Repository: zoomin-lee/scene-scale-diffusion Branch: main Commit: a23745f37edd Files: 30 Total size: 163.3 KB Directory structure: gitextract_ghos05eb/ ├── .gitignore ├── LICENSE.txt ├── README.md ├── SSC_train.py ├── __init__.py ├── datasets/ │ ├── carla.yaml │ ├── carla_dataset.py │ └── data.py ├── layers/ │ ├── Ablation/ │ │ └── wo_diffusion.py │ ├── Latent_Level/ │ │ ├── stage1/ │ │ │ ├── model.py │ │ │ ├── vector_quantizer.py │ │ │ └── vqvae.py │ │ └── stage2/ │ │ ├── Gen_diffusion.py │ │ └── gen_denoise.py │ ├── Voxel_Level/ │ │ ├── Con_Diffusion.py │ │ ├── Gen_Diffusion.py │ │ ├── denoise.py │ │ └── gen_denoise.py │ └── __init__.py ├── requirements.txt ├── setup.py ├── simple_visualize.py ├── train.py ├── utils/ │ ├── cuda.py │ ├── dicts.py │ ├── intermediate_vis.py │ ├── loss.py │ ├── multistep.py │ └── tables.py └── visualization.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging build/ dist/ *.egg-info/ .eggs/ # Virtual environment venv/ env/ .venv/ .env/ # Jupyter Notebook checkpoints .ipynb_checkpoints/ # PyInstaller *.manifest *.spec # pytest .cache/ .pytest_cache/ # mypy .mypy_cache/ # coverage htmlcov/ .coverage .coverage.* # logs and temporary files *.log *.tmp *.bak # IDEs and editors .vscode/ .idea/ *.sublime-workspace *.sublime-project # OS files .DS_Store Thumbs.db # dotenv / secrets .env .env.* ================================================ FILE: LICENSE.txt ================================================ MIT License Copyright (c) 2023 jumin Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without 
limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ # Diffusion Probabilistic Models for Scene-Scale 3D Categorical Data 📌[Paper](http://arxiv.org/abs/2301.00527)   Comparison of object-scale and scene scale generation (ours). Our result includes multiple objects in a generated scene, while the object-scale generation crafts one object at a time. (a) is obtained by [Point-E](https://github.com/openai/point-e) ## Abstract In this paper, we learn a diffusion model to generate 3D data on a scene-scale. Specifically, our model crafts a 3D scene consisting of multiple objects, while recent diffusion research has focused on a single object. To realize our goal, we represent a scene with discrete class labels, i.e., categorical distribution, to assign multiple objects into semantic categories. Thus, we extend discrete diffusion models to learn scene-scale categorical distributions. In addition, we validate that a latent diffusion model can reduce computation costs for training and deploying. 
To the best of our knowledge, our work is the first to apply discrete and latent diffusion to 3D categorical data at scene scale. We further propose to perform semantic scene completion (SSC) by learning a conditional distribution with our diffusion model, where the condition is a partial observation given as a sparse point cloud. In experiments, we empirically show that our diffusion models not only generate reasonable scenes but also perform the scene completion task better than a discriminative model. ## Instructions ### Dataset We use the Cartesian version of the [CarlaSC](https://umich-curly.github.io/CarlaSC.github.io/download/) dataset. ### Training All training options are defined with argparse in `SSC_train.py`. `--dataset_dir` is required and must point to the CarlaSC root that contains the `Train`, `Val`, and `Test` folders. python SSC_train.py --dataset_dir /path/to/CarlaSC - For **multi-GPU** (enabled by default) : --distribution True - For **Discrete Diffusion Model** : --mode gen/con/vis - For **Latent Diffusion Model** : --mode l_vae/l_gen --l_size 882/16162/32322 --init_size 32 --l_attention True --vq_size 100 - For the **ablation without diffusion** (the default mode) : --mode wo_diff Example for training in l_gen mode: python SSC_train.py --dataset_dir /path/to/CarlaSC --mode l_gen --vq_size 100 --l_size 32322 --init_size 32 --l_attention True --log_path ./result --vqvae_path ./lst_stage.tar ### Visualization We save the results to a txt file using the `visulization` function in `utils/tables.py`. These files can then be easily visualized with open3d. ## Results ### 3D Scene Generation ![image](https://github.com/zoomin-lee/scene-scale-diffusion/blob/main/images/3D_scene_generation.png?raw=true) ### Semantic Scene Completion ![image](https://github.com/zoomin-lee/scene-scale-diffusion/blob/main/images/table4.PNG?raw=true) ![image](https://github.com/zoomin-lee/scene-scale-diffusion/blob/main/images/semantic_scene_completion.png?raw=true) ## Acknowledgments This project is based on the following codebases:
- [Multinomial Diffusion](https://github.com/ehoogeboom/multinomial_diffusion/tree/9d907a60536ad793efd6d2a6067b3c3d6ba9fce7) - [MotionSC](https://github.com/UMich-CURLY/3DMapping) - [Cylinder3D](https://github.com/xinge008/Cylinder3D) ================================================ FILE: SSC_train.py ================================================ import argparse import os import warnings import time import torch from utils.intermediate_vis import Vis_iter from datasets.data import * from utils.cuda import launch from utils.multistep import get_optim from train import Experiment from layers.Voxel_Level.Gen_Diffusion import Diffusion from layers.Voxel_Level.Con_Diffusion import Con_Diffusion from layers.Latent_Level.stage1.vqvae import vqvae from layers.Latent_Level.stage2.Gen_diffusion import latent_diffusion from layers.Ablation.wo_diffusion import wo_diff # environment variables NODE_RANK = os.environ['AZ_BATCHAI_TASK_INDEX'] if 'AZ_BATCHAI_TASK_INDEX' in os.environ else 0 NODE_RANK = int(NODE_RANK) MASTER_ADDR, MASTER_PORT = os.environ['AZ_BATCH_MASTER_NODE'].split(':') if 'AZ_BATCH_MASTER_NODE' in os.environ else ("127.0.0.1", 29500) MASTER_PORT = int(MASTER_PORT) DIST_URL = 'tcp://%s:%s' % (MASTER_ADDR, MASTER_PORT) def get_args(): ########### ## Setup ## ########### parser = argparse.ArgumentParser() parser.add_argument('--gpu', type=int, default=None, help='GPU id to use. 
If given, only the specific gpu will be used, and ddp will be disabled') parser.add_argument('--distribution', type=bool, default=True) parser.add_argument('--num_node', type=int, default=1, help='number of nodes for distributed training') parser.add_argument('--node_rank', type=int, default=0, help='node rank for distributed training') parser.add_argument('--dist_url', type=str, default='tcp://127.0.0.1:29500', help='url used to set up distributed training') # Data params parser.add_argument('--dataset', type=str, default='carla', choices='carla') parser.add_argument('--dataset_dir', type=str, required=True, help='Path to the dataset directory') # Train params parser.add_argument('--batch_size', type=int, default=4) parser.add_argument('--num_workers', type=int, default=4) parser.add_argument('--pin_memory', type=eval, default=False) parser.add_argument('--augmentation', type=str, default=None) # Experiemtn params parser.add_argument('--clip_value', type=float, default=None) parser.add_argument('--clip_norm', type=float, default=None) parser.add_argument('--recon_loss', default=False) parser.add_argument('--mode', default='wo_diff', choices='gen, con, vis, l_vae l_gen, wo_diff') parser.add_argument('--l_size', default='32322', choices=['882', '16162', '32322']) parser.add_argument('--init_size', type=int, default=8) parser.add_argument('--l_attention', default=True) parser.add_argument('--vq_size', type=int, default=50) # Model params parser.add_argument('--auxiliary_loss_weight', type=int, default=0.0005) parser.add_argument('--diffusion_steps', type=int, default=100) parser.add_argument('--diffusion_dim', type=int, default=32) parser.add_argument('--dp_rate', type=float, default=0.) 
# Optim params parser.add_argument('--optimizer', type=str, default='adam') parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--warmup', type=int, default=None) parser.add_argument('--momentum', type=float, default=0.9) parser.add_argument('--momentum_sqr', type=float, default=0.999) parser.add_argument('--milestones', type=eval, default=[]) parser.add_argument('--gamma', type=float, default=0.1) # Train params parser.add_argument('--epochs', type=int, default=5000) parser.add_argument('--resume', type=str, default=False) parser.add_argument('--resume_path', type=str, default='') parser.add_argument('--vqvae_path', type=str, default='') # Logging params parser.add_argument('--eval_every', type=int, default=10) parser.add_argument('--check_every', type=int, default=5) parser.add_argument('--completion_epoch', type=int, default=20) parser.add_argument('--log_tb', type=eval, default=True) parser.add_argument('--log_home', type=str, default=None) parser.add_argument('--log_path', type=str, default='') args = parser.parse_args() return args def main(): print('start!') args = get_args() if args.gpu is not None: warnings.warn('You have chosen a specific GPU. 
This will completely disable ddp.') torch.cuda.set_device(args.gpu) args.ngpus_per_node = 1 args.world_size = 1 else: if args.num_node == 1: args.dist_url == "auto" else: assert args.num_node > 1 args.ngpus_per_node = torch.cuda.device_count() args.world_size = args.ngpus_per_node * args.num_node launch(start, args.ngpus_per_node, args.num_node, args.node_rank, args.dist_url, args=(args,)) def start(local_rank, args): args.local_rank = local_rank args.global_rank = args.local_rank + args.node_rank * args.ngpus_per_node args.distributed = args.world_size > 1 ################## ## Specify data ## ################## train_loader, eval_loader, test_loader, num_classes, comp_weights, seg_weights, train_sampler = get_data(args) args.num_classes = num_classes completion_criterion = torch.nn.CrossEntropyLoss(weight=comp_weights) seg_criterion = torch.nn.CrossEntropyLoss(weight=seg_weights, ignore_index=0) similarity_criterion = torch.nn.MSELoss() ####################### ## Without Diffusion ## ####################### if args.mode == 'wo_diff': model = wo_diff(args, completion_criterion).cuda() if args.distribution : model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False) ######################## ## Discrete Diffusion ## ######################## elif args.mode == 'gen': model = Diffusion(args, completion_criterion, auxiliary_loss_weight=args.auxiliary_loss_weight).cuda() if args.distribution : model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False) elif args.mode == 'con': model = Con_Diffusion(args, completion_criterion, auxiliary_loss_weight=args.auxiliary_loss_weight).cuda() if args.distribution : model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False) ###################### ## Latent Diffusion ## ###################### elif args.mode == 'l_vae': model = vqvae(args, completion_criterion).cuda() if 
args.distribution: model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False) elif args.mode == 'l_gen': Dense = vqvae(args, completion_criterion).cuda() dense_check = torch.load(args.vqvae_path) model = latent_diffusion(args, Dense, completion_criterion, auxiliary_loss_weight=args.auxiliary_loss_weight).cuda() if args.distribution: Dense = torch.nn.parallel.DistributedDataParallel(Dense, device_ids=[args.gpu], find_unused_parameters=False) Dense.module.load_state_dict(dense_check['model']) for p in Dense.module.parameters(): p.requires_grad = False model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False) ################### ## Visualization ## ################### elif args.mode == 'vis': model = Con_Diffusion(args, completion_criterion, auxiliary_loss_weight=args.auxiliary_loss_weight).cuda() if args.distribution : model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False) optimizer, scheduler_iter, scheduler_epoch = get_optim(args, model) if args.mode == 'vis': exp = Vis_iter(args, model, optimizer, scheduler_iter, scheduler_epoch, test_loader, args.log_path) else : exp = Experiment(args, model, optimizer, scheduler_iter, scheduler_epoch, train_loader, eval_loader, test_loader, train_sampler, args.log_path, args.eval_every, args.check_every) exp.run(epochs = args.epochs) if __name__ == '__main__': main() ================================================ FILE: __init__.py ================================================ ================================================ FILE: datasets/carla.yaml ================================================ color_map : 0 : [255, 255, 255] # None 1 : [70, 70, 70] # Building 2 : [100, 40, 40] # Fences 3 : [55, 90, 80] # Other 4 : [255, 255, 0 ] # Pedestrian 5 : [153, 153, 153] # Pole 6 : [157, 234, 50] # RoadLines 7 : [0, 0, 255] # Road 8 : [255, 255, 255] # Sidewalk 9 : [0, 155, 
0] # Vegetation 10 : [255, 0, 0] # Vehicle 11 : [102, 102, 156] # Wall 12 : [220, 220, 0] # TrafficSign 13 : [70, 130, 180] # Sky 14 : [255, 255, 255] # Ground 15 : [150, 100, 100] # Bridge 16 : [230, 150, 140] # RailTrack 17 : [180, 165, 180] # GuardRail 18 : [250, 170, 30] # TrafficLight 19 : [110, 190, 160] # Static 20 : [170, 120, 50] # Dynamic 21 : [45, 60, 150] # Water 22 : [145, 170, 100] # Terrain learning_map : 0 : 0 1 : 1 2 : 2 3 : 3 4 : 4 5 : 5 6 : 6 7 : 6 8 : 8 9 : 9 10: 10 11 : 2 12 : 5 13 : 3 14 : 7 15 : 3 16 : 3 17 : 2 18 : 5 19 : 3 20 : 3 21 : 3 22 : 7 remap_color_map: 0 : [255, 255, 255] # None 1 : [255, 200, 0] # Building 2 : [255, 120, 50] # Fences 3 : [55, 90, 80] # Other 4 : [255, 30, 30] # Pedestrian 5 : [255, 240, 150] # Pole 6 : [255, 0, 255] # Road 7 : [175, 0, 75] # Ground 8 : [75, 0, 75] # Sidewalk 9 : [0, 175, 0] # Vegetation 10 : [100, 150, 245] # Vehicle label_to_names: 0 : Free 1 : Building 2 : Barrier 3 : Other 4 : Pedestrian 5 : Pole 6 : Road 7 : Ground 8 : Sidewalk 9 : Vegetation 10 : Vehicle content : 0 : 4166593275 1 : 42309744 2 : 8550180 3 : 478193 4 : 905663 5 : 2801091 6 : 6452733 7 : 229316930 8 : 112863867 9 : 29816894 10: 13839655 11 : 15581458 12 : 221821 13 : 0 14 : 7931550 15 : 467989 16 : 3354 17 : 9201043 18 : 61011 19 : 3796746 20 : 3217865 21 : 215372 22 : 79669695 remap_content : 0 : 4.16659328e+09 1 : 4.23097440e+07 2 : 3.33326810e+07 3 : 8.17951900e+06 4 : 9.05663000e+05 5 : 3.08392300e+06 6 : 2.35769663e+08 7 : 8.76012450e+07 8 : 1.12863867e+08 9 : 2.98168940e+07 10 : 1.38396550e+07 ================================================ FILE: datasets/carla_dataset.py ================================================ import os import numpy as np import random import json import yaml import torch import numba as nb from torch.utils.data import Dataset base_dir = os.path.dirname(__file__) config_file = os.path.join(base_dir, 'carla.yaml') carla_config = yaml.safe_load(open(config_file, 'r')) LABELS_REMAP = 
carla_config["learning_map"] REMAP_FREQUENCIES = carla_config["remap_content"] FREQUENCIES= carla_config["content"] LABELS_REMAP = np.asarray(list(LABELS_REMAP.values())) frequencies_cartesian = np.asarray(list(FREQUENCIES.values())) remap_frequencies_cartesian = np.asarray(list(REMAP_FREQUENCIES.values())) class CarlaDataset(Dataset): """Carla Simulation Dataset for 3D mapping project Access to the processed data, including evaluation labels predictions velodyne poses times """ def __init__(self, directory, voxelize_input=True, binary_counts=True, random_flips=False, remap=True, num_frames=1, transform_pose=True, get_gt=True, ): '''Constructor. Parameters: directory: directory to the dataset ''' self.get_gt = get_gt self.voxelize_input = voxelize_input self.binary_counts = binary_counts self._directory = directory self._num_frames = num_frames self.random_flips = random_flips self.remap = remap self.transform_pose = transform_pose self.sparse_output = True self._scenes = sorted(os.listdir(self._directory)) self._scenes = [os.path.join(scene, "cartesian") for scene in self._scenes] self._num_scenes = len(self._scenes) self._num_frames_scene = [] param_file = os.path.join(self._directory, self._scenes[0], 'evaluation', 'params.json') with open(param_file) as f: self._eval_param = json.load(f) self._out_dim = self._eval_param['num_channels'] self._grid_size = self._eval_param['grid_size'] self.grid_dims = np.asarray(self._grid_size) self._eval_size = list(np.uint32(self._grid_size)) self.coor_ranges = self._eval_param['min_bound'] + self._eval_param['max_bound'] self.voxel_sizes = [abs(self.coor_ranges[3] - self.coor_ranges[0]) / self._grid_size[0], abs(self.coor_ranges[4] - self.coor_ranges[1]) / self._grid_size[1], abs(self.coor_ranges[5] - self.coor_ranges[2]) / self._grid_size[2]] self.min_bound = np.asarray(self.coor_ranges[:3]) self.max_bound = np.asarray(self.coor_ranges[3:]) self.voxel_sizes = np.asarray(self.voxel_sizes) self._velodyne_list = [] 
self._label_list = [] self._pred_list = [] self._eval_labels = [] self._eval_counts = [] self._frames_list = [] self._timestamps = [] self._poses = [] for scene in self._scenes: velodyne_dir = os.path.join(self._directory, scene, 'velodyne') label_dir = os.path.join(self._directory, scene, 'labels') pred_dir = os.path.join(self._directory, scene, 'predictions') eval_dir = os.path.join(self._directory, scene, 'evaluation') self._num_frames_scene.append(len(os.listdir(velodyne_dir))) frames_list = [os.path.splitext(filename)[0] for filename in sorted(os.listdir(velodyne_dir))] self._frames_list.extend(frames_list) self._velodyne_list.extend([os.path.join(velodyne_dir, str(frame).zfill(6)+'.bin') for frame in frames_list]) self._label_list.extend([os.path.join(label_dir, str(frame).zfill(6)+'.label') for frame in frames_list]) self._pred_list.extend([os.path.join(pred_dir, str(frame).zfill(6)+'.bin') for frame in frames_list]) self._eval_labels.extend([os.path.join(eval_dir, str(frame).zfill(6)+'.label') for frame in frames_list]) self._eval_counts.extend([os.path.join(eval_dir, str(frame).zfill(6) + '.bin') for frame in frames_list]) self._timestamps.append(np.loadtxt(os.path.join(self._directory, scene, 'times.txt'))) self._poses.append(np.loadtxt(os.path.join(self._directory, scene, 'poses.txt'))) # for poses and timestamps self._timestamps = np.array(self._timestamps).reshape(sum(self._num_frames_scene)) self._poses = np.array(self._poses).reshape(sum(self._num_frames_scene), 12) self._cum_num_frames = np.cumsum(np.array(self._num_frames_scene) - self._num_frames + 1) # Use all frames, if there is no data then zero pad def __len__(self): return sum(self._num_frames_scene) def collate_fn(self, data): voxel_batch = [bi[0] for bi in data] output_batch = [bi[1] for bi in data] counts_batch = [bi[2] for bi in data] return voxel_batch, output_batch, counts_batch def points_to_voxels(self, voxel_grid, points, t_i): # Valid voxels (make sure to clip) voxels = 
np.floor((points - self.min_bound) / self.voxel_sizes).astype(np.int32) # Clamp to account for any floating point errors maxes = np.reshape(self.grid_dims - 1, (1, 3)) mins = np.zeros_like(maxes) voxels = np.clip(voxels, mins, maxes).astype(np.int32) # This line is needed to create a mask with number of points, not just binary occupied if self.binary_counts: voxel_grid[t_i, voxels[:, 0], voxels[:, 1], voxels[:, 2]] += 1 else: unique_voxels, counts = np.unique(voxels, return_counts=True, axis=0) unique_voxels = unique_voxels.astype(np.int32) voxel_grid[t_i, unique_voxels[:, 0], unique_voxels[:, 1], unique_voxels[:, 2]] += counts return voxel_grid def get_pose(self, idx): pose = np.zeros((4, 4)) pose[3, 3] = 1 pose[:3, :4] = self._poses[idx].reshape(3, 4) return pose def __getitem__(self, idx): # -1 indicates no data # the final index is the output idx_range = self.find_horizon(idx) if self.transform_pose: ego_pose = self.get_pose(idx_range[-1]) to_ego = np.linalg.inv(ego_pose) if self.voxelize_input: voxel_input = np.zeros((idx_range.shape[0], int(self.grid_dims[0]), int(self.grid_dims[1]), int(self.grid_dims[2])), dtype=np.float32) t_i = 0 for i in idx_range: if i == -1: # Zero pad points = np.zeros((1, 3), dtype=np.float32) else: points = np.fromfile(self._velodyne_list[i],dtype=np.float32).reshape(-1, 4)[:, :3] if self.transform_pose: to_world = self.get_pose(i) relative_pose = np.matmul(to_ego, to_world) points = np.dot(relative_pose[:3, :3], points.T).T + relative_pose[:3, 3] valid_point_mask= np.all((points < self.max_bound) & (points >= self.min_bound), axis=1) valid_points = points[valid_point_mask, :] if self.voxelize_input: voxel_input = self.points_to_voxels(voxel_input, valid_points, t_i) t_i += 1 if self.get_gt: output = np.fromfile(self._eval_labels[idx_range[-1]],dtype=np.uint32).reshape(self._eval_size).astype(np.uint8) counts = np.fromfile(self._eval_counts[idx_range[-1]],dtype=np.float32).reshape(self._eval_size) else: output = None counts = None 
if self.voxelize_input and self.random_flips: # X flip if np.random.randint(2): output = np.flip(output, axis=0) counts = np.flip(counts, axis=0) voxel_input = np.flip(voxel_input, axis=1) # Because there is a time dimension # Y Flip if np.random.randint(2): output = np.flip(output, axis=1) counts = np.flip(counts, axis=1) voxel_input = np.flip(voxel_input, axis=2) # Because there is a time dimension if self.remap: output = LABELS_REMAP[output].astype(np.uint8) return voxel_input, output, counts # no enough frames def find_horizon(self, idx): end_idx = idx idx_range = np.arange(idx-self._num_frames, idx)+1 diffs = np.asarray([int(self._frames_list[end_idx]) - int(self._frames_list[i]) for i in idx_range]) good_difs = -1 * (np.arange(-self._num_frames, 0) + 1) idx_range[good_difs != diffs] = -1 return idx_range ================================================ FILE: datasets/data.py ================================================ import os import math import torch import numpy as np from torch.utils.data import DataLoader from datasets.carla_dataset import * dataset_choices = {'carla', 'kitti'} def get_data_id(args): return '{}'.format(args.dataset) def get_class_weights(freq): ''' Cless weights being 1/log(fc) (https://arxiv.org/pdf/2008.10559.pdf) ''' epsilon_w = 0.001 # eps to avoid zero division weights = torch.from_numpy(1 / np.log(freq + epsilon_w)) return weights def get_data(args): assert args.dataset in dataset_choices if args.dataset == 'carla': train_dir = os.path.join(args.dataset_dir, "Train") val_dir = os.path.join(args.dataset_dir, "Val") test_dir = os.path.join(args.dataset_dir, "Test") x_dim = 128 y_dim = 128 z_dim = 8 data_shape = [x_dim, y_dim, z_dim] args.data_shape= data_shape binary_counts = True transform_pose = True remap = True if remap: class_frequencies = remap_frequencies_cartesian args.num_classes = 11 else: args.num_classes = 23 comp_weights = get_class_weights(class_frequencies).to(torch.float32) seg_weights = 
get_class_weights(class_frequencies[1:]).to(torch.float32) train_ds = CarlaDataset(directory=train_dir, random_flips=True, remap=remap, binary_counts=binary_counts, transform_pose=transform_pose) coor_ranges = train_ds._eval_param['min_bound'] + train_ds._eval_param['max_bound'] voxel_sizes = [abs(coor_ranges[3] - coor_ranges[0]) / x_dim, abs(coor_ranges[4] - coor_ranges[1]) / y_dim, abs(coor_ranges[5] - coor_ranges[2]) / z_dim] # since BEV val_ds = CarlaDataset(directory=val_dir, remap=remap, binary_counts=binary_counts, transform_pose=transform_pose) test_ds = CarlaDataset(directory=test_dir, remap=remap, binary_counts=binary_counts, transform_pose=transform_pose) if args is not None and args.distributed: train_sampler = torch.utils.data.distributed.DistributedSampler(train_ds, shuffle=True) val_sampler = torch.utils.data.distributed.DistributedSampler(val_ds, shuffle=False) train_iters = len(train_sampler) // args.batch_size val_iters = len(val_sampler) // args.batch_size else: train_sampler = None val_sampler = None train_iters = len(train_ds) // args.batch_size val_iters = len(val_ds) // args.batch_size dataloader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, collate_fn=train_ds.collate_fn, num_workers=args.num_workers) dataloader_val = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, sampler=val_sampler, collate_fn=val_ds.collate_fn, num_workers=args.num_workers) dataloader_test = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False, collate_fn=test_ds.collate_fn, num_workers=args.num_workers) else: raise NotImplementedError("Wrong `dataset` has come. 
Other datasets are not supported.") return dataloader, dataloader_val, dataloader_test, args.num_classes, comp_weights, seg_weights, train_sampler ================================================ FILE: layers/Ablation/wo_diffusion.py ================================================ import torch from torch import nn from torch.nn import functional as F import numpy as np from layers.Latent_Level.stage1.model import C_Encoder, C_Decoder class wo_diff(torch.nn.Module): def __init__(self, args, multi_criterion) -> None: super(wo_diff, self).__init__() self.args = args if self.args.dataset == 'kitti': init_size = args.init_size elif self.args.dataset == 'carla': init_size = args.init_size self.encoder = C_Encoder(args, nclasses=self.args.num_classes, init_size=init_size, l_size=args.l_size, attention=args.l_attention) self.decoder = C_Decoder(args, nclasses=self.args.num_classes, init_size=init_size, l_size=args.l_size, attention=args.l_attention) self.multi_criterion = multi_criterion def device(self): return self.encoder.device def forward(self, x, input_ten): latent = self.encoder(input_ten, out_conv=False) recons = self.decoder(latent, in_conv=False) recons_loss = self.multi_criterion(recons, x) return recons_loss def sample(self, x): latent = self.encoder(x, out_conv=False) recons = self.decoder(latent, in_conv=False) recons = recons.argmax(1) return recons ================================================ FILE: layers/Latent_Level/stage1/model.py ================================================ import numpy as np import math import torch from torch import nn import torch.nn.functional as F from einops import rearrange, reduce, repeat from torch import nn, einsum def conv3x3x3(in_planes, out_planes, stride=1): return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) def conv1x3x3(in_planes, out_planes, stride=1): return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 3), stride=stride,padding=(0, 1, 1), bias=False) def 
conv1x1x3(in_planes, out_planes, stride=1): return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 1, 3), stride=stride, padding=(0, 0, 1), bias=False) def conv1x3x1(in_planes, out_planes, stride=1): return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 1), stride=stride, padding=(0, 1, 0), bias=False) def conv3x1x1(in_planes, out_planes, stride=1): return nn.Conv3d(in_planes, out_planes, kernel_size=(3, 1, 1), stride=stride, padding=(1, 0, 0), bias=False) def conv3x1x3(in_planes, out_planes, stride=1): return nn.Conv3d(in_planes, out_planes, kernel_size=(3, 1, 3), stride=stride, padding=(1, 0, 1), bias=False) def conv1x1(in_planes, out_planes, stride=1): return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride) class Asymmetric_Residual_Block(nn.Module): def __init__(self, in_filters, out_filters): super(Asymmetric_Residual_Block, self).__init__() self.conv1 = conv1x3x3(in_filters, out_filters) self.act1 = nn.LeakyReLU() self.conv1_2 = conv3x1x3(out_filters, out_filters) self.act1_2 = nn.LeakyReLU() self.conv2 = conv3x1x3(in_filters, out_filters) self.act2 = nn.LeakyReLU() self.conv3 = conv1x3x3(out_filters, out_filters) self.act3 = nn.LeakyReLU() if in_filters<32 : self.GroupNorm = nn.GroupNorm(8, in_filters) self.bn0 = nn.GroupNorm(8, out_filters) self.bn0_2 = nn.GroupNorm(8, out_filters) self.bn1 = nn.GroupNorm(8, out_filters) self.bn2 = nn.GroupNorm(8, out_filters) else : self.GroupNorm = nn.GroupNorm(32, in_filters) self.bn0 = nn.GroupNorm(32, out_filters) self.bn0_2 = nn.GroupNorm(32, out_filters) self.bn1 = nn.GroupNorm(32, out_filters) self.bn2 = nn.GroupNorm(32, out_filters) def forward(self, x): shortcut = self.conv1(x) shortcut = self.act1(shortcut) shortcut = self.bn0(shortcut) shortcut = self.conv1_2(shortcut) shortcut = self.act1_2(shortcut) shortcut = self.bn0_2(shortcut) resA = self.conv2(x) resA = self.act2(resA) resA = self.bn1(resA) resA = self.conv3(resA) resA = self.act3(resA) resA = self.bn2(resA) resA += shortcut return 
resA class DownBlock(nn.Module): def __init__(self, in_filters, out_filters, pooling=True, drop_out=True, height_pooling=False): super(DownBlock, self).__init__() self.pooling = pooling self.drop_out = drop_out self.residual_block = Asymmetric_Residual_Block(in_filters, out_filters) if pooling: if height_pooling: self.pool = nn.Conv3d(out_filters, out_filters, kernel_size=3, stride=2,padding=1, bias=False) else: self.pool = nn.Conv3d(out_filters, out_filters, kernel_size=3, stride=(2, 2, 1),padding=1, bias=False) def forward(self, x): resA = self.residual_block(x) if self.pooling: resB = self.pool(resA) return resB, resA else: return resA class UpBlock(nn.Module): def __init__(self, in_filters, out_filters, height_pooling): super(UpBlock, self).__init__() # self.drop_out = drop_out self.trans_dilao = conv3x3x3(in_filters, out_filters) self.trans_act = nn.LeakyReLU() self.conv1 = conv1x3x3(out_filters, out_filters) self.act1 = nn.LeakyReLU() self.conv2 = conv3x1x3(out_filters, out_filters) self.act2 = nn.LeakyReLU() self.conv3 = conv3x3x3(out_filters, out_filters) self.act3 = nn.LeakyReLU() if out_filters<32 : self.trans_bn = nn.GroupNorm(8, out_filters) self.bn1 = nn.GroupNorm(8, out_filters) self.bn2 = nn.GroupNorm(8, out_filters) self.bn3 = nn.GroupNorm(8, out_filters) else : self.trans_bn = nn.GroupNorm(32, out_filters) self.bn1 = nn.GroupNorm(32, out_filters) self.bn2 = nn.GroupNorm(32, out_filters) self.bn3 = nn.GroupNorm(32, out_filters) if height_pooling : self.up_subm = nn.ConvTranspose3d(out_filters, out_filters, kernel_size=3, bias=False, stride=2, padding=1, output_padding=1, dilation=1) else : self.up_subm = nn.ConvTranspose3d(out_filters, out_filters, kernel_size=(3,3,1), bias=False, stride=(2,2,1), padding=(1,1,0), output_padding=(1,1,0), dilation=1) def forward(self, x, skip=False): if skip : x, residual = x upA = self.trans_dilao(x) upA = self.trans_act(upA) upA = self.trans_bn(upA) upA = self.up_subm(upA) if skip : upA += residual upE = 
self.conv1(upA) upE = self.act1(upE) upE = self.bn1(upE) upE = self.conv2(upE) upE = self.act2(upE) upE = self.bn2(upE) upE = self.conv3(upE) upE = self.act3(upE) upE = self.bn3(upE) return upE class DDCM(nn.Module): def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), stride=1): super(DDCM, self).__init__() self.conv1 = conv3x1x1(in_filters, out_filters) self.act1 = nn.Sigmoid() self.conv1_2 = conv1x3x1(in_filters, out_filters) self.act1_2 = nn.Sigmoid() self.conv1_3 = conv1x1x3(in_filters, out_filters) self.act1_3 = nn.Sigmoid() if in_filters<32 : self.bn0 = nn.GroupNorm(8, out_filters) self.bn0_2 = nn.GroupNorm(8, out_filters) self.bn0_3 = nn.GroupNorm(8, out_filters) else : self.bn0 = nn.GroupNorm(32, out_filters) self.bn0_2 = nn.GroupNorm(32, out_filters) self.bn0_3 = nn.GroupNorm(32, out_filters) def forward(self, x): shortcut = self.conv1(x) shortcut = self.bn0(shortcut) shortcut = self.act1(shortcut) shortcut2 = self.conv1_2(x) shortcut2 = self.bn0_2(shortcut2) shortcut2 = self.act1_2(shortcut2) shortcut3 = self.conv1_3(x) shortcut3 = self.bn0_3(shortcut3) shortcut3 = self.act1_3(shortcut3) shortcut = shortcut + shortcut2 + shortcut3 shortcut = shortcut * x return shortcut def l2norm(t): return F.normalize(t, dim = -1) class Attention(nn.Module): def __init__(self, dim, heads = 4, scale = 10): super().__init__() self.scale = scale self.heads = heads self.to_qkv = conv1x1(dim, dim*3, stride=1) self.to_out = conv1x1(dim, dim, stride=1) def forward(self, x): b, c, h, w, Z = x.shape qkv = self.to_qkv(x).chunk(3, dim = 1) q, k, v = map(lambda t: rearrange(t, 'b (h c) x y z-> b h c (x y z)', h = self.heads), qkv) q, k = map(l2norm, (q, k)) sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale attn = sim.softmax(dim = -1) out = einsum('b h i j, b h d j -> b h i d', attn, v) out = rearrange(out, 'b h (x y z) d -> b (h d) x y z', x = h, y = w, z = Z) return self.to_out(out) class C_Encoder(nn.Module): def __init__(self, args, nclasses=20, 
init_size=16, l_size='882', attention=True): super(C_Encoder, self).__init__() self.nclasses = nclasses self.args = args self.l_size = l_size self.attention = attention self.embedding = nn.Embedding(nclasses, init_size) self.A = Asymmetric_Residual_Block(init_size, init_size) self.downBlock1 = DownBlock(init_size, 2 * init_size, height_pooling=True) self.downBlock2 = DownBlock(2 * init_size, 4 * init_size, height_pooling=True) self.downBlock3 = DownBlock(4 * init_size, 8 * init_size, height_pooling=False) self.downBlock4 = DownBlock(8 * init_size, 16 * init_size, height_pooling=False) if self.l_size == '32322': self.midBlock1 = Asymmetric_Residual_Block(4 * init_size, 4 * init_size) self.attention = Attention(4 * init_size, 32) self.midBlock2 = Asymmetric_Residual_Block(4 * init_size, 4 * init_size) self.out = nn.Conv3d(4 * init_size, nclasses, kernel_size=3, stride=1, padding=1,bias=True) elif self.l_size == '16162': self.midBlock1 = Asymmetric_Residual_Block(8 * init_size, 8 * init_size) self.attention = Attention(8 * init_size, 32) self.midBlock2 = Asymmetric_Residual_Block(8 * init_size, 8 * init_size) self.out = nn.Conv3d(8 * init_size, nclasses, kernel_size=3, stride=1, padding=1,bias=True) elif self.l_size == '882': self.midBlock1 = Asymmetric_Residual_Block(16 * init_size, 16 * init_size) self.attention = Attention(16 * init_size, 32) self.midBlock2 = Asymmetric_Residual_Block(16 * init_size, 16 * init_size) self.out = nn.Conv3d(16 * init_size, nclasses, kernel_size=3, stride=1, padding=1,bias=True) else: raise NotImplementedError("Unsupported `l_size` has come") def forward(self, x, out_conv=True): x = self.embedding(x) x = x.permute(0, 4, 1, 2, 3) x = self.A(x) x, down1b = self.downBlock1(x) x, down2b = self.downBlock2(x) if self.l_size == '882': x, down3b = self.downBlock3(x) x, down4b = self.downBlock4(x) elif self.l_size == '16162': x, down3b = self.downBlock3(x) if self.attention : x = self.midBlock1(x) # (4, 128, 32, 32, 2) x = self.attention(x) x = 
self.midBlock2(x) # (4, 128, 32, 32, 2) if out_conv : x = self.out(x) return x class C_Decoder(nn.Module): def __init__(self, args, nclasses=20, init_size=16, l_size='882', attention=True): super(C_Decoder, self).__init__() self.nclasses = nclasses self.args = args self.l_size = l_size self.attention = attention if l_size == '882': self.conv_in = nn.Conv3d(nclasses, 16 * init_size, kernel_size=3, stride=1, padding=1,bias=True) self.midBlock1 = Asymmetric_Residual_Block(16 * init_size, 16 * init_size) self.attention = Attention(16 * init_size, 32) self.midBlock2 = Asymmetric_Residual_Block(16 * init_size, 16 * init_size) elif l_size == '16162': self.conv_in = nn.Conv3d(nclasses, 8 * init_size, kernel_size=3, stride=1, padding=1,bias=True) self.midBlock1 = Asymmetric_Residual_Block(8 * init_size, 8 * init_size) self.attention = Attention(8 * init_size, 32) self.midBlock2 = Asymmetric_Residual_Block(8 * init_size, 8 * init_size) elif (l_size =='32322'): self.conv_in = nn.Conv3d(nclasses, 4 * init_size, kernel_size=3, stride=1, padding=1,bias=True) self.midBlock1 = Asymmetric_Residual_Block(4 * init_size, 4 * init_size) self.attention = Attention(4 * init_size, 32) self.midBlock2 = Asymmetric_Residual_Block(4 * init_size, 4 * init_size) self.upBlock4 = UpBlock(16 * init_size, 8 * init_size, height_pooling=False) self.upBlock3 = UpBlock(8 * init_size, 4 * init_size, height_pooling=False) self.upBlock2 = UpBlock(4 * init_size, 2 * init_size, height_pooling=True) self.upBlock1 = UpBlock(2 * init_size, 2 * init_size, height_pooling=True) self.DDCM = DDCM(2 * init_size, 2 * init_size) self.logits = nn.Conv3d(4 * init_size, self.nclasses, kernel_size=3, stride=1, padding=1, bias=True) def forward(self, x, in_conv=True): if in_conv : x = self.conv_in(x) if self.attention : x = self.midBlock1(x) x = self.attention(x) x = self.midBlock2(x) if self.l_size == '882': x = self.upBlock4(x) x = self.upBlock3(x) elif self.l_size == '16162': x = self.upBlock3(x) x = self.upBlock2(x) 
up1 = self.upBlock1(x) up0 = self.DDCM(up1) up = torch.cat((up1, up0), 1) logits = self.logits(up) return logits class Completion(nn.Module): def __init__(self, args, num_class = 11, init_size=32): super(Completion, self).__init__() self.args = args self.num_class = num_class self.init_size = init_size self.embedding = nn.Embedding(self.num_class, init_size) self.A = Asymmetric_Residual_Block(init_size, init_size) self.downBlock1 = DownBlock(init_size, 2 * init_size, height_pooling=True) self.downBlock2 = DownBlock(2 * init_size, 4 * init_size, height_pooling=True) self.downBlock3 = DownBlock(4 * init_size, 8 * init_size, height_pooling=False) self.downBlock4 = DownBlock(8 * init_size, 16 * init_size, height_pooling=False) self.midBlock1 = Asymmetric_Residual_Block(16 * init_size, 16 * init_size) self.attention = Attention(16 * init_size, 32) self.midBlock2 = Asymmetric_Residual_Block(16 * init_size, 16 * init_size) self.upBlock4 = UpBlock(16 * init_size, 8 * init_size, height_pooling=False) self.upBlock3 = UpBlock(8 * init_size, 4 * init_size, height_pooling=False) self.upBlock2 = UpBlock(4 * init_size, 2 * init_size, height_pooling=True) self.upBlock1 = UpBlock(2 * init_size, 2 * init_size, height_pooling=True) self.DDCM = DDCM(2 * init_size, 2 * init_size) self.logits = nn.Conv3d(4 * init_size, self.num_class, kernel_size=3, stride=1, padding=1, bias=True) def forward(self, x): x = self.embedding(x) x = x.permute(0, 4, 1, 2, 3) x = self.A(x) down1c, down1b = self.downBlock1(x) down2c, down2b = self.downBlock2(down1c) down3c, down3b = self.downBlock3(down2c) down4c, down4b = self.downBlock4(down3c) down4c = self.midBlock1(down4c) down4c = self.attention(down4c) down4c = self.midBlock2(down4c) up4 = self.upBlock4((down4c, down4b), skip=True) up3 = self.upBlock3((up4, down3b), skip=True) up2 = self.upBlock2((up3, down2b), skip=True) up1 = self.upBlock1((up2, down1b), skip=True) up0 = self.DDCM(up1) up = torch.cat((up1, up0), 1) logits = self.logits(up) return 
logits ================================================ FILE: layers/Latent_Level/stage1/vector_quantizer.py ================================================ import torch from torch import nn from torch.nn import functional as F class VectorQuantizer(nn.Module): def __init__(self, num_embeddings: int, embedding_dim: int, beta: float = 0.25): super(VectorQuantizer, self).__init__() self.K = num_embeddings self.D = embedding_dim self.beta = beta self.embedding = nn.Embedding(self.K, self.D) self.embedding.weight.data.uniform_(-1 / self.K, 1 / self.K) def forward(self, z: torch.tensor, point=False) -> torch.tensor: # latents (8, 128, 8, 8, 2) z = z.permute(0, 2, 3, 4, 1).contiguous() # [B x D x H x W x Z] -> [B x H x W x Z x D] latents_shape = z.shape # ( 8, 8, 8, 2, 128 ) flat_latents = z.view(-1, self.D) # [BHWZ x D] = [1024, 128] # Compute L2 distance between latents and embedding weights dist = torch.sum(flat_latents ** 2, dim=1, keepdim=True) + torch.sum(self.embedding.weight ** 2, dim=1) - \ 2 * torch.matmul(flat_latents, self.embedding.weight.t()) # [BHWZ x K] # Get the encoding that has the min distance min_encoding_indices = torch.argmin(dist, dim=1).unsqueeze(1) # [BHWZ, 1] z_q = self.embedding(min_encoding_indices).view(z.shape) # Compute the VQ Losses commitment_loss = F.mse_loss(z_q.detach(), z) embedding_loss = F.mse_loss(z_q, z.detach()) if point : vq_loss = commitment_loss * self.beta else : vq_loss = commitment_loss * self.beta + embedding_loss # Add the residue back to the latents z_q = z + (z_q - z).detach() return z_q.permute(0, 4, 1, 2, 3).contiguous(), vq_loss, min_encoding_indices, latents_shape def codebook_to_embedding(self, encoding_inds, latents_shape): # latents (16, 512, 8, 8, 2) # Convert to one-hot encodings z_q = self.embedding(encoding_inds).view(latents_shape) return z_q.permute(0, 4, 1, 2, 3).contiguous() ================================================ FILE: layers/Latent_Level/stage1/vqvae.py 
================================================ import torch from torch import nn from torch.nn import functional as F import numpy as np import math from utils.loss import lovasz_softmax from layers.Latent_Level.stage1.model import C_Encoder, C_Decoder from layers.Latent_Level.stage1.vector_quantizer import VectorQuantizer class vqvae(torch.nn.Module): def __init__(self, args, multi_criterion) -> None: super(vqvae, self).__init__() self.args = args init_size = args.init_size embedding_dim = int(self.args.num_classes) self.VQ = VectorQuantizer(num_embeddings = int(self.args.num_classes)*int(self.args.vq_size), embedding_dim = embedding_dim) self.encoder = C_Encoder(args, nclasses=self.args.num_classes, init_size=init_size, l_size=args.l_size, attention=args.l_attention) self.quant_conv = nn.Conv3d(self.args.num_classes, self.args.num_classes, kernel_size=1, stride=1) self.decoder = C_Decoder(args, nclasses=self.args.num_classes, init_size=init_size, l_size=args.l_size, attention=args.l_attention) self.post_quant_conv = nn.Conv3d(self.args.num_classes, self.args.num_classes, kernel_size=1, stride=1) self.multi_criterion = multi_criterion def device(self): return self.encoder.device def encode(self, x): latent = self.encoder(x) latent = self.quant_conv(latent) return latent def vector_quantize(self, latent): quantized_latent, vq_loss, quantized_latent_ind, latents_shape = self.VQ(latent) return quantized_latent, vq_loss, quantized_latent_ind, latents_shape def coodbook(self,quantized_latent_ind, latents_shape): quantized_latent = self.VQ.codebook_to_embedding(quantized_latent_ind.view(-1,1), latents_shape) return quantized_latent def decode(self, quantized_latent): quantized_latent = self.post_quant_conv(quantized_latent) recons = self.decoder(quantized_latent) return recons def forward(self, x, input_ten): latent = self.encode(x) quantized_latent, vq_loss, _, _ = self.vector_quantize(latent) recons = self.decode(quantized_latent) recons_loss = 
self.multi_criterion(recons, x) loss = recons_loss + vq_loss return loss def sample(self, x): latent = self.encode(x) quantized_latent, _, _, _ = self.vector_quantize(latent) recons = self.decode(quantized_latent) recons = recons.argmax(1) return recons ================================================ FILE: layers/Latent_Level/stage2/Gen_diffusion.py ================================================ import torch import torch.nn.functional as F import numpy as np import math from inspect import isfunction from layers.Latent_Level.stage2.gen_denoise import Denoise """ Based in part on: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/5989f4c77eafcdc6be0fb4739f0f277a6dd7f7d8/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L281 """ eps = 1e-8 def sum_except_batch(x, num_dims=1): return x.reshape(*x.shape[:num_dims], -1).sum(-1) def log_1_min_a(a): return torch.log(1 - a.exp() + 1e-40) def log_add_exp(a, b): maximum = torch.max(a, b) return maximum + torch.log(torch.exp(a - maximum) + torch.exp(b - maximum)) def exists(x): return x is not None def extract(a, t, x_shape): b, *_ = t.shape out = a.gather(-1, t) return out.reshape(b, *((1,) * (len(x_shape) - 1))) def default(val, d): if exists(val): return val return d() if isfunction(d) else d def log_categorical(log_x_start, log_prob): return (log_x_start.exp() * log_prob).sum(dim=1) def index_to_log_onehot(x, num_classes): assert x.max().item() < num_classes, f'Error: {x.max().item()} >= {num_classes}' x_onehot = F.one_hot(x, num_classes) permute_order = (0, -1) + tuple(range(1, len(x.size()))) x_onehot = x_onehot.permute(permute_order) log_x = torch.log(x_onehot.float().clamp(min=1e-30)) return log_x def log_onehot_to_index(log_x): return log_x.argmax(1) def cosine_beta_schedule(timesteps, s = 0.008): """ cosine schedule as proposed in https://openreview.net/forum?id=-NEXDKk8gZ """ steps = timesteps + 1 x = np.linspace(0, steps, steps) alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 
0.5) ** 2 alphas_cumprod = alphas_cumprod / alphas_cumprod[0] alphas = (alphas_cumprod[1:] / alphas_cumprod[:-1]) alphas = np.clip(alphas, a_min=0.001, a_max=1.) alphas = np.sqrt(alphas) return alphas class latent_diffusion(torch.nn.Module): def __init__(self, args, VAE_DENSE, multi_criterion, auxiliary_loss_weight=0.0005, adaptive_auxiliary_loss=True): super(latent_diffusion, self).__init__() self.args = args self.num_classes = self.args.num_classes * self.args.vq_size self.denoise = Denoise(args= self.args, num_class = self.num_classes) self.num_timesteps = self.args.diffusion_steps self.auxiliary_loss_weight = auxiliary_loss_weight self.adaptive_auxiliary_loss = adaptive_auxiliary_loss self.VAE_DENSE = VAE_DENSE self.multi_criterion = multi_criterion alphas = cosine_beta_schedule(self.num_timesteps ) alphas = torch.tensor(alphas.astype('float64')) log_alpha = np.log(alphas) log_cumprod_alpha = np.cumsum(log_alpha) log_1_min_alpha = log_1_min_a(log_alpha) log_1_min_cumprod_alpha = log_1_min_a(log_cumprod_alpha) assert log_add_exp(log_alpha, log_1_min_alpha).abs().sum().item() < 1.e-5 assert log_add_exp(log_cumprod_alpha, log_1_min_cumprod_alpha).abs().sum().item() < 1e-5 assert (np.cumsum(log_alpha) - log_cumprod_alpha).abs().sum().item() < 1.e-5 # Convert to float32 and register buffers. 
self.register_buffer('log_alpha', log_alpha.float()) self.register_buffer('log_1_min_alpha', log_1_min_alpha.float()) self.register_buffer('log_cumprod_alpha', log_cumprod_alpha.float()) self.register_buffer('log_1_min_cumprod_alpha', log_1_min_cumprod_alpha.float()) self.register_buffer('Lt_history', torch.zeros(self.num_timesteps )) self.register_buffer('Lt_count', torch.zeros(self.num_timesteps )) def device(self): return self.denoise.device def multinomial_kl(self, log_prob1, log_prob2): kl = (log_prob1.exp() * (log_prob1 - log_prob2)).sum(dim=1) return kl def q_pred_one_timestep(self, log_x_t, t): log_alpha_t = extract(self.log_alpha, t, log_x_t.shape) log_1_min_alpha_t = extract(self.log_1_min_alpha, t, log_x_t.shape) # alpha_t * E[xt] + (1 - alpha_t) 1 / K log_probs = log_add_exp( log_x_t + log_alpha_t, log_1_min_alpha_t - np.log(self.num_classes) ) return log_probs def q_pred(self, log_x_start, t): log_cumprod_alpha_t = extract(self.log_cumprod_alpha, t, log_x_start.shape) log_1_min_cumprod_alpha = extract(self.log_1_min_cumprod_alpha, t, log_x_start.shape) log_probs = log_add_exp( log_x_start + log_cumprod_alpha_t, log_1_min_cumprod_alpha - np.log(self.num_classes) ) return log_probs def predict_start(self, log_x_t, t): x_t = log_onehot_to_index(log_x_t) out = self.denoise(x_t, t) assert out.size(0) == x_t.size(0) assert out.size(1) == self.num_classes assert out.size()[2:] == x_t.size()[1:] log_pred = F.log_softmax(out, dim=1) return log_pred def q_posterior(self, log_x_start, log_x_t, t): # q(xt-1 | xt, x0) = q(xt | xt-1, x0) * q(xt-1 | x0) / q(xt | x0) # where q(xt | xt-1, x0) = q(xt | xt-1). 
t_minus_1 = t - 1 # Remove negative values, will not be used anyway for final decoder t_minus_1 = torch.where(t_minus_1 < 0, torch.zeros_like(t_minus_1), t_minus_1) log_EV_qxtmin_x0 = self.q_pred(log_x_start, t_minus_1) num_axes = (1,) * (len(log_x_start.size()) - 1) t_broadcast = t.view(-1, *num_axes) * torch.ones_like(log_x_start) log_EV_qxtmin_x0 = torch.where(t_broadcast == 0, log_x_start, log_EV_qxtmin_x0) # Note: _NOT_ x_tmin1, which is how the formula is typically used!!! # Not very easy to see why this is true. But it is :) unnormed_logprobs = log_EV_qxtmin_x0 + self.q_pred_one_timestep(log_x_t, t) log_EV_xtmin_given_xt_given_xstart = \ unnormed_logprobs \ - torch.logsumexp(unnormed_logprobs, dim=1, keepdim=True) return log_EV_xtmin_given_xt_given_xstart def p_pred(self, log_x, t): log_x0_recon = self.predict_start(log_x, t=t) log_model_pred = self.q_posterior(log_x_start=log_x0_recon, log_x_t=log_x, t=t) return log_model_pred, log_x0_recon def log_sample_categorical(self, logits): uniform = torch.rand_like(logits) gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30) sample = (gumbel_noise + logits).argmax(dim=1) log_sample = index_to_log_onehot(sample, self.num_classes) return log_sample def q_sample(self, log_x_start, t): log_EV_qxt_x0 = self.q_pred(log_x_start, t) log_sample = self.log_sample_categorical(log_EV_qxt_x0) return log_sample def kl_prior(self, log_x_start): b = log_x_start.size(0) device = log_x_start.device ones = torch.ones(b, device=device).long() log_qxT_prob = self.q_pred(log_x_start, t=(self.num_timesteps - 1) * ones) log_half_prob = -torch.log(self.num_classes * torch.ones_like(log_qxT_prob)) kl_prior = self.multinomial_kl(log_qxT_prob, log_half_prob) return sum_except_batch(kl_prior) def sample_time(self, b, device, method='uniform'): if method == 'importance': if not (self.Lt_count > 10).all(): return self.sample_time(b, device, method='uniform') Lt_sqrt = torch.sqrt(self.Lt_history + 1e-10) + 0.0001 Lt_sqrt[0] = Lt_sqrt[1] 
# Overwrite decoder term with L1. pt_all = Lt_sqrt / Lt_sqrt.sum() t = torch.multinomial(pt_all, num_samples=b, replacement=True) pt = pt_all.gather(dim=0, index=t) return t, pt elif method == 'uniform': t = torch.randint(0, self.num_timesteps, (b,), device=device).long() pt = torch.ones_like(t).float() / self.num_timesteps return t, pt else: raise ValueError def forward(self, x, input_data): b, device = x.size(0), x.device self.shape = x.size()[1:] latent = self.VAE_DENSE.encode(x) _, _, dense_ind, latents_shape = self.VAE_DENSE.vector_quantize(latent) reshape_size = [latent.size()[0], latent.size()[2], latent.size()[3], latent.size()[4]] t, pt = self.sample_time(b, device, 'importance') log_x_start = index_to_log_onehot(dense_ind.view(reshape_size), self.num_classes) log_x_t = self.q_sample(log_x_start=log_x_start, t=t) # log_x_t : (8,551,8,8,2) log_true_prob = self.q_posterior(log_x_start=log_x_start, log_x_t=log_x_t, t=t) log_model_prob, log_x0_recon = self.p_pred(log_x=log_x_t, t=t) kl = self.multinomial_kl(log_true_prob, log_model_prob) kl = sum_except_batch(kl) decoder_nll = -log_categorical(log_x_start, log_model_prob) decoder_nll = sum_except_batch(decoder_nll) mask = (t == torch.zeros_like(t)).float() kl_loss = mask * decoder_nll + (1. - mask) * kl if self.training: Lt2 = kl_loss.pow(2) Lt2_prev = self.Lt_history.gather(dim=0, index=t) new_Lt_history = (0.1 * Lt2 + 0.9 * Lt2_prev).detach() self.Lt_history.scatter_(dim=0, index=t, src=new_Lt_history) self.Lt_count.scatter_add_(dim=0, index=t, src=torch.ones_like(Lt2)) kl_prior = self.kl_prior(log_x_start) # Upweigh loss term of the kl loss = kl_loss / pt + kl_prior kl_aux = self.multinomial_kl(log_x_start[:,:-1,:,:,:], log_x0_recon[:,:-1,:,:,:]) kl_aux = sum_except_batch(kl_aux) kl_aux_loss = mask * decoder_nll + (1. 
- mask) * kl_aux if self.adaptive_auxiliary_loss: addition_loss_weight = (1-t/self.num_timesteps) + 1.0 else: addition_loss_weight = 1.0 aux_loss = addition_loss_weight * self.auxiliary_loss_weight * kl_aux_loss / pt loss += aux_loss loss = -loss.sum() / (math.log(2) * dense_ind.view(reshape_size).shape.numel()) x0 = log_onehot_to_index(F.log_softmax(log_x0_recon, dim=1)) return -loss def sample(self, x): device = self.log_alpha.device self.shape = x.size()[1:] x = torch.randint(self.args.num_classes, size=x.size()).to(device) latent = self.VAE_DENSE.encode(x) _, _, sparse_ind, latents_shape = self.VAE_DENSE.vector_quantize(latent) reshape_size = [latent.size()[0], latent.size()[2], latent.size()[3], latent.size()[4]] log_z = index_to_log_onehot(sparse_ind.view(reshape_size), self.num_classes) # log_x_t : (8,551,8,8,2) for i in reversed(range(0, self.num_timesteps)): print(f'Sample timestep {i:4d}', end='\r') t = torch.full((self.args.batch_size,), i, device=device, dtype=torch.long) log_model_prob, log_x0_recon = self.p_pred(log_x=log_z, t=t) uniform = torch.rand_like(log_model_prob) gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30) pre_sample = gumbel_noise + log_model_prob sample = pre_sample.argmax(dim=1) # (32, 1, 32, 64) log_z = index_to_log_onehot(sample, self.num_classes) vq_ind = log_onehot_to_index(log_z) vq_latent = self.VAE_DENSE.coodbook(vq_ind.view(-1,1), latents_shape) recons = self.VAE_DENSE.decode(vq_latent) recons = recons.argmax(1) return recons ================================================ FILE: layers/Latent_Level/stage2/gen_denoise.py ================================================ import math from mimetypes import init import torch import torch.nn as nn import numpy as np import torch.nn.functional as F from einops import rearrange, reduce, repeat from torch import nn, einsum def conv3x3x3(in_planes, out_planes, stride=1): return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) def 
conv1x3x3(in_planes, out_planes, stride=1): return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 3), stride=stride,padding=(0, 1, 1), bias=False) def conv1x1x3(in_planes, out_planes, stride=1): return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 1, 3), stride=stride, padding=(0, 0, 1), bias=False) def conv1x3x1(in_planes, out_planes, stride=1): return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 1), stride=stride, padding=(0, 1, 0), bias=False) def conv3x1x1(in_planes, out_planes, stride=1): return nn.Conv3d(in_planes, out_planes, kernel_size=(3, 1, 1), stride=stride, padding=(1, 0, 0), bias=False) def conv3x1x3(in_planes, out_planes, stride=1): return nn.Conv3d(in_planes, out_planes, kernel_size=(3, 1, 3), stride=stride, padding=(1, 0, 1), bias=False) def conv1x1(in_planes, out_planes, stride=1): return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride) class Asymmetric_Residual_Block(nn.Module): def __init__(self, in_filters, out_filters, time_filters=128): super(Asymmetric_Residual_Block, self).__init__() self.GroupNorm = nn.GroupNorm(32, in_filters) self.time_layers = nn.Sequential( nn.SiLU(), nn.Linear(time_filters, in_filters*2) ) self.conv1 = conv1x3x3(in_filters, out_filters) self.bn0 = nn.GroupNorm(32, out_filters) self.act1 = nn.LeakyReLU() self.conv1_2 = conv3x1x3(out_filters, out_filters) self.bn0_2 = nn.GroupNorm(32, out_filters) self.act1_2 = nn.LeakyReLU() self.conv2 = conv3x1x3(in_filters, out_filters) self.act2 = nn.LeakyReLU() self.bn1 = nn.GroupNorm(32, out_filters) self.conv3 = conv1x3x3(out_filters, out_filters) self.act3 = nn.LeakyReLU() self.bn2 = nn.GroupNorm(32, out_filters) def forward(self, x, t): t = self.time_layers(t) while len(t.shape) < len(x.shape): t = t[..., None] scale, shift = torch.chunk(t, 2, dim=1) x = self.GroupNorm(x) * (1 + scale) + shift shortcut = self.conv1(x) shortcut = self.act1(shortcut) shortcut = self.bn0(shortcut) shortcut = self.conv1_2(shortcut) shortcut = self.act1_2(shortcut) 
shortcut = self.bn0_2(shortcut) resA = self.conv2(x) resA = self.act2(resA) resA = self.bn1(resA) resA = self.conv3(resA) resA = self.act3(resA) resA = self.bn2(resA) resA += shortcut return resA class DDCM(nn.Module): def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), stride=1): super(DDCM, self).__init__() self.conv1 = conv3x1x1(in_filters, out_filters) self.bn0 = nn.GroupNorm(32, out_filters) self.act1 = nn.Sigmoid() self.conv1_2 = conv1x3x1(in_filters, out_filters) self.bn0_2 = nn.GroupNorm(32, out_filters) self.act1_2 = nn.Sigmoid() self.conv1_3 = conv1x1x3(in_filters, out_filters) self.bn0_3 = nn.GroupNorm(32, out_filters) self.act1_3 = nn.Sigmoid() def forward(self, x): shortcut = self.conv1(x) shortcut = self.bn0(shortcut) shortcut = self.act1(shortcut) shortcut2 = self.conv1_2(x) shortcut2 = self.bn0_2(shortcut2) shortcut2 = self.act1_2(shortcut2) shortcut3 = self.conv1_3(x) shortcut3 = self.bn0_3(shortcut3) shortcut3 = self.act1_3(shortcut3) shortcut = shortcut + shortcut2 + shortcut3 shortcut = shortcut * x return shortcut def l2norm(t): return F.normalize(t, dim = -1) class Attention(nn.Module): def __init__(self, dim, heads = 4, scale = 10): super().__init__() self.scale = scale self.heads = heads self.to_qkv = conv1x1(dim, dim*3, stride=1) self.to_out = conv1x1(dim, dim, stride=1) def forward(self, x): b, c, h, w, Z = x.shape qkv = self.to_qkv(x).chunk(3, dim = 1) q, k, v = map(lambda t: rearrange(t, 'b (h c) x y z-> b h c (x y z)', h = self.heads), qkv) q, k = map(l2norm, (q, k)) sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale attn = sim.softmax(dim = -1) out = einsum('b h i j, b h d j -> b h i d', attn, v) out = rearrange(out, 'b h (x y z) d -> b (h d) x y z', x = h, y = w, z = Z) return self.to_out(out) class Cross_Attention(nn.Module): def __init__(self, dim, heads = 4, scale = 10): super().__init__() self.scale = scale self.heads = heads self.to_q = conv1x1(dim, dim, stride=1) self.to_k = conv1x1(dim, dim, stride=1) 
self.to_v = conv1x1(dim, dim, stride=1) self.to_out = conv1x1(dim, dim, stride=1) def forward(self, x, cond_x): b, c, h, w, Z = x.shape q = self.to_q(x) k = self.to_k(cond_x) v = self.to_v(cond_x) q, k, v = map(lambda t: rearrange(t, 'b (h c) x y z-> b h c (x y z)', h = self.heads), (q, k, v)) q, k = map(l2norm, (q, k)) sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale attn = sim.softmax(dim = -1) out = einsum('b h i j, b h d j -> b h i d', attn, v) out = rearrange(out, 'b h (x y z) d -> b (h d) x y z', x = h, y = w, z = Z) return self.to_out(out) class DownBlock(nn.Module): def __init__(self, in_filters, out_filters, dropout_rate, kernel_size=(3, 3, 3), stride=1, pooling=True, drop_out=True, height_pooling=False): super(DownBlock, self).__init__() self.pooling = pooling self.drop_out = drop_out self.residual_block = Asymmetric_Residual_Block(in_filters, out_filters) if pooling: if height_pooling: self.pool = nn.Conv3d(out_filters, out_filters, kernel_size=3, stride=2, padding=1, bias=False) else: self.pool = nn.Conv3d(out_filters, out_filters, kernel_size=3, stride=(2, 2, 1), padding=1, bias=False) def forward(self, x, t): resA = self.residual_block(x, t) if self.pooling: resB = self.pool(resA) return resB, resA else: return resA class UpBlock(nn.Module): def __init__(self, in_filters, out_filters, height_pooling, time_filters=32*4): super(UpBlock, self).__init__() # self.drop_out = drop_out self.trans_dilao = conv3x3x3(in_filters, in_filters) self.trans_act = nn.LeakyReLU() self.trans_bn = nn.GroupNorm(32, in_filters) self.time_layers = nn.Sequential( nn.SiLU(), nn.Linear(time_filters, in_filters*2) ) self.conv1 = conv1x3x3(in_filters, out_filters) self.act1 = nn.LeakyReLU() self.bn1 = nn.GroupNorm(32, out_filters) self.conv2 = conv3x1x3(out_filters, out_filters) self.act2 = nn.LeakyReLU() self.bn2 = nn.GroupNorm(32, out_filters) self.conv3 = conv3x3x3(out_filters, out_filters) self.act3 = nn.LeakyReLU() self.bn3 = nn.GroupNorm(32, out_filters) if 
height_pooling : self.up_subm = nn.ConvTranspose3d(in_filters, in_filters, kernel_size=3, bias=False, stride=2, padding=1, output_padding=1, dilation=1) else : self.up_subm = nn.ConvTranspose3d(in_filters, in_filters, kernel_size=(3,3,1), bias=False, stride=(2,2,1), padding=(1,1,0), output_padding=(1,1,0), dilation=1) def forward(self, x, residual, t): upA = self.trans_dilao(x) upA = self.trans_act(upA) t = self.time_layers(t) while len(t.shape) < len(x.shape): t = t[..., None] scale, shift = torch.chunk(t, 2, dim=1) upA = self.trans_bn(upA) * (1 + scale) + shift ## upsample upA = self.up_subm(upA) upA += residual upE = self.conv1(upA) upE = self.act1(upE) upE = self.bn1(upE) upE = self.conv2(upE) upE = self.act2(upE) upE = self.bn2(upE) upE = self.conv3(upE) upE = self.act3(upE) upE = self.bn3(upE) return upE def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): """ Create sinusoidal timestep embeddings. :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional. :param dim: the dimension of the output. :param max_period: controls the minimum frequency of the embeddings. :return: an [N x dim] Tensor of positional embeddings. 
""" if not repeat_only: half = dim // 2 freqs = torch.exp( -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half ).to(device=timesteps.device) args = timesteps[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) else: embedding = repeat(timesteps, 'b -> b d', d=dim) return embedding class Denoise(nn.Module): def __init__(self, args, num_class = 11, init_size=32, discrete=True): super(Denoise, self).__init__() self.args = args self.discrete = discrete self.num_class = num_class self.init_size = init_size self.time_size = init_size*4 self.time_embed = nn.Sequential( nn.Linear(init_size, self.time_size), nn.SiLU(), nn.Linear(self.time_size, self.time_size), ) self.embedding = nn.Embedding(self.num_class, init_size) self.conv_in = nn.Conv3d(init_size, init_size, kernel_size=1, stride=1) self.A = Asymmetric_Residual_Block(init_size, init_size) self.midBlock1_1 = Asymmetric_Residual_Block(init_size, 2 * init_size) self.attention1 = Attention(2 * init_size, 4) self.midBlock1_2 = Asymmetric_Residual_Block(2 * init_size, 2 * init_size) self.downBlock2 = DownBlock(init_size*2, 2 * init_size, 0.2, height_pooling=False) self.downBlock3 = DownBlock(2 * init_size, 4 * init_size, 0.2, height_pooling=False) self.midBlock2_1 = Asymmetric_Residual_Block(4 * init_size, 4 * init_size) self.attention2 = Attention(4 * init_size, 4) self.midBlock2_2 = Asymmetric_Residual_Block(4 * init_size, 4 * init_size) self.upBlock0 = UpBlock(4 * init_size, 2 * init_size, height_pooling=False) self.upBlock1 = UpBlock(2 * init_size, init_size, height_pooling=False) self.midBlock3_1 = Asymmetric_Residual_Block(init_size, init_size) self.attention3 = Attention(init_size, 4) self.midBlock3_2 = Asymmetric_Residual_Block(init_size, init_size) self.DDCM = DDCM(init_size, init_size) self.logits = nn.Sequential( nn.Conv3d(2 * init_size, 
self.num_class, kernel_size=3, stride=1, padding=1, bias=True), ) def forward(self, x, t): x = self.embedding(x) x = x.permute(0, 4, 1, 2, 3) x = self.conv_in(x) t = self.time_embed(timestep_embedding(t, self.init_size)) ret = self.A(x, t) mid1 = self.midBlock1_1(ret, t) att = self.attention1(mid1) mid2 = self.midBlock1_2(att, t) down1c, down1b = self.downBlock2(mid2, t) down2c, down2b = self.downBlock3(down1c, t) d_mid2 = self.midBlock2_1(down2c, t) d_att = self.attention2(d_mid2) d_mid1 = self.midBlock2_2(d_att, t) up3e = self.upBlock0(d_mid1, down2b, t) up2e = self.upBlock1(up3e, down1b, t) u_mid2 = self.midBlock3_1(up2e, t) u_att = self.attention3(u_mid2) u_mid1 = self.midBlock3_2(u_att, t) up0e = self.DDCM(u_mid1) up0e = torch.cat((up0e, up2e), 1) logits = self.logits(up0e) return logits ================================================ FILE: layers/Voxel_Level/Con_Diffusion.py ================================================ import torch import torch.nn.functional as F import numpy as np import math from inspect import isfunction from layers.Voxel_Level.denoise import Denoise from utils.loss import * """ Based in part on: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/5989f4c77eafcdc6be0fb4739f0f277a6dd7f7d8/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L281 """ eps = 1e-8 def sum_except_batch(x, num_dims=1): return x.reshape(*x.shape[:num_dims], -1).sum(-1) def log_1_min_a(a): return torch.log(1 - a.exp() + 1e-40) def log_add_exp(a, b): maximum = torch.max(a, b) return maximum + torch.log(torch.exp(a - maximum) + torch.exp(b - maximum)) def exists(x): return x is not None def extract(a, t, x_shape): b, *_ = t.shape out = a.gather(-1, t) return out.reshape(b, *((1,) * (len(x_shape) - 1))) def default(val, d): if exists(val): return val return d() if isfunction(d) else d def log_categorical(log_x_start, log_prob): return (log_x_start.exp() * log_prob).sum(dim=1) def index_to_log_onehot(x, num_classes): assert x.max().item() < 
num_classes, f'Error: {x.max().item()} >= {num_classes}' x_onehot = F.one_hot(x, num_classes) permute_order = (0, -1) + tuple(range(1, len(x.size()))) x_onehot = x_onehot.permute(permute_order) log_x = torch.log(x_onehot.float().clamp(min=1e-30)) return log_x def log_onehot_to_index(log_x): return log_x.argmax(1) def cosine_beta_schedule(timesteps, s = 0.008): """ cosine schedule as proposed in https://openreview.net/forum?id=-NEXDKk8gZ """ steps = timesteps + 1 x = np.linspace(0, steps, steps) alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2 alphas_cumprod = alphas_cumprod / alphas_cumprod[0] alphas = (alphas_cumprod[1:] / alphas_cumprod[:-1]) alphas = np.clip(alphas, a_min=0.001, a_max=1.) alphas = np.sqrt(alphas) return alphas class Con_Diffusion(torch.nn.Module): def __init__(self, args, multi_criterion,auxiliary_loss_weight=0.05, adaptive_auxiliary_loss=True): super(Con_Diffusion, self).__init__() #self._denoise_fn = SSCNet(num_classes=args.num_classes*50, num_steps=args.diffusion_steps) self.args = args self.num_classes = self.args.num_classes self.num_timesteps = self.args.diffusion_steps self.recon_loss = self.args.recon_loss if args.dataset == 'carla': self._denoise_fn = Denoise(args= self.args, num_class = self.num_classes) elif args.dataset=='kitti': self._denoise_fn = Denoise(args= self.args, num_class = self.num_classes, init_size=16) self.auxiliary_loss_weight = auxiliary_loss_weight self.adaptive_auxiliary_loss = adaptive_auxiliary_loss self.multi_criterion = multi_criterion alphas = cosine_beta_schedule(self.num_timesteps ) alphas = torch.tensor(alphas.astype('float64')) log_alpha = np.log(alphas) log_cumprod_alpha = np.cumsum(log_alpha) log_1_min_alpha = log_1_min_a(log_alpha) log_1_min_cumprod_alpha = log_1_min_a(log_cumprod_alpha) assert log_add_exp(log_alpha, log_1_min_alpha).abs().sum().item() < 1.e-5 assert log_add_exp(log_cumprod_alpha, log_1_min_cumprod_alpha).abs().sum().item() < 1e-5 assert (np.cumsum(log_alpha) - 
log_cumprod_alpha).abs().sum().item() < 1.e-5 # Convert to float32 and register buffers. self.register_buffer('log_alpha', log_alpha.float()) self.register_buffer('log_1_min_alpha', log_1_min_alpha.float()) self.register_buffer('log_cumprod_alpha', log_cumprod_alpha.float()) self.register_buffer('log_1_min_cumprod_alpha', log_1_min_cumprod_alpha.float()) self.register_buffer('Lt_history', torch.zeros(self.num_timesteps )) self.register_buffer('Lt_count', torch.zeros(self.num_timesteps )) def device(self): return self.denoise_fn.device def multinomial_kl(self, log_prob1, log_prob2): kl = (log_prob1.exp() * (log_prob1 - log_prob2)).sum(dim=1) return kl def q_pred_one_timestep(self, log_x_t, t): log_alpha_t = extract(self.log_alpha, t, log_x_t.shape) log_1_min_alpha_t = extract(self.log_1_min_alpha, t, log_x_t.shape) # alpha_t * E[xt] + (1 - alpha_t) 1 / K log_probs = log_add_exp( log_x_t + log_alpha_t, log_1_min_alpha_t - np.log(self.num_classes) ) return log_probs def q_pred(self, log_x_start, t): log_cumprod_alpha_t = extract(self.log_cumprod_alpha, t, log_x_start.shape) log_1_min_cumprod_alpha = extract(self.log_1_min_cumprod_alpha, t, log_x_start.shape) log_probs = log_add_exp( log_x_start + log_cumprod_alpha_t, log_1_min_cumprod_alpha - np.log(self.num_classes) ) return log_probs def predict_start(self, log_x_t, t, cond): x_t = log_onehot_to_index(log_x_t) out = self._denoise_fn(x_t, cond, t) log_pred = F.log_softmax(out, dim=1) return log_pred def q_posterior(self, log_x_start, log_x_t, t): # q(xt-1 | xt, x0) = q(xt | xt-1, x0) * q(xt-1 | x0) / q(xt | x0) # where q(xt | xt-1, x0) = q(xt | xt-1). 
t_minus_1 = t - 1 # Remove negative values, will not be used anyway for final decoder t_minus_1 = torch.where(t_minus_1 < 0, torch.zeros_like(t_minus_1), t_minus_1) log_EV_qxtmin_x0 = self.q_pred(log_x_start, t_minus_1) num_axes = (1,) * (len(log_x_start.size()) - 1) t_broadcast = t.view(-1, *num_axes) * torch.ones_like(log_x_start) log_EV_qxtmin_x0 = torch.where(t_broadcast == 0, log_x_start, log_EV_qxtmin_x0) # Note: _NOT_ x_tmin1, which is how the formula is typically used!!! # Not very easy to see why this is true. But it is :) unnormed_logprobs = log_EV_qxtmin_x0 + self.q_pred_one_timestep(log_x_t, t) log_EV_xtmin_given_xt_given_xstart = unnormed_logprobs - torch.logsumexp(unnormed_logprobs, dim=1, keepdim=True) return log_EV_xtmin_given_xt_given_xstart def p_pred(self, log_x, t, cond): log_x0_recon = self.predict_start(log_x, t, cond) log_model_pred = self.q_posterior(log_x_start=log_x0_recon, log_x_t=log_x, t=t) return log_model_pred, log_x0_recon def log_sample_categorical(self, logits): uniform = torch.rand_like(logits) gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30) sample = (gumbel_noise + logits).argmax(dim=1) log_sample = index_to_log_onehot(sample, self.num_classes) return log_sample def q_sample(self, log_x_start, t): log_EV_qxt_x0 = self.q_pred(log_x_start, t) log_sample = self.log_sample_categorical(log_EV_qxt_x0) return log_sample def kl_prior(self, log_x_start): b = log_x_start.size(0) device = log_x_start.device ones = torch.ones(b, device=device).long() log_qxT_prob = self.q_pred(log_x_start, t=(self.num_timesteps - 1) * ones) log_half_prob = -torch.log(self.num_classes * torch.ones_like(log_qxT_prob)) kl_prior = self.multinomial_kl(log_qxT_prob, log_half_prob) return sum_except_batch(kl_prior) def sample_time(self, b, device, method='uniform'): if method == 'importance': if not (self.Lt_count > 10).all(): return self.sample_time(b, device, method='uniform') Lt_sqrt = torch.sqrt(self.Lt_history + 1e-10) + 0.0001 Lt_sqrt[0] = 
Lt_sqrt[1] # Overwrite decoder term with L1. pt_all = Lt_sqrt / Lt_sqrt.sum() t = torch.multinomial(pt_all, num_samples=b, replacement=True) pt = pt_all.gather(dim=0, index=t) return t, pt elif method == 'uniform': t = torch.randint(0, self.num_timesteps, (b,), device=device).long() pt = torch.ones_like(t).float() / self.num_timesteps return t, pt else: raise ValueError def forward(self, x, voxel_input): b, device = x.size(0), x.device self.shape = x.size()[1:] t, pt = self.sample_time(b, device, 'importance') log_x_start = index_to_log_onehot(x, self.num_classes) log_x_t = self.q_sample(log_x_start, t) # log_x_t : (batch, #class, 128, 128, 8) log_true_prob = self.q_posterior(log_x_start=log_x_start, log_x_t=log_x_t, t=t) log_model_prob, log_x0_recon = self.p_pred(log_x=log_x_t, t=t, cond=voxel_input) kl = self.multinomial_kl(log_true_prob, log_model_prob) kl = sum_except_batch(kl) decoder_nll = -log_categorical(log_x_start, log_model_prob) decoder_nll = sum_except_batch(decoder_nll) mask = (t == torch.zeros_like(t)).float() kl_loss = mask * decoder_nll + (1. - mask) * kl if self.training: Lt2 = kl_loss.pow(2) Lt2_prev = self.Lt_history.gather(dim=0, index=t) new_Lt_history = (0.1 * Lt2 + 0.9 * Lt2_prev).detach() self.Lt_history.scatter_(dim=0, index=t, src=new_Lt_history) self.Lt_count.scatter_add_(dim=0, index=t, src=torch.ones_like(Lt2)) kl_prior = self.kl_prior(log_x_start) # Upweigh loss term of the kl loss = kl_loss / pt + kl_prior kl_aux = self.multinomial_kl(log_x_start[:,:-1,:,:,:], log_x0_recon[:,:-1,:,:,:]) kl_aux = sum_except_batch(kl_aux) if self.recon_loss : kl_aux += self.multi_criterion(log_x0_recon.exp(), x) #kl_aux += lovasz_softmax(torch.nn.functional.softmax(log_x0_recon.exp(), dim=1), x) kl_aux_loss = mask * decoder_nll + (1. 
- mask) * kl_aux if self.adaptive_auxiliary_loss: addition_loss_weight = (1-t/self.num_timesteps) + 1.0 else: addition_loss_weight = 1.0 aux_loss = addition_loss_weight * self.auxiliary_loss_weight * kl_aux_loss / pt loss += aux_loss loss = -loss.sum() / (self.shape[0]*self.shape[1]) #loss += seg_loss return -loss def sample(self, voxel_input, intermediate=False): device = self.log_alpha.device self.shape = voxel_input.size()[1:] uniform_logits = torch.zeros((self.args.batch_size, self.num_classes) + self.shape, device=device) log_z = self.log_sample_categorical(uniform_logits) diffusion = [] for i in reversed(range(0, self.num_timesteps)): print(f'Sample timestep {i:4d}', end='\r') t = torch.full((self.args.batch_size,), i, device=device, dtype=torch.long) log_model_prob, log_x0_recon = self.p_pred(log_x=log_z, t=t, cond=voxel_input) log_z = self.log_sample_categorical(log_model_prob) if i%10 ==0: diffusion.append(log_onehot_to_index(log_z)) result = log_onehot_to_index(log_z) if intermediate : return result, diffusion else : return result ================================================ FILE: layers/Voxel_Level/Gen_Diffusion.py ================================================ import torch import torch.nn.functional as F import numpy as np import math from inspect import isfunction from layers.Voxel_Level.gen_denoise import Denoise from utils.loss import * """ Based in part on: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/5989f4c77eafcdc6be0fb4739f0f277a6dd7f7d8/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L281 """ eps = 1e-8 def sum_except_batch(x, num_dims=1): return x.reshape(*x.shape[:num_dims], -1).sum(-1) def log_1_min_a(a): return torch.log(1 - a.exp() + 1e-40) def log_add_exp(a, b): maximum = torch.max(a, b) return maximum + torch.log(torch.exp(a - maximum) + torch.exp(b - maximum)) def exists(x): return x is not None def extract(a, t, x_shape): b, *_ = t.shape out = a.gather(-1, t) return out.reshape(b, *((1,) * 
(len(x_shape) - 1))) def default(val, d): if exists(val): return val return d() if isfunction(d) else d def log_categorical(log_x_start, log_prob): return (log_x_start.exp() * log_prob).sum(dim=1) def index_to_log_onehot(x, num_classes): assert x.max().item() < num_classes, f'Error: {x.max().item()} >= {num_classes}' x_onehot = F.one_hot(x, num_classes) permute_order = (0, -1) + tuple(range(1, len(x.size()))) x_onehot = x_onehot.permute(permute_order) log_x = torch.log(x_onehot.float().clamp(min=1e-30)) return log_x def log_onehot_to_index(log_x): return log_x.argmax(1) def cosine_beta_schedule(timesteps, s = 0.008): """ cosine schedule as proposed in https://openreview.net/forum?id=-NEXDKk8gZ """ steps = timesteps + 1 x = np.linspace(0, steps, steps) alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2 alphas_cumprod = alphas_cumprod / alphas_cumprod[0] alphas = (alphas_cumprod[1:] / alphas_cumprod[:-1]) alphas = np.clip(alphas, a_min=0.001, a_max=1.) alphas = np.sqrt(alphas) return alphas class Diffusion(torch.nn.Module): def __init__(self, args, multi_criterion,auxiliary_loss_weight=0.05, adaptive_auxiliary_loss=True): super(Diffusion, self).__init__() #self._denoise_fn = SSCNet(num_classes=args.num_classes*50, num_steps=args.diffusion_steps) self.args = args self.num_classes = self.args.num_classes self.num_timesteps = self.args.diffusion_steps self.recon_loss = self.args.recon_loss self._denoise_fn = Denoise(args= self.args, num_class = self.num_classes) self.auxiliary_loss_weight = auxiliary_loss_weight self.adaptive_auxiliary_loss = adaptive_auxiliary_loss self.multi_criterion = multi_criterion alphas = cosine_beta_schedule(self.num_timesteps ) alphas = torch.tensor(alphas.astype('float64')) log_alpha = np.log(alphas) log_cumprod_alpha = np.cumsum(log_alpha) log_1_min_alpha = log_1_min_a(log_alpha) log_1_min_cumprod_alpha = log_1_min_a(log_cumprod_alpha) assert log_add_exp(log_alpha, log_1_min_alpha).abs().sum().item() < 1.e-5 assert 
log_add_exp(log_cumprod_alpha, log_1_min_cumprod_alpha).abs().sum().item() < 1e-5 assert (np.cumsum(log_alpha) - log_cumprod_alpha).abs().sum().item() < 1.e-5 # Convert to float32 and register buffers. self.register_buffer('log_alpha', log_alpha.float()) self.register_buffer('log_1_min_alpha', log_1_min_alpha.float()) self.register_buffer('log_cumprod_alpha', log_cumprod_alpha.float()) self.register_buffer('log_1_min_cumprod_alpha', log_1_min_cumprod_alpha.float()) self.register_buffer('Lt_history', torch.zeros(self.num_timesteps )) self.register_buffer('Lt_count', torch.zeros(self.num_timesteps )) def device(self): return self.denoise_fn.device def multinomial_kl(self, log_prob1, log_prob2): kl = (log_prob1.exp() * (log_prob1 - log_prob2)).sum(dim=1) return kl def q_pred_one_timestep(self, log_x_t, t): log_alpha_t = extract(self.log_alpha, t, log_x_t.shape) log_1_min_alpha_t = extract(self.log_1_min_alpha, t, log_x_t.shape) # alpha_t * E[xt] + (1 - alpha_t) 1 / K log_probs = log_add_exp( log_x_t + log_alpha_t, log_1_min_alpha_t - np.log(self.num_classes) ) return log_probs def q_pred(self, log_x_start, t): log_cumprod_alpha_t = extract(self.log_cumprod_alpha, t, log_x_start.shape) log_1_min_cumprod_alpha = extract(self.log_1_min_cumprod_alpha, t, log_x_start.shape) log_probs = log_add_exp( log_x_start + log_cumprod_alpha_t, log_1_min_cumprod_alpha - np.log(self.num_classes) ) return log_probs def predict_start(self, log_x_t, t): x_t = log_onehot_to_index(log_x_t) out = self._denoise_fn(x_t, t) log_pred = F.log_softmax(out, dim=1) return log_pred def q_posterior(self, log_x_start, log_x_t, t): # q(xt-1 | xt, x0) = q(xt | xt-1, x0) * q(xt-1 | x0) / q(xt | x0) # where q(xt | xt-1, x0) = q(xt | xt-1). 
t_minus_1 = t - 1 # Remove negative values, will not be used anyway for final decoder t_minus_1 = torch.where(t_minus_1 < 0, torch.zeros_like(t_minus_1), t_minus_1) log_EV_qxtmin_x0 = self.q_pred(log_x_start, t_minus_1) num_axes = (1,) * (len(log_x_start.size()) - 1) t_broadcast = t.view(-1, *num_axes) * torch.ones_like(log_x_start) log_EV_qxtmin_x0 = torch.where(t_broadcast == 0, log_x_start, log_EV_qxtmin_x0) # Note: _NOT_ x_tmin1, which is how the formula is typically used!!! # Not very easy to see why this is true. But it is :) unnormed_logprobs = log_EV_qxtmin_x0 + self.q_pred_one_timestep(log_x_t, t) log_EV_xtmin_given_xt_given_xstart = unnormed_logprobs - torch.logsumexp(unnormed_logprobs, dim=1, keepdim=True) return log_EV_xtmin_given_xt_given_xstart def p_pred(self, log_x, t): log_x0_recon = self.predict_start(log_x, t) log_model_pred = self.q_posterior(log_x_start=log_x0_recon, log_x_t=log_x, t=t) return log_model_pred, log_x0_recon def log_sample_categorical(self, logits): uniform = torch.rand_like(logits) gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30) sample = (gumbel_noise + logits).argmax(dim=1) log_sample = index_to_log_onehot(sample, self.num_classes) return log_sample def q_sample(self, log_x_start, t): log_EV_qxt_x0 = self.q_pred(log_x_start, t) log_sample = self.log_sample_categorical(log_EV_qxt_x0) return log_sample def kl_prior(self, log_x_start): b = log_x_start.size(0) device = log_x_start.device ones = torch.ones(b, device=device).long() log_qxT_prob = self.q_pred(log_x_start, t=(self.num_timesteps - 1) * ones) log_half_prob = -torch.log(self.num_classes * torch.ones_like(log_qxT_prob)) kl_prior = self.multinomial_kl(log_qxT_prob, log_half_prob) return sum_except_batch(kl_prior) def sample_time(self, b, device, method='uniform'): if method == 'importance': if not (self.Lt_count > 10).all(): return self.sample_time(b, device, method='uniform') Lt_sqrt = torch.sqrt(self.Lt_history + 1e-10) + 0.0001 Lt_sqrt[0] = Lt_sqrt[1] # 
Overwrite decoder term with L1. pt_all = Lt_sqrt / Lt_sqrt.sum() t = torch.multinomial(pt_all, num_samples=b, replacement=True) pt = pt_all.gather(dim=0, index=t) return t, pt elif method == 'uniform': t = torch.randint(0, self.num_timesteps, (b,), device=device).long() pt = torch.ones_like(t).float() / self.num_timesteps return t, pt else: raise ValueError def forward(self, x, voxel_input): b, device = x.size(0), x.device self.shape = x.size()[1:] t, pt = self.sample_time(b, device, 'importance') log_x_start = index_to_log_onehot(x, self.num_classes) log_x_t = self.q_sample(log_x_start, t) # log_x_t : (batch, #class, 128, 128, 8) log_true_prob = self.q_posterior(log_x_start=log_x_start, log_x_t=log_x_t, t=t) log_model_prob, log_x0_recon = self.p_pred(log_x=log_x_t, t=t) kl = self.multinomial_kl(log_true_prob, log_model_prob) kl = sum_except_batch(kl) decoder_nll = -log_categorical(log_x_start, log_model_prob) decoder_nll = sum_except_batch(decoder_nll) mask = (t == torch.zeros_like(t)).float() kl_loss = mask * decoder_nll + (1. - mask) * kl if self.training: Lt2 = kl_loss.pow(2) Lt2_prev = self.Lt_history.gather(dim=0, index=t) new_Lt_history = (0.1 * Lt2 + 0.9 * Lt2_prev).detach() self.Lt_history.scatter_(dim=0, index=t, src=new_Lt_history) self.Lt_count.scatter_add_(dim=0, index=t, src=torch.ones_like(Lt2)) kl_prior = self.kl_prior(log_x_start) # Upweigh loss term of the kl loss = kl_loss / pt + kl_prior kl_aux = self.multinomial_kl(log_x_start[:,:-1,:,:,:], log_x0_recon[:,:-1,:,:,:]) kl_aux = sum_except_batch(kl_aux) '''if self.recon_loss : kl_aux += self.multi_criterion(log_x0_recon.exp(), x) kl_aux += lovasz_softmax(torch.nn.functional.softmax(log_x0_recon.exp(), dim=1), x)''' kl_aux_loss = mask * decoder_nll + (1. 
- mask) * kl_aux if self.adaptive_auxiliary_loss: addition_loss_weight = (1-t/self.num_timesteps) + 1.0 else: addition_loss_weight = 1.0 aux_loss = addition_loss_weight * self.auxiliary_loss_weight * kl_aux_loss / pt loss += aux_loss loss = -loss.sum() / (self.shape[0]*self.shape[1]) #loss += seg_loss return -loss def sample(self, voxel_input): device = self.log_alpha.device self.shape = voxel_input.size()[1:] uniform_logits = torch.zeros((self.args.batch_size, self.num_classes) + self.shape, device=device) log_z = self.log_sample_categorical(uniform_logits) for i in reversed(range(0, self.num_timesteps)): print(f'Sample timestep {i:4d}', end='\r') t = torch.full((self.args.batch_size,), i, device=device, dtype=torch.long) log_model_prob, log_x0_recon = self.p_pred(log_x=log_z, t=t) log_z = self.log_sample_categorical(log_model_prob) result = log_onehot_to_index(log_z) return result ================================================ FILE: layers/Voxel_Level/denoise.py ================================================ import math from mimetypes import init import torch import torch.nn as nn import numpy as np import torch.nn.functional as F from einops import rearrange, reduce, repeat from torch import nn, einsum def conv3x3x3(in_planes, out_planes, stride=1): return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) def conv1x3x3(in_planes, out_planes, stride=1): return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 3), stride=stride,padding=(0, 1, 1), bias=False) def conv1x1x3(in_planes, out_planes, stride=1): return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 1, 3), stride=stride, padding=(0, 0, 1), bias=False) def conv1x3x1(in_planes, out_planes, stride=1): return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 1), stride=stride, padding=(0, 1, 0), bias=False) def conv3x1x1(in_planes, out_planes, stride=1): return nn.Conv3d(in_planes, out_planes, kernel_size=(3, 1, 1), stride=stride, padding=(1, 0, 0), bias=False) 
def conv3x1x3(in_planes, out_planes, stride=1): return nn.Conv3d(in_planes, out_planes, kernel_size=(3, 1, 3), stride=stride, padding=(1, 0, 1), bias=False) def conv1x1(in_planes, out_planes, stride=1): return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride) class Asymmetric_Residual_Block(nn.Module): def __init__(self, in_filters, out_filters, time_filters=32*4): super(Asymmetric_Residual_Block, self).__init__() if in_filters<32 : self.GroupNorm = nn.GroupNorm(16, in_filters) self.bn0 = nn.GroupNorm(16, out_filters) self.bn0_2 = nn.GroupNorm(16, out_filters) self.bn1 = nn.GroupNorm(16, out_filters) self.bn2 = nn.GroupNorm(16, out_filters) else : self.GroupNorm = nn.GroupNorm(32, in_filters) self.bn0 = nn.GroupNorm(32, out_filters) self.bn0_2 = nn.GroupNorm(32, out_filters) self.bn1 = nn.GroupNorm(32, out_filters) self.bn2 = nn.GroupNorm(32, out_filters) self.time_layers = nn.Sequential( nn.SiLU(), nn.Linear(time_filters, in_filters*2) ) self.conv1 = conv1x3x3(in_filters, out_filters) self.act1 = nn.LeakyReLU() self.conv1_2 = conv3x1x3(out_filters, out_filters) self.act1_2 = nn.LeakyReLU() self.conv2 = conv3x1x3(in_filters, out_filters) self.act2 = nn.LeakyReLU() self.conv3 = conv1x3x3(out_filters, out_filters) self.act3 = nn.LeakyReLU() def forward(self, x, t): t = self.time_layers(t) while len(t.shape) < len(x.shape): t = t[..., None] scale, shift = torch.chunk(t, 2, dim=1) x = self.GroupNorm(x) * (1 + scale) + shift shortcut = self.conv1(x) shortcut = self.act1(shortcut) shortcut = self.bn0(shortcut) shortcut = self.conv1_2(shortcut) shortcut = self.act1_2(shortcut) shortcut = self.bn0_2(shortcut) resA = self.conv2(x) resA = self.act2(resA) resA = self.bn1(resA) resA = self.conv3(resA) resA = self.act3(resA) resA = self.bn2(resA) resA += shortcut return resA class DDCM(nn.Module): def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), stride=1): super(DDCM, self).__init__() self.conv1 = conv3x1x1(in_filters, out_filters) if 
in_filters<32 : self.bn0 = nn.GroupNorm(16, out_filters) self.bn0_2 = nn.GroupNorm(16, out_filters) self.bn0_3 = nn.GroupNorm(16, out_filters) else : self.bn0 = nn.GroupNorm(32, out_filters) self.bn0_2 = nn.GroupNorm(32, out_filters) self.bn0_3 = nn.GroupNorm(32, out_filters) self.act1 = nn.Sigmoid() self.conv1_2 = conv1x3x1(in_filters, out_filters) self.act1_2 = nn.Sigmoid() self.conv1_3 = conv1x1x3(in_filters, out_filters) self.act1_3 = nn.Sigmoid() def forward(self, x): shortcut = self.conv1(x) shortcut = self.bn0(shortcut) shortcut = self.act1(shortcut) shortcut2 = self.conv1_2(x) shortcut2 = self.bn0_2(shortcut2) shortcut2 = self.act1_2(shortcut2) shortcut3 = self.conv1_3(x) shortcut3 = self.bn0_3(shortcut3) shortcut3 = self.act1_3(shortcut3) shortcut = shortcut + shortcut2 + shortcut3 shortcut = shortcut * x return shortcut def l2norm(t): return F.normalize(t, dim = -1) class Attention(nn.Module): def __init__(self, dim, heads = 4, scale = 10): super().__init__() self.scale = scale self.heads = heads self.to_qkv = conv1x1(dim, dim*3, stride=1) self.to_out = conv1x1(dim, dim, stride=1) def forward(self, x): b, c, h, w, Z = x.shape qkv = self.to_qkv(x).chunk(3, dim = 1) q, k, v = map(lambda t: rearrange(t, 'b (h c) x y z-> b h c (x y z)', h = self.heads), qkv) q, k = map(l2norm, (q, k)) sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale attn = sim.softmax(dim = -1) out = einsum('b h i j, b h d j -> b h i d', attn, v) out = rearrange(out, 'b h (x y z) d -> b (h d) x y z', x = h, y = w, z = Z) return self.to_out(out) class Cross_Attention(nn.Module): def __init__(self, dim, heads = 4, scale = 10): super().__init__() self.scale = scale self.heads = heads self.to_q = conv1x1(dim, dim, stride=1) self.to_k = conv1x1(dim, dim, stride=1) self.to_v = conv1x1(dim, dim, stride=1) self.to_out = conv1x1(dim, dim, stride=1) def forward(self, x, cond_x): b, c, h, w, Z = x.shape q = self.to_q(x) k = self.to_k(cond_x) v = self.to_v(cond_x) q, k, v = map(lambda t: 
rearrange(t, 'b (h c) x y z-> b h c (x y z)', h = self.heads), (q, k, v)) q, k = map(l2norm, (q, k)) sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale attn = sim.softmax(dim = -1) out = einsum('b h i j, b h d j -> b h i d', attn, v) out = rearrange(out, 'b h (x y z) d -> b (h d) x y z', x = h, y = w, z = Z) return self.to_out(out) class DownBlock(nn.Module): def __init__(self, in_filters, out_filters, time_filters=32*4, kernel_size=(3, 3, 3), stride=1, pooling=True, height_pooling=False): super(DownBlock, self).__init__() self.pooling = pooling self.residual_block = Asymmetric_Residual_Block(in_filters, out_filters, time_filters=time_filters) if pooling: if height_pooling: self.pool = nn.Conv3d(out_filters, out_filters, kernel_size=3, stride=2,padding=1, bias=False) else: self.pool = nn.Conv3d(out_filters, out_filters, kernel_size=3, stride=(2, 2, 1),padding=1, bias=False) def forward(self, x, t): resA = self.residual_block(x, t) if self.pooling: resB = self.pool(resA) return resB, resA else: return resA class UpBlock(nn.Module): def __init__(self, in_filters, out_filters, height_pooling, time_filters=32*4): super(UpBlock, self).__init__() # self.drop_out = drop_out if out_filters<32 : self.trans_bn = nn.GroupNorm(16, in_filters) self.bn1 = nn.GroupNorm(16, out_filters) self.bn2 = nn.GroupNorm(16, out_filters) self.bn3 = nn.GroupNorm(16, out_filters) else : self.trans_bn = nn.GroupNorm(32, in_filters) self.bn1 = nn.GroupNorm(32, out_filters) self.bn2 = nn.GroupNorm(32, out_filters) self.bn3 = nn.GroupNorm(32, out_filters) self.trans_dilao = conv3x3x3(in_filters, in_filters) self.trans_act = nn.LeakyReLU() self.time_layers = nn.Sequential( nn.SiLU(), nn.Linear(time_filters, in_filters*2) ) self.conv1 = conv1x3x3(in_filters, out_filters) self.act1 = nn.LeakyReLU() self.conv2 = conv3x1x3(out_filters, out_filters) self.act2 = nn.LeakyReLU() self.conv3 = conv3x3x3(out_filters, out_filters) self.act3 = nn.LeakyReLU() if height_pooling : self.up_subm = 
nn.ConvTranspose3d(in_filters, in_filters, kernel_size=3, bias=False, stride=2, padding=1, output_padding=1, dilation=1) else : self.up_subm = nn.ConvTranspose3d(in_filters, in_filters, kernel_size=(3,3,1), bias=False, stride=(2,2,1), padding=(1,1,0), output_padding=(1,1,0), dilation=1) def forward(self, x, residual, t): upA = self.trans_dilao(x) upA = self.trans_act(upA) t = self.time_layers(t) while len(t.shape) < len(x.shape): t = t[..., None] scale, shift = torch.chunk(t, 2, dim=1) upA = self.trans_bn(upA) * (1 + scale) + shift ## upsample upA = self.up_subm(upA) upA += residual upE = self.conv1(upA) upE = self.act1(upE) upE = self.bn1(upE) upE = self.conv2(upE) upE = self.act2(upE) upE = self.bn2(upE) upE = self.conv3(upE) upE = self.act3(upE) upE = self.bn3(upE) return upE def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): if not repeat_only: half = dim // 2 freqs = torch.exp( -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half ).to(device=timesteps.device) args = timesteps[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) else: embedding = repeat(timesteps, 'b -> b d', d=dim) return embedding class Denoise(nn.Module): def __init__(self, args, num_class = 11, init_size=32, discrete=True): super(Denoise, self).__init__() self.args = args self.discrete = discrete self.num_class = num_class self.init_size = init_size self.time_size = self.init_size*4 self.time_embed = nn.Sequential( nn.Linear(init_size, self.time_size), nn.SiLU(), nn.Linear(self.time_size, self.time_size), ) self.embedding = nn.Embedding(self.num_class, init_size) self.conv_in = nn.Conv3d(init_size+1, init_size, kernel_size=1, stride=1) self.A = Asymmetric_Residual_Block(init_size, init_size, time_filters=init_size*4) self.downBlock1 = DownBlock(init_size, 2 * init_size, height_pooling=True, 
time_filters=init_size*4) self.downBlock2 = DownBlock(2 * init_size, 4 * init_size, height_pooling=True, time_filters=init_size*4) self.downBlock3 = DownBlock(4 * init_size, 8 * init_size, height_pooling=False, time_filters=init_size*4) self.downBlock4 = DownBlock(8 * init_size, 16 * init_size, height_pooling=False, time_filters=init_size*4) self.midBlock1 = Asymmetric_Residual_Block(16 * init_size, 16 * init_size, time_filters=init_size*4) self.attention = Attention(16 * init_size, 32) self.midBlock2 = Asymmetric_Residual_Block(16 * init_size, 16 * init_size, time_filters=init_size*4) self.upBlock4 = UpBlock(16 * init_size, 8 * init_size, height_pooling=False, time_filters=init_size*4) self.upBlock3 = UpBlock(8 * init_size, 4 * init_size, height_pooling=False, time_filters=init_size*4) self.upBlock2 = UpBlock(4 * init_size, 2 * init_size, height_pooling=True, time_filters=init_size*4) self.upBlock1 = UpBlock(2 * init_size, 2 * init_size, height_pooling=True, time_filters=init_size*4) self.DDCM = DDCM(2 * init_size, 2 * init_size) self.logits = nn.Conv3d(4 * init_size, self.num_class, kernel_size=3, stride=1, padding=1, bias=True) def forward(self, x, x_cond, t): x = self.embedding(x) x = x.permute(0, 4, 1, 2, 3) x_cond = x_cond.unsqueeze(1) x = torch.cat([x, x_cond], dim=1) x = self.conv_in(x) t = self.time_embed(timestep_embedding(t, self.init_size)) x = self.A(x, t) down1c, down1b = self.downBlock1(x, t) down2c, down2b = self.downBlock2(down1c, t) down3c, down3b = self.downBlock3(down2c, t) down4c, down4b = self.downBlock4(down3c, t) down4c = self.midBlock1(down4c, t) down4c = self.attention(down4c) down4c = self.midBlock2(down4c, t) up4 = self.upBlock4(down4c, down4b, t) up3 = self.upBlock3(up4, down3b, t) up2 = self.upBlock2(up3, down2b, t) up1 = self.upBlock1(up2, down1b, t) up0 = self.DDCM(up1) up = torch.cat((up1, up0), 1) logits = self.logits(up) return logits ================================================ FILE: layers/Voxel_Level/gen_denoise.py 
================================================ import math from mimetypes import init import torch import torch.nn as nn import numpy as np import torch.nn.functional as F from einops import rearrange, reduce, repeat from torch import nn, einsum def conv3x3x3(in_planes, out_planes, stride=1): return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) def conv1x3x3(in_planes, out_planes, stride=1): return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 3), stride=stride,padding=(0, 1, 1), bias=False) def conv1x1x3(in_planes, out_planes, stride=1): return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 1, 3), stride=stride, padding=(0, 0, 1), bias=False) def conv1x3x1(in_planes, out_planes, stride=1): return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 1), stride=stride, padding=(0, 1, 0), bias=False) def conv3x1x1(in_planes, out_planes, stride=1): return nn.Conv3d(in_planes, out_planes, kernel_size=(3, 1, 1), stride=stride, padding=(1, 0, 0), bias=False) def conv3x1x3(in_planes, out_planes, stride=1): return nn.Conv3d(in_planes, out_planes, kernel_size=(3, 1, 3), stride=stride, padding=(1, 0, 1), bias=False) def conv1x1(in_planes, out_planes, stride=1): return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride) class Asymmetric_Residual_Block(nn.Module): def __init__(self, in_filters, out_filters, time_filters=128): super(Asymmetric_Residual_Block, self).__init__() if in_filters < 32 : n_ng = in_filters else : n_ng =32 self.GroupNorm = nn.GroupNorm(n_ng, in_filters) self.time_layers = nn.Sequential( nn.SiLU(), nn.Linear(time_filters, in_filters*2) ) self.conv1 = conv1x3x3(in_filters, out_filters) if out_filters < 32 : n_ng = out_filters else : n_ng =32 self.bn0 = nn.GroupNorm(n_ng, out_filters) self.act1 = nn.LeakyReLU() self.conv1_2 = conv3x1x3(out_filters, out_filters) self.bn0_2 = nn.GroupNorm(n_ng, out_filters) self.act1_2 = nn.LeakyReLU() self.conv2 = conv3x1x3(in_filters, out_filters) self.act2 = 
nn.LeakyReLU() self.bn1 = nn.GroupNorm(n_ng, out_filters) self.conv3 = conv1x3x3(out_filters, out_filters) self.act3 = nn.LeakyReLU() self.bn2 = nn.GroupNorm(n_ng, out_filters) def forward(self, x, t): t = self.time_layers(t) while len(t.shape) < len(x.shape): t = t[..., None] scale, shift = torch.chunk(t, 2, dim=1) x = self.GroupNorm(x) * (1 + scale) + shift shortcut = self.conv1(x) shortcut = self.act1(shortcut) shortcut = self.bn0(shortcut) shortcut = self.conv1_2(shortcut) shortcut = self.act1_2(shortcut) shortcut = self.bn0_2(shortcut) resA = self.conv2(x) resA = self.act2(resA) resA = self.bn1(resA) resA = self.conv3(resA) resA = self.act3(resA) resA = self.bn2(resA) resA += shortcut return resA class DDCM(nn.Module): def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), stride=1): super(DDCM, self).__init__() self.conv1 = conv3x1x1(in_filters, out_filters) if out_filters < 32 : n_ng = out_filters else : n_ng =32 self.bn0 = nn.GroupNorm(n_ng, out_filters) self.act1 = nn.Sigmoid() self.conv1_2 = conv1x3x1(in_filters, out_filters) self.bn0_2 = nn.GroupNorm(n_ng, out_filters) self.act1_2 = nn.Sigmoid() self.conv1_3 = conv1x1x3(in_filters, out_filters) self.bn0_3 = nn.GroupNorm(n_ng, out_filters) self.act1_3 = nn.Sigmoid() def forward(self, x): shortcut = self.conv1(x) shortcut = self.bn0(shortcut) shortcut = self.act1(shortcut) shortcut2 = self.conv1_2(x) shortcut2 = self.bn0_2(shortcut2) shortcut2 = self.act1_2(shortcut2) shortcut3 = self.conv1_3(x) shortcut3 = self.bn0_3(shortcut3) shortcut3 = self.act1_3(shortcut3) shortcut = shortcut + shortcut2 + shortcut3 shortcut = shortcut * x return shortcut def l2norm(t): return F.normalize(t, dim = -1) class Attention(nn.Module): def __init__(self, dim, heads = 4, scale = 10): super().__init__() self.scale = scale self.heads = heads self.to_qkv = conv1x1(dim, dim*3, stride=1) self.to_out = conv1x1(dim, dim, stride=1) def forward(self, x): b, c, h, w, Z = x.shape qkv = self.to_qkv(x).chunk(3, dim = 1) q, 
k, v = map(lambda t: rearrange(t, 'b (h c) x y z-> b h c (x y z)', h = self.heads), qkv) q, k = map(l2norm, (q, k)) sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale attn = sim.softmax(dim = -1) out = einsum('b h i j, b h d j -> b h i d', attn, v) out = rearrange(out, 'b h (x y z) d -> b (h d) x y z', x = h, y = w, z = Z) return self.to_out(out) class Cross_Attention(nn.Module): def __init__(self, dim, heads = 4, scale = 10): super().__init__() self.scale = scale self.heads = heads self.to_q = conv1x1(dim, dim, stride=1) self.to_k = conv1x1(dim, dim, stride=1) self.to_v = conv1x1(dim, dim, stride=1) self.to_out = conv1x1(dim, dim, stride=1) def forward(self, x, cond_x): b, c, h, w, Z = x.shape q = self.to_q(x) k = self.to_k(cond_x) v = self.to_v(cond_x) q, k, v = map(lambda t: rearrange(t, 'b (h c) x y z-> b h c (x y z)', h = self.heads), (q, k, v)) q, k = map(l2norm, (q, k)) sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale attn = sim.softmax(dim = -1) out = einsum('b h i j, b h d j -> b h i d', attn, v) out = rearrange(out, 'b h (x y z) d -> b (h d) x y z', x = h, y = w, z = Z) return self.to_out(out) class DownBlock(nn.Module): def __init__(self, in_filters, out_filters, time_filters, kernel_size=(3, 3, 3), stride=1, pooling=True, drop_out=True, height_pooling=False): super(DownBlock, self).__init__() self.pooling = pooling self.drop_out = drop_out self.residual_block = Asymmetric_Residual_Block(in_filters, out_filters, time_filters) if pooling: if height_pooling: self.pool = nn.Conv3d(out_filters, out_filters, kernel_size=3, stride=2, padding=1, bias=False) else: self.pool = nn.Conv3d(out_filters, out_filters, kernel_size=3, stride=(2, 2, 1), padding=1, bias=False) def forward(self, x, t): resA = self.residual_block(x, t) if self.pooling: resB = self.pool(resA) return resB, resA else: return resA class UpBlock(nn.Module): def __init__(self, in_filters, out_filters, height_pooling, time_filters): super(UpBlock, self).__init__() # 
self.drop_out = drop_out self.trans_dilao = conv3x3x3(in_filters, in_filters) self.trans_act = nn.LeakyReLU() if in_filters < 32 : n_ng = out_filters else : n_ng =32 self.trans_bn = nn.GroupNorm(n_ng, in_filters) self.time_layers = nn.Sequential( nn.SiLU(), nn.Linear(time_filters, in_filters*2) ) self.conv1 = conv1x3x3(in_filters, out_filters) self.act1 = nn.LeakyReLU() if out_filters < 32 : n_ng = out_filters else :n_ng = 32 self.bn1 = nn.GroupNorm(n_ng, out_filters) self.conv2 = conv3x1x3(out_filters, out_filters) self.act2 = nn.LeakyReLU() self.bn2 = nn.GroupNorm(n_ng, out_filters) self.conv3 = conv3x3x3(out_filters, out_filters) self.act3 = nn.LeakyReLU() self.bn3 = nn.GroupNorm(n_ng, out_filters) if height_pooling : self.up_subm = nn.ConvTranspose3d(in_filters, in_filters, kernel_size=3, bias=False, stride=2, padding=1, output_padding=1, dilation=1) else : self.up_subm = nn.ConvTranspose3d(in_filters, in_filters, kernel_size=(3,3,1), bias=False, stride=(2,2,1), padding=(1,1,0), output_padding=(1,1,0), dilation=1) def forward(self, x, residual, t): upA = self.trans_dilao(x) upA = self.trans_act(upA) t = self.time_layers(t) while len(t.shape) < len(x.shape): t = t[..., None] scale, shift = torch.chunk(t, 2, dim=1) upA = self.trans_bn(upA) * (1 + scale) + shift ## upsample upA = self.up_subm(upA) upA += residual upE = self.conv1(upA) upE = self.act1(upE) upE = self.bn1(upE) upE = self.conv2(upE) upE = self.act2(upE) upE = self.bn2(upE) upE = self.conv3(upE) upE = self.act3(upE) upE = self.bn3(upE) return upE def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): """ Create sinusoidal timestep embeddings. :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional. :param dim: the dimension of the output. :param max_period: controls the minimum frequency of the embeddings. :return: an [N x dim] Tensor of positional embeddings. 
""" if not repeat_only: half = dim // 2 freqs = torch.exp( -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half ).to(device=timesteps.device) args = timesteps[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) else: embedding = repeat(timesteps, 'b -> b d', d=dim) return embedding class Denoise(nn.Module): def __init__(self, args, num_class = 11, init_size=32, discrete=True): super(Denoise, self).__init__() self.args = args self.discrete = discrete self.num_class = num_class self.init_size = init_size self.time_size = init_size*4 self.time_embed = nn.Sequential( nn.Linear(self.init_size, self.time_size), nn.SiLU(), nn.Linear(self.time_size, self.time_size), ) self.embedding = nn.Embedding(self.num_class, self.init_size) self.conv_in = nn.Conv3d(self.init_size, self.init_size, kernel_size=1, stride=1) self.A = Asymmetric_Residual_Block(self.init_size, self.init_size, self.time_size) self.downBlock1 = DownBlock(init_size, 2 * init_size, self.time_size, height_pooling=True) self.downBlock2 = DownBlock(2 * init_size, 4 * init_size, self.time_size, height_pooling=True) self.downBlock3 = DownBlock(4 * init_size, 8 * init_size, self.time_size, height_pooling=False) self.downBlock4 = DownBlock(8 * init_size, 16 * init_size, self.time_size, height_pooling=False) self.midBlock1 = Asymmetric_Residual_Block(16 * init_size, 16 * init_size, self.time_size) self.attention = Attention(16 * init_size, 32) self.midBlock2 = Asymmetric_Residual_Block(16 * init_size, 16 * init_size, self.time_size) self.upBlock4 = UpBlock(16 * init_size, 8 * init_size, height_pooling=False, time_filters=self.time_size) self.upBlock3 = UpBlock(8 * init_size, 4 * init_size, height_pooling=False, time_filters=self.time_size) self.upBlock2 = UpBlock(4 * init_size, 2 * init_size, height_pooling=True, time_filters=self.time_size) self.upBlock1 
= UpBlock(2 * init_size, 2 * init_size, height_pooling=True, time_filters=self.time_size) self.DDCM = DDCM(2 * init_size, 2 * init_size) self.logits = nn.Conv3d(4 * init_size, self.num_class, kernel_size=3, stride=1, padding=1, bias=True) def forward(self, x, t): x = self.embedding(x) x = x.permute(0, 4, 1, 2, 3) x = self.conv_in(x) t = self.time_embed(timestep_embedding(t, self.init_size)) x = self.A(x, t) down1c, down1b = self.downBlock1(x, t) down2c, down2b = self.downBlock2(down1c, t) down3c, down3b = self.downBlock3(down2c, t) down4c, down4b = self.downBlock4(down3c, t) down4c = self.midBlock1(down4c, t) down4c = self.attention(down4c) down4c = self.midBlock2(down4c, t) up4 = self.upBlock4(down4c, down4b, t) up3 = self.upBlock3(up4, down3b, t) up2 = self.upBlock2(up3, down2b, t) up1 = self.upBlock1(up2, down1b, t) up0 = self.DDCM(up1) up = torch.cat((up1, up0), 1) logits = self.logits(up) return logits ================================================ FILE: layers/__init__.py ================================================ ================================================ FILE: requirements.txt ================================================ numpy torch scipy scikit-learn matplotlib tqdm open3d pyyaml prettytable tensorboard numba einops ================================================ FILE: setup.py ================================================ from setuptools import setup, find_packages setup( name="scene_scale_diffusion", version="0.1", author="Lee Jumin, Im Woobin, Lee Sebin, Yoon Sung-Eui", author_email="", description="Experiments in PyTorch", long_description="", packages=setuptools.find_packages(), classifiers=[ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ], ) ================================================ FILE: simple_visualize.py ================================================ import os import numpy as np import open3d as o3d import argparse import yaml def 
load_config(yaml_path): with open(yaml_path, 'r') as f: config = yaml.safe_load(f) return config["learning_map"], config["remap_color_map"] def load_pointcloud(filepath, learning_map, color_map): data = np.loadtxt(filepath, delimiter=' ') if data.shape[1] < 4: raise ValueError(f"Expected at least 4 columns (label + x y z), got shape {data.shape}") raw_labels = data[:, 0].astype(int) points = data[:, 1:4] # Map raw labels → remapped labels → colors remapped_labels = np.array([learning_map.get(int(l), 0) for l in raw_labels]) colors = np.array([color_map.get(int(l), [255, 255, 255]) for l in remapped_labels]) / 255.0 pcd = o3d.geometry.PointCloud() pcd.points = o3d.utility.Vector3dVector(points) pcd.colors = o3d.utility.Vector3dVector(colors) return pcd def main(): parser = argparse.ArgumentParser() parser.add_argument('--file', default='result_for_l_gen/Completion/result_0.txt', help='Path to the point cloud .txt file') parser.add_argument('--config', default='datasets/carla.yaml', help='Path to Carla YAML config file') args = parser.parse_args() if not os.path.exists(args.file): raise FileNotFoundError(f"Point cloud file not found: {args.file}") if not os.path.exists(args.config): raise FileNotFoundError(f"YAML config file not found: {args.config}") learning_map, color_map = load_config(args.config) pcd = load_pointcloud(args.file, learning_map, color_map) o3d.visualization.draw([pcd]) if __name__ == "__main__": main() ================================================ FILE: train.py ================================================ from dataclasses import astuple import torch import argparse import numpy as np import os import pickle import torch import torch.nn.functional as F import yaml from prettytable import PrettyTable from torch.utils.tensorboard import SummaryWriter from utils.tables import * from utils.dicts import clean_dict from utils.loss import lovasz_softmax class Experiment(object): no_log_keys = ['project', 'name','log_tb', 'log_wandb','check_every', 
'eval_every','device', 'parallel', 'pin_memory', 'num_workers'] def __init__(self, args, model, optimizer, scheduler_iter, scheduler_epoch, train_loader, eval_loader, test_loader, train_sampler, log_path, eval_every, check_every): # Objects self.model = model self.loss_fun = torch.nn.CrossEntropyLoss(ignore_index=0) self.optimizer, self.scheduler_iter, self.scheduler_epoch= optimizer, scheduler_iter, scheduler_epoch # Paths self.log_path = log_path if args.dataset =='carla': config_file = os.path.join('./datasets/carla.yaml') carla_config = yaml.safe_load(open(config_file, 'r')) self.color_map = carla_config["remap_color_map"] self.remap = None LABEL_TO_NAMES = carla_config["label_to_names"] self.label_to_names = np.asarray(list(LABEL_TO_NAMES.values())) # Intervals self.eval_every, self.check_every = eval_every, check_every # Initialize self.current_epoch = 0 self.train_metrics, self.eval_metrics, self.ssc_metrics, self.seg_metrics = {}, {}, {}, {} self.eval_epochs = [] self.completion_epochs = [] # Store data loaders self.train_loader, self.eval_loader, self.test_loader, self.train_sampler = train_loader, eval_loader, test_loader, train_sampler # Store args create_folders(args) save_args(args) self.args = args # Init logging args_dict = clean_dict(vars(args), keys=self.no_log_keys) if args.log_tb: self.writer = SummaryWriter(os.path.join(self.log_path, 'tb')) self.writer.add_text("args", get_args_table(args_dict).get_html_string(), global_step=0) def run(self, epochs): if self.args.resume: self.resume() for epoch in range(self.current_epoch, epochs): # Train train_dict = self.train_fn(epoch) self.log_metrics(train_dict, self.train_metrics) # Checkpoint self.current_epoch += 1 if (epoch+1) % self.check_every == 0: self.checkpoint_save(epoch) # Eval if (epoch+1) % self.eval_every == 0: eval_dict = self.eval_fn(epoch) self.log_metrics(eval_dict, self.eval_metrics) self.eval_epochs.append(epoch) else: eval_dict = None if (epoch+1) % self.args.completion_epoch == 0: 
ssc_dict, miou, seg_dict, seg_miou = self.sample() self.log_metrics(ssc_dict, self.ssc_metrics) self.log_metrics(seg_dict, self.ssc_metrics) self.completion_epochs.append(epoch) else : ssc_dict, seg_dict = None, None # Log #self.save_metrics() if self.args.log_tb: for metric_name, metric_value in train_dict.items(): self.writer.add_scalar('base/{}'.format(metric_name), metric_value, global_step=epoch+1) if eval_dict: for metric_name, metric_value in eval_dict.items(): self.writer.add_scalar('eval/{}'.format(metric_name), metric_value, global_step=epoch+1) if ssc_dict: for metric_name, metric_value in ssc_dict.items(): self.writer.add_scalar('SSC/{}'.format(metric_name), metric_value, global_step=epoch+1) self.writer.add_text("SSC_mIoU", get_miou_table(self.args, self.label_to_names, miou).get_html_string(), global_step=epoch+1) for metric_name, metric_value in seg_dict.items(): self.writer.add_scalar('Seg/{}'.format(metric_name), metric_value, global_step=epoch+1) self.writer.add_text("Seg_mIoU", get_miou_table(self.args, self.label_to_names, seg_miou).get_html_string(), global_step=epoch+1) def train_fn(self, epoch): self.model.train() loss_sum = 0.0 loss_count = 0 if self.args.distribution : self.train_sampler.set_epoch(epoch) for voxel_input, output, counts in self.train_loader: self.optimizer.zero_grad() voxel_input = torch.from_numpy(np.asarray(voxel_input)).long().squeeze(1).cuda() # (4,1,256,256,32) output = torch.from_numpy(np.asarray(output)).long().cuda() if self.args.distribution: loss = self.model.module(output, voxel_input) else : loss = self.model(output, voxel_input) loss.backward() if self.args.clip_value: torch.nn.utils.clip_grad_value_(self.model.parameters(), self.args.clip_value) if self.args.clip_norm: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip_norm) self.optimizer.step() if self.scheduler_iter: self.scheduler_iter.step() loss_sum += loss.detach().cpu().item() * len(output) loss_count += len(output) print('Training. 
Epoch: {}/{}, Datapoint: {}/{}, Bits/dim: {:.3f}'.format(epoch+1, self.args.epochs, loss_count, len(self.train_loader.dataset), loss_sum/loss_count), end='\r') print('') if self.scheduler_epoch: self.scheduler_epoch.step() return {'loss': loss_sum/loss_count} def eval_fn(self, epoch): self.model.eval() with torch.no_grad(): loss_sum = 0.0 loss_count = 0 for voxel_input, output, counts in self.eval_loader: voxel_input = torch.from_numpy(np.asarray(voxel_input)).long().squeeze(1).cuda() # (4,1,256,256,32) output = torch.from_numpy(np.asarray(output)).long().cuda() if self.args.distribution: loss = self.model.module(output, voxel_input) else : loss = self.model(output, voxel_input) loss_sum += loss.detach().cpu().item() * len(output) loss_count += len(output) print('Train evaluating. Epoch: {}/{}, Datapoint: {}/{}, Bits/dim: {:.3f}'.format(epoch+1, self.args.epochs, loss_count, len(self.eval_loader.dataset), loss_sum/loss_count), end='\r') print('') return {'loss': loss_sum/loss_count} def sample(self): self.model.eval() with torch.no_grad(): TP, FP, TN, FN, num_correct, num_total = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 s_TP, s_FP, s_TN, s_FN, s_num_correct, s_num_total = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 all_intersections, all_unions = np.zeros(self.args.num_classes), np.zeros(self.args.num_classes) + 1e-6 s_all_intersections, s_all_unions = np.zeros(self.args.num_classes), np.zeros(self.args.num_classes) + 1e-6 if self.args.dataset == 'carla': dataloader = self.test_loader else : dataloader = self.eval_loader for iterate, (voxel_input, output, counts) in enumerate(dataloader): if len(voxel_input) == self.args.batch_size : voxel_input = torch.from_numpy(np.asarray(voxel_input)).long().squeeze(1).cuda() # (4,1,256,256,32) output = torch.from_numpy(np.asarray(output)).long().cuda() invalid = torch.from_numpy(np.asarray(counts)).cuda() if self.args.mode == 'l_vae': if self.args.distribution: recons = self.model.module.sample(output) else : recons = self.model.sample(output) else : if 
self.args.distribution: recons = self.model.module.sample(voxel_input) else : recons = self.model.sample(voxel_input) visualization(self.args, recons, voxel_input, output, invalid, iteration = iterate) correct, total, pred_TP, pred_FP, pred_TN, pred_FN, intersection, union = get_result(self.args, invalid, output, recons) all_intersections += intersection all_unions += union num_correct += correct num_total += total TP += pred_TP FP += pred_FP TN += pred_TN FN += pred_FN s_correct, s_total, s_pred_TP, s_pred_FP, s_pred_TN, s_pred_FN, s_intersection, s_union = get_result(self.args, voxel_input, output, recons, SSC=False) s_all_intersections += s_intersection s_all_unions += s_union s_num_correct += s_correct s_num_total += s_total s_TP += s_pred_TP s_FP += s_pred_FP s_TN += s_pred_TN s_FN += s_pred_FN iou, miou = print_result(self.args, self.label_to_names, num_correct, num_total, all_intersections, all_unions, TP, FP, FN) s_iou, seg_miou = print_result(self.args, self.label_to_names, s_num_correct, s_num_total, s_all_intersections, s_all_unions, s_TP, s_FP, s_FN, SSC=False) return {"IoU" : iou, "mIoU": np.mean(miou)*100 }, miou, {"IoU" : s_iou, "mIoU": np.mean(seg_miou)*100 }, seg_miou def resume(self): self.checkpoint_load(self.args.resume_path) for epoch in range(self.current_epoch): train_dict = {} for metric_name, metric_values in self.train_metrics.items(): train_dict[metric_name] = metric_values[epoch] if epoch in self.eval_epochs: eval_dict = {} for metric_name, metric_values in self.eval_metrics.items(): eval_dict[metric_name] = metric_values[self.eval_epochs.index(epoch)] else: eval_dict = None if epoch in self.completion_epochs: sample_dict = {} for metric_name, metric_values in self.eval_metrics.items(): sample_dict[metric_name] = metric_values[self.eval_epochs.index(epoch)] else: sample_dict = None for metric_name, metric_value in train_dict.items(): self.writer.add_scalar('base/{}'.format(metric_name), metric_value, global_step=epoch+1) if eval_dict: 
for metric_name, metric_value in eval_dict.items(): self.writer.add_scalar('eval/{}'.format(metric_name), metric_value, global_step=epoch+1) if sample_dict: for metric_name, metric_value in sample_dict.items(): self.writer.add_scalar('sample/{}'.format(metric_name), metric_value, global_step=epoch+1) def log_metrics(self, dict, type): if len(type)==0: for metric_name, metric_value in dict.items(): type[metric_name] = [metric_value] else: for metric_name, metric_value in dict.items(): type[metric_name].append(metric_value) def save_metrics(self): # Save metrics with open(os.path.join(self.log_path,'metrics_train.pickle'), 'wb') as f: pickle.dump(self.train_metrics, f) with open(os.path.join(self.log_path,'metrics_eval.pickle'), 'wb') as f: pickle.dump(self.eval_metrics, f) # Save metrics table metric_table = get_metric_table(self.train_metrics, epochs=list(range(1, self.current_epoch+2))) with open(os.path.join(self.log_path,'metrics_train.txt'), "w") as f: f.write(str(metric_table)) metric_table = get_metric_table(self.eval_metrics, epochs=[e+1 for e in self.eval_epochs]) with open(os.path.join(self.log_path,'metrics_eval.txt'), "w") as f: f.write(str(metric_table)) def checkpoint_save(self, epoch): if self.args.distribution: checkpoint = {'current_epoch': self.current_epoch, 'train_metrics': self.train_metrics, 'eval_metrics': self.eval_metrics, 'eval_epochs': self.eval_epochs, 'optimizer': self.optimizer.state_dict(), 'model': self.model.module.state_dict(), 'scheduler_iter': self.scheduler_iter.state_dict() if self.scheduler_iter else None, 'scheduler_epoch': self.scheduler_epoch.state_dict() if self.scheduler_epoch else None,} else : checkpoint = {'current_epoch': self.current_epoch, 'train_metrics': self.train_metrics, 'eval_metrics': self.eval_metrics, 'eval_epochs': self.eval_epochs, 'optimizer': self.optimizer.state_dict(), 'model': self.model.state_dict(), 'scheduler_iter': self.scheduler_iter.state_dict() if self.scheduler_iter else None, 
'scheduler_epoch': self.scheduler_epoch.state_dict() if self.scheduler_epoch else None,} epoch_name = 'epoch{}.tar'.format(epoch) torch.save(checkpoint, os.path.join(self.log_path, epoch_name)) def checkpoint_load(self, resume_path): checkpoint = torch.load(resume_path) if self.args.distribution: self.model.module.load_state_dict(checkpoint['model']) else : self.model.load_state_dict(checkpoint['model']) self.optimizer.load_state_dict(checkpoint['optimizer']) if self.scheduler_iter: self.scheduler_iter.load_state_dict(checkpoint['scheduler_iter']) if self.scheduler_epoch: self.scheduler_epoch.load_state_dict(checkpoint['scheduler_epoch']) self.current_epoch = checkpoint['current_epoch'] self.train_metrics = checkpoint['train_metrics'] self.eval_metrics = checkpoint['eval_metrics'] self.eval_epochs = checkpoint['eval_epochs'] ================================================ FILE: utils/cuda.py ================================================ import os import os import torch from torch import distributed as dist from torch import multiprocessing as mp import utils.dicts as dist_fn def find_free_port(): import socket sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.bind(("", 0)) port = sock.getsockname()[1] sock.close() return port def launch(fn, n_gpu_per_machine, n_machine=1, machine_rank=0, dist_url=None, args=()): world_size = n_machine * n_gpu_per_machine if world_size > 1: # if "OMP_NUM_THREADS" not in os.environ: # os.environ["OMP_NUM_THREADS"] = "1" if dist_url == "auto": if n_machine != 1: raise ValueError('dist_url="auto" not supported in multi-machine jobs') port = find_free_port() dist_url = f"tcp://127.0.0.1:{port}" if n_machine > 1 and dist_url.startswith("file://"): raise ValueError( "file:// is not a reliable init method in multi-machine jobs. 
Prefer tcp://" ) mp.spawn( distributed_worker, nprocs=n_gpu_per_machine, args=(fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args), daemon=False, ) else: local_rank = 0 fn(local_rank, *args) def distributed_worker(local_rank, fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args): if not torch.cuda.is_available(): raise OSError("CUDA is not available. Please check your environments") global_rank = machine_rank * n_gpu_per_machine + local_rank try: dist.init_process_group( backend="NCCL", init_method=dist_url, world_size=world_size, rank=global_rank, ) except Exception: raise OSError("failed to initialize NCCL groups") dist_fn.synchronize() if n_gpu_per_machine > torch.cuda.device_count(): raise ValueError( f"specified n_gpu_per_machine larger than available device ({torch.cuda.device_count()})" ) torch.cuda.set_device(local_rank) if dist_fn.LOCAL_PROCESS_GROUP is not None: raise ValueError("torch.distributed.LOCAL_PROCESS_GROUP is not None") n_machine = world_size // n_gpu_per_machine for i in range(n_machine): ranks_on_i = list(range(i * n_gpu_per_machine, (i + 1) * n_gpu_per_machine)) pg = dist.new_group(ranks_on_i) if i == machine_rank: dist_fn.LOCAL_PROCESS_GROUP = pg fn(local_rank, *args) def set_cuda_vd(gpu_ids, verbose=True): os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(id) for id in gpu_ids) if verbose: print("CUDA_VISIBLE_DEVICES = {}",format(os.environ["CUDA_VISIBLE_DEVICES"])) ================================================ FILE: utils/dicts.py ================================================ import copy import math import pickle import torch from torch import distributed as dist from torch.utils import data LOCAL_PROCESS_GROUP = None def is_primary(): return get_rank() == 0 def get_rank(): if not dist.is_available(): return 0 if not dist.is_initialized(): return 0 return dist.get_rank() def get_local_rank(): if not dist.is_available(): return 0 if not dist.is_initialized(): return 0 if LOCAL_PROCESS_GROUP is None: raise 
ValueError("tensorfn.distributed.LOCAL_PROCESS_GROUP is None") return dist.get_rank(group=LOCAL_PROCESS_GROUP) def synchronize(): if not dist.is_available(): return if not dist.is_initialized(): return world_size = dist.get_world_size() if world_size == 1: return dist.barrier() def get_world_size(): if not dist.is_available(): return 1 if not dist.is_initialized(): return 1 return dist.get_world_size() def is_distributed(): raise RuntimeError('Please debug this function!') return get_world_size() > 1 def all_reduce(tensor, op=dist.ReduceOp.SUM, async_op=False): world_size = get_world_size() if world_size == 1: return tensor dist.all_reduce(tensor, op=op, async_op=async_op) return tensor def all_gather(data): world_size = get_world_size() if world_size == 1: return [data] buffer = pickle.dumps(data) storage = torch.ByteStorage.from_buffer(buffer) tensor = torch.ByteTensor(storage).to("cuda") local_size = torch.IntTensor([tensor.numel()]).to("cuda") size_list = [torch.IntTensor([1]).to("cuda") for _ in range(world_size)] dist.all_gather(size_list, local_size) size_list = [int(size.item()) for size in size_list] max_size = max(size_list) tensor_list = [] for _ in size_list: tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda")) if local_size != max_size: padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda") tensor = torch.cat((tensor, padding), 0) dist.all_gather(tensor_list, tensor) data_list = [] for size, tensor in zip(size_list, tensor_list): buffer = tensor.cpu().numpy().tobytes()[:size] data_list.append(pickle.loads(buffer)) return data_list def reduce_dict(input_dict, average=True): world_size = get_world_size() if world_size < 2: return input_dict with torch.no_grad(): keys = [] values = [] for k in sorted(input_dict.keys()): keys.append(k) values.append(input_dict[k]) values = torch.stack(values, 0) dist.reduce(values, dst=0) if dist.get_rank() == 0 and average: values /= world_size reduced_dict = {k: v for k, v in zip(keys, values)} 
return reduced_dict def data_sampler(dataset, shuffle, distributed): if distributed: return data.distributed.DistributedSampler(dataset, shuffle=shuffle) if shuffle: return data.RandomSampler(dataset) else: return data.SequentialSampler(dataset) def clean_dict(d, keys): d2 = copy.deepcopy(d) for key in keys: if key in d2: del d2[key] return d2 ================================================ FILE: utils/intermediate_vis.py ================================================ from dataclasses import astuple import torch import argparse import numpy as np import os import pickle import torch import torch.nn.functional as F import yaml from prettytable import PrettyTable from torch.utils.tensorboard import SummaryWriter from utils.tables import * from utils.dicts import clean_dict from utils.loss import lovasz_softmax class Vis_iter(object): no_log_keys = ['project', 'name','log_tb', 'log_wandb','check_every', 'eval_every','device', 'parallel', 'pin_memory', 'num_workers'] def __init__(self, args, model, optimizer, scheduler_iter, scheduler_epoch, test_loader,log_path): # Objects self.model = model self.optimizer, self.scheduler_iter, self.scheduler_epoch= optimizer, scheduler_iter, scheduler_epoch # Paths self.log_path = log_path if args.dataset =='kitti': config_file = os.path.join('/home/jumin/multinomial_diffusion/datasets/semantic_kitti.yaml') kitti_config = yaml.safe_load(open(config_file, 'r')) self.remap = kitti_config['learning_map_inv'] self.color_map = kitti_config["color_map"] label = kitti_config['labels'] map_index = np.asarray([self.remap[i] for i in range(20)]) self.label_to_names = np.asarray([label[map_i] for map_i in map_index]) elif args.dataset =='carla': base_dir = os.path.dirname(__file__) config_file = os.path.join(base_dir, '../datasets/carla.yaml') carla_config = yaml.safe_load(open(config_file, 'r')) self.color_map = carla_config["remap_color_map"] self.remap = None LABEL_TO_NAMES = carla_config["label_to_names"] self.label_to_names = 
np.asarray(list(LABEL_TO_NAMES.values())) # Initialize self.current_epoch = 0 self.train_metrics, self.eval_metrics, self.ssc_metrics, self.seg_metrics = {}, {}, {}, {} self.eval_epochs = [] self.completion_epochs = [] # Store data loaders self.test_loader = test_loader # Store args create_folders(args) save_args(args) self.args = args # Init logging args_dict = clean_dict(vars(args), keys=self.no_log_keys) if args.log_tb: self.writer = SummaryWriter(os.path.join(self.log_path, 'tb')) self.writer.add_text("args", get_args_table(args_dict).get_html_string(), global_step=0) def run(self, epochs): self.checkpoint_load(self.args.resume_path) for epoch in range(self.current_epoch, epochs): self.sample() def sample(self): self.model.eval() with torch.no_grad(): for iterate, (voxel_input, output, counts) in enumerate(self.test_loader): voxel_input = torch.from_numpy(np.asarray(voxel_input)).squeeze(1).cuda() output = torch.from_numpy(np.asarray(output)).long().cuda() _, intermediate = self.model.module.sample(voxel_input, intermediate=True) inter_vis(self.args, intermediate) break def checkpoint_load(self, resume_path): checkpoint = torch.load(resume_path) if self.args.distribution: self.model.module.load_state_dict(checkpoint['model']) else : self.model.load_state_dict(checkpoint['model']) self.optimizer.load_state_dict(checkpoint['optimizer']) if self.scheduler_iter: self.scheduler_iter.load_state_dict(checkpoint['scheduler_iter']) if self.scheduler_epoch: self.scheduler_epoch.load_state_dict(checkpoint['scheduler_epoch']) self.current_epoch = checkpoint['current_epoch'] self.train_metrics = checkpoint['train_metrics'] self.eval_metrics = checkpoint['eval_metrics'] self.eval_epochs = checkpoint['eval_epochs'] ================================================ FILE: utils/loss.py ================================================ import math import torch from torch.autograd import Variable import torch.nn.functional as F import numpy as np try: from itertools import 
ifilterfalse except ImportError: # py3k from itertools import filterfalse as ifilterfalse # -*- coding:utf-8 -*- # author: Xinge def dice_coef(y_true, y_pred, smooth=1e-6): y_true_f = y_true.view(-1) y_pred_f = y_pred.view(-1) intersection = (y_true_f * y_pred_f).sum() return (2. * intersection + smooth) / (y_true_f.sum() + y_pred_f.sum() + smooth) def dice_coef_multilabel(y_true, y_pred, numLabels=11): dice=0 for index in range(1, numLabels): dice += dice_coef(y_true[:,index,:,:,:], y_pred[:,index,:,:,:]) return (numLabels-1) - dice """ Lovasz-Softmax and Jaccard hinge loss in PyTorch Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License) """ def lovasz_grad(gt_sorted): """ Computes gradient of the Lovasz extension w.r.t sorted errors See Alg. 1 in paper """ p = len(gt_sorted) gts = gt_sorted.sum() intersection = gts - gt_sorted.float().cumsum(0) union = gts + (1 - gt_sorted).float().cumsum(0) jaccard = 1. - intersection / union if p > 1: # cover 1-pixel case jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] return jaccard # --------------------------- MULTICLASS LOSSES --------------------------- def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None): """ Multi-class Lovasz-Softmax loss probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1). Interpreted as binary (sigmoid) output with outputs of size [B, H, W]. labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1) classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. 
per_image: compute the loss per image instead of per batch ignore: void class labels """ if per_image: loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes) for prob, lab in zip(probas, labels)) else: loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes) return loss def lovasz_softmax_flat(probas, labels, classes='present'): """ Multi-class Lovasz-Softmax loss probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1) labels: [P] Tensor, ground truth labels (between 0 and C - 1) classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. """ if probas.numel() == 0: # only void pixels, the gradients should be 0 return probas * 0. C = probas.size(1) losses = [] class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes for c in class_to_sum: fg = (labels == c).float() # foreground for class c if (classes is 'present' and fg.sum() == 0): continue if C == 1: if len(classes) > 1: raise ValueError('Sigmoid output possible only with 1 class') class_pred = probas[:, 0] else: class_pred = probas[:, c] errors = (Variable(fg) - class_pred).abs() errors_sorted, perm = torch.sort(errors, 0, descending=True) perm = perm.data fg_sorted = fg[perm] losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted)))) return mean(losses) def flatten_probas(probas, labels, ignore=None): """ Flattens predictions in the batch """ if probas.dim() == 3: # assumes output of a sigmoid layer B, H, W = probas.size() probas = probas.view(B, 1, H, W) elif probas.dim() == 5: #3D segmentation B, C, L, H, W = probas.size() probas = probas.contiguous().view(B, C, L, H*W) B, C, H, W = probas.size() probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C labels = labels.view(-1) if ignore is None: return probas, labels valid = (labels != ignore) vprobas = probas[valid.nonzero().squeeze()] vlabels = 
labels[valid] return vprobas, vlabels # --------------------------- HELPER FUNCTIONS --------------------------- def isnan(x): return x != x def mean(l, ignore_nan=False, empty=0): """ nanmean compatible with generators. """ l = iter(l) if ignore_nan: l = ifilterfalse(isnan, l) try: n = 1 acc = next(l) except StopIteration: if empty == 'raise': raise ValueError('Empty mean') return empty for n, v in enumerate(l, 2): acc += v if n == 1: return acc return acc / n ================================================ FILE: utils/multistep.py ================================================ import torch.optim as optim from torch.optim.lr_scheduler import MultiStepLR from torch.optim.lr_scheduler import _LRScheduler class LinearWarmupScheduler(_LRScheduler): """ Linearly warm-up (increasing) learning rate, starting from zero. Args: optimizer (Optimizer): Wrapped optimizer. total_epoch: target learning rate is reached at total_epoch. """ def __init__(self, optimizer, total_epoch, last_epoch=-1): self.total_epoch = total_epoch super(LinearWarmupScheduler, self).__init__(optimizer, last_epoch) def get_lr(self): return [base_lr * min(1, (self.last_epoch / self.total_epoch)) for base_lr in self.base_lrs] optim_choices = {'sgd', 'adam', 'adamax'} def get_optim(args, model): assert args.optimizer in optim_choices if args.optimizer == 'sgd': optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum) elif args.optimizer == 'adam': optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(args.momentum, args.momentum_sqr)) elif args.optimizer == 'adamax': optimizer = optim.Adamax(model.parameters(), lr=args.lr, betas=(args.momentum, args.momentum_sqr)) if args.warmup is not None: scheduler_iter = LinearWarmupScheduler(optimizer, total_epoch=args.warmup) else: scheduler_iter = None if len(args.milestones)>0: scheduler_epoch = MultiStepLR(optimizer, milestones=args.milestones, gamma=args.gamma) else: scheduler_epoch = None return optimizer, scheduler_iter, 
scheduler_epoch ================================================ FILE: utils/tables.py ================================================ from prettytable import PrettyTable import torch import os import pickle import numpy as np import torch.nn.functional as F import open3d as o3d def get_args_table(args_dict): table = PrettyTable(['Arg', 'Value']) for arg, val in args_dict.items(): table.add_row([arg, val]) return table def get_miou_table(args, label_to_names, miou): table = PrettyTable(['Label', 'mIoU']) for i in range(args.num_classes): table.add_row([label_to_names[i], 100 * miou[i]]) return table def get_metric_table(metric_dict, epochs): table = PrettyTable() table.add_column('Epoch', epochs) if len(metric_dict)>0: for metric_name, metric_values in metric_dict.items(): table.add_column(metric_name, metric_values) return table def create_folders(args): # Create log folder os.makedirs(args.log_path, exist_ok=True) os.makedirs(args.log_path+'/Completion', exist_ok=True) os.makedirs(args.log_path+'/Input', exist_ok=True) os.makedirs(args.log_path+'/Output', exist_ok=True) os.makedirs(args.log_path+'/Invalid', exist_ok=True) print("Storing logs in:", args.log_path) def inter_vis(args, recons): for r in range(len(recons)): for batch, samples_i in enumerate(recons[r]): color_index = [] for i in range(1, args.num_classes): index = torch.nonzero(samples_i == i ,as_tuple=False) color_index.append(F.pad(index,(1,0),'constant',value = i)) colors_indexs = torch.cat(color_index, dim = 0).cpu().numpy() np.savetxt('/home/jumin/multinomial_diffusion/Result/Condition/Completion/iteration/batch{}_{}.txt'.format(batch, r), colors_indexs) def visualization(args, recons, input_data, output, invalid, iteration): for batch, (samples_i, input_i, output_i, invalid_i) in enumerate(zip(recons, input_data, output, invalid)): color_index = [] output_index = [] input_points = torch.nonzero(input_i == 1, as_tuple=False).cpu().numpy() if args.dataset =='carla': invalid_points = 
torch.nonzero(invalid_i == 0, as_tuple=False).cpu().numpy() elif args.dataset =='kitti': invalid_points = torch.nonzero(invalid_i == 1, as_tuple=False).cpu().numpy() for i in range(1, args.num_classes): index = torch.nonzero(samples_i == i ,as_tuple=False) out_color = torch.nonzero(output_i == i, as_tuple=False) color_index.append(F.pad(index,(1,0),'constant',value = i)) output_index.append(F.pad(out_color,(1,0),'constant',value=i)) colors_indexs = torch.cat(color_index, dim = 0).cpu().numpy() out_indexs = torch.cat(output_index, dim = 0).cpu().numpy() np.savetxt(args.log_path+'/Completion/result_{}.txt'.format((iteration * args.batch_size) + batch), colors_indexs) '''np.savetxt(args.log_path+'/Input/input_{}.txt'.format((iteration * args.batch_size) + batch), input_points) np.savetxt(args.log_path+'/Invalid/invalid_{}.txt'.format((iteration * args.batch_size) + batch), invalid_points) np.savetxt(args.log_path+'/Output/gt_{}.txt'.format((iteration * args.batch_size) + batch), out_indexs)''' def completion_vis(args, input_p, recons): for batch, (recon_i, input_i) in enumerate(zip(recons, input_p)): recon_points = torch.nonzero(recon_i == 1, as_tuple=False).cpu().numpy() input_points = torch.nonzero(input_i == 1, as_tuple=False).cpu().numpy() np.savetxt(args.log_path+'/Completion/completion_{}.txt'.format(batch), recon_points) np.savetxt(args.log_path+'/Input/input_{}.txt'.format(batch), input_points) def iou_one_frame(pred, target, n_classes=23): pred = pred.view(-1).detach().cpu().numpy() target = target.view(-1).detach().cpu().numpy() intersection = np.zeros(n_classes) union = np.zeros(n_classes) for cls in range(n_classes): intersection[cls] = np.sum((pred == cls) & (target == cls)) union[cls] = np.sum((pred == cls) | (target == cls)) return intersection, union def get_result(args, for_mask, output, preds, SSC=True): for_mask = for_mask.contiguous().view(-1) output = output.contiguous().view(-1) preds = preds.contiguous().view(-1) if SSC : if args.dataset == 
'kitti': mask = for_mask == 0 elif args.dataset== 'carla': mask = for_mask > 0 else : mask = for_mask == 1 output_masked = output[mask] iou_output_masked = output_masked.cpu().numpy() iou_output_masked[iou_output_masked != 0] = 1 preds_masked = preds[mask] iou_preds_masked = preds_masked.cpu().numpy() iou_preds_masked[iou_preds_masked != 0] = 1 # I, U for a frame correct = np.sum(output_masked.cpu().numpy() == preds_masked.cpu().numpy()) total = preds_masked.shape[0] pred_TP = np.sum((iou_preds_masked == 1) & (iou_output_masked == 1)) pred_FP = np.sum((iou_preds_masked == 1) & (iou_output_masked == 0)) pred_TN = np.sum((iou_preds_masked == 0) & (iou_output_masked == 0)) pred_FN = np.sum((iou_preds_masked == 0) & (iou_output_masked == 1)) intersection, union = iou_one_frame(preds_masked, output_masked, n_classes=args.num_classes) return correct, total, pred_TP, pred_FP, pred_TN, pred_FN, intersection, union def save_args(args): # Save args with open(os.path.join(args.log_path, 'args.pickle'), "wb") as f: pickle.dump(args, f) # Save args table args_table = get_args_table(vars(args)) with open(os.path.join(args.log_path,'args_table.txt'), "w") as f: f.write(str(args_table)) def print_completion(num_correct, num_total, TP, FP, FN): print("\n=========================================\n") accuracy = num_correct/num_total print("\nAccuracy : ", accuracy) precision = 100 * TP / (TP + FP) recall = 100 * TP / (TP + FN) iou = 100 * TP / (TP + FP + FN) print("\nCompleteness") print("precision:", precision) print("recall:", recall) print("iou:", iou) print("\n=========================================\n") return iou def print_result(args, label_to_names, num_correct, num_total, all_intersections, all_unions, TP, FP, FN, SSC=True): if SSC : print("\n========== Semantic Scene Completion =============\n") else : print("\n============ Semantic Segmentation ===============\n") accuracy = num_correct/num_total print("\nAccuracy : ", accuracy) precision = 100 * TP / (TP + FP) recall = 
100 * TP / (TP + FN) iou = 100 * TP / (TP + FP + FN) print("\nCompleteness") print("precision:", precision) print("recall:", recall) print("iou:", iou) print("\nSemantic IoU Per Class") miou = all_intersections / all_unions for i in range(args.num_classes): print(label_to_names[i], ':', 100 * miou[i]) print("\n====================================================\n") return iou, miou ================================================ FILE: visualization.py ================================================ import os import open3d as o3d import open3d.visualization.gui as gui import open3d.visualization.rendering as rendering import argparse import numpy as np import yaml import struct parser = argparse.ArgumentParser() parser.add_argument('--M', default='scene-scale-diffusion') # VQVAE, multinomial_diffusion parser.add_argument('--Driver', default='D') parser.add_argument('--frame', default='0') parser.add_argument('--file', default='result_') parser.add_argument('--folder', default='Completion') parser.add_argument('--model', default='image_init8_concat_att') parser.add_argument('--name', default='Semantic Scene Completion') parser.add_argument('--invalid', default = False) class SpheresApp: MENU_SCENE = 1 MENU_BEFORE = 2 MENU_QUIT = 3 def __init__(self, opt): self._id = 0 self.opt = opt self.window = gui.Application.instance.create_window("Semantic Scene Completion", 1500, 1000) self.scene = gui.SceneWidget() self.scene.scene = rendering.Open3DScene(self.window.renderer) self.scene.scene.set_background([1, 1, 1, 1]) self.scene.scene.scene.set_sun_light( [-0.577, 0.577, -0.577], # direction [1, 1, 1], # color 60000) # intensity self.scene.scene.scene.enable_sun_light(True) bbox = o3d.geometry.AxisAlignedBoundingBox([64, 64, -60], [64, 64, 60]) self.scene.setup_camera(60, bbox, [0, 0, 1]) self.window.add_child(self.scene) if gui.Application.instance.menubar is None: debug_menu = gui.Menu() debug_menu.add_item("Next Scene", SpheresApp.MENU_SCENE) 
debug_menu.add_separator() debug_menu.add_item("Before Scene", SpheresApp.MENU_BEFORE) debug_menu.add_separator() debug_menu.add_item("Quit", SpheresApp.MENU_QUIT) menu = gui.Menu() menu.add_menu("SSC", debug_menu) gui.Application.instance.menubar = menu # The menubar is global, but we need to connect the menu items to the # window, so that the window can call the appropriate function when the menu item is activated. self.window.set_on_menu_item_activated(SpheresApp.MENU_SCENE,self._on_menu_scene) self.window.set_on_menu_item_activated(SpheresApp.MENU_QUIT,self._on_menu_quit) self.window.set_on_menu_item_activated(SpheresApp.MENU_BEFORE,self._on_menu_before) def _on_menu_before(self): self._id -= 1 mat = rendering.MaterialRecord() mat.shader = "defaultLit" if self.opt.file == 'input_': points = get_input(self.opt) else : points, colors = get_voxel(self.opt) pcd = o3d.geometry.PointCloud() pcd.points = o3d.utility.Vector3dVector(points) if (self.opt.file != 'input_'): pcd.colors = o3d.utility.Vector3dVector(colors/255) self.scene.scene.clear_geometry() voxel_grid = o3d.geometry.VoxelGrid.create_from_point_cloud(pcd, voxel_size=1) self.scene.scene.add_geometry("scene" + str(self._id), voxel_grid, mat) print(self.opt.frame) self.opt.frame = str(int(self.opt.frame)-1) def _on_menu_quit(self): gui.Application.instance.quit() def _on_menu_scene(self): self._id += 1 mat = rendering.MaterialRecord() mat.shader = "defaultLit" if self.opt.file == 'input_': points = get_input(self.opt) else : points, colors = get_voxel(self.opt) pcd = o3d.geometry.PointCloud() pcd.points = o3d.utility.Vector3dVector(points) if (self.opt.file != 'input_'): pcd.colors = o3d.utility.Vector3dVector(colors/255) self.scene.scene.clear_geometry() voxel_grid = o3d.geometry.VoxelGrid.create_from_point_cloud(pcd, voxel_size=1) self.scene.scene.add_geometry("scene" + str(self._id), voxel_grid, mat) print(self.opt.frame) self.opt.frame = str(int(self.opt.frame)+1) def get_voxel(opt): if opt.invalid : 
invalid_path = opt.Driver+':/'+opt.M+ '/result/' + opt.model +'/Invalid/invalid_'+ opt.frame +'.txt' invalid_points = np.loadtxt(invalid_path, delimiter=' ') invalid_colors = np.full(len(invalid_points,), 0) point_cloud_path = opt.Driver+':/'+opt.M+ '/result/' + opt.model +'/' + opt.folder +'/'+ opt.file + opt.frame +'.txt' points_colors = np.loadtxt(point_cloud_path, delimiter=' ') points = points_colors[:, 1:] colors = points_colors[:, 0] points = np.concatenate((invalid_points, points), axis=0) colors = np.concatenate((invalid_colors, colors), axis=0) points, index = np.unique(points, return_index=True, axis=0) colors = colors[index, ...] else : point_cloud_path = 'C:/Users/jumin/Dataset/result_319_110.txt' points_colors = np.loadtxt(point_cloud_path, delimiter=' ') points = points_colors[:, 1:] colors = points_colors[:, 0] if opt.dataset == 'carla' : base_dir = os.path.dirname(__file__) config_file = os.path.join(base_dir, 'datasets/carla.yaml') config = yaml.safe_load(open(config_file, 'r')) color_map = config["remap_color_map"] color = np.asarray([color_map[c] for c in colors]) return points, color def get_input(opt): point_cloud_path=opt.Driver+':/'+opt.M+'/result/' + opt.model +'/Invalid/invalid_' + opt.frame +'.txt' points_colors = np.loadtxt(point_cloud_path, delimiter=' ') points = points_colors pcd = o3d.geometry.PointCloud() pcd.points = o3d.utility.Vector3dVector(points) return points def main(opt): gui.Application.instance.initialize() SpheresApp(opt) gui.Application.instance.run() if __name__ == "__main__": opt = parser.parse_args() main(opt)