Repository: zoomin-lee/scene-scale-diffusion
Branch: main
Commit: a23745f37edd
Files: 30
Total size: 163.3 KB
Directory structure:
gitextract_ghos05eb/
├── .gitignore
├── LICENSE.txt
├── README.md
├── SSC_train.py
├── __init__.py
├── datasets/
│ ├── carla.yaml
│ ├── carla_dataset.py
│ └── data.py
├── layers/
│ ├── Ablation/
│ │ └── wo_diffusion.py
│ ├── Latent_Level/
│ │ ├── stage1/
│ │ │ ├── model.py
│ │ │ ├── vector_quantizer.py
│ │ │ └── vqvae.py
│ │ └── stage2/
│ │ ├── Gen_diffusion.py
│ │ └── gen_denoise.py
│ ├── Voxel_Level/
│ │ ├── Con_Diffusion.py
│ │ ├── Gen_Diffusion.py
│ │ ├── denoise.py
│ │ └── gen_denoise.py
│ └── __init__.py
├── requirements.txt
├── setup.py
├── simple_visualize.py
├── train.py
├── utils/
│ ├── cuda.py
│ ├── dicts.py
│ ├── intermediate_vis.py
│ ├── loss.py
│ ├── multistep.py
│ └── tables.py
└── visualization.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
build/
dist/
*.egg-info/
.eggs/
# Virtual environment
venv/
env/
.venv/
.env/
# Jupyter Notebook checkpoints
.ipynb_checkpoints/
# PyInstaller
*.manifest
*.spec
# pytest
.cache/
.pytest_cache/
# mypy
.mypy_cache/
# coverage
htmlcov/
.coverage
.coverage.*
# logs and temporary files
*.log
*.tmp
*.bak
# IDEs and editors
.vscode/
.idea/
*.sublime-workspace
*.sublime-project
# OS files
.DS_Store
Thumbs.db
# dotenv / secrets
.env
.env.*
================================================
FILE: LICENSE.txt
================================================
MIT License
Copyright (c) 2023 jumin
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# Diffusion Probabilistic Models for Scene-Scale 3D Categorical Data
📌[Paper](http://arxiv.org/abs/2301.00527)
<img src=https://user-images.githubusercontent.com/65997635/210452550-2c7c7c6d-7260-43ce-b4b6-18d3f15fccde.png width="480"
height="400">
Comparison of object-scale and scene-scale generation (ours). Our result includes multiple objects in a generated scene,
while the object-scale generation crafts one object at a time. (a) is obtained by [Point-E](https://github.com/openai/point-e)
## Abstract
In this paper, we learn a diffusion model to generate 3D data on a scene-scale. Specifically, our model crafts a 3D scene consisting of multiple objects, while recent diffusion research has focused on a single object. To realize our goal, we represent a scene with discrete class labels, i.e., categorical distribution, to assign multiple objects into semantic categories. Thus, we extend discrete diffusion models to learn scene-scale categorical distributions. In addition, we validate that a latent diffusion model can reduce computation costs for training and deployment. To the best of our knowledge, our work is the first to apply discrete and latent diffusion for 3D categorical data on a scene-scale. We further propose to perform semantic scene completion (SSC) by learning a conditional distribution using our diffusion model, where the condition is a partial observation in a sparse point cloud. In experiments, we empirically show that our diffusion models not only generate reasonable scenes, but also perform the scene completion task better than a discriminative model.
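The forward (noising) process of the discrete diffusion described above can be sketched as follows: at each timestep, every voxel keeps its current class label with probability 1 − β and otherwise resamples uniformly over the K semantic classes. A minimal NumPy sketch in the spirit of multinomial diffusion (the function name and signature are illustrative, not this repo's API):

```python
import numpy as np

def q_step(x_onehot, beta, rng=None):
    """One forward step q(x_t | x_{t-1}) of multinomial diffusion:
    keep the current class with probability 1 - beta, else resample
    uniformly over the K classes.  x_onehot: (..., K) one-hot array."""
    rng = rng or np.random.default_rng(0)
    K = x_onehot.shape[-1]
    # Mix the current one-hot distribution with the uniform distribution.
    probs = (1.0 - beta) * x_onehot + beta / K  # each row sums to 1
    flat = probs.reshape(-1, K)
    samples = np.array([rng.choice(K, p=p) for p in flat])
    return np.eye(K)[samples].reshape(x_onehot.shape)
```

Chaining `q_step` over many timesteps drives any labeled scene toward a uniform categorical distribution, which is the prior the reverse (denoising) model starts sampling from.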
## Instructions
### Dataset
: We use the [CarlaSC](https://umich-curly.github.io/CarlaSC.github.io/download/) Cartesian dataset.
### Training
: Training options are exposed as argparse arguments in `SSC_train.py`.
python SSC_train.py
- For **multi-GPU** : --distribution True
- For **Discrete Diffusion Model** : --mode gen/con/vis
- For **Latent Diffusion Model** : --mode l_vae/l_gen --l_size 882/16162/32322 --init_size 32 --l_attention True --vq_size 100
Example for training in `l_gen` mode:
python SSC_train.py --mode l_gen --vq_size 100 --l_size 32322 --init_size 32 --l_attention True --log_path ./result --vqvae_path ./lst_stage.tar
### Visualization
: We save the result to a txt file using the visualization function in `utils/tables.py`.
If you use Open3D, you can easily visualize it.
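A minimal sketch of that workflow, assuming each line of the saved txt file is `x y z label` and using the `remap_color_map` from `datasets/carla.yaml` (the helper names here are illustrative, not part of this repo):

```python
import numpy as np

# Per-label RGB, copied from remap_color_map in datasets/carla.yaml.
REMAP_COLOR_MAP = {
    0: [255, 255, 255], 1: [255, 200, 0], 2: [255, 120, 50], 3: [55, 90, 80],
    4: [255, 30, 30], 5: [255, 240, 150], 6: [255, 0, 255], 7: [175, 0, 75],
    8: [75, 0, 75], 9: [0, 175, 0], 10: [100, 150, 245],
}

def load_voxels(txt_path):
    """Load 'x y z label' rows; return coordinates and 0-1 RGB colors."""
    data = np.loadtxt(txt_path).reshape(-1, 4)
    coords = data[:, :3].astype(np.float64)
    labels = data[:, 3].astype(np.int64)
    colors = np.array([REMAP_COLOR_MAP[int(l)] for l in labels],
                      dtype=np.float64) / 255.0
    return coords, colors

def show(txt_path):
    import open3d as o3d  # imported lazily so load_voxels works without Open3D
    coords, colors = load_voxels(txt_path)
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(coords)
    pcd.colors = o3d.utility.Vector3dVector(colors)
    o3d.visualization.draw_geometries([pcd])
```

Calling `show('scene.txt')` opens an interactive Open3D window; `load_voxels` alone needs only NumPy.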
## Result
### 3D Scene Generation

### Semantic Scene Completion


## Acknowledgments
This project is based on the following codebases.
- [Multinomial Diffusion](https://github.com/ehoogeboom/multinomial_diffusion/tree/9d907a60536ad793efd6d2a6067b3c3d6ba9fce7)
- [MotionSC](https://github.com/UMich-CURLY/3DMapping)
- [Cylinder3D](https://github.com/xinge008/Cylinder3D)
================================================
FILE: SSC_train.py
================================================
import argparse
import os
import warnings
import time
import torch
from utils.intermediate_vis import Vis_iter
from datasets.data import *
from utils.cuda import launch
from utils.multistep import get_optim
from train import Experiment
from layers.Voxel_Level.Gen_Diffusion import Diffusion
from layers.Voxel_Level.Con_Diffusion import Con_Diffusion
from layers.Latent_Level.stage1.vqvae import vqvae
from layers.Latent_Level.stage2.Gen_diffusion import latent_diffusion
from layers.Ablation.wo_diffusion import wo_diff
# environment variables
NODE_RANK = os.environ['AZ_BATCHAI_TASK_INDEX'] if 'AZ_BATCHAI_TASK_INDEX' in os.environ else 0
NODE_RANK = int(NODE_RANK)
MASTER_ADDR, MASTER_PORT = os.environ['AZ_BATCH_MASTER_NODE'].split(':') if 'AZ_BATCH_MASTER_NODE' in os.environ else ("127.0.0.1", 29500)
MASTER_PORT = int(MASTER_PORT)
DIST_URL = 'tcp://%s:%s' % (MASTER_ADDR, MASTER_PORT)
def get_args():
###########
## Setup ##
###########
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=int, default=None, help='GPU id to use. If given, only the specific gpu will be used, and ddp will be disabled')
parser.add_argument('--distribution', type=eval, default=True)
parser.add_argument('--num_node', type=int, default=1, help='number of nodes for distributed training')
parser.add_argument('--node_rank', type=int, default=0, help='node rank for distributed training')
parser.add_argument('--dist_url', type=str, default='tcp://127.0.0.1:29500', help='url used to set up distributed training')
# Data params
parser.add_argument('--dataset', type=str, default='carla', choices=['carla'])
parser.add_argument('--dataset_dir', type=str, required=True, help='Path to the dataset directory')
# Train params
parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--pin_memory', type=eval, default=False)
parser.add_argument('--augmentation', type=str, default=None)
# Experiment params
parser.add_argument('--clip_value', type=float, default=None)
parser.add_argument('--clip_norm', type=float, default=None)
parser.add_argument('--recon_loss', default=False)
parser.add_argument('--mode', default='wo_diff', choices=['gen', 'con', 'vis', 'l_vae', 'l_gen', 'wo_diff'])
parser.add_argument('--l_size', default='32322', choices=['882', '16162', '32322'])
parser.add_argument('--init_size', type=int, default=8)
parser.add_argument('--l_attention', default=True)
parser.add_argument('--vq_size', type=int, default=50)
# Model params
parser.add_argument('--auxiliary_loss_weight', type=float, default=0.0005)
parser.add_argument('--diffusion_steps', type=int, default=100)
parser.add_argument('--diffusion_dim', type=int, default=32)
parser.add_argument('--dp_rate', type=float, default=0.)
# Optim params
parser.add_argument('--optimizer', type=str, default='adam')
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--warmup', type=int, default=None)
parser.add_argument('--momentum', type=float, default=0.9)
parser.add_argument('--momentum_sqr', type=float, default=0.999)
parser.add_argument('--milestones', type=eval, default=[])
parser.add_argument('--gamma', type=float, default=0.1)
# Train params
parser.add_argument('--epochs', type=int, default=5000)
parser.add_argument('--resume', type=eval, default=False)
parser.add_argument('--resume_path', type=str, default='')
parser.add_argument('--vqvae_path', type=str, default='')
# Logging params
parser.add_argument('--eval_every', type=int, default=10)
parser.add_argument('--check_every', type=int, default=5)
parser.add_argument('--completion_epoch', type=int, default=20)
parser.add_argument('--log_tb', type=eval, default=True)
parser.add_argument('--log_home', type=str, default=None)
parser.add_argument('--log_path', type=str, default='')
args = parser.parse_args()
return args
def main():
print('start!')
args = get_args()
if args.gpu is not None:
warnings.warn('You have chosen a specific GPU. This will completely disable ddp.')
torch.cuda.set_device(args.gpu)
args.ngpus_per_node = 1
args.world_size = 1
else:
if args.num_node == 1:
args.dist_url = "auto"
else:
assert args.num_node > 1
args.ngpus_per_node = torch.cuda.device_count()
args.world_size = args.ngpus_per_node * args.num_node
launch(start, args.ngpus_per_node, args.num_node, args.node_rank, args.dist_url, args=(args,))
def start(local_rank, args):
args.local_rank = local_rank
args.global_rank = args.local_rank + args.node_rank * args.ngpus_per_node
args.distributed = args.world_size > 1
##################
## Specify data ##
##################
train_loader, eval_loader, test_loader, num_classes, comp_weights, seg_weights, train_sampler = get_data(args)
args.num_classes = num_classes
completion_criterion = torch.nn.CrossEntropyLoss(weight=comp_weights)
seg_criterion = torch.nn.CrossEntropyLoss(weight=seg_weights, ignore_index=0)
similarity_criterion = torch.nn.MSELoss()
#######################
## Without Diffusion ##
#######################
if args.mode == 'wo_diff':
model = wo_diff(args, completion_criterion).cuda()
if args.distribution :
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)
########################
## Discrete Diffusion ##
########################
elif args.mode == 'gen':
model = Diffusion(args, completion_criterion, auxiliary_loss_weight=args.auxiliary_loss_weight).cuda()
if args.distribution :
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)
elif args.mode == 'con':
model = Con_Diffusion(args, completion_criterion, auxiliary_loss_weight=args.auxiliary_loss_weight).cuda()
if args.distribution :
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)
######################
## Latent Diffusion ##
######################
elif args.mode == 'l_vae':
model = vqvae(args, completion_criterion).cuda()
if args.distribution:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)
elif args.mode == 'l_gen':
Dense = vqvae(args, completion_criterion).cuda()
dense_check = torch.load(args.vqvae_path)
model = latent_diffusion(args, Dense, completion_criterion, auxiliary_loss_weight=args.auxiliary_loss_weight).cuda()
if args.distribution:
Dense = torch.nn.parallel.DistributedDataParallel(Dense, device_ids=[args.gpu], find_unused_parameters=False)
Dense.module.load_state_dict(dense_check['model'])
for p in Dense.module.parameters():
p.requires_grad = False
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)
###################
## Visualization ##
###################
elif args.mode == 'vis':
model = Con_Diffusion(args, completion_criterion, auxiliary_loss_weight=args.auxiliary_loss_weight).cuda()
if args.distribution :
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)
optimizer, scheduler_iter, scheduler_epoch = get_optim(args, model)
if args.mode == 'vis':
exp = Vis_iter(args, model, optimizer, scheduler_iter, scheduler_epoch, test_loader, args.log_path)
else :
exp = Experiment(args, model, optimizer, scheduler_iter, scheduler_epoch,
train_loader, eval_loader, test_loader, train_sampler,
args.log_path, args.eval_every, args.check_every)
exp.run(epochs = args.epochs)
if __name__ == '__main__':
main()
================================================
FILE: __init__.py
================================================
================================================
FILE: datasets/carla.yaml
================================================
color_map :
0 : [255, 255, 255] # None
1 : [70, 70, 70] # Building
2 : [100, 40, 40] # Fences
3 : [55, 90, 80] # Other
4 : [255, 255, 0 ] # Pedestrian
5 : [153, 153, 153] # Pole
6 : [157, 234, 50] # RoadLines
7 : [0, 0, 255] # Road
8 : [255, 255, 255] # Sidewalk
9 : [0, 155, 0] # Vegetation
10 : [255, 0, 0] # Vehicle
11 : [102, 102, 156] # Wall
12 : [220, 220, 0] # TrafficSign
13 : [70, 130, 180] # Sky
14 : [255, 255, 255] # Ground
15 : [150, 100, 100] # Bridge
16 : [230, 150, 140] # RailTrack
17 : [180, 165, 180] # GuardRail
18 : [250, 170, 30] # TrafficLight
19 : [110, 190, 160] # Static
20 : [170, 120, 50] # Dynamic
21 : [45, 60, 150] # Water
22 : [145, 170, 100] # Terrain
learning_map :
0 : 0
1 : 1
2 : 2
3 : 3
4 : 4
5 : 5
6 : 6
7 : 6
8 : 8
9 : 9
10: 10
11 : 2
12 : 5
13 : 3
14 : 7
15 : 3
16 : 3
17 : 2
18 : 5
19 : 3
20 : 3
21 : 3
22 : 7
remap_color_map:
0 : [255, 255, 255] # None
1 : [255, 200, 0] # Building
2 : [255, 120, 50] # Fences
3 : [55, 90, 80] # Other
4 : [255, 30, 30] # Pedestrian
5 : [255, 240, 150] # Pole
6 : [255, 0, 255] # Road
7 : [175, 0, 75] # Ground
8 : [75, 0, 75] # Sidewalk
9 : [0, 175, 0] # Vegetation
10 : [100, 150, 245] # Vehicle
label_to_names:
0 : Free
1 : Building
2 : Barrier
3 : Other
4 : Pedestrian
5 : Pole
6 : Road
7 : Ground
8 : Sidewalk
9 : Vegetation
10 : Vehicle
content :
0 : 4166593275
1 : 42309744
2 : 8550180
3 : 478193
4 : 905663
5 : 2801091
6 : 6452733
7 : 229316930
8 : 112863867
9 : 29816894
10: 13839655
11 : 15581458
12 : 221821
13 : 0
14 : 7931550
15 : 467989
16 : 3354
17 : 9201043
18 : 61011
19 : 3796746
20 : 3217865
21 : 215372
22 : 79669695
remap_content :
0 : 4.16659328e+09
1 : 4.23097440e+07
2 : 3.33326810e+07
3 : 8.17951900e+06
4 : 9.05663000e+05
5 : 3.08392300e+06
6 : 2.35769663e+08
7 : 8.76012450e+07
8 : 1.12863867e+08
9 : 2.98168940e+07
10 : 1.38396550e+07
================================================
FILE: datasets/carla_dataset.py
================================================
import os
import numpy as np
import random
import json
import yaml
import torch
import numba as nb
from torch.utils.data import Dataset
base_dir = os.path.dirname(__file__)
config_file = os.path.join(base_dir, 'carla.yaml')
carla_config = yaml.safe_load(open(config_file, 'r'))
LABELS_REMAP = carla_config["learning_map"]
REMAP_FREQUENCIES = carla_config["remap_content"]
FREQUENCIES= carla_config["content"]
LABELS_REMAP = np.asarray(list(LABELS_REMAP.values()))
frequencies_cartesian = np.asarray(list(FREQUENCIES.values()))
remap_frequencies_cartesian = np.asarray(list(REMAP_FREQUENCIES.values()))
class CarlaDataset(Dataset):
"""Carla Simulation Dataset for 3D mapping project
Provides access to the processed data, including evaluation labels, predictions, velodyne scans, poses, and timestamps.
"""
def __init__(self, directory,
voxelize_input=True,
binary_counts=True,
random_flips=False,
remap=True,
num_frames=1,
transform_pose=True,
get_gt=True,
):
'''Constructor.
Parameters:
directory: directory to the dataset
'''
self.get_gt = get_gt
self.voxelize_input = voxelize_input
self.binary_counts = binary_counts
self._directory = directory
self._num_frames = num_frames
self.random_flips = random_flips
self.remap = remap
self.transform_pose = transform_pose
self.sparse_output = True
self._scenes = sorted(os.listdir(self._directory))
self._scenes = [os.path.join(scene, "cartesian") for scene in self._scenes]
self._num_scenes = len(self._scenes)
self._num_frames_scene = []
param_file = os.path.join(self._directory, self._scenes[0], 'evaluation', 'params.json')
with open(param_file) as f:
self._eval_param = json.load(f)
self._out_dim = self._eval_param['num_channels']
self._grid_size = self._eval_param['grid_size']
self.grid_dims = np.asarray(self._grid_size)
self._eval_size = list(np.uint32(self._grid_size))
self.coor_ranges = self._eval_param['min_bound'] + self._eval_param['max_bound']
self.voxel_sizes = [abs(self.coor_ranges[3] - self.coor_ranges[0]) / self._grid_size[0],
abs(self.coor_ranges[4] - self.coor_ranges[1]) / self._grid_size[1],
abs(self.coor_ranges[5] - self.coor_ranges[2]) / self._grid_size[2]]
self.min_bound = np.asarray(self.coor_ranges[:3])
self.max_bound = np.asarray(self.coor_ranges[3:])
self.voxel_sizes = np.asarray(self.voxel_sizes)
self._velodyne_list = []
self._label_list = []
self._pred_list = []
self._eval_labels = []
self._eval_counts = []
self._frames_list = []
self._timestamps = []
self._poses = []
for scene in self._scenes:
velodyne_dir = os.path.join(self._directory, scene, 'velodyne')
label_dir = os.path.join(self._directory, scene, 'labels')
pred_dir = os.path.join(self._directory, scene, 'predictions')
eval_dir = os.path.join(self._directory, scene, 'evaluation')
self._num_frames_scene.append(len(os.listdir(velodyne_dir)))
frames_list = [os.path.splitext(filename)[0] for filename in sorted(os.listdir(velodyne_dir))]
self._frames_list.extend(frames_list)
self._velodyne_list.extend([os.path.join(velodyne_dir, str(frame).zfill(6)+'.bin') for frame in frames_list])
self._label_list.extend([os.path.join(label_dir, str(frame).zfill(6)+'.label') for frame in frames_list])
self._pred_list.extend([os.path.join(pred_dir, str(frame).zfill(6)+'.bin') for frame in frames_list])
self._eval_labels.extend([os.path.join(eval_dir, str(frame).zfill(6)+'.label') for frame in frames_list])
self._eval_counts.extend([os.path.join(eval_dir, str(frame).zfill(6) + '.bin') for frame in frames_list])
self._timestamps.append(np.loadtxt(os.path.join(self._directory, scene, 'times.txt')))
self._poses.append(np.loadtxt(os.path.join(self._directory, scene, 'poses.txt')))
# for poses and timestamps
self._timestamps = np.array(self._timestamps).reshape(sum(self._num_frames_scene))
self._poses = np.array(self._poses).reshape(sum(self._num_frames_scene), 12)
self._cum_num_frames = np.cumsum(np.array(self._num_frames_scene) - self._num_frames + 1)
# Use all frames, if there is no data then zero pad
def __len__(self):
return sum(self._num_frames_scene)
def collate_fn(self, data):
voxel_batch = [bi[0] for bi in data]
output_batch = [bi[1] for bi in data]
counts_batch = [bi[2] for bi in data]
return voxel_batch, output_batch, counts_batch
def points_to_voxels(self, voxel_grid, points, t_i):
# Valid voxels (make sure to clip)
voxels = np.floor((points - self.min_bound) / self.voxel_sizes).astype(np.int32)
# Clamp to account for any floating point errors
maxes = np.reshape(self.grid_dims - 1, (1, 3))
mins = np.zeros_like(maxes)
voxels = np.clip(voxels, mins, maxes).astype(np.int32)
# With fancy indexing, `+= 1` increments each voxel at most once per call (duplicate indices are not accumulated), so the binary_counts branch yields occupancy; the else branch accumulates true per-voxel point counts
if self.binary_counts:
voxel_grid[t_i, voxels[:, 0], voxels[:, 1], voxels[:, 2]] += 1
else:
unique_voxels, counts = np.unique(voxels, return_counts=True, axis=0)
unique_voxels = unique_voxels.astype(np.int32)
voxel_grid[t_i, unique_voxels[:, 0], unique_voxels[:, 1], unique_voxels[:, 2]] += counts
return voxel_grid
def get_pose(self, idx):
pose = np.zeros((4, 4))
pose[3, 3] = 1
pose[:3, :4] = self._poses[idx].reshape(3, 4)
return pose
def __getitem__(self, idx):
# -1 indicates no data
# the final index is the output
idx_range = self.find_horizon(idx)
if self.transform_pose:
ego_pose = self.get_pose(idx_range[-1])
to_ego = np.linalg.inv(ego_pose)
if self.voxelize_input:
voxel_input = np.zeros((idx_range.shape[0], int(self.grid_dims[0]), int(self.grid_dims[1]), int(self.grid_dims[2])), dtype=np.float32)
t_i = 0
for i in idx_range:
if i == -1: # Zero pad
points = np.zeros((1, 3), dtype=np.float32)
else:
points = np.fromfile(self._velodyne_list[i],dtype=np.float32).reshape(-1, 4)[:, :3]
if self.transform_pose:
to_world = self.get_pose(i)
relative_pose = np.matmul(to_ego, to_world)
points = np.dot(relative_pose[:3, :3], points.T).T + relative_pose[:3, 3]
valid_point_mask= np.all((points < self.max_bound) & (points >= self.min_bound), axis=1)
valid_points = points[valid_point_mask, :]
if self.voxelize_input:
voxel_input = self.points_to_voxels(voxel_input, valid_points, t_i)
t_i += 1
if self.get_gt:
output = np.fromfile(self._eval_labels[idx_range[-1]],dtype=np.uint32).reshape(self._eval_size).astype(np.uint8)
counts = np.fromfile(self._eval_counts[idx_range[-1]],dtype=np.float32).reshape(self._eval_size)
else:
output = None
counts = None
if self.voxelize_input and self.random_flips:
# X flip
if np.random.randint(2):
output = np.flip(output, axis=0)
counts = np.flip(counts, axis=0)
voxel_input = np.flip(voxel_input, axis=1) # Because there is a time dimension
# Y Flip
if np.random.randint(2):
output = np.flip(output, axis=1)
counts = np.flip(counts, axis=1)
voxel_input = np.flip(voxel_input, axis=2) # Because there is a time dimension
if self.remap:
output = LABELS_REMAP[output].astype(np.uint8)
return voxel_input, output, counts
# Indices without enough preceding frames are marked -1 (zero-padded later)
def find_horizon(self, idx):
end_idx = idx
idx_range = np.arange(idx-self._num_frames, idx)+1
diffs = np.asarray([int(self._frames_list[end_idx]) - int(self._frames_list[i]) for i in idx_range])
good_difs = -1 * (np.arange(-self._num_frames, 0) + 1)
idx_range[good_difs != diffs] = -1
return idx_range
================================================
FILE: datasets/data.py
================================================
import os
import math
import torch
import numpy as np
from torch.utils.data import DataLoader
from datasets.carla_dataset import *
dataset_choices = {'carla', 'kitti'}
def get_data_id(args):
return '{}'.format(args.dataset)
def get_class_weights(freq):
'''
Class weights are 1/log(fc) (https://arxiv.org/pdf/2008.10559.pdf)
'''
epsilon_w = 0.001 # eps to avoid zero division
weights = torch.from_numpy(1 / np.log(freq + epsilon_w))
return weights
def get_data(args):
assert args.dataset in dataset_choices
if args.dataset == 'carla':
train_dir = os.path.join(args.dataset_dir, "Train")
val_dir = os.path.join(args.dataset_dir, "Val")
test_dir = os.path.join(args.dataset_dir, "Test")
x_dim = 128
y_dim = 128
z_dim = 8
data_shape = [x_dim, y_dim, z_dim]
args.data_shape= data_shape
binary_counts = True
transform_pose = True
remap = True
if remap:
class_frequencies = remap_frequencies_cartesian
args.num_classes = 11
else:
args.num_classes = 23
comp_weights = get_class_weights(class_frequencies).to(torch.float32)
seg_weights = get_class_weights(class_frequencies[1:]).to(torch.float32)
train_ds = CarlaDataset(directory=train_dir, random_flips=True, remap=remap, binary_counts=binary_counts, transform_pose=transform_pose)
coor_ranges = train_ds._eval_param['min_bound'] + train_ds._eval_param['max_bound']
voxel_sizes = [abs(coor_ranges[3] - coor_ranges[0]) / x_dim,
abs(coor_ranges[4] - coor_ranges[1]) / y_dim,
abs(coor_ranges[5] - coor_ranges[2]) / z_dim] # since BEV
val_ds = CarlaDataset(directory=val_dir, remap=remap, binary_counts=binary_counts, transform_pose=transform_pose)
test_ds = CarlaDataset(directory=test_dir, remap=remap, binary_counts=binary_counts, transform_pose=transform_pose)
if args is not None and args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_ds, shuffle=True)
val_sampler = torch.utils.data.distributed.DistributedSampler(val_ds, shuffle=False)
train_iters = len(train_sampler) // args.batch_size
val_iters = len(val_sampler) // args.batch_size
else:
train_sampler = None
val_sampler = None
train_iters = len(train_ds) // args.batch_size
val_iters = len(val_ds) // args.batch_size
dataloader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, collate_fn=train_ds.collate_fn, num_workers=args.num_workers)
dataloader_val = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, sampler=val_sampler, collate_fn=val_ds.collate_fn, num_workers=args.num_workers)
dataloader_test = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False, collate_fn=test_ds.collate_fn, num_workers=args.num_workers)
else:
raise NotImplementedError("Unsupported dataset: only 'carla' is currently supported.")
return dataloader, dataloader_val, dataloader_test, args.num_classes, comp_weights, seg_weights, train_sampler
================================================
FILE: layers/Ablation/wo_diffusion.py
================================================
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
from layers.Latent_Level.stage1.model import C_Encoder, C_Decoder
class wo_diff(torch.nn.Module):
def __init__(self, args, multi_criterion) -> None:
super(wo_diff, self).__init__()
self.args = args
if self.args.dataset == 'kitti':
init_size = args.init_size
elif self.args.dataset == 'carla':
init_size = args.init_size
self.encoder = C_Encoder(args, nclasses=self.args.num_classes, init_size=init_size, l_size=args.l_size, attention=args.l_attention)
self.decoder = C_Decoder(args, nclasses=self.args.num_classes, init_size=init_size, l_size=args.l_size, attention=args.l_attention)
self.multi_criterion = multi_criterion
def device(self):
return self.encoder.device
def forward(self, x, input_ten):
latent = self.encoder(input_ten, out_conv=False)
recons = self.decoder(latent, in_conv=False)
recons_loss = self.multi_criterion(recons, x)
return recons_loss
def sample(self, x):
latent = self.encoder(x, out_conv=False)
recons = self.decoder(latent, in_conv=False)
recons = recons.argmax(1)
return recons
================================================
FILE: layers/Latent_Level/stage1/model.py
================================================
import numpy as np
import math
import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from torch import nn, einsum
def conv3x3x3(in_planes, out_planes, stride=1):
return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
def conv1x3x3(in_planes, out_planes, stride=1):
return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 3), stride=stride,padding=(0, 1, 1), bias=False)
def conv1x1x3(in_planes, out_planes, stride=1):
return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 1, 3), stride=stride, padding=(0, 0, 1), bias=False)
def conv1x3x1(in_planes, out_planes, stride=1):
return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 1), stride=stride, padding=(0, 1, 0), bias=False)
def conv3x1x1(in_planes, out_planes, stride=1):
return nn.Conv3d(in_planes, out_planes, kernel_size=(3, 1, 1), stride=stride, padding=(1, 0, 0), bias=False)
def conv3x1x3(in_planes, out_planes, stride=1):
return nn.Conv3d(in_planes, out_planes, kernel_size=(3, 1, 3), stride=stride, padding=(1, 0, 1), bias=False)
def conv1x1(in_planes, out_planes, stride=1):
return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride)
class Asymmetric_Residual_Block(nn.Module):
def __init__(self, in_filters, out_filters):
super(Asymmetric_Residual_Block, self).__init__()
self.conv1 = conv1x3x3(in_filters, out_filters)
self.act1 = nn.LeakyReLU()
self.conv1_2 = conv3x1x3(out_filters, out_filters)
self.act1_2 = nn.LeakyReLU()
self.conv2 = conv3x1x3(in_filters, out_filters)
self.act2 = nn.LeakyReLU()
self.conv3 = conv1x3x3(out_filters, out_filters)
self.act3 = nn.LeakyReLU()
if in_filters<32 :
self.GroupNorm = nn.GroupNorm(8, in_filters)
self.bn0 = nn.GroupNorm(8, out_filters)
self.bn0_2 = nn.GroupNorm(8, out_filters)
self.bn1 = nn.GroupNorm(8, out_filters)
self.bn2 = nn.GroupNorm(8, out_filters)
else :
self.GroupNorm = nn.GroupNorm(32, in_filters)
self.bn0 = nn.GroupNorm(32, out_filters)
self.bn0_2 = nn.GroupNorm(32, out_filters)
self.bn1 = nn.GroupNorm(32, out_filters)
self.bn2 = nn.GroupNorm(32, out_filters)
def forward(self, x):
shortcut = self.conv1(x)
shortcut = self.act1(shortcut)
shortcut = self.bn0(shortcut)
shortcut = self.conv1_2(shortcut)
shortcut = self.act1_2(shortcut)
shortcut = self.bn0_2(shortcut)
resA = self.conv2(x)
resA = self.act2(resA)
resA = self.bn1(resA)
resA = self.conv3(resA)
resA = self.act3(resA)
resA = self.bn2(resA)
resA += shortcut
return resA
class DownBlock(nn.Module):
def __init__(self, in_filters, out_filters, pooling=True, drop_out=True, height_pooling=False):
super(DownBlock, self).__init__()
self.pooling = pooling
self.drop_out = drop_out
self.residual_block = Asymmetric_Residual_Block(in_filters, out_filters)
if pooling:
if height_pooling:
self.pool = nn.Conv3d(out_filters, out_filters, kernel_size=3, stride=2,padding=1, bias=False)
else:
self.pool = nn.Conv3d(out_filters, out_filters, kernel_size=3, stride=(2, 2, 1),padding=1, bias=False)
def forward(self, x):
resA = self.residual_block(x)
if self.pooling:
resB = self.pool(resA)
return resB, resA
else:
return resA
class UpBlock(nn.Module):
def __init__(self, in_filters, out_filters, height_pooling):
super(UpBlock, self).__init__()
# self.drop_out = drop_out
self.trans_dilao = conv3x3x3(in_filters, out_filters)
self.trans_act = nn.LeakyReLU()
self.conv1 = conv1x3x3(out_filters, out_filters)
self.act1 = nn.LeakyReLU()
self.conv2 = conv3x1x3(out_filters, out_filters)
self.act2 = nn.LeakyReLU()
self.conv3 = conv3x3x3(out_filters, out_filters)
self.act3 = nn.LeakyReLU()
if out_filters<32 :
self.trans_bn = nn.GroupNorm(8, out_filters)
self.bn1 = nn.GroupNorm(8, out_filters)
self.bn2 = nn.GroupNorm(8, out_filters)
self.bn3 = nn.GroupNorm(8, out_filters)
else :
self.trans_bn = nn.GroupNorm(32, out_filters)
self.bn1 = nn.GroupNorm(32, out_filters)
self.bn2 = nn.GroupNorm(32, out_filters)
self.bn3 = nn.GroupNorm(32, out_filters)
if height_pooling :
self.up_subm = nn.ConvTranspose3d(out_filters, out_filters, kernel_size=3, bias=False, stride=2, padding=1, output_padding=1, dilation=1)
else :
self.up_subm = nn.ConvTranspose3d(out_filters, out_filters, kernel_size=(3,3,1), bias=False, stride=(2,2,1), padding=(1,1,0), output_padding=(1,1,0), dilation=1)
def forward(self, x, skip=False):
if skip :
x, residual = x
upA = self.trans_dilao(x)
upA = self.trans_act(upA)
upA = self.trans_bn(upA)
upA = self.up_subm(upA)
if skip :
upA += residual
upE = self.conv1(upA)
upE = self.act1(upE)
upE = self.bn1(upE)
upE = self.conv2(upE)
upE = self.act2(upE)
upE = self.bn2(upE)
upE = self.conv3(upE)
upE = self.act3(upE)
upE = self.bn3(upE)
return upE
class DDCM(nn.Module):
def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), stride=1):
super(DDCM, self).__init__()
self.conv1 = conv3x1x1(in_filters, out_filters)
self.act1 = nn.Sigmoid()
self.conv1_2 = conv1x3x1(in_filters, out_filters)
self.act1_2 = nn.Sigmoid()
self.conv1_3 = conv1x1x3(in_filters, out_filters)
self.act1_3 = nn.Sigmoid()
if in_filters<32 :
self.bn0 = nn.GroupNorm(8, out_filters)
self.bn0_2 = nn.GroupNorm(8, out_filters)
self.bn0_3 = nn.GroupNorm(8, out_filters)
else :
self.bn0 = nn.GroupNorm(32, out_filters)
self.bn0_2 = nn.GroupNorm(32, out_filters)
self.bn0_3 = nn.GroupNorm(32, out_filters)
def forward(self, x):
shortcut = self.conv1(x)
shortcut = self.bn0(shortcut)
shortcut = self.act1(shortcut)
shortcut2 = self.conv1_2(x)
shortcut2 = self.bn0_2(shortcut2)
shortcut2 = self.act1_2(shortcut2)
shortcut3 = self.conv1_3(x)
shortcut3 = self.bn0_3(shortcut3)
shortcut3 = self.act1_3(shortcut3)
shortcut = shortcut + shortcut2 + shortcut3
shortcut = shortcut * x
return shortcut
def l2norm(t):
return F.normalize(t, dim = -1)
class Attention(nn.Module):
def __init__(self, dim, heads = 4, scale = 10):
super().__init__()
self.scale = scale
self.heads = heads
self.to_qkv = conv1x1(dim, dim*3, stride=1)
self.to_out = conv1x1(dim, dim, stride=1)
def forward(self, x):
b, c, h, w, Z = x.shape
qkv = self.to_qkv(x).chunk(3, dim = 1)
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y z-> b h c (x y z)', h = self.heads), qkv)
q, k = map(l2norm, (q, k))
sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale
attn = sim.softmax(dim = -1)
out = einsum('b h i j, b h d j -> b h i d', attn, v)
out = rearrange(out, 'b h (x y z) d -> b (h d) x y z', x = h, y = w, z = Z)
return self.to_out(out)
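The `Attention` block above departs from standard scaled dot-product attention: queries and keys are l2-normalized first, so the logits are cosine similarities multiplied by a fixed `scale` (10) rather than `1/sqrt(d)`. A pure-Python, single-head, single-query sketch of that scaled-cosine attention (illustrative only, not code from this repo):

```python
import math

def l2norm_list(v):
    n = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / n for x in v]

def cosine_attention(q, keys, values, scale=10.0):
    """One query attending over `keys`; mirrors the qk-normalized logits above."""
    q = l2norm_list(q)
    logits = [scale * sum(a * b for a, b in zip(q, l2norm_list(k))) for k in keys]
    m = max(logits)
    weights = [math.exp(l - m) for l in logits]
    s = sum(weights)
    weights = [w / s for w in weights]  # softmax over the key axis
    out = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(len(values[0]))]
    return out, weights

# query aligned with the first key -> almost all weight on the first value
out, weights = cosine_attention([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0]], [[1.0], [0.0]])
```

Because the logits are bounded by `scale`, this variant avoids the logit blow-up that plain dot-product attention can suffer at large feature magnitudes.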
class C_Encoder(nn.Module):
    def __init__(self, args, nclasses=20, init_size=16, l_size='882', attention=True):
        super(C_Encoder, self).__init__()
        self.nclasses = nclasses
        self.args = args
        self.l_size = l_size
        # Keep the on/off flag separate from the attention module below, so the
        # constructor argument is not shadowed by the nn.Module assignment.
        self.use_attention = attention
        self.embedding = nn.Embedding(nclasses, init_size)
        self.A = Asymmetric_Residual_Block(init_size, init_size)
        self.downBlock1 = DownBlock(init_size, 2 * init_size, height_pooling=True)
        self.downBlock2 = DownBlock(2 * init_size, 4 * init_size, height_pooling=True)
        self.downBlock3 = DownBlock(4 * init_size, 8 * init_size, height_pooling=False)
        self.downBlock4 = DownBlock(8 * init_size, 16 * init_size, height_pooling=False)
        if self.l_size == '32322':
            self.midBlock1 = Asymmetric_Residual_Block(4 * init_size, 4 * init_size)
            self.attention = Attention(4 * init_size, 32)
            self.midBlock2 = Asymmetric_Residual_Block(4 * init_size, 4 * init_size)
            self.out = nn.Conv3d(4 * init_size, nclasses, kernel_size=3, stride=1, padding=1, bias=True)
        elif self.l_size == '16162':
            self.midBlock1 = Asymmetric_Residual_Block(8 * init_size, 8 * init_size)
            self.attention = Attention(8 * init_size, 32)
            self.midBlock2 = Asymmetric_Residual_Block(8 * init_size, 8 * init_size)
            self.out = nn.Conv3d(8 * init_size, nclasses, kernel_size=3, stride=1, padding=1, bias=True)
        elif self.l_size == '882':
            self.midBlock1 = Asymmetric_Residual_Block(16 * init_size, 16 * init_size)
            self.attention = Attention(16 * init_size, 32)
            self.midBlock2 = Asymmetric_Residual_Block(16 * init_size, 16 * init_size)
            self.out = nn.Conv3d(16 * init_size, nclasses, kernel_size=3, stride=1, padding=1, bias=True)
        else:
            raise NotImplementedError(f"Unsupported l_size: {self.l_size!r}")

    def forward(self, x, out_conv=True):
        x = self.embedding(x)
        x = x.permute(0, 4, 1, 2, 3)
        x = self.A(x)
        x, down1b = self.downBlock1(x)
        x, down2b = self.downBlock2(x)
        if self.l_size == '882':
            x, down3b = self.downBlock3(x)
            x, down4b = self.downBlock4(x)
        elif self.l_size == '16162':
            x, down3b = self.downBlock3(x)
        if self.use_attention:
            x = self.midBlock1(x)
            x = self.attention(x)
            x = self.midBlock2(x)
        if out_conv:
            x = self.out(x)
        return x
class C_Decoder(nn.Module):
    def __init__(self, args, nclasses=20, init_size=16, l_size='882', attention=True):
        super(C_Decoder, self).__init__()
        self.nclasses = nclasses
        self.args = args
        self.l_size = l_size
        # Keep the on/off flag separate from the attention module below, so the
        # constructor argument is not shadowed by the nn.Module assignment.
        self.use_attention = attention
        if l_size == '882':
            self.conv_in = nn.Conv3d(nclasses, 16 * init_size, kernel_size=3, stride=1, padding=1, bias=True)
            self.midBlock1 = Asymmetric_Residual_Block(16 * init_size, 16 * init_size)
            self.attention = Attention(16 * init_size, 32)
            self.midBlock2 = Asymmetric_Residual_Block(16 * init_size, 16 * init_size)
        elif l_size == '16162':
            self.conv_in = nn.Conv3d(nclasses, 8 * init_size, kernel_size=3, stride=1, padding=1, bias=True)
            self.midBlock1 = Asymmetric_Residual_Block(8 * init_size, 8 * init_size)
            self.attention = Attention(8 * init_size, 32)
            self.midBlock2 = Asymmetric_Residual_Block(8 * init_size, 8 * init_size)
        elif l_size == '32322':
            self.conv_in = nn.Conv3d(nclasses, 4 * init_size, kernel_size=3, stride=1, padding=1, bias=True)
            self.midBlock1 = Asymmetric_Residual_Block(4 * init_size, 4 * init_size)
            self.attention = Attention(4 * init_size, 32)
            self.midBlock2 = Asymmetric_Residual_Block(4 * init_size, 4 * init_size)
        else:
            raise NotImplementedError(f"Unsupported l_size: {l_size!r}")
        self.upBlock4 = UpBlock(16 * init_size, 8 * init_size, height_pooling=False)
        self.upBlock3 = UpBlock(8 * init_size, 4 * init_size, height_pooling=False)
        self.upBlock2 = UpBlock(4 * init_size, 2 * init_size, height_pooling=True)
        self.upBlock1 = UpBlock(2 * init_size, 2 * init_size, height_pooling=True)
        self.DDCM = DDCM(2 * init_size, 2 * init_size)
        self.logits = nn.Conv3d(4 * init_size, self.nclasses, kernel_size=3, stride=1, padding=1, bias=True)

    def forward(self, x, in_conv=True):
        if in_conv:
            x = self.conv_in(x)
        if self.use_attention:
            x = self.midBlock1(x)
            x = self.attention(x)
            x = self.midBlock2(x)
        if self.l_size == '882':
            x = self.upBlock4(x)
            x = self.upBlock3(x)
        elif self.l_size == '16162':
            x = self.upBlock3(x)
        x = self.upBlock2(x)
        up1 = self.upBlock1(x)
        up0 = self.DDCM(up1)
        up = torch.cat((up1, up0), 1)
        logits = self.logits(up)
        return logits
class Completion(nn.Module):
def __init__(self, args, num_class = 11, init_size=32):
super(Completion, self).__init__()
self.args = args
self.num_class = num_class
self.init_size = init_size
self.embedding = nn.Embedding(self.num_class, init_size)
self.A = Asymmetric_Residual_Block(init_size, init_size)
self.downBlock1 = DownBlock(init_size, 2 * init_size, height_pooling=True)
self.downBlock2 = DownBlock(2 * init_size, 4 * init_size, height_pooling=True)
self.downBlock3 = DownBlock(4 * init_size, 8 * init_size, height_pooling=False)
self.downBlock4 = DownBlock(8 * init_size, 16 * init_size, height_pooling=False)
self.midBlock1 = Asymmetric_Residual_Block(16 * init_size, 16 * init_size)
self.attention = Attention(16 * init_size, 32)
self.midBlock2 = Asymmetric_Residual_Block(16 * init_size, 16 * init_size)
self.upBlock4 = UpBlock(16 * init_size, 8 * init_size, height_pooling=False)
self.upBlock3 = UpBlock(8 * init_size, 4 * init_size, height_pooling=False)
self.upBlock2 = UpBlock(4 * init_size, 2 * init_size, height_pooling=True)
self.upBlock1 = UpBlock(2 * init_size, 2 * init_size, height_pooling=True)
self.DDCM = DDCM(2 * init_size, 2 * init_size)
self.logits = nn.Conv3d(4 * init_size, self.num_class, kernel_size=3, stride=1, padding=1, bias=True)
def forward(self, x):
x = self.embedding(x)
x = x.permute(0, 4, 1, 2, 3)
x = self.A(x)
down1c, down1b = self.downBlock1(x)
down2c, down2b = self.downBlock2(down1c)
down3c, down3b = self.downBlock3(down2c)
down4c, down4b = self.downBlock4(down3c)
down4c = self.midBlock1(down4c)
down4c = self.attention(down4c)
down4c = self.midBlock2(down4c)
up4 = self.upBlock4((down4c, down4b), skip=True)
up3 = self.upBlock3((up4, down3b), skip=True)
up2 = self.upBlock2((up3, down2b), skip=True)
up1 = self.upBlock1((up2, down1b), skip=True)
up0 = self.DDCM(up1)
up = torch.cat((up1, up0), 1)
logits = self.logits(up)
return logits
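Throughout these encoder/decoder blocks, each stride-2 `Conv3d` in a `DownBlock` is undone by a matching `ConvTranspose3d` in an `UpBlock` (kernel 3, stride 2, padding 1, output_padding 1 doubles an axis; the `(2,2,1)` variants leave the height axis fixed), which is what lets `upA += residual` line up with the encoder skips. A small sketch of the standard per-axis PyTorch conv-arithmetic formulas (illustrative only):

```python
def conv_out(n, k, s, p, d=1):
    # Conv3d output length along one axis
    return (n + 2 * p - d * (k - 1) - 1) // s + 1

def convT_out(n, k, s, p, out_p, d=1):
    # ConvTranspose3d output length along one axis
    return (n - 1) * s - 2 * p + d * (k - 1) + out_p + 1

down = conv_out(32, k=3, s=2, p=1)            # stride-2 pooling conv: 32 -> 16
up = convT_out(down, k=3, s=2, p=1, out_p=1)  # matching transpose conv: 16 -> 32
height = convT_out(2, k=1, s=1, p=0, out_p=0) # k=1/s=1 height axis: unchanged
```

The `output_padding=1` is what disambiguates the transpose-conv output size so the upsampled tensor exactly matches the pre-pooling resolution.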
================================================
FILE: layers/Latent_Level/stage1/vector_quantizer.py
================================================
import torch
from torch import nn
from torch.nn import functional as F
class VectorQuantizer(nn.Module):
def __init__(self,
num_embeddings: int,
embedding_dim: int,
beta: float = 0.25):
super(VectorQuantizer, self).__init__()
self.K = num_embeddings
self.D = embedding_dim
self.beta = beta
self.embedding = nn.Embedding(self.K, self.D)
self.embedding.weight.data.uniform_(-1 / self.K, 1 / self.K)
    def forward(self, z: torch.Tensor, point=False) -> torch.Tensor:  # latents, e.g. (8, 128, 8, 8, 2)
z = z.permute(0, 2, 3, 4, 1).contiguous() # [B x D x H x W x Z] -> [B x H x W x Z x D]
latents_shape = z.shape # ( 8, 8, 8, 2, 128 )
flat_latents = z.view(-1, self.D) # [BHWZ x D] = [1024, 128]
# Compute L2 distance between latents and embedding weights
dist = torch.sum(flat_latents ** 2, dim=1, keepdim=True) + torch.sum(self.embedding.weight ** 2, dim=1) - \
2 * torch.matmul(flat_latents, self.embedding.weight.t()) # [BHWZ x K]
# Get the encoding that has the min distance
min_encoding_indices = torch.argmin(dist, dim=1).unsqueeze(1) # [BHWZ, 1]
z_q = self.embedding(min_encoding_indices).view(z.shape)
# Compute the VQ Losses
commitment_loss = F.mse_loss(z_q.detach(), z)
embedding_loss = F.mse_loss(z_q, z.detach())
if point :
vq_loss = commitment_loss * self.beta
else :
vq_loss = commitment_loss * self.beta + embedding_loss
# Add the residue back to the latents
z_q = z + (z_q - z).detach()
return z_q.permute(0, 4, 1, 2, 3).contiguous(), vq_loss, min_encoding_indices, latents_shape
def codebook_to_embedding(self, encoding_inds, latents_shape): # latents (16, 512, 8, 8, 2)
# Convert to one-hot encodings
z_q = self.embedding(encoding_inds).view(latents_shape)
return z_q.permute(0, 4, 1, 2, 3).contiguous()
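`VectorQuantizer` finds each latent's nearest codebook entry via the expansion ||z - e||^2 = ||z||^2 + ||e||^2 - 2<z, e>, which turns the search into a single matrix multiply instead of materializing all pairwise differences. A pure-Python check of the identity on toy vectors (illustrative only):

```python
def sq_dist_direct(z, e):
    return sum((a - b) ** 2 for a, b in zip(z, e))

def sq_dist_expanded(z, e):
    # |z|^2 + |e|^2 - 2<z, e>: the same expansion the batched torch code uses
    return sum(a * a for a in z) + sum(b * b for b in e) - 2.0 * sum(a * b for a, b in zip(z, e))

codebook = [[0.0, 0.0], [1.0, 1.0], [2.0, 0.0]]  # toy K=3, D=2 codebook
z = [1.2, 0.9]
dists = [sq_dist_expanded(z, e) for e in codebook]
nearest = min(range(len(codebook)), key=dists.__getitem__)
```

In the torch version the three terms are computed batched over all `BHWZ` latents at once, and `argmin` over the `K` axis yields `min_encoding_indices`.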
================================================
FILE: layers/Latent_Level/stage1/vqvae.py
================================================
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import math
from utils.loss import lovasz_softmax
from layers.Latent_Level.stage1.model import C_Encoder, C_Decoder
from layers.Latent_Level.stage1.vector_quantizer import VectorQuantizer
class vqvae(torch.nn.Module):
def __init__(self, args, multi_criterion) -> None:
super(vqvae, self).__init__()
self.args = args
init_size = args.init_size
embedding_dim = int(self.args.num_classes)
self.VQ = VectorQuantizer(num_embeddings = int(self.args.num_classes)*int(self.args.vq_size), embedding_dim = embedding_dim)
self.encoder = C_Encoder(args, nclasses=self.args.num_classes, init_size=init_size, l_size=args.l_size, attention=args.l_attention)
self.quant_conv = nn.Conv3d(self.args.num_classes, self.args.num_classes, kernel_size=1, stride=1)
self.decoder = C_Decoder(args, nclasses=self.args.num_classes, init_size=init_size, l_size=args.l_size, attention=args.l_attention)
self.post_quant_conv = nn.Conv3d(self.args.num_classes, self.args.num_classes, kernel_size=1, stride=1)
self.multi_criterion = multi_criterion
    def device(self):
        # nn.Module has no `.device` attribute; read it off a parameter instead.
        return next(self.encoder.parameters()).device
def encode(self, x):
latent = self.encoder(x)
latent = self.quant_conv(latent)
return latent
def vector_quantize(self, latent):
quantized_latent, vq_loss, quantized_latent_ind, latents_shape = self.VQ(latent)
return quantized_latent, vq_loss, quantized_latent_ind, latents_shape
def coodbook(self,quantized_latent_ind, latents_shape):
quantized_latent = self.VQ.codebook_to_embedding(quantized_latent_ind.view(-1,1), latents_shape)
return quantized_latent
def decode(self, quantized_latent):
quantized_latent = self.post_quant_conv(quantized_latent)
recons = self.decoder(quantized_latent)
return recons
def forward(self, x, input_ten):
latent = self.encode(x)
quantized_latent, vq_loss, _, _ = self.vector_quantize(latent)
recons = self.decode(quantized_latent)
recons_loss = self.multi_criterion(recons, x)
loss = recons_loss + vq_loss
return loss
def sample(self, x):
latent = self.encode(x)
quantized_latent, _, _, _ = self.vector_quantize(latent)
recons = self.decode(quantized_latent)
recons = recons.argmax(1)
return recons
================================================
FILE: layers/Latent_Level/stage2/Gen_diffusion.py
================================================
import torch
import torch.nn.functional as F
import numpy as np
import math
from inspect import isfunction
from layers.Latent_Level.stage2.gen_denoise import Denoise
"""
Based in part on: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/5989f4c77eafcdc6be0fb4739f0f277a6dd7f7d8/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L281
"""
eps = 1e-8
def sum_except_batch(x, num_dims=1):
return x.reshape(*x.shape[:num_dims], -1).sum(-1)
def log_1_min_a(a):
return torch.log(1 - a.exp() + 1e-40)
def log_add_exp(a, b):
maximum = torch.max(a, b)
return maximum + torch.log(torch.exp(a - maximum) + torch.exp(b - maximum))
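`log_add_exp` computes log(e^a + e^b) stably by factoring out the max so neither exponential overflows. A scalar stdlib check of the same identity the torch version above implements (illustrative only):

```python
import math

def log_add_exp_scalar(a, b):
    # m + log(exp(a - m) + exp(b - m)) == log(exp(a) + exp(b)), with m = max(a, b)
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))

naive = math.log(math.exp(1.0) + math.exp(2.0))  # fine at small magnitudes
stable = log_add_exp_scalar(1.0, 2.0)
big = log_add_exp_scalar(1000.0, 1000.0)         # math.exp(1000.0) alone would overflow
```

After factoring out the max, at least one exponent is exactly 0 and the other is non-positive, so the result stays finite for any input magnitudes.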
def exists(x):
return x is not None
def extract(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def log_categorical(log_x_start, log_prob):
return (log_x_start.exp() * log_prob).sum(dim=1)
def index_to_log_onehot(x, num_classes):
assert x.max().item() < num_classes, f'Error: {x.max().item()} >= {num_classes}'
x_onehot = F.one_hot(x, num_classes)
permute_order = (0, -1) + tuple(range(1, len(x.size())))
x_onehot = x_onehot.permute(permute_order)
log_x = torch.log(x_onehot.float().clamp(min=1e-30))
return log_x
def log_onehot_to_index(log_x):
return log_x.argmax(1)
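`index_to_log_onehot` clamps the one-hot encoding before taking the log, so absent classes get log(1e-30) rather than -inf, and `log_onehot_to_index` inverts it with an argmax. A per-voxel pure-Python round trip (illustrative only):

```python
import math

def index_to_log_onehot_1d(idx, num_classes, floor=1e-30):
    # one voxel: one-hot, clamped, then log -- 0.0 for the true class, log(floor) elsewhere
    return [math.log(1.0 if c == idx else floor) for c in range(num_classes)]

def log_onehot_to_index_1d(log_x):
    return max(range(len(log_x)), key=log_x.__getitem__)

log_x = index_to_log_onehot_1d(3, 5)
recovered = log_onehot_to_index_1d(log_x)
```

The clamp keeps downstream sums like `log_x_start.exp() * log_prob` finite while leaving the argmax (and hence the round trip) unchanged.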
def cosine_beta_schedule(timesteps, s = 0.008):
"""
cosine schedule
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
"""
steps = timesteps + 1
x = np.linspace(0, steps, steps)
alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
alphas = (alphas_cumprod[1:] / alphas_cumprod[:-1])
alphas = np.clip(alphas, a_min=0.001, a_max=1.)
alphas = np.sqrt(alphas)
return alphas
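`cosine_beta_schedule` follows the cosine noise schedule of Nichol & Dhariwal: the cumulative alpha traces a squared cosine, per-step alphas are the ratio of consecutive cumulative values (clipped to [0.001, 1]), and this variant additionally takes a square root. A stdlib re-implementation mirroring the NumPy code above, to show the expected behaviour rather than serve as a drop-in replacement:

```python
import math

def cosine_alphas(timesteps, s=0.008):
    steps = timesteps + 1
    xs = [i * steps / (steps - 1) for i in range(steps)]  # linspace(0, steps, steps)
    a_bar = [math.cos((x / steps + s) / (1 + s) * math.pi / 2) ** 2 for x in xs]
    a_bar = [a / a_bar[0] for a in a_bar]                 # normalize so a_bar[0] == 1
    ratios = [min(max(a_bar[i + 1] / a_bar[i], 0.001), 1.0) for i in range(timesteps)]
    return [math.sqrt(r) for r in ratios]

alphas = cosine_alphas(100)
```

Early steps keep alpha close to 1 (little corruption per step) and late steps approach the 0.001 clip, so most of the signal is destroyed near the end of the forward process.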
class latent_diffusion(torch.nn.Module):
def __init__(self, args, VAE_DENSE, multi_criterion,
auxiliary_loss_weight=0.0005, adaptive_auxiliary_loss=True):
super(latent_diffusion, self).__init__()
self.args = args
self.num_classes = self.args.num_classes * self.args.vq_size
self.denoise = Denoise(args= self.args, num_class = self.num_classes)
self.num_timesteps = self.args.diffusion_steps
self.auxiliary_loss_weight = auxiliary_loss_weight
self.adaptive_auxiliary_loss = adaptive_auxiliary_loss
self.VAE_DENSE = VAE_DENSE
self.multi_criterion = multi_criterion
alphas = cosine_beta_schedule(self.num_timesteps )
        alphas = torch.tensor(alphas.astype('float64'))
        log_alpha = torch.log(alphas)
        log_cumprod_alpha = torch.cumsum(log_alpha, dim=0)
        log_1_min_alpha = log_1_min_a(log_alpha)
        log_1_min_cumprod_alpha = log_1_min_a(log_cumprod_alpha)
        assert log_add_exp(log_alpha, log_1_min_alpha).abs().sum().item() < 1.e-5
        assert log_add_exp(log_cumprod_alpha, log_1_min_cumprod_alpha).abs().sum().item() < 1e-5
        assert (torch.cumsum(log_alpha, dim=0) - log_cumprod_alpha).abs().sum().item() < 1.e-5
# Convert to float32 and register buffers.
self.register_buffer('log_alpha', log_alpha.float())
self.register_buffer('log_1_min_alpha', log_1_min_alpha.float())
self.register_buffer('log_cumprod_alpha', log_cumprod_alpha.float())
self.register_buffer('log_1_min_cumprod_alpha', log_1_min_cumprod_alpha.float())
self.register_buffer('Lt_history', torch.zeros(self.num_timesteps ))
self.register_buffer('Lt_count', torch.zeros(self.num_timesteps ))
    def device(self):
        # nn.Module has no `.device` attribute; read it off a registered buffer.
        return self.log_alpha.device
def multinomial_kl(self, log_prob1, log_prob2):
kl = (log_prob1.exp() * (log_prob1 - log_prob2)).sum(dim=1)
return kl
def q_pred_one_timestep(self, log_x_t, t):
log_alpha_t = extract(self.log_alpha, t, log_x_t.shape)
log_1_min_alpha_t = extract(self.log_1_min_alpha, t, log_x_t.shape)
# alpha_t * E[xt] + (1 - alpha_t) 1 / K
log_probs = log_add_exp(
log_x_t + log_alpha_t,
log_1_min_alpha_t - np.log(self.num_classes)
)
return log_probs
def q_pred(self, log_x_start, t):
log_cumprod_alpha_t = extract(self.log_cumprod_alpha, t, log_x_start.shape)
log_1_min_cumprod_alpha = extract(self.log_1_min_cumprod_alpha, t, log_x_start.shape)
log_probs = log_add_exp(
log_x_start + log_cumprod_alpha_t,
log_1_min_cumprod_alpha - np.log(self.num_classes)
)
return log_probs
def predict_start(self, log_x_t, t):
x_t = log_onehot_to_index(log_x_t)
out = self.denoise(x_t, t)
assert out.size(0) == x_t.size(0)
assert out.size(1) == self.num_classes
assert out.size()[2:] == x_t.size()[1:]
log_pred = F.log_softmax(out, dim=1)
return log_pred
def q_posterior(self, log_x_start, log_x_t, t):
# q(xt-1 | xt, x0) = q(xt | xt-1, x0) * q(xt-1 | x0) / q(xt | x0)
# where q(xt | xt-1, x0) = q(xt | xt-1).
t_minus_1 = t - 1
# Remove negative values, will not be used anyway for final decoder
t_minus_1 = torch.where(t_minus_1 < 0, torch.zeros_like(t_minus_1), t_minus_1)
log_EV_qxtmin_x0 = self.q_pred(log_x_start, t_minus_1)
num_axes = (1,) * (len(log_x_start.size()) - 1)
t_broadcast = t.view(-1, *num_axes) * torch.ones_like(log_x_start)
log_EV_qxtmin_x0 = torch.where(t_broadcast == 0, log_x_start, log_EV_qxtmin_x0)
# Note: _NOT_ x_tmin1, which is how the formula is typically used!!!
# Not very easy to see why this is true. But it is :)
unnormed_logprobs = log_EV_qxtmin_x0 + self.q_pred_one_timestep(log_x_t, t)
log_EV_xtmin_given_xt_given_xstart = \
unnormed_logprobs \
- torch.logsumexp(unnormed_logprobs, dim=1, keepdim=True)
return log_EV_xtmin_given_xt_given_xstart
def p_pred(self, log_x, t):
log_x0_recon = self.predict_start(log_x, t=t)
log_model_pred = self.q_posterior(log_x_start=log_x0_recon, log_x_t=log_x, t=t)
return log_model_pred, log_x0_recon
def log_sample_categorical(self, logits):
uniform = torch.rand_like(logits)
gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30)
sample = (gumbel_noise + logits).argmax(dim=1)
log_sample = index_to_log_onehot(sample, self.num_classes)
return log_sample
def q_sample(self, log_x_start, t):
log_EV_qxt_x0 = self.q_pred(log_x_start, t)
log_sample = self.log_sample_categorical(log_EV_qxt_x0)
return log_sample
def kl_prior(self, log_x_start):
b = log_x_start.size(0)
device = log_x_start.device
ones = torch.ones(b, device=device).long()
log_qxT_prob = self.q_pred(log_x_start, t=(self.num_timesteps - 1) * ones)
        # KL to the uniform prior log(1/K) that q(x_T | x_0) converges to.
        log_uniform_prob = -torch.log(self.num_classes * torch.ones_like(log_qxT_prob))
        kl_prior = self.multinomial_kl(log_qxT_prob, log_uniform_prob)
return sum_except_batch(kl_prior)
def sample_time(self, b, device, method='uniform'):
if method == 'importance':
if not (self.Lt_count > 10).all():
return self.sample_time(b, device, method='uniform')
Lt_sqrt = torch.sqrt(self.Lt_history + 1e-10) + 0.0001
Lt_sqrt[0] = Lt_sqrt[1] # Overwrite decoder term with L1.
pt_all = Lt_sqrt / Lt_sqrt.sum()
t = torch.multinomial(pt_all, num_samples=b, replacement=True)
pt = pt_all.gather(dim=0, index=t)
return t, pt
elif method == 'uniform':
t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
pt = torch.ones_like(t).float() / self.num_timesteps
return t, pt
        else:
            raise ValueError(f"Unknown time sampling method: {method!r}")
def forward(self, x, input_data):
b, device = x.size(0), x.device
self.shape = x.size()[1:]
latent = self.VAE_DENSE.encode(x)
_, _, dense_ind, latents_shape = self.VAE_DENSE.vector_quantize(latent)
reshape_size = [latent.size()[0], latent.size()[2], latent.size()[3], latent.size()[4]]
t, pt = self.sample_time(b, device, 'importance')
log_x_start = index_to_log_onehot(dense_ind.view(reshape_size), self.num_classes)
log_x_t = self.q_sample(log_x_start=log_x_start, t=t) # log_x_t : (8,551,8,8,2)
log_true_prob = self.q_posterior(log_x_start=log_x_start, log_x_t=log_x_t, t=t)
log_model_prob, log_x0_recon = self.p_pred(log_x=log_x_t, t=t)
kl = self.multinomial_kl(log_true_prob, log_model_prob)
kl = sum_except_batch(kl)
decoder_nll = -log_categorical(log_x_start, log_model_prob)
decoder_nll = sum_except_batch(decoder_nll)
mask = (t == torch.zeros_like(t)).float()
kl_loss = mask * decoder_nll + (1. - mask) * kl
if self.training:
Lt2 = kl_loss.pow(2)
Lt2_prev = self.Lt_history.gather(dim=0, index=t)
new_Lt_history = (0.1 * Lt2 + 0.9 * Lt2_prev).detach()
self.Lt_history.scatter_(dim=0, index=t, src=new_Lt_history)
self.Lt_count.scatter_add_(dim=0, index=t, src=torch.ones_like(Lt2))
kl_prior = self.kl_prior(log_x_start)
# Upweigh loss term of the kl
loss = kl_loss / pt + kl_prior
kl_aux = self.multinomial_kl(log_x_start[:,:-1,:,:,:], log_x0_recon[:,:-1,:,:,:])
kl_aux = sum_except_batch(kl_aux)
kl_aux_loss = mask * decoder_nll + (1. - mask) * kl_aux
if self.adaptive_auxiliary_loss:
addition_loss_weight = (1-t/self.num_timesteps) + 1.0
else:
addition_loss_weight = 1.0
aux_loss = addition_loss_weight * self.auxiliary_loss_weight * kl_aux_loss / pt
loss += aux_loss
        # Normalize to bits per dimension.
        loss = loss.sum() / (math.log(2) * dense_ind.view(reshape_size).shape.numel())
        return loss
def sample(self, x):
device = self.log_alpha.device
self.shape = x.size()[1:]
        # Encode a random dummy volume only to obtain the latent index and codebook
        # shapes; the content itself is generated by the reverse diffusion loop below.
        x = torch.randint(self.args.num_classes, size=x.size()).to(device)
latent = self.VAE_DENSE.encode(x)
_, _, sparse_ind, latents_shape = self.VAE_DENSE.vector_quantize(latent)
reshape_size = [latent.size()[0], latent.size()[2], latent.size()[3], latent.size()[4]]
log_z = index_to_log_onehot(sparse_ind.view(reshape_size), self.num_classes) # log_x_t : (8,551,8,8,2)
for i in reversed(range(0, self.num_timesteps)):
print(f'Sample timestep {i:4d}', end='\r')
t = torch.full((self.args.batch_size,), i, device=device, dtype=torch.long)
log_model_prob, log_x0_recon = self.p_pred(log_x=log_z, t=t)
uniform = torch.rand_like(log_model_prob)
gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30)
pre_sample = gumbel_noise + log_model_prob
sample = pre_sample.argmax(dim=1) # (32, 1, 32, 64)
log_z = index_to_log_onehot(sample, self.num_classes)
vq_ind = log_onehot_to_index(log_z)
vq_latent = self.VAE_DENSE.coodbook(vq_ind.view(-1,1), latents_shape)
recons = self.VAE_DENSE.decode(vq_latent)
recons = recons.argmax(1)
return recons
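Both `log_sample_categorical` and the reverse loop in `sample` draw categorical samples with the Gumbel-max trick: add independent Gumbel(0,1) noise to the logits and take the argmax, which is distributed exactly like sampling from the softmax of those logits. A stdlib Monte-Carlo sketch on a toy two-class distribution (illustrative only):

```python
import math
import random

random.seed(0)

def gumbel_argmax(log_probs):
    # argmax(log p + g), g ~ Gumbel(0, 1) via -log(-log(U)), U ~ Uniform(0, 1)
    noisy = []
    for lp in log_probs:
        u = random.random()
        noisy.append(lp - math.log(-math.log(u + 1e-12) + 1e-12))
    return max(range(len(noisy)), key=noisy.__getitem__)

log_probs = [math.log(0.8), math.log(0.2)]
draws = 20000
frac_class0 = sum(gumbel_argmax(log_probs) == 0 for _ in range(draws)) / draws
```

Over many draws the empirical frequency of class 0 approaches 0.8, matching the softmax probabilities without ever normalizing or calling a categorical sampler.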
================================================
FILE: layers/Latent_Level/stage2/gen_denoise.py
================================================
import math
import torch
import numpy as np
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from torch import nn, einsum
def conv3x3x3(in_planes, out_planes, stride=1):
return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
def conv1x3x3(in_planes, out_planes, stride=1):
return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 3), stride=stride,padding=(0, 1, 1), bias=False)
def conv1x1x3(in_planes, out_planes, stride=1):
return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 1, 3), stride=stride, padding=(0, 0, 1), bias=False)
def conv1x3x1(in_planes, out_planes, stride=1):
return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 1), stride=stride, padding=(0, 1, 0), bias=False)
def conv3x1x1(in_planes, out_planes, stride=1):
return nn.Conv3d(in_planes, out_planes, kernel_size=(3, 1, 1), stride=stride, padding=(1, 0, 0), bias=False)
def conv3x1x3(in_planes, out_planes, stride=1):
return nn.Conv3d(in_planes, out_planes, kernel_size=(3, 1, 3), stride=stride, padding=(1, 0, 1), bias=False)
def conv1x1(in_planes, out_planes, stride=1):
return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride)
class Asymmetric_Residual_Block(nn.Module):
def __init__(self, in_filters, out_filters, time_filters=128):
super(Asymmetric_Residual_Block, self).__init__()
self.GroupNorm = nn.GroupNorm(32, in_filters)
self.time_layers = nn.Sequential(
nn.SiLU(),
nn.Linear(time_filters, in_filters*2)
)
self.conv1 = conv1x3x3(in_filters, out_filters)
self.bn0 = nn.GroupNorm(32, out_filters)
self.act1 = nn.LeakyReLU()
self.conv1_2 = conv3x1x3(out_filters, out_filters)
self.bn0_2 = nn.GroupNorm(32, out_filters)
self.act1_2 = nn.LeakyReLU()
self.conv2 = conv3x1x3(in_filters, out_filters)
self.act2 = nn.LeakyReLU()
self.bn1 = nn.GroupNorm(32, out_filters)
self.conv3 = conv1x3x3(out_filters, out_filters)
self.act3 = nn.LeakyReLU()
self.bn2 = nn.GroupNorm(32, out_filters)
def forward(self, x, t):
t = self.time_layers(t)
while len(t.shape) < len(x.shape):
t = t[..., None]
scale, shift = torch.chunk(t, 2, dim=1)
x = self.GroupNorm(x) * (1 + scale) + shift
shortcut = self.conv1(x)
shortcut = self.act1(shortcut)
shortcut = self.bn0(shortcut)
shortcut = self.conv1_2(shortcut)
shortcut = self.act1_2(shortcut)
shortcut = self.bn0_2(shortcut)
resA = self.conv2(x)
resA = self.act2(resA)
resA = self.bn1(resA)
resA = self.conv3(resA)
resA = self.act3(resA)
resA = self.bn2(resA)
resA += shortcut
return resA
class DDCM(nn.Module):
def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), stride=1):
super(DDCM, self).__init__()
self.conv1 = conv3x1x1(in_filters, out_filters)
self.bn0 = nn.GroupNorm(32, out_filters)
self.act1 = nn.Sigmoid()
self.conv1_2 = conv1x3x1(in_filters, out_filters)
self.bn0_2 = nn.GroupNorm(32, out_filters)
self.act1_2 = nn.Sigmoid()
self.conv1_3 = conv1x1x3(in_filters, out_filters)
self.bn0_3 = nn.GroupNorm(32, out_filters)
self.act1_3 = nn.Sigmoid()
def forward(self, x):
shortcut = self.conv1(x)
shortcut = self.bn0(shortcut)
shortcut = self.act1(shortcut)
shortcut2 = self.conv1_2(x)
shortcut2 = self.bn0_2(shortcut2)
shortcut2 = self.act1_2(shortcut2)
shortcut3 = self.conv1_3(x)
shortcut3 = self.bn0_3(shortcut3)
shortcut3 = self.act1_3(shortcut3)
shortcut = shortcut + shortcut2 + shortcut3
shortcut = shortcut * x
return shortcut
def l2norm(t):
return F.normalize(t, dim = -1)
class Attention(nn.Module):
def __init__(self, dim, heads = 4, scale = 10):
super().__init__()
self.scale = scale
self.heads = heads
self.to_qkv = conv1x1(dim, dim*3, stride=1)
self.to_out = conv1x1(dim, dim, stride=1)
def forward(self, x):
b, c, h, w, Z = x.shape
qkv = self.to_qkv(x).chunk(3, dim = 1)
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y z-> b h c (x y z)', h = self.heads), qkv)
q, k = map(l2norm, (q, k))
sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale
attn = sim.softmax(dim = -1)
out = einsum('b h i j, b h d j -> b h i d', attn, v)
out = rearrange(out, 'b h (x y z) d -> b (h d) x y z', x = h, y = w, z = Z)
return self.to_out(out)
class Cross_Attention(nn.Module):
def __init__(self, dim, heads = 4, scale = 10):
super().__init__()
self.scale = scale
self.heads = heads
self.to_q = conv1x1(dim, dim, stride=1)
self.to_k = conv1x1(dim, dim, stride=1)
self.to_v = conv1x1(dim, dim, stride=1)
self.to_out = conv1x1(dim, dim, stride=1)
def forward(self, x, cond_x):
b, c, h, w, Z = x.shape
q = self.to_q(x)
k = self.to_k(cond_x)
v = self.to_v(cond_x)
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y z-> b h c (x y z)', h = self.heads), (q, k, v))
q, k = map(l2norm, (q, k))
sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale
attn = sim.softmax(dim = -1)
out = einsum('b h i j, b h d j -> b h i d', attn, v)
out = rearrange(out, 'b h (x y z) d -> b (h d) x y z', x = h, y = w, z = Z)
return self.to_out(out)
class DownBlock(nn.Module):
def __init__(self, in_filters, out_filters, dropout_rate, kernel_size=(3, 3, 3), stride=1,
pooling=True, drop_out=True, height_pooling=False):
super(DownBlock, self).__init__()
self.pooling = pooling
self.drop_out = drop_out
self.residual_block = Asymmetric_Residual_Block(in_filters, out_filters)
if pooling:
if height_pooling:
self.pool = nn.Conv3d(out_filters, out_filters, kernel_size=3, stride=2,
padding=1, bias=False)
else:
self.pool = nn.Conv3d(out_filters, out_filters, kernel_size=3, stride=(2, 2, 1),
padding=1, bias=False)
def forward(self, x, t):
resA = self.residual_block(x, t)
if self.pooling:
resB = self.pool(resA)
return resB, resA
else:
return resA
class UpBlock(nn.Module):
def __init__(self, in_filters, out_filters, height_pooling, time_filters=32*4):
super(UpBlock, self).__init__()
# self.drop_out = drop_out
self.trans_dilao = conv3x3x3(in_filters, in_filters)
self.trans_act = nn.LeakyReLU()
self.trans_bn = nn.GroupNorm(32, in_filters)
self.time_layers = nn.Sequential(
nn.SiLU(),
nn.Linear(time_filters, in_filters*2)
)
self.conv1 = conv1x3x3(in_filters, out_filters)
self.act1 = nn.LeakyReLU()
self.bn1 = nn.GroupNorm(32, out_filters)
self.conv2 = conv3x1x3(out_filters, out_filters)
self.act2 = nn.LeakyReLU()
self.bn2 = nn.GroupNorm(32, out_filters)
self.conv3 = conv3x3x3(out_filters, out_filters)
self.act3 = nn.LeakyReLU()
self.bn3 = nn.GroupNorm(32, out_filters)
if height_pooling :
self.up_subm = nn.ConvTranspose3d(in_filters, in_filters, kernel_size=3, bias=False, stride=2, padding=1, output_padding=1, dilation=1)
else :
self.up_subm = nn.ConvTranspose3d(in_filters, in_filters, kernel_size=(3,3,1), bias=False, stride=(2,2,1), padding=(1,1,0), output_padding=(1,1,0), dilation=1)
def forward(self, x, residual, t):
upA = self.trans_dilao(x)
upA = self.trans_act(upA)
t = self.time_layers(t)
while len(t.shape) < len(x.shape):
t = t[..., None]
scale, shift = torch.chunk(t, 2, dim=1)
upA = self.trans_bn(upA) * (1 + scale) + shift
## upsample
upA = self.up_subm(upA)
upA += residual
upE = self.conv1(upA)
upE = self.act1(upE)
upE = self.bn1(upE)
upE = self.conv2(upE)
upE = self.act2(upE)
upE = self.bn2(upE)
upE = self.conv3(upE)
upE = self.act3(upE)
upE = self.bn3(upE)
return upE
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
if not repeat_only:
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
else:
embedding = repeat(timesteps, 'b -> b d', d=dim)
return embedding
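`timestep_embedding` is the standard sinusoidal positional encoding: `dim // 2` geometrically spaced frequencies, with cosines in the first half of the channels and sines in the second. A stdlib version for a single integer timestep (illustrative only):

```python
import math

def timestep_embedding_1d(t, dim, max_period=10000):
    half = dim // 2
    # frequencies fall geometrically from 1 down toward 1/max_period
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    args = [t * f for f in freqs]
    return [math.cos(a) for a in args] + [math.sin(a) for a in args]

emb = timestep_embedding_1d(7, 8)
```

The torch version does the same computation batched over a 1-D tensor of timesteps, producing an `[N x dim]` embedding that `Denoise.time_embed` then projects.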
class Denoise(nn.Module):
def __init__(self, args, num_class = 11, init_size=32, discrete=True):
super(Denoise, self).__init__()
self.args = args
self.discrete = discrete
self.num_class = num_class
self.init_size = init_size
self.time_size = init_size*4
self.time_embed = nn.Sequential(
nn.Linear(init_size, self.time_size),
nn.SiLU(),
nn.Linear(self.time_size, self.time_size),
)
self.embedding = nn.Embedding(self.num_class, init_size)
self.conv_in = nn.Conv3d(init_size, init_size, kernel_size=1, stride=1)
self.A = Asymmetric_Residual_Block(init_size, init_size)
self.midBlock1_1 = Asymmetric_Residual_Block(init_size, 2 * init_size)
self.attention1 = Attention(2 * init_size, 4)
self.midBlock1_2 = Asymmetric_Residual_Block(2 * init_size, 2 * init_size)
self.downBlock2 = DownBlock(init_size*2, 2 * init_size, 0.2, height_pooling=False)
self.downBlock3 = DownBlock(2 * init_size, 4 * init_size, 0.2, height_pooling=False)
self.midBlock2_1 = Asymmetric_Residual_Block(4 * init_size, 4 * init_size)
self.attention2 = Attention(4 * init_size, 4)
self.midBlock2_2 = Asymmetric_Residual_Block(4 * init_size, 4 * init_size)
self.upBlock0 = UpBlock(4 * init_size, 2 * init_size, height_pooling=False)
self.upBlock1 = UpBlock(2 * init_size, init_size, height_pooling=False)
self.midBlock3_1 = Asymmetric_Residual_Block(init_size, init_size)
self.attention3 = Attention(init_size, 4)
self.midBlock3_2 = Asymmetric_Residual_Block(init_size, init_size)
self.DDCM = DDCM(init_size, init_size)
self.logits = nn.Sequential(
nn.Conv3d(2 * init_size, self.num_class, kernel_size=3, stride=1, padding=1, bias=True),
)
def forward(self, x, t):
x = self.embedding(x)
x = x.permute(0, 4, 1, 2, 3)
x = self.conv_in(x)
t = self.time_embed(timestep_embedding(t, self.init_size))
ret = self.A(x, t)
mid1 = self.midBlock1_1(ret, t)
att = self.attention1(mid1)
mid2 = self.midBlock1_2(att, t)
down1c, down1b = self.downBlock2(mid2, t)
down2c, down2b = self.downBlock3(down1c, t)
d_mid2 = self.midBlock2_1(down2c, t)
d_att = self.attention2(d_mid2)
d_mid1 = self.midBlock2_2(d_att, t)
up3e = self.upBlock0(d_mid1, down2b, t)
up2e = self.upBlock1(up3e, down1b, t)
u_mid2 = self.midBlock3_1(up2e, t)
u_att = self.attention3(u_mid2)
u_mid1 = self.midBlock3_2(u_att, t)
up0e = self.DDCM(u_mid1)
up0e = torch.cat((up0e, up2e), 1)
logits = self.logits(up0e)
return logits
================================================
FILE: layers/Voxel_Level/Con_Diffusion.py
================================================
import torch
import torch.nn.functional as F
import numpy as np
import math
from inspect import isfunction
from layers.Voxel_Level.denoise import Denoise
from utils.loss import *
"""
Based in part on: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/5989f4c77eafcdc6be0fb4739f0f277a6dd7f7d8/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L281
"""
eps = 1e-8
def sum_except_batch(x, num_dims=1):
return x.reshape(*x.shape[:num_dims], -1).sum(-1)
def log_1_min_a(a):
return torch.log(1 - a.exp() + 1e-40)
def log_add_exp(a, b):
maximum = torch.max(a, b)
return maximum + torch.log(torch.exp(a - maximum) + torch.exp(b - maximum))
def exists(x):
return x is not None
def extract(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
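`extract` gathers one scalar per batch element from a 1-D schedule buffer and reshapes it so it broadcasts against a 5-D volume. A minimal NumPy sketch of the same indexing (array names and values here are illustrative, not from the repo):

```python
import numpy as np

# A per-timestep schedule (e.g. log_alpha) and a batch of sampled timesteps.
schedule = np.array([0.9, 0.8, 0.7, 0.6])
t = np.array([2, 0])                       # one timestep per batch element

# Gather schedule[t] and reshape to (batch, 1, 1, 1, 1) so it broadcasts
# against a (batch, C, X, Y, Z) tensor, mirroring what extract() does.
out = schedule[t].reshape(len(t), 1, 1, 1, 1)

x = np.zeros((2, 3, 4, 4, 2))              # dummy (batch, C, X, Y, Z) volume
broadcast = x + out                        # broadcasts without error
```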
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def log_categorical(log_x_start, log_prob):
return (log_x_start.exp() * log_prob).sum(dim=1)
def index_to_log_onehot(x, num_classes):
assert x.max().item() < num_classes, f'Error: {x.max().item()} >= {num_classes}'
x_onehot = F.one_hot(x, num_classes)
permute_order = (0, -1) + tuple(range(1, len(x.size())))
x_onehot = x_onehot.permute(permute_order)
log_x = torch.log(x_onehot.float().clamp(min=1e-30))
return log_x
def log_onehot_to_index(log_x):
return log_x.argmax(1)
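The two helpers above convert between integer class maps and log-one-hot tensors with the class axis in position 1. A NumPy re-sketch of the round trip (the `_np` names are illustrative, not part of the repo):

```python
import numpy as np

def index_to_log_onehot_np(x, num_classes):
    # One-hot encode, move the class axis to position 1, then take a
    # clamped log so exact zeros become a large negative number.
    onehot = np.eye(num_classes)[x]                  # (..., K)
    onehot = np.moveaxis(onehot, -1, 1)              # (B, K, ...)
    return np.log(np.clip(onehot, 1e-30, None))

def log_onehot_to_index_np(log_x):
    # Inverse: the argmax over the class axis recovers the class ids.
    return log_x.argmax(1)

x = np.array([[0, 2], [1, 1]])                       # (B=2, N=2) class ids
log_x = index_to_log_onehot_np(x, num_classes=3)
```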
def cosine_beta_schedule(timesteps, s = 0.008):
"""
Cosine schedule as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
Note: despite the name, this returns per-step sqrt(alpha_t) values, not betas.
"""
steps = timesteps + 1
x = np.linspace(0, steps, steps)
alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
alphas = (alphas_cumprod[1:] / alphas_cumprod[:-1])
alphas = np.clip(alphas, a_min=0.001, a_max=1.)
alphas = np.sqrt(alphas)
return alphas
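Two properties worth sanity-checking here: the schedule values stay in (0, 1], and for each step `alpha + (1 - alpha)` recombines to 1 in log space, which is exactly what the `log_add_exp` asserts in `__init__` verify. A NumPy sketch (the `cosine_alphas` helper mirrors the function above; it is illustrative, not a repo API):

```python
import numpy as np

def cosine_alphas(timesteps, s=0.008):
    # Mirrors cosine_beta_schedule: per-step sqrt(alpha_t) derived from
    # the cosine cumulative-product schedule, clipped away from zero.
    steps = timesteps + 1
    x = np.linspace(0, steps, steps)
    ac = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
    ac = ac / ac[0]
    alphas = np.clip(ac[1:] / ac[:-1], 0.001, 1.0)
    return np.sqrt(alphas)

alphas = cosine_alphas(100)
log_alpha = np.log(alphas)
log_1_min_alpha = np.log(1 - np.exp(log_alpha) + 1e-40)
# alpha_t + (1 - alpha_t) must recombine to 1 when added in log space.
recombined = np.exp(np.logaddexp(log_alpha, log_1_min_alpha))
```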
class Con_Diffusion(torch.nn.Module):
def __init__(self, args, multi_criterion, auxiliary_loss_weight=0.05, adaptive_auxiliary_loss=True):
super(Con_Diffusion, self).__init__()
#self._denoise_fn = SSCNet(num_classes=args.num_classes*50, num_steps=args.diffusion_steps)
self.args = args
self.num_classes = self.args.num_classes
self.num_timesteps = self.args.diffusion_steps
self.recon_loss = self.args.recon_loss
if args.dataset == 'carla':
self._denoise_fn = Denoise(args= self.args, num_class = self.num_classes)
elif args.dataset=='kitti':
self._denoise_fn = Denoise(args= self.args, num_class = self.num_classes, init_size=16)
self.auxiliary_loss_weight = auxiliary_loss_weight
self.adaptive_auxiliary_loss = adaptive_auxiliary_loss
self.multi_criterion = multi_criterion
alphas = cosine_beta_schedule(self.num_timesteps)
alphas = torch.tensor(alphas.astype('float64'))
log_alpha = torch.log(alphas)
log_cumprod_alpha = torch.cumsum(log_alpha, dim=0)
log_1_min_alpha = log_1_min_a(log_alpha)
log_1_min_cumprod_alpha = log_1_min_a(log_cumprod_alpha)
assert log_add_exp(log_alpha, log_1_min_alpha).abs().sum().item() < 1.e-5
assert log_add_exp(log_cumprod_alpha, log_1_min_cumprod_alpha).abs().sum().item() < 1e-5
assert (torch.cumsum(log_alpha, dim=0) - log_cumprod_alpha).abs().sum().item() < 1.e-5
# Convert to float32 and register buffers.
self.register_buffer('log_alpha', log_alpha.float())
self.register_buffer('log_1_min_alpha', log_1_min_alpha.float())
self.register_buffer('log_cumprod_alpha', log_cumprod_alpha.float())
self.register_buffer('log_1_min_cumprod_alpha', log_1_min_cumprod_alpha.float())
self.register_buffer('Lt_history', torch.zeros(self.num_timesteps ))
self.register_buffer('Lt_count', torch.zeros(self.num_timesteps ))
def device(self):
# The attribute is `_denoise_fn` (not `denoise_fn`), and nn.Module has no
# .device; read the device from a registered buffer instead, as sample() does.
return self.log_alpha.device
def multinomial_kl(self, log_prob1, log_prob2):
kl = (log_prob1.exp() * (log_prob1 - log_prob2)).sum(dim=1)
return kl
def q_pred_one_timestep(self, log_x_t, t):
log_alpha_t = extract(self.log_alpha, t, log_x_t.shape)
log_1_min_alpha_t = extract(self.log_1_min_alpha, t, log_x_t.shape)
# alpha_t * E[xt] + (1 - alpha_t) 1 / K
log_probs = log_add_exp(
log_x_t + log_alpha_t,
log_1_min_alpha_t - np.log(self.num_classes)
)
return log_probs
def q_pred(self, log_x_start, t):
log_cumprod_alpha_t = extract(self.log_cumprod_alpha, t, log_x_start.shape)
log_1_min_cumprod_alpha = extract(self.log_1_min_cumprod_alpha, t, log_x_start.shape)
log_probs = log_add_exp(
log_x_start + log_cumprod_alpha_t,
log_1_min_cumprod_alpha - np.log(self.num_classes)
)
return log_probs
def predict_start(self, log_x_t, t, cond):
x_t = log_onehot_to_index(log_x_t)
out = self._denoise_fn(x_t, cond, t)
log_pred = F.log_softmax(out, dim=1)
return log_pred
def q_posterior(self, log_x_start, log_x_t, t):
# q(xt-1 | xt, x0) = q(xt | xt-1, x0) * q(xt-1 | x0) / q(xt | x0)
# where q(xt | xt-1, x0) = q(xt | xt-1).
t_minus_1 = t - 1
# Remove negative values, will not be used anyway for final decoder
t_minus_1 = torch.where(t_minus_1 < 0, torch.zeros_like(t_minus_1), t_minus_1)
log_EV_qxtmin_x0 = self.q_pred(log_x_start, t_minus_1)
num_axes = (1,) * (len(log_x_start.size()) - 1)
t_broadcast = t.view(-1, *num_axes) * torch.ones_like(log_x_start)
log_EV_qxtmin_x0 = torch.where(t_broadcast == 0, log_x_start, log_EV_qxtmin_x0)
# Note: _NOT_ x_tmin1, which is how the formula is typically used!!!
# Not very easy to see why this is true. But it is :)
unnormed_logprobs = log_EV_qxtmin_x0 + self.q_pred_one_timestep(log_x_t, t)
log_EV_xtmin_given_xt_given_xstart = unnormed_logprobs - torch.logsumexp(unnormed_logprobs, dim=1, keepdim=True)
return log_EV_xtmin_given_xt_given_xstart
def p_pred(self, log_x, t, cond):
log_x0_recon = self.predict_start(log_x, t, cond)
log_model_pred = self.q_posterior(log_x_start=log_x0_recon, log_x_t=log_x, t=t)
return log_model_pred, log_x0_recon
def log_sample_categorical(self, logits):
uniform = torch.rand_like(logits)
gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30)
sample = (gumbel_noise + logits).argmax(dim=1)
log_sample = index_to_log_onehot(sample, self.num_classes)
return log_sample
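`log_sample_categorical` draws from the categorical distribution implied by `logits` via the Gumbel-max trick: add independent Gumbel(0, 1) noise to each logit and take the argmax over the class axis. The NumPy sketch below re-creates the trick (function name and seed are illustrative); with one logit overwhelmingly larger than the rest, the sample is deterministic:

```python
import numpy as np

rng = np.random.default_rng(0)

def gumbel_max_sample(logits):
    # Gumbel(0, 1) noise via the double-log transform of uniforms, with
    # the same 1e-30 guards against log(0) used in the module above.
    u = rng.random(logits.shape)
    gumbel = -np.log(-np.log(u + 1e-30) + 1e-30)
    # argmax over the class axis draws exact samples from softmax(logits).
    return (logits + gumbel).argmax(axis=1)

# A logit gap of 100 exceeds the range of the Gumbel noise, so the
# argmax always lands on index 1.
logits = np.array([[0.0, 100.0, 0.0]])
sample = gumbel_max_sample(logits)
```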
def q_sample(self, log_x_start, t):
log_EV_qxt_x0 = self.q_pred(log_x_start, t)
log_sample = self.log_sample_categorical(log_EV_qxt_x0)
return log_sample
def kl_prior(self, log_x_start):
b = log_x_start.size(0)
device = log_x_start.device
ones = torch.ones(b, device=device).long()
log_qxT_prob = self.q_pred(log_x_start, t=(self.num_timesteps - 1) * ones)
log_half_prob = -torch.log(self.num_classes * torch.ones_like(log_qxT_prob)) # log of the uniform 1/K distribution
kl_prior = self.multinomial_kl(log_qxT_prob, log_half_prob)
return sum_except_batch(kl_prior)
def sample_time(self, b, device, method='uniform'):
if method == 'importance':
if not (self.Lt_count > 10).all():
return self.sample_time(b, device, method='uniform')
Lt_sqrt = torch.sqrt(self.Lt_history + 1e-10) + 0.0001
Lt_sqrt[0] = Lt_sqrt[1] # Overwrite decoder term with L1.
pt_all = Lt_sqrt / Lt_sqrt.sum()
t = torch.multinomial(pt_all, num_samples=b, replacement=True)
pt = pt_all.gather(dim=0, index=t)
return t, pt
elif method == 'uniform':
t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
pt = torch.ones_like(t).float() / self.num_timesteps
return t, pt
else:
raise ValueError
def forward(self, x, voxel_input):
b, device = x.size(0), x.device
self.shape = x.size()[1:]
t, pt = self.sample_time(b, device, 'importance')
log_x_start = index_to_log_onehot(x, self.num_classes)
log_x_t = self.q_sample(log_x_start, t) # log_x_t : (batch, #class, 128, 128, 8)
log_true_prob = self.q_posterior(log_x_start=log_x_start, log_x_t=log_x_t, t=t)
log_model_prob, log_x0_recon = self.p_pred(log_x=log_x_t, t=t, cond=voxel_input)
kl = self.multinomial_kl(log_true_prob, log_model_prob)
kl = sum_except_batch(kl)
decoder_nll = -log_categorical(log_x_start, log_model_prob)
decoder_nll = sum_except_batch(decoder_nll)
mask = (t == torch.zeros_like(t)).float()
kl_loss = mask * decoder_nll + (1. - mask) * kl
if self.training:
Lt2 = kl_loss.pow(2)
Lt2_prev = self.Lt_history.gather(dim=0, index=t)
new_Lt_history = (0.1 * Lt2 + 0.9 * Lt2_prev).detach()
self.Lt_history.scatter_(dim=0, index=t, src=new_Lt_history)
self.Lt_count.scatter_add_(dim=0, index=t, src=torch.ones_like(Lt2))
kl_prior = self.kl_prior(log_x_start)
# Upweigh loss term of the kl
loss = kl_loss / pt + kl_prior
kl_aux = self.multinomial_kl(log_x_start[:,:-1,:,:,:], log_x0_recon[:,:-1,:,:,:])
kl_aux = sum_except_batch(kl_aux)
if self.recon_loss :
kl_aux += self.multi_criterion(log_x0_recon.exp(), x)
#kl_aux += lovasz_softmax(torch.nn.functional.softmax(log_x0_recon.exp(), dim=1), x)
kl_aux_loss = mask * decoder_nll + (1. - mask) * kl_aux
if self.adaptive_auxiliary_loss:
addition_loss_weight = (1-t/self.num_timesteps) + 1.0
else:
addition_loss_weight = 1.0
aux_loss = addition_loss_weight * self.auxiliary_loss_weight * kl_aux_loss / pt
loss += aux_loss
loss = loss.sum() / (self.shape[0]*self.shape[1])
#loss += seg_loss
return loss
def sample(self, voxel_input, intermediate=False):
device = self.log_alpha.device
self.shape = voxel_input.size()[1:]
uniform_logits = torch.zeros((self.args.batch_size, self.num_classes) + self.shape, device=device)
log_z = self.log_sample_categorical(uniform_logits)
diffusion = []
for i in reversed(range(0, self.num_timesteps)):
print(f'Sample timestep {i:4d}', end='\r')
t = torch.full((self.args.batch_size,), i, device=device, dtype=torch.long)
log_model_prob, log_x0_recon = self.p_pred(log_x=log_z, t=t, cond=voxel_input)
log_z = self.log_sample_categorical(log_model_prob)
if i % 10 == 0:
diffusion.append(log_onehot_to_index(log_z))
result = log_onehot_to_index(log_z)
if intermediate :
return result, diffusion
else :
return result
================================================
FILE: layers/Voxel_Level/Gen_Diffusion.py
================================================
import torch
import torch.nn.functional as F
import numpy as np
import math
from inspect import isfunction
from layers.Voxel_Level.gen_denoise import Denoise
from utils.loss import *
"""
Based in part on: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/5989f4c77eafcdc6be0fb4739f0f277a6dd7f7d8/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L281
"""
eps = 1e-8
def sum_except_batch(x, num_dims=1):
return x.reshape(*x.shape[:num_dims], -1).sum(-1)
def log_1_min_a(a):
return torch.log(1 - a.exp() + 1e-40)
def log_add_exp(a, b):
maximum = torch.max(a, b)
return maximum + torch.log(torch.exp(a - maximum) + torch.exp(b - maximum))
def exists(x):
return x is not None
def extract(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def log_categorical(log_x_start, log_prob):
return (log_x_start.exp() * log_prob).sum(dim=1)
def index_to_log_onehot(x, num_classes):
assert x.max().item() < num_classes, f'Error: {x.max().item()} >= {num_classes}'
x_onehot = F.one_hot(x, num_classes)
permute_order = (0, -1) + tuple(range(1, len(x.size())))
x_onehot = x_onehot.permute(permute_order)
log_x = torch.log(x_onehot.float().clamp(min=1e-30))
return log_x
def log_onehot_to_index(log_x):
return log_x.argmax(1)
def cosine_beta_schedule(timesteps, s = 0.008):
"""
Cosine schedule as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
Note: despite the name, this returns per-step sqrt(alpha_t) values, not betas.
"""
steps = timesteps + 1
x = np.linspace(0, steps, steps)
alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
alphas = (alphas_cumprod[1:] / alphas_cumprod[:-1])
alphas = np.clip(alphas, a_min=0.001, a_max=1.)
alphas = np.sqrt(alphas)
return alphas
class Diffusion(torch.nn.Module):
def __init__(self, args, multi_criterion, auxiliary_loss_weight=0.05, adaptive_auxiliary_loss=True):
super(Diffusion, self).__init__()
#self._denoise_fn = SSCNet(num_classes=args.num_classes*50, num_steps=args.diffusion_steps)
self.args = args
self.num_classes = self.args.num_classes
self.num_timesteps = self.args.diffusion_steps
self.recon_loss = self.args.recon_loss
self._denoise_fn = Denoise(args= self.args, num_class = self.num_classes)
self.auxiliary_loss_weight = auxiliary_loss_weight
self.adaptive_auxiliary_loss = adaptive_auxiliary_loss
self.multi_criterion = multi_criterion
alphas = cosine_beta_schedule(self.num_timesteps)
alphas = torch.tensor(alphas.astype('float64'))
log_alpha = torch.log(alphas)
log_cumprod_alpha = torch.cumsum(log_alpha, dim=0)
log_1_min_alpha = log_1_min_a(log_alpha)
log_1_min_cumprod_alpha = log_1_min_a(log_cumprod_alpha)
assert log_add_exp(log_alpha, log_1_min_alpha).abs().sum().item() < 1.e-5
assert log_add_exp(log_cumprod_alpha, log_1_min_cumprod_alpha).abs().sum().item() < 1e-5
assert (torch.cumsum(log_alpha, dim=0) - log_cumprod_alpha).abs().sum().item() < 1.e-5
# Convert to float32 and register buffers.
self.register_buffer('log_alpha', log_alpha.float())
self.register_buffer('log_1_min_alpha', log_1_min_alpha.float())
self.register_buffer('log_cumprod_alpha', log_cumprod_alpha.float())
self.register_buffer('log_1_min_cumprod_alpha', log_1_min_cumprod_alpha.float())
self.register_buffer('Lt_history', torch.zeros(self.num_timesteps ))
self.register_buffer('Lt_count', torch.zeros(self.num_timesteps ))
def device(self):
# The attribute is `_denoise_fn` (not `denoise_fn`), and nn.Module has no
# .device; read the device from a registered buffer instead, as sample() does.
return self.log_alpha.device
def multinomial_kl(self, log_prob1, log_prob2):
kl = (log_prob1.exp() * (log_prob1 - log_prob2)).sum(dim=1)
return kl
def q_pred_one_timestep(self, log_x_t, t):
log_alpha_t = extract(self.log_alpha, t, log_x_t.shape)
log_1_min_alpha_t = extract(self.log_1_min_alpha, t, log_x_t.shape)
# alpha_t * E[xt] + (1 - alpha_t) 1 / K
log_probs = log_add_exp(
log_x_t + log_alpha_t,
log_1_min_alpha_t - np.log(self.num_classes)
)
return log_probs
def q_pred(self, log_x_start, t):
log_cumprod_alpha_t = extract(self.log_cumprod_alpha, t, log_x_start.shape)
log_1_min_cumprod_alpha = extract(self.log_1_min_cumprod_alpha, t, log_x_start.shape)
log_probs = log_add_exp(
log_x_start + log_cumprod_alpha_t,
log_1_min_cumprod_alpha - np.log(self.num_classes)
)
return log_probs
def predict_start(self, log_x_t, t):
x_t = log_onehot_to_index(log_x_t)
out = self._denoise_fn(x_t, t)
log_pred = F.log_softmax(out, dim=1)
return log_pred
def q_posterior(self, log_x_start, log_x_t, t):
# q(xt-1 | xt, x0) = q(xt | xt-1, x0) * q(xt-1 | x0) / q(xt | x0)
# where q(xt | xt-1, x0) = q(xt | xt-1).
t_minus_1 = t - 1
# Remove negative values, will not be used anyway for final decoder
t_minus_1 = torch.where(t_minus_1 < 0, torch.zeros_like(t_minus_1), t_minus_1)
log_EV_qxtmin_x0 = self.q_pred(log_x_start, t_minus_1)
num_axes = (1,) * (len(log_x_start.size()) - 1)
t_broadcast = t.view(-1, *num_axes) * torch.ones_like(log_x_start)
log_EV_qxtmin_x0 = torch.where(t_broadcast == 0, log_x_start, log_EV_qxtmin_x0)
# Note: _NOT_ x_tmin1, which is how the formula is typically used!!!
# Not very easy to see why this is true. But it is :)
unnormed_logprobs = log_EV_qxtmin_x0 + self.q_pred_one_timestep(log_x_t, t)
log_EV_xtmin_given_xt_given_xstart = unnormed_logprobs - torch.logsumexp(unnormed_logprobs, dim=1, keepdim=True)
return log_EV_xtmin_given_xt_given_xstart
def p_pred(self, log_x, t):
log_x0_recon = self.predict_start(log_x, t)
log_model_pred = self.q_posterior(log_x_start=log_x0_recon, log_x_t=log_x, t=t)
return log_model_pred, log_x0_recon
def log_sample_categorical(self, logits):
uniform = torch.rand_like(logits)
gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30)
sample = (gumbel_noise + logits).argmax(dim=1)
log_sample = index_to_log_onehot(sample, self.num_classes)
return log_sample
def q_sample(self, log_x_start, t):
log_EV_qxt_x0 = self.q_pred(log_x_start, t)
log_sample = self.log_sample_categorical(log_EV_qxt_x0)
return log_sample
def kl_prior(self, log_x_start):
b = log_x_start.size(0)
device = log_x_start.device
ones = torch.ones(b, device=device).long()
log_qxT_prob = self.q_pred(log_x_start, t=(self.num_timesteps - 1) * ones)
log_half_prob = -torch.log(self.num_classes * torch.ones_like(log_qxT_prob)) # log of the uniform 1/K distribution
kl_prior = self.multinomial_kl(log_qxT_prob, log_half_prob)
return sum_except_batch(kl_prior)
def sample_time(self, b, device, method='uniform'):
if method == 'importance':
if not (self.Lt_count > 10).all():
return self.sample_time(b, device, method='uniform')
Lt_sqrt = torch.sqrt(self.Lt_history + 1e-10) + 0.0001
Lt_sqrt[0] = Lt_sqrt[1] # Overwrite decoder term with L1.
pt_all = Lt_sqrt / Lt_sqrt.sum()
t = torch.multinomial(pt_all, num_samples=b, replacement=True)
pt = pt_all.gather(dim=0, index=t)
return t, pt
elif method == 'uniform':
t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
pt = torch.ones_like(t).float() / self.num_timesteps
return t, pt
else:
raise ValueError
def forward(self, x, voxel_input):
b, device = x.size(0), x.device
self.shape = x.size()[1:]
t, pt = self.sample_time(b, device, 'importance')
log_x_start = index_to_log_onehot(x, self.num_classes)
log_x_t = self.q_sample(log_x_start, t) # log_x_t : (batch, #class, 128, 128, 8)
log_true_prob = self.q_posterior(log_x_start=log_x_start, log_x_t=log_x_t, t=t)
log_model_prob, log_x0_recon = self.p_pred(log_x=log_x_t, t=t)
kl = self.multinomial_kl(log_true_prob, log_model_prob)
kl = sum_except_batch(kl)
decoder_nll = -log_categorical(log_x_start, log_model_prob)
decoder_nll = sum_except_batch(decoder_nll)
mask = (t == torch.zeros_like(t)).float()
kl_loss = mask * decoder_nll + (1. - mask) * kl
if self.training:
Lt2 = kl_loss.pow(2)
Lt2_prev = self.Lt_history.gather(dim=0, index=t)
new_Lt_history = (0.1 * Lt2 + 0.9 * Lt2_prev).detach()
self.Lt_history.scatter_(dim=0, index=t, src=new_Lt_history)
self.Lt_count.scatter_add_(dim=0, index=t, src=torch.ones_like(Lt2))
kl_prior = self.kl_prior(log_x_start)
# Upweigh loss term of the kl
loss = kl_loss / pt + kl_prior
kl_aux = self.multinomial_kl(log_x_start[:,:-1,:,:,:], log_x0_recon[:,:-1,:,:,:])
kl_aux = sum_except_batch(kl_aux)
#if self.recon_loss :
#    kl_aux += self.multi_criterion(log_x0_recon.exp(), x)
#    kl_aux += lovasz_softmax(torch.nn.functional.softmax(log_x0_recon.exp(), dim=1), x)
kl_aux_loss = mask * decoder_nll + (1. - mask) * kl_aux
if self.adaptive_auxiliary_loss:
addition_loss_weight = (1-t/self.num_timesteps) + 1.0
else:
addition_loss_weight = 1.0
aux_loss = addition_loss_weight * self.auxiliary_loss_weight * kl_aux_loss / pt
loss += aux_loss
loss = loss.sum() / (self.shape[0]*self.shape[1])
#loss += seg_loss
return loss
def sample(self, voxel_input):
device = self.log_alpha.device
self.shape = voxel_input.size()[1:]
uniform_logits = torch.zeros((self.args.batch_size, self.num_classes) + self.shape, device=device)
log_z = self.log_sample_categorical(uniform_logits)
for i in reversed(range(0, self.num_timesteps)):
print(f'Sample timestep {i:4d}', end='\r')
t = torch.full((self.args.batch_size,), i, device=device, dtype=torch.long)
log_model_prob, log_x0_recon = self.p_pred(log_x=log_z, t=t)
log_z = self.log_sample_categorical(log_model_prob)
result = log_onehot_to_index(log_z)
return result
================================================
FILE: layers/Voxel_Level/denoise.py
================================================
import math
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from torch import nn, einsum
def conv3x3x3(in_planes, out_planes, stride=1):
return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
def conv1x3x3(in_planes, out_planes, stride=1):
return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 3), stride=stride,padding=(0, 1, 1), bias=False)
def conv1x1x3(in_planes, out_planes, stride=1):
return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 1, 3), stride=stride, padding=(0, 0, 1), bias=False)
def conv1x3x1(in_planes, out_planes, stride=1):
return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 1), stride=stride, padding=(0, 1, 0), bias=False)
def conv3x1x1(in_planes, out_planes, stride=1):
return nn.Conv3d(in_planes, out_planes, kernel_size=(3, 1, 1), stride=stride, padding=(1, 0, 0), bias=False)
def conv3x1x3(in_planes, out_planes, stride=1):
return nn.Conv3d(in_planes, out_planes, kernel_size=(3, 1, 3), stride=stride, padding=(1, 0, 1), bias=False)
def conv1x1(in_planes, out_planes, stride=1):
return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride)
class Asymmetric_Residual_Block(nn.Module):
def __init__(self, in_filters, out_filters, time_filters=32*4):
super(Asymmetric_Residual_Block, self).__init__()
if in_filters<32 :
self.GroupNorm = nn.GroupNorm(16, in_filters)
self.bn0 = nn.GroupNorm(16, out_filters)
self.bn0_2 = nn.GroupNorm(16, out_filters)
self.bn1 = nn.GroupNorm(16, out_filters)
self.bn2 = nn.GroupNorm(16, out_filters)
else :
self.GroupNorm = nn.GroupNorm(32, in_filters)
self.bn0 = nn.GroupNorm(32, out_filters)
self.bn0_2 = nn.GroupNorm(32, out_filters)
self.bn1 = nn.GroupNorm(32, out_filters)
self.bn2 = nn.GroupNorm(32, out_filters)
self.time_layers = nn.Sequential(
nn.SiLU(),
nn.Linear(time_filters, in_filters*2)
)
self.conv1 = conv1x3x3(in_filters, out_filters)
self.act1 = nn.LeakyReLU()
self.conv1_2 = conv3x1x3(out_filters, out_filters)
self.act1_2 = nn.LeakyReLU()
self.conv2 = conv3x1x3(in_filters, out_filters)
self.act2 = nn.LeakyReLU()
self.conv3 = conv1x3x3(out_filters, out_filters)
self.act3 = nn.LeakyReLU()
def forward(self, x, t):
t = self.time_layers(t)
while len(t.shape) < len(x.shape):
t = t[..., None]
scale, shift = torch.chunk(t, 2, dim=1)
x = self.GroupNorm(x) * (1 + scale) + shift
shortcut = self.conv1(x)
shortcut = self.act1(shortcut)
shortcut = self.bn0(shortcut)
shortcut = self.conv1_2(shortcut)
shortcut = self.act1_2(shortcut)
shortcut = self.bn0_2(shortcut)
resA = self.conv2(x)
resA = self.act2(resA)
resA = self.bn1(resA)
resA = self.conv3(resA)
resA = self.act3(resA)
resA = self.bn2(resA)
resA += shortcut
return resA
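The time embedding enters each block as a FiLM-style modulation: it is projected to `2 * in_filters` channels, padded with trailing singleton axes until it broadcasts over the volume, split channel-wise into `(scale, shift)`, and applied as `norm(x) * (1 + scale) + shift`. A NumPy sketch of that broadcast, with the GroupNorm omitted for brevity (shapes and values are illustrative):

```python
import numpy as np

batch, C = 2, 8
t_emb = np.ones((batch, 2 * C))            # projected time embedding
x = np.random.default_rng(1).normal(size=(batch, C, 4, 4, 2))

# Append trailing singleton axes until t broadcasts over the volume,
# then split channel-wise into scale and shift, as the forward above does.
t = t_emb
while t.ndim < x.ndim:
    t = t[..., None]
scale, shift = np.split(t, 2, axis=1)      # each (batch, C, 1, 1, 1)
out = x * (1 + scale) + shift              # elementwise FiLM modulation
```

With an all-ones embedding this reduces to `2 * x + 1`, which makes the broadcast easy to verify.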
class DDCM(nn.Module):
def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), stride=1):
super(DDCM, self).__init__()
self.conv1 = conv3x1x1(in_filters, out_filters)
if in_filters<32 :
self.bn0 = nn.GroupNorm(16, out_filters)
self.bn0_2 = nn.GroupNorm(16, out_filters)
self.bn0_3 = nn.GroupNorm(16, out_filters)
else :
self.bn0 = nn.GroupNorm(32, out_filters)
self.bn0_2 = nn.GroupNorm(32, out_filters)
self.bn0_3 = nn.GroupNorm(32, out_filters)
self.act1 = nn.Sigmoid()
self.conv1_2 = conv1x3x1(in_filters, out_filters)
self.act1_2 = nn.Sigmoid()
self.conv1_3 = conv1x1x3(in_filters, out_filters)
self.act1_3 = nn.Sigmoid()
def forward(self, x):
shortcut = self.conv1(x)
shortcut = self.bn0(shortcut)
shortcut = self.act1(shortcut)
shortcut2 = self.conv1_2(x)
shortcut2 = self.bn0_2(shortcut2)
shortcut2 = self.act1_2(shortcut2)
shortcut3 = self.conv1_3(x)
shortcut3 = self.bn0_3(shortcut3)
shortcut3 = self.act1_3(shortcut3)
shortcut = shortcut + shortcut2 + shortcut3
shortcut = shortcut * x
return shortcut
def l2norm(t):
return F.normalize(t, dim = -1)
class Attention(nn.Module):
def __init__(self, dim, heads = 4, scale = 10):
super().__init__()
self.scale = scale
self.heads = heads
self.to_qkv = conv1x1(dim, dim*3, stride=1)
self.to_out = conv1x1(dim, dim, stride=1)
def forward(self, x):
b, c, h, w, Z = x.shape
qkv = self.to_qkv(x).chunk(3, dim = 1)
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y z-> b h c (x y z)', h = self.heads), qkv)
q, k = map(l2norm, (q, k))
sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale
attn = sim.softmax(dim = -1)
out = einsum('b h i j, b h d j -> b h i d', attn, v)
out = rearrange(out, 'b h (x y z) d -> b (h d) x y z', x = h, y = w, z = Z)
return self.to_out(out)
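This attention L2-normalizes q and k and multiplies by a fixed scale (10) instead of the usual 1/sqrt(d) scaling, so similarities are bounded. A simplified single-head NumPy sketch of that computation (normalizing each position's feature vector, which is the standard cosine-attention form; the module applies the normalization along the flattened-position axis of its (b, h, c, n) layout):

```python
import numpy as np

rng = np.random.default_rng(2)
d, n = 16, 10                              # head dim, number of voxels
q = rng.normal(size=(d, n))
k = rng.normal(size=(d, n))
v = rng.normal(size=(d, n))

def l2norm(t):
    # Unit-normalize each column (one feature vector per position).
    return t / np.linalg.norm(t, axis=0, keepdims=True)

# Cosine-similarity attention: normalized q.k scaled by a constant.
sim = l2norm(q).T @ l2norm(k) * 10.0       # (n, n), entries in [-10, 10]
attn = np.exp(sim - sim.max(axis=-1, keepdims=True))
attn = attn / attn.sum(axis=-1, keepdims=True)  # softmax over positions
out = v @ attn.T                           # (d, n) aggregated values
```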
class Cross_Attention(nn.Module):
def __init__(self, dim, heads = 4, scale = 10):
super().__init__()
self.scale = scale
self.heads = heads
self.to_q = conv1x1(dim, dim, stride=1)
self.to_k = conv1x1(dim, dim, stride=1)
self.to_v = conv1x1(dim, dim, stride=1)
self.to_out = conv1x1(dim, dim, stride=1)
def forward(self, x, cond_x):
b, c, h, w, Z = x.shape
q = self.to_q(x)
k = self.to_k(cond_x)
v = self.to_v(cond_x)
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y z-> b h c (x y z)', h = self.heads), (q, k, v))
q, k = map(l2norm, (q, k))
sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale
attn = sim.softmax(dim = -1)
out = einsum('b h i j, b h d j -> b h i d', attn, v)
out = rearrange(out, 'b h (x y z) d -> b (h d) x y z', x = h, y = w, z = Z)
return self.to_out(out)
class DownBlock(nn.Module):
def __init__(self, in_filters, out_filters, time_filters=32*4, kernel_size=(3, 3, 3), stride=1,
pooling=True, height_pooling=False):
super(DownBlock, self).__init__()
self.pooling = pooling
self.residual_block = Asymmetric_Residual_Block(in_filters, out_filters, time_filters=time_filters)
if pooling:
if height_pooling:
self.pool = nn.Conv3d(out_filters, out_filters, kernel_size=3, stride=2,padding=1, bias=False)
else:
self.pool = nn.Conv3d(out_filters, out_filters, kernel_size=3, stride=(2, 2, 1),padding=1, bias=False)
def forward(self, x, t):
resA = self.residual_block(x, t)
if self.pooling:
resB = self.pool(resA)
return resB, resA
else:
return resA
class UpBlock(nn.Module):
def __init__(self, in_filters, out_filters, height_pooling, time_filters=32*4):
super(UpBlock, self).__init__()
# self.drop_out = drop_out
if out_filters<32 :
self.trans_bn = nn.GroupNorm(16, in_filters)
self.bn1 = nn.GroupNorm(16, out_filters)
self.bn2 = nn.GroupNorm(16, out_filters)
self.bn3 = nn.GroupNorm(16, out_filters)
else :
self.trans_bn = nn.GroupNorm(32, in_filters)
self.bn1 = nn.GroupNorm(32, out_filters)
self.bn2 = nn.GroupNorm(32, out_filters)
self.bn3 = nn.GroupNorm(32, out_filters)
self.trans_dilao = conv3x3x3(in_filters, in_filters)
self.trans_act = nn.LeakyReLU()
self.time_layers = nn.Sequential(
nn.SiLU(),
nn.Linear(time_filters, in_filters*2)
)
self.conv1 = conv1x3x3(in_filters, out_filters)
self.act1 = nn.LeakyReLU()
self.conv2 = conv3x1x3(out_filters, out_filters)
self.act2 = nn.LeakyReLU()
self.conv3 = conv3x3x3(out_filters, out_filters)
self.act3 = nn.LeakyReLU()
if height_pooling :
self.up_subm = nn.ConvTranspose3d(in_filters, in_filters, kernel_size=3, bias=False, stride=2, padding=1, output_padding=1, dilation=1)
else :
self.up_subm = nn.ConvTranspose3d(in_filters, in_filters, kernel_size=(3,3,1), bias=False, stride=(2,2,1), padding=(1,1,0), output_padding=(1,1,0), dilation=1)
def forward(self, x, residual, t):
upA = self.trans_dilao(x)
upA = self.trans_act(upA)
t = self.time_layers(t)
while len(t.shape) < len(x.shape):
t = t[..., None]
scale, shift = torch.chunk(t, 2, dim=1)
upA = self.trans_bn(upA) * (1 + scale) + shift
## upsample
upA = self.up_subm(upA)
upA += residual
upE = self.conv1(upA)
upE = self.act1(upE)
upE = self.bn1(upE)
upE = self.conv2(upE)
upE = self.act2(upE)
upE = self.bn2(upE)
upE = self.conv3(upE)
upE = self.act3(upE)
upE = self.bn3(upE)
return upE
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
if not repeat_only:
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
else:
embedding = repeat(timesteps, 'b -> b d', d=dim)
return embedding
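`timestep_embedding` is the standard sinusoidal embedding: geometrically spaced frequencies from 1 down to roughly 1/max_period, with cos and sin halves concatenated (note this version puts cos first, so t = 0 maps to ones followed by zeros). A NumPy sketch (function name is illustrative):

```python
import math
import numpy as np

def sinusoidal_embedding(timesteps, dim, max_period=10000):
    # Frequencies decay geometrically; each timestep is multiplied by
    # every frequency, then cos/sin of the products are concatenated.
    half = dim // 2
    freqs = np.exp(-math.log(max_period) * np.arange(half) / half)
    args = timesteps[:, None].astype(float) * freqs[None]
    return np.concatenate([np.cos(args), np.sin(args)], axis=-1)

emb = sinusoidal_embedding(np.array([0, 1, 50]), dim=32)
```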
class Denoise(nn.Module):
def __init__(self, args, num_class = 11, init_size=32, discrete=True):
super(Denoise, self).__init__()
self.args = args
self.discrete = discrete
self.num_class = num_class
self.init_size = init_size
self.time_size = self.init_size*4
self.time_embed = nn.Sequential(
nn.Linear(init_size, self.time_size),
nn.SiLU(),
nn.Linear(self.time_size, self.time_size),
)
self.embedding = nn.Embedding(self.num_class, init_size)
self.conv_in = nn.Conv3d(init_size+1, init_size, kernel_size=1, stride=1)
self.A = Asymmetric_Residual_Block(init_size, init_size, time_filters=init_size*4)
self.downBlock1 = DownBlock(init_size, 2 * init_size, height_pooling=True, time_filters=init_size*4)
self.downBlock2 = DownBlock(2 * init_size, 4 * init_size, height_pooling=True, time_filters=init_size*4)
self.downBlock3 = DownBlock(4 * init_size, 8 * init_size, height_pooling=False, time_filters=init_size*4)
self.downBlock4 = DownBlock(8 * init_size, 16 * init_size, height_pooling=False, time_filters=init_size*4)
self.midBlock1 = Asymmetric_Residual_Block(16 * init_size, 16 * init_size, time_filters=init_size*4)
self.attention = Attention(16 * init_size, 32)
self.midBlock2 = Asymmetric_Residual_Block(16 * init_size, 16 * init_size, time_filters=init_size*4)
self.upBlock4 = UpBlock(16 * init_size, 8 * init_size, height_pooling=False, time_filters=init_size*4)
self.upBlock3 = UpBlock(8 * init_size, 4 * init_size, height_pooling=False, time_filters=init_size*4)
self.upBlock2 = UpBlock(4 * init_size, 2 * init_size, height_pooling=True, time_filters=init_size*4)
self.upBlock1 = UpBlock(2 * init_size, 2 * init_size, height_pooling=True, time_filters=init_size*4)
self.DDCM = DDCM(2 * init_size, 2 * init_size)
self.logits = nn.Conv3d(4 * init_size, self.num_class, kernel_size=3, stride=1, padding=1, bias=True)
def forward(self, x, x_cond, t):
x = self.embedding(x)
x = x.permute(0, 4, 1, 2, 3)
x_cond = x_cond.unsqueeze(1)
x = torch.cat([x, x_cond], dim=1)
x = self.conv_in(x)
t = self.time_embed(timestep_embedding(t, self.init_size))
x = self.A(x, t)
down1c, down1b = self.downBlock1(x, t)
down2c, down2b = self.downBlock2(down1c, t)
down3c, down3b = self.downBlock3(down2c, t)
down4c, down4b = self.downBlock4(down3c, t)
down4c = self.midBlock1(down4c, t)
down4c = self.attention(down4c)
down4c = self.midBlock2(down4c, t)
up4 = self.upBlock4(down4c, down4b, t)
up3 = self.upBlock3(up4, down3b, t)
up2 = self.upBlock2(up3, down2b, t)
up1 = self.upBlock1(up2, down1b, t)
up0 = self.DDCM(up1)
up = torch.cat((up1, up0), 1)
logits = self.logits(up)
return logits
================================================
FILE: layers/Voxel_Level/gen_denoise.py
================================================
import math
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from torch import nn, einsum
def conv3x3x3(in_planes, out_planes, stride=1):
return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
def conv1x3x3(in_planes, out_planes, stride=1):
return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 3), stride=stride,padding=(0, 1, 1), bias=False)
def conv1x1x3(in_planes, out_planes, stride=1):
return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 1, 3), stride=stride, padding=(0, 0, 1), bias=False)
def conv1x3x1(in_planes, out_planes, stride=1):
return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 1), stride=stride, padding=(0, 1, 0), bias=False)
def conv3x1x1(in_planes, out_planes, stride=1):
return nn.Conv3d(in_planes, out_planes, kernel_size=(3, 1, 1), stride=stride, padding=(1, 0, 0), bias=False)
def conv3x1x3(in_planes, out_planes, stride=1):
return nn.Conv3d(in_planes, out_planes, kernel_size=(3, 1, 3), stride=stride, padding=(1, 0, 1), bias=False)
def conv1x1(in_planes, out_planes, stride=1):
return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride)
class Asymmetric_Residual_Block(nn.Module):
def __init__(self, in_filters, out_filters, time_filters=128):
super(Asymmetric_Residual_Block, self).__init__()
if in_filters < 32 :
n_ng = in_filters
else : n_ng =32
self.GroupNorm = nn.GroupNorm(n_ng, in_filters)
self.time_layers = nn.Sequential(
nn.SiLU(),
nn.Linear(time_filters, in_filters*2)
)
self.conv1 = conv1x3x3(in_filters, out_filters)
if out_filters < 32 :
n_ng = out_filters
else : n_ng =32
self.bn0 = nn.GroupNorm(n_ng, out_filters)
self.act1 = nn.LeakyReLU()
self.conv1_2 = conv3x1x3(out_filters, out_filters)
self.bn0_2 = nn.GroupNorm(n_ng, out_filters)
self.act1_2 = nn.LeakyReLU()
self.conv2 = conv3x1x3(in_filters, out_filters)
self.act2 = nn.LeakyReLU()
self.bn1 = nn.GroupNorm(n_ng, out_filters)
self.conv3 = conv1x3x3(out_filters, out_filters)
self.act3 = nn.LeakyReLU()
self.bn2 = nn.GroupNorm(n_ng, out_filters)
def forward(self, x, t):
t = self.time_layers(t)
while len(t.shape) < len(x.shape):
t = t[..., None]
scale, shift = torch.chunk(t, 2, dim=1)
x = self.GroupNorm(x) * (1 + scale) + shift
shortcut = self.conv1(x)
shortcut = self.act1(shortcut)
shortcut = self.bn0(shortcut)
shortcut = self.conv1_2(shortcut)
shortcut = self.act1_2(shortcut)
shortcut = self.bn0_2(shortcut)
resA = self.conv2(x)
resA = self.act2(resA)
resA = self.bn1(resA)
resA = self.conv3(resA)
resA = self.act3(resA)
resA = self.bn2(resA)
resA += shortcut
return resA
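The block conditions on time with a FiLM-style transform: the time embedding is projected to a (scale, shift) pair and applied right after GroupNorm as `x * (1 + scale) + shift`. A scalar sketch with hypothetical numbers:

```python
def film(x, scale, shift):
    # the per-channel modulation applied after GroupNorm in the block above
    return [xi * (1 + scale) + shift for xi in x]

# scale=0.5 amplifies the normalized activations, shift=0.1 biases them
out = film([1.0, 2.0], 0.5, 0.1)
```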
class DDCM(nn.Module):
def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), stride=1):
super(DDCM, self).__init__()
self.conv1 = conv3x1x1(in_filters, out_filters)
        n_ng = out_filters if out_filters < 32 else 32
self.bn0 = nn.GroupNorm(n_ng, out_filters)
self.act1 = nn.Sigmoid()
self.conv1_2 = conv1x3x1(in_filters, out_filters)
self.bn0_2 = nn.GroupNorm(n_ng, out_filters)
self.act1_2 = nn.Sigmoid()
self.conv1_3 = conv1x1x3(in_filters, out_filters)
self.bn0_3 = nn.GroupNorm(n_ng, out_filters)
self.act1_3 = nn.Sigmoid()
def forward(self, x):
shortcut = self.conv1(x)
shortcut = self.bn0(shortcut)
shortcut = self.act1(shortcut)
shortcut2 = self.conv1_2(x)
shortcut2 = self.bn0_2(shortcut2)
shortcut2 = self.act1_2(shortcut2)
shortcut3 = self.conv1_3(x)
shortcut3 = self.bn0_3(shortcut3)
shortcut3 = self.act1_3(shortcut3)
shortcut = shortcut + shortcut2 + shortcut3
shortcut = shortcut * x
return shortcut
def l2norm(t):
return F.normalize(t, dim = -1)
class Attention(nn.Module):
def __init__(self, dim, heads = 4, scale = 10):
super().__init__()
self.scale = scale
self.heads = heads
self.to_qkv = conv1x1(dim, dim*3, stride=1)
self.to_out = conv1x1(dim, dim, stride=1)
def forward(self, x):
b, c, h, w, Z = x.shape
qkv = self.to_qkv(x).chunk(3, dim = 1)
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y z-> b h c (x y z)', h = self.heads), qkv)
q, k = map(l2norm, (q, k))
sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale
attn = sim.softmax(dim = -1)
out = einsum('b h i j, b h d j -> b h i d', attn, v)
out = rearrange(out, 'b h (x y z) d -> b (h d) x y z', x = h, y = w, z = Z)
return self.to_out(out)
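Both attention blocks L2-normalize q and k before the dot product, so the logits are cosine similarities scaled by a fixed factor (`scale=10`) rather than by 1/sqrt(d). A single-query, pure-Python sketch of that scheme:

```python
import math

def l2norm(v):
    n = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / n for x in v]

def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

def cosine_attention(q, keys, values, scale=10):
    # logits are scaled cosine similarities, as in the Attention module above
    q = l2norm(q)
    logits = [scale * sum(a * b for a, b in zip(q, l2norm(k))) for k in keys]
    weights = softmax(logits)
    return [sum(w * v[d] for w, v in zip(weights, values))
            for d in range(len(values[0]))]

# the query attends almost entirely to the matching key
out = cosine_attention([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0]],
                       [[1.0, 0.0], [0.0, 1.0]])
```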
class Cross_Attention(nn.Module):
def __init__(self, dim, heads = 4, scale = 10):
super().__init__()
self.scale = scale
self.heads = heads
self.to_q = conv1x1(dim, dim, stride=1)
self.to_k = conv1x1(dim, dim, stride=1)
self.to_v = conv1x1(dim, dim, stride=1)
self.to_out = conv1x1(dim, dim, stride=1)
def forward(self, x, cond_x):
b, c, h, w, Z = x.shape
q = self.to_q(x)
k = self.to_k(cond_x)
v = self.to_v(cond_x)
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y z-> b h c (x y z)', h = self.heads), (q, k, v))
q, k = map(l2norm, (q, k))
sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale
attn = sim.softmax(dim = -1)
out = einsum('b h i j, b h d j -> b h i d', attn, v)
out = rearrange(out, 'b h (x y z) d -> b (h d) x y z', x = h, y = w, z = Z)
return self.to_out(out)
class DownBlock(nn.Module):
def __init__(self, in_filters, out_filters, time_filters, kernel_size=(3, 3, 3), stride=1,
pooling=True, drop_out=True, height_pooling=False):
super(DownBlock, self).__init__()
self.pooling = pooling
self.drop_out = drop_out
self.residual_block = Asymmetric_Residual_Block(in_filters, out_filters, time_filters)
if pooling:
if height_pooling:
self.pool = nn.Conv3d(out_filters, out_filters, kernel_size=3, stride=2,
padding=1, bias=False)
else:
self.pool = nn.Conv3d(out_filters, out_filters, kernel_size=3, stride=(2, 2, 1),
padding=1, bias=False)
def forward(self, x, t):
resA = self.residual_block(x, t)
if self.pooling:
resB = self.pool(resA)
return resB, resA
else:
return resA
class UpBlock(nn.Module):
def __init__(self, in_filters, out_filters, height_pooling, time_filters):
super(UpBlock, self).__init__()
# self.drop_out = drop_out
self.trans_dilao = conv3x3x3(in_filters, in_filters)
self.trans_act = nn.LeakyReLU()
        n_ng = in_filters if in_filters < 32 else 32
self.trans_bn = nn.GroupNorm(n_ng, in_filters)
self.time_layers = nn.Sequential(
nn.SiLU(),
nn.Linear(time_filters, in_filters*2)
)
self.conv1 = conv1x3x3(in_filters, out_filters)
self.act1 = nn.LeakyReLU()
        n_ng = out_filters if out_filters < 32 else 32
self.bn1 = nn.GroupNorm(n_ng, out_filters)
self.conv2 = conv3x1x3(out_filters, out_filters)
self.act2 = nn.LeakyReLU()
self.bn2 = nn.GroupNorm(n_ng, out_filters)
self.conv3 = conv3x3x3(out_filters, out_filters)
self.act3 = nn.LeakyReLU()
self.bn3 = nn.GroupNorm(n_ng, out_filters)
if height_pooling :
self.up_subm = nn.ConvTranspose3d(in_filters, in_filters, kernel_size=3, bias=False, stride=2, padding=1, output_padding=1, dilation=1)
else :
self.up_subm = nn.ConvTranspose3d(in_filters, in_filters, kernel_size=(3,3,1), bias=False, stride=(2,2,1), padding=(1,1,0), output_padding=(1,1,0), dilation=1)
def forward(self, x, residual, t):
upA = self.trans_dilao(x)
upA = self.trans_act(upA)
t = self.time_layers(t)
while len(t.shape) < len(x.shape):
t = t[..., None]
scale, shift = torch.chunk(t, 2, dim=1)
upA = self.trans_bn(upA) * (1 + scale) + shift
## upsample
upA = self.up_subm(upA)
upA += residual
upE = self.conv1(upA)
upE = self.act1(upE)
upE = self.bn1(upE)
upE = self.conv2(upE)
upE = self.act2(upE)
upE = self.bn2(upE)
upE = self.conv3(upE)
upE = self.act3(upE)
upE = self.bn3(upE)
return upE
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
if not repeat_only:
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
else:
embedding = repeat(timesteps, 'b -> b d', d=dim)
return embedding
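For a single scalar timestep the sinusoidal embedding above reduces to the following stdlib-only sketch (the torch version vectorizes the same computation over a batch):

```python
import math

def sinusoidal_embedding(t, dim, max_period=10000):
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    emb = [math.cos(t * f) for f in freqs] + [math.sin(t * f) for f in freqs]
    if dim % 2:          # odd dims are zero-padded, as above
        emb.append(0.0)
    return emb

# at t = 0 the cosine half is all ones and the sine half all zeros
emb = sinusoidal_embedding(0, 8)
```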
class Denoise(nn.Module):
def __init__(self, args, num_class = 11, init_size=32, discrete=True):
super(Denoise, self).__init__()
self.args = args
self.discrete = discrete
self.num_class = num_class
self.init_size = init_size
self.time_size = init_size*4
self.time_embed = nn.Sequential(
nn.Linear(self.init_size, self.time_size),
nn.SiLU(),
nn.Linear(self.time_size, self.time_size),
)
self.embedding = nn.Embedding(self.num_class, self.init_size)
self.conv_in = nn.Conv3d(self.init_size, self.init_size, kernel_size=1, stride=1)
self.A = Asymmetric_Residual_Block(self.init_size, self.init_size, self.time_size)
self.downBlock1 = DownBlock(init_size, 2 * init_size, self.time_size, height_pooling=True)
self.downBlock2 = DownBlock(2 * init_size, 4 * init_size, self.time_size, height_pooling=True)
self.downBlock3 = DownBlock(4 * init_size, 8 * init_size, self.time_size, height_pooling=False)
self.downBlock4 = DownBlock(8 * init_size, 16 * init_size, self.time_size, height_pooling=False)
self.midBlock1 = Asymmetric_Residual_Block(16 * init_size, 16 * init_size, self.time_size)
        self.attention = Attention(16 * init_size, heads=32)
self.midBlock2 = Asymmetric_Residual_Block(16 * init_size, 16 * init_size, self.time_size)
self.upBlock4 = UpBlock(16 * init_size, 8 * init_size, height_pooling=False, time_filters=self.time_size)
self.upBlock3 = UpBlock(8 * init_size, 4 * init_size, height_pooling=False, time_filters=self.time_size)
self.upBlock2 = UpBlock(4 * init_size, 2 * init_size, height_pooling=True, time_filters=self.time_size)
self.upBlock1 = UpBlock(2 * init_size, 2 * init_size, height_pooling=True, time_filters=self.time_size)
self.DDCM = DDCM(2 * init_size, 2 * init_size)
self.logits = nn.Conv3d(4 * init_size, self.num_class, kernel_size=3, stride=1, padding=1, bias=True)
def forward(self, x, t):
x = self.embedding(x)
x = x.permute(0, 4, 1, 2, 3)
x = self.conv_in(x)
t = self.time_embed(timestep_embedding(t, self.init_size))
x = self.A(x, t)
down1c, down1b = self.downBlock1(x, t)
down2c, down2b = self.downBlock2(down1c, t)
down3c, down3b = self.downBlock3(down2c, t)
down4c, down4b = self.downBlock4(down3c, t)
down4c = self.midBlock1(down4c, t)
down4c = self.attention(down4c)
down4c = self.midBlock2(down4c, t)
up4 = self.upBlock4(down4c, down4b, t)
up3 = self.upBlock3(up4, down3b, t)
up2 = self.upBlock2(up3, down2b, t)
up1 = self.upBlock1(up2, down1b, t)
up0 = self.DDCM(up1)
up = torch.cat((up1, up0), 1)
logits = self.logits(up)
return logits
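Tracing the encoder shapes helps when adapting the grid size: only the first two DownBlocks pool along height (`height_pooling=True`); the last two pool x/y only. A stdlib sketch, assuming the (256, 256, 32) grids seen in train.py:

```python
def pooled(n):
    # stride-2 conv with kernel 3, padding 1: floor((n - 1) / 2) + 1
    return (n - 1) // 2 + 1

shape = (256, 256, 32)
for height_pooling in (True, True, False, False):   # downBlock1..4
    x, y, z = shape
    shape = (pooled(x), pooled(y), pooled(z) if height_pooling else z)
# shape is now the bottleneck resolution fed to the attention block
```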
================================================
FILE: layers/__init__.py
================================================
================================================
FILE: requirements.txt
================================================
numpy
torch
scipy
scikit-learn
matplotlib
tqdm
open3d
pyyaml
prettytable
tensorboard
numba
einops
================================================
FILE: setup.py
================================================
from setuptools import setup, find_packages
setup(
name="scene_scale_diffusion",
version="0.1",
author="Lee Jumin, Im Woobin, Lee Sebin, Yoon Sung-Eui",
author_email="",
description="Experiments in PyTorch",
long_description="",
    packages=find_packages(),
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
],
)
================================================
FILE: simple_visualize.py
================================================
import os
import numpy as np
import open3d as o3d
import argparse
import yaml
def load_config(yaml_path):
with open(yaml_path, 'r') as f:
config = yaml.safe_load(f)
return config["learning_map"], config["remap_color_map"]
def load_pointcloud(filepath, learning_map, color_map):
data = np.loadtxt(filepath, delimiter=' ')
if data.shape[1] < 4:
raise ValueError(f"Expected at least 4 columns (label + x y z), got shape {data.shape}")
raw_labels = data[:, 0].astype(int)
points = data[:, 1:4]
# Map raw labels → remapped labels → colors
remapped_labels = np.array([learning_map.get(int(l), 0) for l in raw_labels])
colors = np.array([color_map.get(int(l), [255, 255, 255]) for l in remapped_labels]) / 255.0
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(points)
pcd.colors = o3d.utility.Vector3dVector(colors)
return pcd
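The two-stage remap (raw label → train id → RGB) can be exercised with hypothetical map entries standing in for the real tables in datasets/carla.yaml:

```python
# hypothetical entries; the actual tables come from the YAML config
learning_map = {0: 0, 40: 1, 48: 2}
color_map = {0: [0, 0, 0], 1: [255, 0, 255], 2: [75, 0, 75]}

raw_labels = [40, 48, 7]                       # 7 is unmapped -> falls back to 0
remapped = [learning_map.get(l, 0) for l in raw_labels]
colors = [[c / 255.0 for c in color_map.get(l, [255, 255, 255])]
          for l in remapped]
```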
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--file', default='result_for_l_gen/Completion/result_0.txt',
help='Path to the point cloud .txt file')
parser.add_argument('--config', default='datasets/carla.yaml',
help='Path to Carla YAML config file')
args = parser.parse_args()
if not os.path.exists(args.file):
raise FileNotFoundError(f"Point cloud file not found: {args.file}")
if not os.path.exists(args.config):
raise FileNotFoundError(f"YAML config file not found: {args.config}")
learning_map, color_map = load_config(args.config)
pcd = load_pointcloud(args.file, learning_map, color_map)
o3d.visualization.draw([pcd])
if __name__ == "__main__":
main()
================================================
FILE: train.py
================================================
from dataclasses import astuple
import torch
import argparse
import numpy as np
import os
import pickle
import torch
import torch.nn.functional as F
import yaml
from prettytable import PrettyTable
from torch.utils.tensorboard import SummaryWriter
from utils.tables import *
from utils.dicts import clean_dict
from utils.loss import lovasz_softmax
class Experiment(object):
no_log_keys = ['project', 'name','log_tb', 'log_wandb','check_every', 'eval_every','device', 'parallel', 'pin_memory', 'num_workers']
def __init__(self, args, model, optimizer, scheduler_iter, scheduler_epoch,
train_loader, eval_loader, test_loader, train_sampler,
log_path, eval_every, check_every):
# Objects
self.model = model
self.loss_fun = torch.nn.CrossEntropyLoss(ignore_index=0)
self.optimizer, self.scheduler_iter, self.scheduler_epoch= optimizer, scheduler_iter, scheduler_epoch
# Paths
self.log_path = log_path
if args.dataset =='carla':
config_file = os.path.join('./datasets/carla.yaml')
carla_config = yaml.safe_load(open(config_file, 'r'))
self.color_map = carla_config["remap_color_map"]
self.remap = None
LABEL_TO_NAMES = carla_config["label_to_names"]
self.label_to_names = np.asarray(list(LABEL_TO_NAMES.values()))
# Intervals
self.eval_every, self.check_every = eval_every, check_every
# Initialize
self.current_epoch = 0
self.train_metrics, self.eval_metrics, self.ssc_metrics, self.seg_metrics = {}, {}, {}, {}
self.eval_epochs = []
self.completion_epochs = []
# Store data loaders
self.train_loader, self.eval_loader, self.test_loader, self.train_sampler = train_loader, eval_loader, test_loader, train_sampler
# Store args
create_folders(args)
save_args(args)
self.args = args
# Init logging
args_dict = clean_dict(vars(args), keys=self.no_log_keys)
if args.log_tb:
self.writer = SummaryWriter(os.path.join(self.log_path, 'tb'))
self.writer.add_text("args", get_args_table(args_dict).get_html_string(), global_step=0)
def run(self, epochs):
if self.args.resume:
self.resume()
for epoch in range(self.current_epoch, epochs):
# Train
train_dict = self.train_fn(epoch)
self.log_metrics(train_dict, self.train_metrics)
# Checkpoint
self.current_epoch += 1
if (epoch+1) % self.check_every == 0:
self.checkpoint_save(epoch)
# Eval
if (epoch+1) % self.eval_every == 0:
eval_dict = self.eval_fn(epoch)
self.log_metrics(eval_dict, self.eval_metrics)
self.eval_epochs.append(epoch)
else:
eval_dict = None
if (epoch+1) % self.args.completion_epoch == 0:
ssc_dict, miou, seg_dict, seg_miou = self.sample()
self.log_metrics(ssc_dict, self.ssc_metrics)
                self.log_metrics(seg_dict, self.seg_metrics)
self.completion_epochs.append(epoch)
else :
ssc_dict, seg_dict = None, None
# Log
#self.save_metrics()
if self.args.log_tb:
for metric_name, metric_value in train_dict.items():
self.writer.add_scalar('base/{}'.format(metric_name), metric_value, global_step=epoch+1)
if eval_dict:
for metric_name, metric_value in eval_dict.items():
self.writer.add_scalar('eval/{}'.format(metric_name), metric_value, global_step=epoch+1)
if ssc_dict:
for metric_name, metric_value in ssc_dict.items():
self.writer.add_scalar('SSC/{}'.format(metric_name), metric_value, global_step=epoch+1)
self.writer.add_text("SSC_mIoU", get_miou_table(self.args, self.label_to_names, miou).get_html_string(), global_step=epoch+1)
for metric_name, metric_value in seg_dict.items():
self.writer.add_scalar('Seg/{}'.format(metric_name), metric_value, global_step=epoch+1)
self.writer.add_text("Seg_mIoU", get_miou_table(self.args, self.label_to_names, seg_miou).get_html_string(), global_step=epoch+1)
def train_fn(self, epoch):
self.model.train()
loss_sum = 0.0
loss_count = 0
if self.args.distribution :
self.train_sampler.set_epoch(epoch)
for voxel_input, output, counts in self.train_loader:
self.optimizer.zero_grad()
voxel_input = torch.from_numpy(np.asarray(voxel_input)).long().squeeze(1).cuda() # (4,1,256,256,32)
output = torch.from_numpy(np.asarray(output)).long().cuda()
if self.args.distribution:
loss = self.model.module(output, voxel_input)
else :
loss = self.model(output, voxel_input)
loss.backward()
if self.args.clip_value: torch.nn.utils.clip_grad_value_(self.model.parameters(), self.args.clip_value)
if self.args.clip_norm: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip_norm)
self.optimizer.step()
if self.scheduler_iter: self.scheduler_iter.step()
loss_sum += loss.detach().cpu().item() * len(output)
loss_count += len(output)
print('Training. Epoch: {}/{}, Datapoint: {}/{}, Bits/dim: {:.3f}'.format(epoch+1, self.args.epochs, loss_count, len(self.train_loader.dataset), loss_sum/loss_count), end='\r')
print('')
if self.scheduler_epoch: self.scheduler_epoch.step()
return {'loss': loss_sum/loss_count}
def eval_fn(self, epoch):
self.model.eval()
with torch.no_grad():
loss_sum = 0.0
loss_count = 0
for voxel_input, output, counts in self.eval_loader:
voxel_input = torch.from_numpy(np.asarray(voxel_input)).long().squeeze(1).cuda() # (4,1,256,256,32)
output = torch.from_numpy(np.asarray(output)).long().cuda()
if self.args.distribution:
loss = self.model.module(output, voxel_input)
else :
loss = self.model(output, voxel_input)
loss_sum += loss.detach().cpu().item() * len(output)
loss_count += len(output)
print('Train evaluating. Epoch: {}/{}, Datapoint: {}/{}, Bits/dim: {:.3f}'.format(epoch+1, self.args.epochs, loss_count, len(self.eval_loader.dataset), loss_sum/loss_count), end='\r')
print('')
return {'loss': loss_sum/loss_count}
def sample(self):
self.model.eval()
with torch.no_grad():
TP, FP, TN, FN, num_correct, num_total = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
s_TP, s_FP, s_TN, s_FN, s_num_correct, s_num_total = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
all_intersections, all_unions = np.zeros(self.args.num_classes), np.zeros(self.args.num_classes) + 1e-6
s_all_intersections, s_all_unions = np.zeros(self.args.num_classes), np.zeros(self.args.num_classes) + 1e-6
if self.args.dataset == 'carla':
dataloader = self.test_loader
else :
dataloader = self.eval_loader
for iterate, (voxel_input, output, counts) in enumerate(dataloader):
if len(voxel_input) == self.args.batch_size :
voxel_input = torch.from_numpy(np.asarray(voxel_input)).long().squeeze(1).cuda() # (4,1,256,256,32)
output = torch.from_numpy(np.asarray(output)).long().cuda()
invalid = torch.from_numpy(np.asarray(counts)).cuda()
if self.args.mode == 'l_vae':
if self.args.distribution:
recons = self.model.module.sample(output)
else :
recons = self.model.sample(output)
else :
if self.args.distribution:
recons = self.model.module.sample(voxel_input)
else :
recons = self.model.sample(voxel_input)
visualization(self.args, recons, voxel_input, output, invalid, iteration = iterate)
correct, total, pred_TP, pred_FP, pred_TN, pred_FN, intersection, union = get_result(self.args, invalid, output, recons)
all_intersections += intersection
all_unions += union
num_correct += correct
num_total += total
TP += pred_TP
FP += pred_FP
TN += pred_TN
FN += pred_FN
s_correct, s_total, s_pred_TP, s_pred_FP, s_pred_TN, s_pred_FN, s_intersection, s_union = get_result(self.args, voxel_input, output, recons, SSC=False)
s_all_intersections += s_intersection
s_all_unions += s_union
s_num_correct += s_correct
s_num_total += s_total
s_TP += s_pred_TP
s_FP += s_pred_FP
s_TN += s_pred_TN
s_FN += s_pred_FN
iou, miou = print_result(self.args, self.label_to_names, num_correct, num_total, all_intersections, all_unions, TP, FP, FN)
s_iou, seg_miou = print_result(self.args, self.label_to_names, s_num_correct, s_num_total, s_all_intersections, s_all_unions, s_TP, s_FP, s_FN, SSC=False)
return {"IoU" : iou, "mIoU": np.mean(miou)*100 }, miou, {"IoU" : s_iou, "mIoU": np.mean(seg_miou)*100 }, seg_miou
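`get_result` and `print_result` are defined in utils/tables.py (outside this excerpt); the per-class accumulation they feed is the standard intersection/union tally, sketched here on flattened label arrays:

```python
def iou_counts(pred, gt, num_classes):
    inter = [0] * num_classes
    union = [0] * num_classes
    for p, g in zip(pred, gt):
        if p == g:
            inter[p] += 1
            union[p] += 1
        else:                      # a disagreement grows both classes' unions
            union[p] += 1
            union[g] += 1
    return inter, union

inter, union = iou_counts([1, 1, 2, 0], [1, 2, 2, 0], num_classes=3)
iou = [i / u if u else 0.0 for i, u in zip(inter, union)]
```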
def resume(self):
self.checkpoint_load(self.args.resume_path)
for epoch in range(self.current_epoch):
train_dict = {}
for metric_name, metric_values in self.train_metrics.items():
train_dict[metric_name] = metric_values[epoch]
if epoch in self.eval_epochs:
eval_dict = {}
for metric_name, metric_values in self.eval_metrics.items():
eval_dict[metric_name] = metric_values[self.eval_epochs.index(epoch)]
else:
eval_dict = None
            if epoch in self.completion_epochs:
                sample_dict = {}
                for metric_name, metric_values in self.ssc_metrics.items():
                    sample_dict[metric_name] = metric_values[self.completion_epochs.index(epoch)]
else:
sample_dict = None
for metric_name, metric_value in train_dict.items():
self.writer.add_scalar('base/{}'.format(metric_name), metric_value, global_step=epoch+1)
if eval_dict:
for metric_name, metric_value in eval_dict.items():
self.writer.add_scalar('eval/{}'.format(metric_name), metric_value, global_step=epoch+1)
if sample_dict:
for metric_name, metric_value in sample_dict.items():
self.writer.add_scalar('sample/{}'.format(metric_name), metric_value, global_step=epoch+1)
def log_metrics(self, dict, type):
if len(type)==0:
for metric_name, metric_value in dict.items():
type[metric_name] = [metric_value]
else:
for metric_name, metric_value in dict.items():
type[metric_name].append(metric_value)
def save_metrics(self):
# Save metrics
with open(os.path.join(self.log_path,'metrics_train.pickle'), 'wb') as f:
pickle.dump(self.train_metrics, f)
with open(os.path.join(self.log_path,'metrics_eval.pickle'), 'wb') as f:
pickle.dump(self.eval_metrics, f)
# Save metrics table
metric_table = get_metric_table(self.train_metrics, epochs=list(range(1, self.current_epoch+2)))
with open(os.path.join(self.log_path,'metrics_train.txt'), "w") as f:
f.write(str(metric_table))
metric_table = get_metric_table(self.eval_metrics, epochs=[e+1 for e in self.eval_epochs])
with open(os.path.join(self.log_path,'metrics_eval.txt'), "w") as f:
f.write(str(metric_table))
def checkpoint_save(self, epoch):
if self.args.distribution:
checkpoint = {'current_epoch': self.current_epoch,
'train_metrics': self.train_metrics,
'eval_metrics': self.eval_metrics,
'eval_epochs': self.eval_epochs,
'optimizer': self.optimizer.state_dict(),
'model': self.model.module.state_dict(),
'scheduler_iter': self.scheduler_iter.state_dict() if self.scheduler_iter else None,
'scheduler_epoch': self.scheduler_epoch.state_dict() if self.scheduler_epoch else None,}
else :
checkpoint = {'current_epoch': self.current_epoch,
'train_metrics': self.train_metrics,
'eval_metrics': self.eval_metrics,
'eval_epochs': self.eval_epochs,
'optimizer': self.optimizer.state_dict(),
'model': self.model.state_dict(),
'scheduler_iter': self.scheduler_iter.state_dict() if self.scheduler_iter else None,
'scheduler_epoch': self.scheduler_epoch.state_dict() if self.scheduler_epoch else None,}
epoch_name = 'epoch{}.tar'.format(epoch)
torch.save(checkpoint, os.path.join(self.log_path, epoch_name))
def checkpoint_load(self, resume_path):
checkpoint = torch.load(resume_path)
if self.args.distribution:
self.model.module.load_state_dict(checkpoint['model'])
else :
self.model.load_state_dict(checkpoint['model'])
self.optimizer.load_state_dict(checkpoint['optimizer'])
if self.scheduler_iter: self.scheduler_iter.load_state_dict(checkpoint['scheduler_iter'])
if self.scheduler_epoch: self.scheduler_epoch.load_state_dict(checkpoint['scheduler_epoch'])
self.current_epoch = checkpoint['current_epoch']
self.train_metrics = checkpoint['train_metrics']
self.eval_metrics = checkpoint['eval_metrics']
self.eval_epochs = checkpoint['eval_epochs']
================================================
FILE: utils/cuda.py
================================================
import os
import torch
from torch import distributed as dist
from torch import multiprocessing as mp
import utils.dicts as dist_fn
def find_free_port():
import socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(("", 0))
port = sock.getsockname()[1]
sock.close()
return port
def launch(fn, n_gpu_per_machine, n_machine=1, machine_rank=0, dist_url=None, args=()):
world_size = n_machine * n_gpu_per_machine
if world_size > 1:
# if "OMP_NUM_THREADS" not in os.environ:
# os.environ["OMP_NUM_THREADS"] = "1"
if dist_url == "auto":
if n_machine != 1:
raise ValueError('dist_url="auto" not supported in multi-machine jobs')
port = find_free_port()
dist_url = f"tcp://127.0.0.1:{port}"
if n_machine > 1 and dist_url.startswith("file://"):
raise ValueError(
"file:// is not a reliable init method in multi-machine jobs. Prefer tcp://"
)
mp.spawn(
distributed_worker,
nprocs=n_gpu_per_machine,
args=(fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args),
daemon=False,
)
else:
local_rank = 0
fn(local_rank, *args)
def distributed_worker(local_rank, fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args):
if not torch.cuda.is_available():
raise OSError("CUDA is not available. Please check your environments")
global_rank = machine_rank * n_gpu_per_machine + local_rank
try:
dist.init_process_group(
backend="NCCL",
init_method=dist_url,
world_size=world_size,
rank=global_rank,
)
except Exception:
raise OSError("failed to initialize NCCL groups")
dist_fn.synchronize()
if n_gpu_per_machine > torch.cuda.device_count():
raise ValueError(
f"specified n_gpu_per_machine larger than available device ({torch.cuda.device_count()})"
)
torch.cuda.set_device(local_rank)
if dist_fn.LOCAL_PROCESS_GROUP is not None:
raise ValueError("torch.distributed.LOCAL_PROCESS_GROUP is not None")
n_machine = world_size // n_gpu_per_machine
for i in range(n_machine):
ranks_on_i = list(range(i * n_gpu_per_machine, (i + 1) * n_gpu_per_machine))
pg = dist.new_group(ranks_on_i)
if i == machine_rank:
dist_fn.LOCAL_PROCESS_GROUP = pg
fn(local_rank, *args)
def set_cuda_vd(gpu_ids, verbose=True):
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(id) for id in gpu_ids)
    if verbose: print("CUDA_VISIBLE_DEVICES = {}".format(os.environ["CUDA_VISIBLE_DEVICES"]))
================================================
FILE: utils/dicts.py
================================================
import copy
import math
import pickle
import torch
from torch import distributed as dist
from torch.utils import data
LOCAL_PROCESS_GROUP = None
def is_primary():
return get_rank() == 0
def get_rank():
if not dist.is_available():
return 0
if not dist.is_initialized():
return 0
return dist.get_rank()
def get_local_rank():
if not dist.is_available():
return 0
if not dist.is_initialized():
return 0
if LOCAL_PROCESS_GROUP is None:
raise ValueError("tensorfn.distributed.LOCAL_PROCESS_GROUP is None")
return dist.get_rank(group=LOCAL_PROCESS_GROUP)
def synchronize():
if not dist.is_available():
return
if not dist.is_initialized():
return
world_size = dist.get_world_size()
if world_size == 1:
return
dist.barrier()
def get_world_size():
if not dist.is_available():
return 1
if not dist.is_initialized():
return 1
return dist.get_world_size()
def is_distributed():
return get_world_size() > 1
def all_reduce(tensor, op=dist.ReduceOp.SUM, async_op=False):
world_size = get_world_size()
if world_size == 1:
return tensor
dist.all_reduce(tensor, op=op, async_op=async_op)
return tensor
def all_gather(data):
world_size = get_world_size()
if world_size == 1:
return [data]
buffer = pickle.dumps(data)
storage = torch.ByteStorage.from_buffer(buffer)
tensor = torch.ByteTensor(storage).to("cuda")
local_size = torch.IntTensor([tensor.numel()]).to("cuda")
size_list = [torch.IntTensor([1]).to("cuda") for _ in range(world_size)]
dist.all_gather(size_list, local_size)
size_list = [int(size.item()) for size in size_list]
max_size = max(size_list)
tensor_list = []
for _ in size_list:
tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
if local_size != max_size:
padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
tensor = torch.cat((tensor, padding), 0)
dist.all_gather(tensor_list, tensor)
data_list = []
for size, tensor in zip(size_list, tensor_list):
buffer = tensor.cpu().numpy().tobytes()[:size]
data_list.append(pickle.loads(buffer))
return data_list
def reduce_dict(input_dict, average=True):
world_size = get_world_size()
if world_size < 2:
return input_dict
with torch.no_grad():
keys = []
values = []
for k in sorted(input_dict.keys()):
keys.append(k)
values.append(input_dict[k])
values = torch.stack(values, 0)
dist.reduce(values, dst=0)
if dist.get_rank() == 0 and average:
values /= world_size
reduced_dict = {k: v for k, v in zip(keys, values)}
return reduced_dict
def data_sampler(dataset, shuffle, distributed):
if distributed:
return data.distributed.DistributedSampler(dataset, shuffle=shuffle)
if shuffle:
return data.RandomSampler(dataset)
else:
return data.SequentialSampler(dataset)
def clean_dict(d, keys):
d2 = copy.deepcopy(d)
for key in keys:
if key in d2:
del d2[key]
return d2
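`clean_dict` is how the `no_log_keys` are stripped from the args dict before logging; reproduced and exercised standalone:

```python
import copy

def clean_dict(d, keys):
    # deep-copy so the caller's dict is untouched, then drop the listed keys
    d2 = copy.deepcopy(d)
    for key in keys:
        if key in d2:
            del d2[key]
    return d2

args = {'lr': 1e-3, 'device': 'cuda', 'num_workers': 4}
logged = clean_dict(args, ['device', 'num_workers', 'missing'])
```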
================================================
FILE: utils/intermediate_vis.py
================================================
from dataclasses import astuple
import torch
import argparse
import numpy as np
import os
import pickle
import torch
import torch.nn.functional as F
import yaml
from prettytable import PrettyTable
from torch.utils.tensorboard import SummaryWriter
from utils.tables import *
from utils.dicts import clean_dict
from utils.loss import lovasz_softmax
class Vis_iter(object):
no_log_keys = ['project', 'name','log_tb', 'log_wandb','check_every', 'eval_every','device', 'parallel', 'pin_memory', 'num_workers']
def __init__(self, args, model, optimizer, scheduler_iter, scheduler_epoch, test_loader,log_path):
# Objects
self.model = model
self.optimizer, self.scheduler_iter, self.scheduler_epoch= optimizer, scheduler_iter, scheduler_epoch
# Paths
self.log_path = log_path
if args.dataset =='kitti':
config_file = os.path.join('/home/jumin/multinomial_diffusion/datasets/semantic_kitti.yaml')
kitti_config = yaml.safe_load(open(config_file, 'r'))
self.remap = kitti_config['learning_map_inv']
self.color_map = kitti_config["color_map"]
label = kitti_config['labels']
map_index = np.asarray([self.remap[i] for i in range(20)])
self.label_to_names = np.asarray([label[map_i] for map_i in map_index])
elif args.dataset =='carla':
base_dir = os.path.dirname(__file__)
config_file = os.path.join(base_dir, '../datasets/carla.yaml')
carla_config = yaml.safe_load(open(config_file, 'r'))
self.color_map = carla_config["remap_color_map"]
self.remap = None
LABEL_TO_NAMES = carla_config["label_to_names"]
self.label_to_names = np.asarray(list(LABEL_TO_NAMES.values()))
# Initialize
self.current_epoch = 0
self.train_metrics, self.eval_metrics, self.ssc_metrics, self.seg_metrics = {}, {}, {}, {}
self.eval_epochs = []
self.completion_epochs = []
# Store data loaders
self.test_loader = test_loader
# Store args
create_folders(args)
save_args(args)
self.args = args
# Init logging
args_dict = clean_dict(vars(args), keys=self.no_log_keys)
if args.log_tb:
self.writer = SummaryWriter(os.path.join(self.log_path, 'tb'))
self.writer.add_text("args", get_args_table(args_dict).get_html_string(), global_step=0)
def run(self, epochs):
self.checkpoint_load(self.args.resume_path)
for epoch in range(self.current_epoch, epochs):
self.sample()
def sample(self):
self.model.eval()
with torch.no_grad():
for iterate, (voxel_input, output, counts) in enumerate(self.test_loader):
voxel_input = torch.from_numpy(np.asarray(voxel_input)).squeeze(1).cuda()
output = torch.from_numpy(np.asarray(output)).long().cuda()
_, intermediate = self.model.module.sample(voxel_input, intermediate=True)
inter_vis(self.args, intermediate)
break
def checkpoint_load(self, resume_path):
checkpoint = torch.load(resume_path)
if self.args.distribution:
self.model.module.load_state_dict(checkpoint['model'])
else :
self.model.load_state_dict(checkpoint['model'])
self.optimizer.load_state_dict(checkpoint['optimizer'])
if self.scheduler_iter: self.scheduler_iter.load_state_dict(checkpoint['scheduler_iter'])
if self.scheduler_epoch: self.scheduler_epoch.load_state_dict(checkpoint['scheduler_epoch'])
self.current_epoch = checkpoint['current_epoch']
self.train_metrics = checkpoint['train_metrics']
self.eval_metrics = checkpoint['eval_metrics']
self.eval_epochs = checkpoint['eval_epochs']
================================================
FILE: utils/loss.py
================================================
import math
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
try:
from itertools import ifilterfalse
except ImportError: # py3k
from itertools import filterfalse as ifilterfalse
# -*- coding:utf-8 -*-
# author: Xinge
def dice_coef(y_true, y_pred, smooth=1e-6):
y_true_f = y_true.view(-1)
y_pred_f = y_pred.view(-1)
intersection = (y_true_f * y_pred_f).sum()
return (2. * intersection + smooth) / (y_true_f.sum() + y_pred_f.sum() + smooth)
def dice_coef_multilabel(y_true, y_pred, numLabels=11):
dice=0
for index in range(1, numLabels):
dice += dice_coef(y_true[:,index,:,:,:], y_pred[:,index,:,:,:])
return (numLabels-1) - dice
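As a minimal, self-contained check of the arithmetic in `dice_coef` above (re-implemented with NumPy for illustration; the repo's version operates on torch tensors):

```python
import numpy as np

# NumPy sketch of dice_coef above: 2*|A∩B| / (|A| + |B|), with smoothing
def dice_coef(y_true, y_pred, smooth=1e-6):
    y_true_f = y_true.reshape(-1)
    y_pred_f = y_pred.reshape(-1)
    intersection = (y_true_f * y_pred_f).sum()
    return (2. * intersection + smooth) / (y_true_f.sum() + y_pred_f.sum() + smooth)

y_true = np.array([1., 1., 0., 0.])
y_pred = np.array([1., 0., 0., 0.])
d = dice_coef(y_true, y_pred)  # 2*1 / (2 + 1) ≈ 0.667
```

`dice_coef_multilabel` then sums `(numLabels - 1) - dice` so that a perfect prediction (Dice of 1 for every non-background class) yields a loss of zero.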
"""
Lovasz-Softmax and Jaccard hinge loss in PyTorch
Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)
"""
def lovasz_grad(gt_sorted):
"""
Computes gradient of the Lovasz extension w.r.t sorted errors
See Alg. 1 in paper
"""
p = len(gt_sorted)
gts = gt_sorted.sum()
intersection = gts - gt_sorted.float().cumsum(0)
union = gts + (1 - gt_sorted).float().cumsum(0)
jaccard = 1. - intersection / union
if p > 1: # cover 1-pixel case
jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
return jaccard
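The differencing step in `lovasz_grad` telescopes, so the gradient entries sum to the final Jaccard loss of the sorted ground truth. A NumPy sketch (illustrative re-implementation, not the repo's torch version) makes this easy to verify:

```python
import numpy as np

# NumPy sketch of lovasz_grad above (Alg. 1 in Berman et al. 2018):
# per-position increments of the Jaccard loss over sorted errors
def lovasz_grad(gt_sorted):
    p = len(gt_sorted)
    gts = gt_sorted.sum()
    intersection = gts - np.cumsum(gt_sorted)
    union = gts + np.cumsum(1.0 - gt_sorted)
    jaccard = 1.0 - intersection / union
    if p > 1:  # cover the 1-pixel case
        jaccard[1:] = jaccard[1:] - jaccard[:-1]
    return jaccard

g = lovasz_grad(np.array([1., 1., 0., 1., 0.]))
# entries are non-negative and telescope to the full Jaccard loss
```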
# --------------------------- MULTICLASS LOSSES ---------------------------
def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None):
"""
Multi-class Lovasz-Softmax loss
probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
per_image: compute the loss per image instead of per batch
ignore: void class labels
"""
if per_image:
loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes)
for prob, lab in zip(probas, labels))
else:
loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes)
return loss
def lovasz_softmax_flat(probas, labels, classes='present'):
"""
Multi-class Lovasz-Softmax loss
probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
labels: [P] Tensor, ground truth labels (between 0 and C - 1)
classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
"""
if probas.numel() == 0:
# only void pixels, the gradients should be 0
return probas * 0.
C = probas.size(1)
losses = []
class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
for c in class_to_sum:
fg = (labels == c).float() # foreground for class c
        if classes == 'present' and fg.sum() == 0:  # compare by value; 'is' on a str literal is fragile
continue
if C == 1:
if len(classes) > 1:
raise ValueError('Sigmoid output possible only with 1 class')
class_pred = probas[:, 0]
else:
class_pred = probas[:, c]
errors = (Variable(fg) - class_pred).abs()
errors_sorted, perm = torch.sort(errors, 0, descending=True)
perm = perm.data
fg_sorted = fg[perm]
losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
return mean(losses)
def flatten_probas(probas, labels, ignore=None):
"""
Flattens predictions in the batch
"""
if probas.dim() == 3:
# assumes output of a sigmoid layer
B, H, W = probas.size()
probas = probas.view(B, 1, H, W)
elif probas.dim() == 5:
#3D segmentation
B, C, L, H, W = probas.size()
probas = probas.contiguous().view(B, C, L, H*W)
B, C, H, W = probas.size()
probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C
labels = labels.view(-1)
if ignore is None:
return probas, labels
valid = (labels != ignore)
vprobas = probas[valid.nonzero().squeeze()]
vlabels = labels[valid]
return vprobas, vlabels
# --------------------------- HELPER FUNCTIONS ---------------------------
def isnan(x):
return x != x
def mean(l, ignore_nan=False, empty=0):
"""
nanmean compatible with generators.
"""
l = iter(l)
if ignore_nan:
l = ifilterfalse(isnan, l)
try:
n = 1
acc = next(l)
except StopIteration:
if empty == 'raise':
raise ValueError('Empty mean')
return empty
for n, v in enumerate(l, 2):
acc += v
if n == 1:
return acc
return acc / n
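The point of `mean` above is that it works on generators, where `len()` is unavailable and `sum()/len()` cannot be used. A pure-Python sketch (renamed `gen_mean` here to avoid clashing with the repo's helper):

```python
import math
from itertools import filterfalse

# sketch of mean() above: nanmean compatible with generators
def gen_mean(values, ignore_nan=False, empty=0):
    it = iter(values)
    if ignore_nan:
        it = filterfalse(math.isnan, it)
    try:
        n = 1
        acc = next(it)          # first element, or fall through to `empty`
    except StopIteration:
        if empty == 'raise':
            raise ValueError('Empty mean')
        return empty
    for n, v in enumerate(it, 2):  # n tracks the running count
        acc += v
    return acc if n == 1 else acc / n

m = gen_mean(x * x for x in range(4))  # (0 + 1 + 4 + 9) / 4 = 3.5
```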
================================================
FILE: utils/multistep.py
================================================
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.optim.lr_scheduler import _LRScheduler
class LinearWarmupScheduler(_LRScheduler):
""" Linearly warm-up (increasing) learning rate, starting from zero.
Args:
optimizer (Optimizer): Wrapped optimizer.
total_epoch: target learning rate is reached at total_epoch.
"""
def __init__(self, optimizer, total_epoch, last_epoch=-1):
self.total_epoch = total_epoch
super(LinearWarmupScheduler, self).__init__(optimizer, last_epoch)
def get_lr(self):
return [base_lr * min(1, (self.last_epoch / self.total_epoch)) for base_lr in self.base_lrs]
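`get_lr` scales every base learning rate by `min(1, last_epoch / total_epoch)`. The schedule this produces can be sketched without torch (names here are illustrative):

```python
# the warm-up multiplier used by LinearWarmupScheduler.get_lr above:
# ramps linearly from 0 to 1 over `total_epoch` steps, then clamps at 1
def warmup_factor(last_epoch, total_epoch):
    return min(1.0, last_epoch / total_epoch)

base_lr = 0.1
lrs = [base_lr * warmup_factor(e, 4) for e in range(6)]
# ≈ [0.0, 0.025, 0.05, 0.075, 0.1, 0.1]
```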
optim_choices = {'sgd', 'adam', 'adamax'}
def get_optim(args, model):
assert args.optimizer in optim_choices
if args.optimizer == 'sgd':
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
elif args.optimizer == 'adam':
optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(args.momentum, args.momentum_sqr))
elif args.optimizer == 'adamax':
optimizer = optim.Adamax(model.parameters(), lr=args.lr, betas=(args.momentum, args.momentum_sqr))
if args.warmup is not None:
scheduler_iter = LinearWarmupScheduler(optimizer, total_epoch=args.warmup)
else:
scheduler_iter = None
if len(args.milestones)>0:
scheduler_epoch = MultiStepLR(optimizer, milestones=args.milestones, gamma=args.gamma)
else:
scheduler_epoch = None
return optimizer, scheduler_iter, scheduler_epoch
================================================
FILE: utils/tables.py
================================================
from prettytable import PrettyTable
import torch
import os
import pickle
import numpy as np
import torch.nn.functional as F
import open3d as o3d
def get_args_table(args_dict):
table = PrettyTable(['Arg', 'Value'])
for arg, val in args_dict.items():
table.add_row([arg, val])
return table
def get_miou_table(args, label_to_names, miou):
table = PrettyTable(['Label', 'mIoU'])
for i in range(args.num_classes):
table.add_row([label_to_names[i], 100 * miou[i]])
return table
def get_metric_table(metric_dict, epochs):
table = PrettyTable()
table.add_column('Epoch', epochs)
if len(metric_dict)>0:
for metric_name, metric_values in metric_dict.items():
table.add_column(metric_name, metric_values)
return table
def create_folders(args):
# Create log folder
os.makedirs(args.log_path, exist_ok=True)
os.makedirs(args.log_path+'/Completion', exist_ok=True)
os.makedirs(args.log_path+'/Input', exist_ok=True)
os.makedirs(args.log_path+'/Output', exist_ok=True)
os.makedirs(args.log_path+'/Invalid', exist_ok=True)
print("Storing logs in:", args.log_path)
def inter_vis(args, recons):
for r in range(len(recons)):
for batch, samples_i in enumerate(recons[r]):
            color_index = []
            for i in range(1, args.num_classes):
                index = torch.nonzero(samples_i == i, as_tuple=False)
                color_index.append(F.pad(index, (1, 0), 'constant', value=i))
            colors_indexs = torch.cat(color_index, dim=0).cpu().numpy()
            # write under the experiment's log folder rather than a hard-coded absolute user path
            np.savetxt(args.log_path + '/Completion/batch{}_{}.txt'.format(batch, r), colors_indexs)
def visualization(args, recons, input_data, output, invalid, iteration):
for batch, (samples_i, input_i, output_i, invalid_i) in enumerate(zip(recons, input_data, output, invalid)):
color_index = []
output_index = []
input_points = torch.nonzero(input_i == 1, as_tuple=False).cpu().numpy()
if args.dataset =='carla':
invalid_points = torch.nonzero(invalid_i == 0, as_tuple=False).cpu().numpy()
elif args.dataset =='kitti':
invalid_points = torch.nonzero(invalid_i == 1, as_tuple=False).cpu().numpy()
for i in range(1, args.num_classes):
index = torch.nonzero(samples_i == i ,as_tuple=False)
out_color = torch.nonzero(output_i == i, as_tuple=False)
color_index.append(F.pad(index,(1,0),'constant',value = i))
output_index.append(F.pad(out_color,(1,0),'constant',value=i))
colors_indexs = torch.cat(color_index, dim = 0).cpu().numpy()
out_indexs = torch.cat(output_index, dim = 0).cpu().numpy()
np.savetxt(args.log_path+'/Completion/result_{}.txt'.format((iteration * args.batch_size) + batch), colors_indexs)
'''np.savetxt(args.log_path+'/Input/input_{}.txt'.format((iteration * args.batch_size) + batch), input_points)
np.savetxt(args.log_path+'/Invalid/invalid_{}.txt'.format((iteration * args.batch_size) + batch), invalid_points)
np.savetxt(args.log_path+'/Output/gt_{}.txt'.format((iteration * args.batch_size) + batch), out_indexs)'''
def completion_vis(args, input_p, recons):
for batch, (recon_i, input_i) in enumerate(zip(recons, input_p)):
recon_points = torch.nonzero(recon_i == 1, as_tuple=False).cpu().numpy()
input_points = torch.nonzero(input_i == 1, as_tuple=False).cpu().numpy()
np.savetxt(args.log_path+'/Completion/completion_{}.txt'.format(batch), recon_points)
np.savetxt(args.log_path+'/Input/input_{}.txt'.format(batch), input_points)
def iou_one_frame(pred, target, n_classes=23):
pred = pred.view(-1).detach().cpu().numpy()
target = target.view(-1).detach().cpu().numpy()
intersection = np.zeros(n_classes)
union = np.zeros(n_classes)
for cls in range(n_classes):
intersection[cls] = np.sum((pred == cls) & (target == cls))
union[cls] = np.sum((pred == cls) | (target == cls))
return intersection, union
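The per-class tallies above can be checked on a toy frame. This NumPy sketch mirrors `iou_one_frame` (the helper name is illustrative); dividing the tallies elementwise afterwards yields the per-class IoU:

```python
import numpy as np

# per-class intersection/union counts, as in iou_one_frame above
def iou_counts(pred, target, n_classes):
    inter = np.zeros(n_classes)
    union = np.zeros(n_classes)
    for cls in range(n_classes):
        inter[cls] = np.sum((pred == cls) & (target == cls))
        union[cls] = np.sum((pred == cls) | (target == cls))
    return inter, union

pred = np.array([0, 1, 1, 2])
target = np.array([0, 1, 2, 2])
inter, union = iou_counts(pred, target, n_classes=3)
iou = inter / union  # class 0: 1/1, class 1: 1/2, class 2: 1/2
```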
def get_result(args, for_mask, output, preds, SSC=True):
for_mask = for_mask.contiguous().view(-1)
output = output.contiguous().view(-1)
preds = preds.contiguous().view(-1)
if SSC :
if args.dataset == 'kitti':
mask = for_mask == 0
elif args.dataset== 'carla':
mask = for_mask > 0
else :
mask = for_mask == 1
output_masked = output[mask]
iou_output_masked = output_masked.cpu().numpy()
iou_output_masked[iou_output_masked != 0] = 1
preds_masked = preds[mask]
iou_preds_masked = preds_masked.cpu().numpy()
iou_preds_masked[iou_preds_masked != 0] = 1
# I, U for a frame
correct = np.sum(output_masked.cpu().numpy() == preds_masked.cpu().numpy())
total = preds_masked.shape[0]
pred_TP = np.sum((iou_preds_masked == 1) & (iou_output_masked == 1))
pred_FP = np.sum((iou_preds_masked == 1) & (iou_output_masked == 0))
pred_TN = np.sum((iou_preds_masked == 0) & (iou_output_masked == 0))
pred_FN = np.sum((iou_preds_masked == 0) & (iou_output_masked == 1))
intersection, union = iou_one_frame(preds_masked, output_masked, n_classes=args.num_classes)
return correct, total, pred_TP, pred_FP, pred_TN, pred_FN, intersection, union
def save_args(args):
# Save args
with open(os.path.join(args.log_path, 'args.pickle'), "wb") as f:
pickle.dump(args, f)
# Save args table
args_table = get_args_table(vars(args))
with open(os.path.join(args.log_path,'args_table.txt'), "w") as f:
f.write(str(args_table))
def print_completion(num_correct, num_total, TP, FP, FN):
print("\n=========================================\n")
accuracy = num_correct/num_total
print("\nAccuracy : ", accuracy)
precision = 100 * TP / (TP + FP)
recall = 100 * TP / (TP + FN)
iou = 100 * TP / (TP + FP + FN)
print("\nCompleteness")
print("precision:", precision)
print("recall:", recall)
print("iou:", iou)
print("\n=========================================\n")
return iou
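The completeness numbers printed above reduce to three ratios over binary occupancy counts. A self-contained sketch (function name is illustrative):

```python
# precision/recall over occupied voxels plus the scene-completion IoU,
# matching the formulas in print_completion above (percentages)
def completeness_metrics(TP, FP, FN):
    precision = 100 * TP / (TP + FP)
    recall = 100 * TP / (TP + FN)
    iou = 100 * TP / (TP + FP + FN)
    return precision, recall, iou

p, r, iou = completeness_metrics(TP=8, FP=2, FN=2)
# p = 80.0, r = 80.0, iou ≈ 66.7
```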
def print_result(args, label_to_names, num_correct, num_total, all_intersections, all_unions, TP, FP, FN, SSC=True):
if SSC :
print("\n========== Semantic Scene Completion =============\n")
else :
print("\n============ Semantic Segmentation ===============\n")
accuracy = num_correct/num_total
print("\nAccuracy : ", accuracy)
precision = 100 * TP / (TP + FP)
recall = 100 * TP / (TP + FN)
iou = 100 * TP / (TP + FP + FN)
print("\nCompleteness")
print("precision:", precision)
print("recall:", recall)
print("iou:", iou)
print("\nSemantic IoU Per Class")
miou = all_intersections / all_unions
for i in range(args.num_classes):
print(label_to_names[i], ':', 100 * miou[i])
print("\n====================================================\n")
return iou, miou
================================================
FILE: visualization.py
================================================
import os
import open3d as o3d
import open3d.visualization.gui as gui
import open3d.visualization.rendering as rendering
import argparse
import numpy as np
import yaml
import struct
parser = argparse.ArgumentParser()
parser.add_argument('--M', default='scene-scale-diffusion') # VQVAE, multinomial_diffusion
parser.add_argument('--Driver', default='D')
parser.add_argument('--frame', default='0')
parser.add_argument('--file', default='result_')
parser.add_argument('--folder', default='Completion')
parser.add_argument('--model', default='image_init8_concat_att')
parser.add_argument('--name', default='Semantic Scene Completion')
parser.add_argument('--dataset', default='carla')  # needed by get_voxel's color-map lookup below
parser.add_argument('--invalid', action='store_true')  # a plain string default would make any passed value truthy
class SpheresApp:
MENU_SCENE = 1
MENU_BEFORE = 2
MENU_QUIT = 3
def __init__(self, opt):
self._id = 0
self.opt = opt
self.window = gui.Application.instance.create_window("Semantic Scene Completion", 1500, 1000)
self.scene = gui.SceneWidget()
self.scene.scene = rendering.Open3DScene(self.window.renderer)
self.scene.scene.set_background([1, 1, 1, 1])
self.scene.scene.scene.set_sun_light(
[-0.577, 0.577, -0.577], # direction
[1, 1, 1], # color
60000) # intensity
self.scene.scene.scene.enable_sun_light(True)
bbox = o3d.geometry.AxisAlignedBoundingBox([64, 64, -60], [64, 64, 60])
self.scene.setup_camera(60, bbox, [0, 0, 1])
self.window.add_child(self.scene)
if gui.Application.instance.menubar is None:
debug_menu = gui.Menu()
debug_menu.add_item("Next Scene", SpheresApp.MENU_SCENE)
debug_menu.add_separator()
debug_menu.add_item("Before Scene", SpheresApp.MENU_BEFORE)
debug_menu.add_separator()
debug_menu.add_item("Quit", SpheresApp.MENU_QUIT)
menu = gui.Menu()
menu.add_menu("SSC", debug_menu)
gui.Application.instance.menubar = menu
# The menubar is global, but we need to connect the menu items to the
# window, so that the window can call the appropriate function when the menu item is activated.
self.window.set_on_menu_item_activated(SpheresApp.MENU_SCENE,self._on_menu_scene)
self.window.set_on_menu_item_activated(SpheresApp.MENU_QUIT,self._on_menu_quit)
self.window.set_on_menu_item_activated(SpheresApp.MENU_BEFORE,self._on_menu_before)
    def _show_scene(self, frame_delta):
        # render the current frame as a voxel grid, then step opt.frame by frame_delta
        mat = rendering.MaterialRecord()
        mat.shader = "defaultLit"
        if self.opt.file == 'input_':
            points = get_input(self.opt)
        else:
            points, colors = get_voxel(self.opt)
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(points)
        if self.opt.file != 'input_':
            pcd.colors = o3d.utility.Vector3dVector(colors / 255)
        self.scene.scene.clear_geometry()
        voxel_grid = o3d.geometry.VoxelGrid.create_from_point_cloud(pcd, voxel_size=1)
        self.scene.scene.add_geometry("scene" + str(self._id), voxel_grid, mat)
        print(self.opt.frame)
        self.opt.frame = str(int(self.opt.frame) + frame_delta)
    def _on_menu_before(self):
        self._id -= 1
        self._show_scene(-1)
    def _on_menu_quit(self):
        gui.Application.instance.quit()
    def _on_menu_scene(self):
        self._id += 1
        self._show_scene(+1)
def get_voxel(opt):
if opt.invalid :
invalid_path = opt.Driver+':/'+opt.M+ '/result/' + opt.model +'/Invalid/invalid_'+ opt.frame +'.txt'
invalid_points = np.loadtxt(invalid_path, delimiter=' ')
        invalid_colors = np.full(len(invalid_points), 0)
point_cloud_path = opt.Driver+':/'+opt.M+ '/result/' + opt.model +'/' + opt.folder +'/'+ opt.file + opt.frame +'.txt'
points_colors = np.loadtxt(point_cloud_path, delimiter=' ')
points = points_colors[:, 1:]
colors = points_colors[:, 0]
points = np.concatenate((invalid_points, points), axis=0)
colors = np.concatenate((invalid_colors, colors), axis=0)
points, index = np.unique(points, return_index=True, axis=0)
colors = colors[index, ...]
else :
point_cloud_path = 'C:/Users/jumin/Dataset/result_319_110.txt'
points_colors = np.loadtxt(point_cloud_path, delimiter=' ')
points = points_colors[:, 1:]
colors = points_colors[:, 0]
if opt.dataset == 'carla' :
base_dir = os.path.dirname(__file__)
config_file = os.path.join(base_dir, 'datasets/carla.yaml')
config = yaml.safe_load(open(config_file, 'r'))
color_map = config["remap_color_map"]
color = np.asarray([color_map[c] for c in colors])
return points, color
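For reference, the `.txt` files consumed by `get_voxel` store one voxel per row: the first column is the semantic class index and the remaining columns are the coordinates. A tiny sketch of that split (data here is made up for illustration):

```python
import numpy as np

# row layout of the files written by the visualization helpers:
# [class_label, x, y, z]
rows = np.array([[1., 10., 4., 2.],
                 [3., 11., 4., 2.]])
points = rows[:, 1:]  # (N, 3) voxel coordinates
labels = rows[:, 0]   # (N,) class index, mapped to RGB via remap_color_map
```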
def get_input(opt):
    point_cloud_path = opt.Driver + ':/' + opt.M + '/result/' + opt.model + '/Invalid/invalid_' + opt.frame + '.txt'
    # rows are bare x/y/z coordinates; no PointCloud object is needed just to return them
    points = np.loadtxt(point_cloud_path, delimiter=' ')
    return points
def main(opt):
gui.Application.instance.initialize()
SpheresApp(opt)
gui.Application.instance.run()
if __name__ == "__main__":
opt = parser.parse_args()
main(opt)
SYMBOL INDEX (295 symbols across 22 files)
FILE: SSC_train.py
function get_args (line 28) | def get_args():
function main (line 92) | def main():
function start (line 112) | def start(local_rank, args):
FILE: datasets/carla_dataset.py
class CarlaDataset (line 21) | class CarlaDataset(Dataset):
method __init__ (line 25) | def __init__(self, directory,
method __len__ (line 104) | def __len__(self):
method collate_fn (line 107) | def collate_fn(self, data):
method points_to_voxels (line 113) | def points_to_voxels(self, voxel_grid, points, t_i):
method get_pose (line 129) | def get_pose(self, idx):
method __getitem__ (line 135) | def __getitem__(self, idx):
method find_horizon (line 193) | def find_horizon(self, idx):
FILE: datasets/data.py
function get_data_id (line 11) | def get_data_id(args):
function get_class_weights (line 14) | def get_class_weights(freq):
function get_data (line 23) | def get_data(args):
FILE: layers/Ablation/wo_diffusion.py
class wo_diff (line 7) | class wo_diff(torch.nn.Module):
method __init__ (line 8) | def __init__(self, args, multi_criterion) -> None:
method device (line 22) | def device(self):
method forward (line 25) | def forward(self, x, input_ten):
method sample (line 31) | def sample(self, x):
FILE: layers/Latent_Level/stage1/model.py
function conv3x3x3 (line 10) | def conv3x3x3(in_planes, out_planes, stride=1):
function conv1x3x3 (line 13) | def conv1x3x3(in_planes, out_planes, stride=1):
function conv1x1x3 (line 16) | def conv1x1x3(in_planes, out_planes, stride=1):
function conv1x3x1 (line 19) | def conv1x3x1(in_planes, out_planes, stride=1):
function conv3x1x1 (line 22) | def conv3x1x1(in_planes, out_planes, stride=1):
function conv3x1x3 (line 25) | def conv3x1x3(in_planes, out_planes, stride=1):
function conv1x1 (line 28) | def conv1x1(in_planes, out_planes, stride=1):
class Asymmetric_Residual_Block (line 32) | class Asymmetric_Residual_Block(nn.Module):
method __init__ (line 33) | def __init__(self, in_filters, out_filters):
method forward (line 60) | def forward(self, x):
class DownBlock (line 81) | class DownBlock(nn.Module):
method __init__ (line 82) | def __init__(self, in_filters, out_filters, pooling=True, drop_out=Tru...
method forward (line 93) | def forward(self, x):
class UpBlock (line 102) | class UpBlock(nn.Module):
method __init__ (line 103) | def __init__(self, in_filters, out_filters, height_pooling):
method forward (line 135) | def forward(self, x, skip=False):
class DDCM (line 159) | class DDCM(nn.Module):
method __init__ (line 160) | def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), str...
method forward (line 180) | def forward(self, x):
function l2norm (line 197) | def l2norm(t):
class Attention (line 200) | class Attention(nn.Module):
method __init__ (line 201) | def __init__(self, dim, heads = 4, scale = 10):
method forward (line 208) | def forward(self, x):
class C_Encoder (line 221) | class C_Encoder(nn.Module):
method __init__ (line 222) | def __init__(self, args, nclasses=20, init_size=16, l_size='882', att...
method forward (line 256) | def forward(self, x, out_conv=True):
class C_Decoder (line 278) | class C_Decoder(nn.Module):
method __init__ (line 279) | def __init__(self, args, nclasses=20, init_size=16, l_size='882', atte...
method forward (line 309) | def forward(self, x, in_conv=True):
class Completion (line 333) | class Completion(nn.Module):
method __init__ (line 334) | def __init__(self, args, num_class = 11, init_size=32):
method forward (line 362) | def forward(self, x):
FILE: layers/Latent_Level/stage1/vector_quantizer.py
class VectorQuantizer (line 5) | class VectorQuantizer(nn.Module):
method __init__ (line 7) | def __init__(self,
method forward (line 19) | def forward(self, z: torch.tensor, point=False) -> torch.tensor: # lat...
method codebook_to_embedding (line 46) | def codebook_to_embedding(self, encoding_inds, latents_shape): # laten...
FILE: layers/Latent_Level/stage1/vqvae.py
class vqvae (line 10) | class vqvae(torch.nn.Module):
method __init__ (line 11) | def __init__(self, args, multi_criterion) -> None:
method device (line 28) | def device(self):
method encode (line 31) | def encode(self, x):
method vector_quantize (line 36) | def vector_quantize(self, latent):
method coodbook (line 40) | def coodbook(self,quantized_latent_ind, latents_shape):
method decode (line 44) | def decode(self, quantized_latent):
method forward (line 49) | def forward(self, x, input_ten):
method sample (line 58) | def sample(self, x):
FILE: layers/Latent_Level/stage2/Gen_diffusion.py
function sum_except_batch (line 14) | def sum_except_batch(x, num_dims=1):
function log_1_min_a (line 18) | def log_1_min_a(a):
function log_add_exp (line 22) | def log_add_exp(a, b):
function exists (line 27) | def exists(x):
function extract (line 31) | def extract(a, t, x_shape):
function default (line 37) | def default(val, d):
function log_categorical (line 43) | def log_categorical(log_x_start, log_prob):
function index_to_log_onehot (line 47) | def index_to_log_onehot(x, num_classes):
function log_onehot_to_index (line 58) | def log_onehot_to_index(log_x):
function cosine_beta_schedule (line 62) | def cosine_beta_schedule(timesteps, s = 0.008):
class latent_diffusion (line 78) | class latent_diffusion(torch.nn.Module):
method __init__ (line 79) | def __init__(self, args, VAE_DENSE, multi_criterion,
method device (line 115) | def device(self):
method multinomial_kl (line 118) | def multinomial_kl(self, log_prob1, log_prob2):
method q_pred_one_timestep (line 122) | def q_pred_one_timestep(self, log_x_t, t):
method q_pred (line 135) | def q_pred(self, log_x_start, t):
method predict_start (line 146) | def predict_start(self, log_x_t, t):
method q_posterior (line 158) | def q_posterior(self, log_x_start, log_x_t, t):
method p_pred (line 182) | def p_pred(self, log_x, t):
method log_sample_categorical (line 187) | def log_sample_categorical(self, logits):
method q_sample (line 194) | def q_sample(self, log_x_start, t):
method kl_prior (line 199) | def kl_prior(self, log_x_start):
method sample_time (line 210) | def sample_time(self, b, device, method='uniform'):
method forward (line 233) | def forward(self, x, input_data):
method sample (line 287) | def sample(self, x):
FILE: layers/Latent_Level/stage2/gen_denoise.py
function conv3x3x3 (line 11) | def conv3x3x3(in_planes, out_planes, stride=1):
function conv1x3x3 (line 14) | def conv1x3x3(in_planes, out_planes, stride=1):
function conv1x1x3 (line 18) | def conv1x1x3(in_planes, out_planes, stride=1):
function conv1x3x1 (line 22) | def conv1x3x1(in_planes, out_planes, stride=1):
function conv3x1x1 (line 26) | def conv3x1x1(in_planes, out_planes, stride=1):
function conv3x1x3 (line 30) | def conv3x1x3(in_planes, out_planes, stride=1):
function conv1x1 (line 34) | def conv1x1(in_planes, out_planes, stride=1):
class Asymmetric_Residual_Block (line 38) | class Asymmetric_Residual_Block(nn.Module):
method __init__ (line 39) | def __init__(self, in_filters, out_filters, time_filters=128):
method forward (line 64) | def forward(self, x, t):
class DDCM (line 91) | class DDCM(nn.Module):
method __init__ (line 92) | def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), str...
method forward (line 106) | def forward(self, x):
function l2norm (line 124) | def l2norm(t):
class Attention (line 127) | class Attention(nn.Module):
method __init__ (line 128) | def __init__(self, dim, heads = 4, scale = 10):
method forward (line 135) | def forward(self, x):
class Cross_Attention (line 148) | class Cross_Attention(nn.Module):
method __init__ (line 149) | def __init__(self, dim, heads = 4, scale = 10):
method forward (line 159) | def forward(self, x, cond_x):
class DownBlock (line 175) | class DownBlock(nn.Module):
method __init__ (line 176) | def __init__(self, in_filters, out_filters, dropout_rate, kernel_size=...
method forward (line 193) | def forward(self, x, t):
class UpBlock (line 201) | class UpBlock(nn.Module):
method __init__ (line 202) | def __init__(self, in_filters, out_filters, height_pooling, time_filte...
method forward (line 231) | def forward(self, x, residual, t):
function timestep_embedding (line 258) | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=Fal...
class Denoise (line 280) | class Denoise(nn.Module):
method __init__ (line 281) | def __init__(self, args, num_class = 11, init_size=32, discrete=True):
method forward (line 324) | def forward(self, x, t):
FILE: layers/Voxel_Level/Con_Diffusion.py
function sum_except_batch (line 14) | def sum_except_batch(x, num_dims=1):
function log_1_min_a (line 18) | def log_1_min_a(a):
function log_add_exp (line 22) | def log_add_exp(a, b):
function exists (line 27) | def exists(x):
function extract (line 31) | def extract(a, t, x_shape):
function default (line 37) | def default(val, d):
function log_categorical (line 43) | def log_categorical(log_x_start, log_prob):
function index_to_log_onehot (line 47) | def index_to_log_onehot(x, num_classes):
function log_onehot_to_index (line 57) | def log_onehot_to_index(log_x):
function cosine_beta_schedule (line 61) | def cosine_beta_schedule(timesteps, s = 0.008):
class Con_Diffusion (line 77) | class Con_Diffusion(torch.nn.Module):
method __init__ (line 78) | def __init__(self, args, multi_criterion,auxiliary_loss_weight=0.05, a...
method device (line 117) | def device(self):
method multinomial_kl (line 120) | def multinomial_kl(self, log_prob1, log_prob2):
method q_pred_one_timestep (line 124) | def q_pred_one_timestep(self, log_x_t, t):
method q_pred (line 137) | def q_pred(self, log_x_start, t):
method predict_start (line 148) | def predict_start(self, log_x_t, t, cond):
method q_posterior (line 156) | def q_posterior(self, log_x_start, log_x_t, t):
method p_pred (line 177) | def p_pred(self, log_x, t, cond):
method log_sample_categorical (line 182) | def log_sample_categorical(self, logits):
method q_sample (line 189) | def q_sample(self, log_x_start, t):
method kl_prior (line 194) | def kl_prior(self, log_x_start):
method sample_time (line 205) | def sample_time(self, b, device, method='uniform'):
method forward (line 228) | def forward(self, x, voxel_input):
method sample (line 280) | def sample(self, voxel_input, intermediate=False):
FILE: layers/Voxel_Level/Gen_Diffusion.py
function sum_except_batch (line 14) | def sum_except_batch(x, num_dims=1):
function log_1_min_a (line 18) | def log_1_min_a(a):
function log_add_exp (line 22) | def log_add_exp(a, b):
function exists (line 27) | def exists(x):
function extract (line 31) | def extract(a, t, x_shape):
function default (line 37) | def default(val, d):
function log_categorical (line 43) | def log_categorical(log_x_start, log_prob):
function index_to_log_onehot (line 47) | def index_to_log_onehot(x, num_classes):
function log_onehot_to_index (line 57) | def log_onehot_to_index(log_x):
function cosine_beta_schedule (line 61) | def cosine_beta_schedule(timesteps, s = 0.008):
class Diffusion (line 77) | class Diffusion(torch.nn.Module):
method __init__ (line 78) | def __init__(self, args, multi_criterion,auxiliary_loss_weight=0.05, a...
method device (line 114) | def device(self):
method multinomial_kl (line 117) | def multinomial_kl(self, log_prob1, log_prob2):
method q_pred_one_timestep (line 121) | def q_pred_one_timestep(self, log_x_t, t):
method q_pred (line 134) | def q_pred(self, log_x_start, t):
method predict_start (line 145) | def predict_start(self, log_x_t, t):
method q_posterior (line 153) | def q_posterior(self, log_x_start, log_x_t, t):
method p_pred (line 174) | def p_pred(self, log_x, t):
method log_sample_categorical (line 179) | def log_sample_categorical(self, logits):
method q_sample (line 186) | def q_sample(self, log_x_start, t):
method kl_prior (line 191) | def kl_prior(self, log_x_start):
method sample_time (line 202) | def sample_time(self, b, device, method='uniform'):
method forward (line 225) | def forward(self, x, voxel_input):
method sample (line 277) | def sample(self, voxel_input):
FILE: layers/Voxel_Level/denoise.py
function conv3x3x3 (line 11) | def conv3x3x3(in_planes, out_planes, stride=1):
function conv1x3x3 (line 14) | def conv1x3x3(in_planes, out_planes, stride=1):
function conv1x1x3 (line 18) | def conv1x1x3(in_planes, out_planes, stride=1):
function conv1x3x1 (line 22) | def conv1x3x1(in_planes, out_planes, stride=1):
function conv3x1x1 (line 26) | def conv3x1x1(in_planes, out_planes, stride=1):
function conv3x1x3 (line 30) | def conv3x1x3(in_planes, out_planes, stride=1):
function conv1x1 (line 34) | def conv1x1(in_planes, out_planes, stride=1):
class Asymmetric_Residual_Block (line 38) | class Asymmetric_Residual_Block(nn.Module):
method __init__ (line 39) | def __init__(self, in_filters, out_filters, time_filters=32*4):
method forward (line 71) | def forward(self, x, t):
class DDCM (line 98) | class DDCM(nn.Module):
method __init__ (line 99) | def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), str...
method forward (line 118) | def forward(self, x):
function l2norm (line 136) | def l2norm(t):
class Attention (line 139) | class Attention(nn.Module):
method __init__ (line 140) | def __init__(self, dim, heads = 4, scale = 10):
method forward (line 147) | def forward(self, x):
class Cross_Attention (line 160) | class Cross_Attention(nn.Module):
method __init__ (line 161) | def __init__(self, dim, heads = 4, scale = 10):
method forward (line 171) | def forward(self, x, cond_x):
class DownBlock (line 187) | class DownBlock(nn.Module):
method __init__ (line 188) | def __init__(self, in_filters, out_filters, time_filters=32*4, kernel_...
method forward (line 201) | def forward(self, x, t):
class UpBlock (line 209) | class UpBlock(nn.Module):
method __init__ (line 210) | def __init__(self, in_filters, out_filters, height_pooling, time_filte...
method forward (line 245) | def forward(self, x, residual, t):
function timestep_embedding (line 272) | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=Fal...
class Denoise (line 286) | class Denoise(nn.Module):
method __init__ (line 287) | def __init__(self, args, num_class = 11, init_size=32, discrete=True):
method forward (line 322) | def forward(self, x, x_cond, t):
FILE: layers/Voxel_Level/gen_denoise.py
function conv3x3x3 (line 11) | def conv3x3x3(in_planes, out_planes, stride=1):
function conv1x3x3 (line 14) | def conv1x3x3(in_planes, out_planes, stride=1):
function conv1x1x3 (line 18) | def conv1x1x3(in_planes, out_planes, stride=1):
function conv1x3x1 (line 22) | def conv1x3x1(in_planes, out_planes, stride=1):
function conv3x1x1 (line 26) | def conv3x1x1(in_planes, out_planes, stride=1):
function conv3x1x3 (line 30) | def conv3x1x3(in_planes, out_planes, stride=1):
function conv1x1 (line 34) | def conv1x1(in_planes, out_planes, stride=1):
class Asymmetric_Residual_Block (line 38) | class Asymmetric_Residual_Block(nn.Module):
method __init__ (line 39) | def __init__(self, in_filters, out_filters, time_filters=128):
method forward (line 70) | def forward(self, x, t):
class DDCM (line 97) | class DDCM(nn.Module):
method __init__ (line 98) | def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), str...
method forward (line 115) | def forward(self, x):
function l2norm (line 132) | def l2norm(t):
class Attention (line 135) | class Attention(nn.Module):
method __init__ (line 136) | def __init__(self, dim, heads = 4, scale = 10):
method forward (line 143) | def forward(self, x):
class Cross_Attention (line 156) | class Cross_Attention(nn.Module):
method __init__ (line 157) | def __init__(self, dim, heads = 4, scale = 10):
method forward (line 167) | def forward(self, x, cond_x):
class DownBlock (line 183) | class DownBlock(nn.Module):
method __init__ (line 184) | def __init__(self, in_filters, out_filters, time_filters, kernel_size=...
method forward (line 200) | def forward(self, x, t):
class UpBlock (line 208) | class UpBlock(nn.Module):
method __init__ (line 209) | def __init__(self, in_filters, out_filters, height_pooling, time_filte...
method forward (line 244) | def forward(self, x, residual, t):
function timestep_embedding (line 271) | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=Fal...
class Denoise (line 293) | class Denoise(nn.Module):
method __init__ (line 294) | def __init__(self, args, num_class = 11, init_size=32, discrete=True):
method forward (line 329) | def forward(self, x, t):
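`gen_denoise.py` begins with a family of convolution factories (`conv3x3x3`, `conv1x3x3`, `conv3x1x3`, ...) whose bodies the index elides. Such helpers are conventionally thin `nn.Conv3d` wrappers with "same" padding, each restricting mixing along one or more axes for the asymmetric residual blocks; a sketch under that assumption:

```python
import torch
from torch import nn

def conv3x3x3(in_planes, out_planes, stride=1):
    # Full 3x3x3 receptive field, shape-preserving padding.
    return nn.Conv3d(in_planes, out_planes, kernel_size=(3, 3, 3),
                     stride=stride, padding=(1, 1, 1), bias=False)

def conv1x3x3(in_planes, out_planes, stride=1):
    # No mixing along the first spatial axis.
    return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 3),
                     stride=stride, padding=(0, 1, 1), bias=False)

def conv3x1x3(in_planes, out_planes, stride=1):
    # No mixing along the second spatial axis.
    return nn.Conv3d(in_planes, out_planes, kernel_size=(3, 1, 3),
                     stride=stride, padding=(1, 0, 1), bias=False)
```

Stacking such factorized kernels approximates a full 3-D convolution at lower parameter cost, which is the usual motivation for asymmetric residual blocks in voxel networks.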
FILE: simple_visualize.py
function load_config (line 7) | def load_config(yaml_path):
function load_pointcloud (line 12) | def load_pointcloud(filepath, learning_map, color_map):
function main (line 29) | def main():
FILE: train.py
class Experiment (line 19) | class Experiment(object):
method __init__ (line 22) | def __init__(self, args, model, optimizer, scheduler_iter, scheduler_e...
method run (line 65) | def run(self, epochs):
method train_fn (line 112) | def train_fn(self, epoch):
method eval_fn (line 142) | def eval_fn(self, epoch):
method sample (line 162) | def sample(self):
method resume (line 215) | def resume(self):
method log_metrics (line 246) | def log_metrics(self, dict, type):
method save_metrics (line 254) | def save_metrics(self):
method checkpoint_save (line 270) | def checkpoint_save(self, epoch):
method checkpoint_load (line 292) | def checkpoint_load(self, resume_path):
FILE: utils/cuda.py
function find_free_port (line 11) | def find_free_port():
function launch (line 23) | def launch(fn, n_gpu_per_machine, n_machine=1, machine_rank=0, dist_url=...
function distributed_worker (line 54) | def distributed_worker(local_rank, fn, world_size, n_gpu_per_machine, ma...
function set_cuda_vd (line 94) | def set_cuda_vd(gpu_ids, verbose=True):
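`utils/cuda.py` exposes `find_free_port()` for building the `dist_url` used by `launch`. The standard way to implement this, and a reasonable guess at what the repository does, is to bind a socket to port 0 and let the OS assign an ephemeral port:

```python
import socket

def find_free_port():
    # Bind to port 0 so the OS picks an unused ephemeral port, then
    # release the socket so the caller (e.g. a torch.distributed
    # init_method URL like tcp://127.0.0.1:<port>) can claim it.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("", 0))
        port = sock.getsockname()[1]
    return port
```

Note the small race window: the port is free when returned but could in principle be taken before the distributed store binds it.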
FILE: utils/dicts.py
function is_primary (line 13) | def is_primary():
function get_rank (line 17) | def get_rank():
function get_local_rank (line 27) | def get_local_rank():
function synchronize (line 40) | def synchronize():
function get_world_size (line 55) | def get_world_size():
function is_distributed (line 65) | def is_distributed():
function all_reduce (line 70) | def all_reduce(tensor, op=dist.ReduceOp.SUM, async_op=False):
function all_gather (line 80) | def all_gather(data):
function reduce_dict (line 115) | def reduce_dict(input_dict, average=True):
function data_sampler (line 140) | def data_sampler(dataset, shuffle, distributed):
function clean_dict (line 150) | def clean_dict(d, keys):
FILE: utils/intermediate_vis.py
class Vis_iter (line 19) | class Vis_iter(object):
method __init__ (line 22) | def __init__(self, args, model, optimizer, scheduler_iter, scheduler_e...
method run (line 69) | def run(self, epochs):
method sample (line 74) | def sample(self):
method checkpoint_load (line 84) | def checkpoint_load(self, resume_path):
FILE: utils/loss.py
function dice_coef (line 16) | def dice_coef(y_true, y_pred, smooth=1e-6):
function dice_coef_multilabel (line 22) | def dice_coef_multilabel(y_true, y_pred, numLabels=11):
function lovasz_grad (line 33) | def lovasz_grad(gt_sorted):
function lovasz_softmax (line 50) | def lovasz_softmax(probas, labels, classes='present', per_image=False, i...
function lovasz_softmax_flat (line 68) | def lovasz_softmax_flat(probas, labels, classes='present'):
function flatten_probas (line 99) | def flatten_probas(probas, labels, ignore=None):
function isnan (line 123) | def isnan(x):
function mean (line 127) | def mean(l, ignore_nan=False, empty=0):
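`utils/loss.py` lists `dice_coef(y_true, y_pred, smooth=1e-6)` and a multilabel wrapper over `numLabels=11` semantic classes. A minimal sketch of the usual soft-Dice formulation matching those signatures (an illustration of the technique, not the repository's exact code):

```python
import torch

def dice_coef(y_true, y_pred, smooth=1e-6):
    # Soft Dice overlap between two (flattened) masks in [0, 1].
    y_true = y_true.reshape(-1).float()
    y_pred = y_pred.reshape(-1).float()
    intersection = (y_true * y_pred).sum()
    return (2.0 * intersection + smooth) / (y_true.sum() + y_pred.sum() + smooth)

def dice_coef_multilabel(y_true, y_pred, numLabels=11):
    # Average the per-class Dice *loss* over one-hot channels (dim 1).
    loss = 0.0
    for c in range(numLabels):
        loss += 1.0 - dice_coef(y_true[:, c], y_pred[:, c])
    return loss / numLabels
```

The `smooth` term keeps the ratio defined (and equal to 1) when a class is absent from both prediction and target, which is why empty classes contribute zero loss.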
FILE: utils/multistep.py
class LinearWarmupScheduler (line 5) | class LinearWarmupScheduler(_LRScheduler):
method __init__ (line 12) | def __init__(self, optimizer, total_epoch, last_epoch=-1):
method get_lr (line 16) | def get_lr(self):
function get_optim (line 21) | def get_optim(args, model):
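`utils/multistep.py` defines a `LinearWarmupScheduler(_LRScheduler)` taking `(optimizer, total_epoch, last_epoch=-1)`. A minimal sketch of the conventional linear-warmup behavior such a class implements (ramp to the base learning rate over `total_epoch` steps, then hold); the exact ramp endpoints in the repository may differ:

```python
import torch
from torch.optim.lr_scheduler import _LRScheduler

class LinearWarmupScheduler(_LRScheduler):
    """Linearly ramp each param group's LR up to its base value over
    `total_epoch` scheduler steps, then hold it constant."""

    def __init__(self, optimizer, total_epoch, last_epoch=-1):
        self.total_epoch = total_epoch
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # last_epoch counts completed step() calls (starts at 0 after init).
        scale = min(1.0, (self.last_epoch + 1) / self.total_epoch)
        return [base_lr * scale for base_lr in self.base_lrs]
```

In `get_optim`, such a warmup scheduler is commonly chained with the imported `MultiStepLR` for post-warmup decay.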
FILE: utils/tables.py
function get_args_table (line 9) | def get_args_table(args_dict):
function get_miou_table (line 15) | def get_miou_table(args, label_to_names, miou):
function get_metric_table (line 21) | def get_metric_table(metric_dict, epochs):
function create_folders (line 29) | def create_folders(args):
function inter_vis (line 38) | def inter_vis(args, recons):
function visualization (line 49) | def visualization(args, recons, input_data, output, invalid, iteration):
function completion_vis (line 74) | def completion_vis(args, input_p, recons):
function iou_one_frame (line 82) | def iou_one_frame(pred, target, n_classes=23):
function get_result (line 94) | def get_result(args, for_mask, output, preds, SSC=True):
function save_args (line 127) | def save_args(args):
function print_completion (line 137) | def print_completion(num_correct, num_total, TP, FP, FN):
function print_result (line 154) | def print_result(args, label_to_names, num_correct, num_total, all_inter...
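Among the evaluation helpers above, `iou_one_frame(pred, target, n_classes=23)` computes the per-class overlap statistics that feed the mIoU tables. A sketch of the standard per-frame intersection/union accumulation that signature suggests (hypothetical body; the repository may instead return ready-made IoU values):

```python
import numpy as np

def iou_one_frame(pred, target, n_classes=23):
    """Per-class intersection and union voxel counts for one frame.

    Per-class IoU is intersection / union; accumulating these counts
    across frames before dividing yields the dataset mIoU.
    """
    pred = np.asarray(pred).reshape(-1)
    target = np.asarray(target).reshape(-1)
    intersection = np.zeros(n_classes)
    union = np.zeros(n_classes)
    for c in range(n_classes):
        pred_c = pred == c
        target_c = target == c
        intersection[c] = np.logical_and(pred_c, target_c).sum()
        union[c] = np.logical_or(pred_c, target_c).sum()
    return intersection, union
```

Summing counts before dividing (rather than averaging per-frame IoUs) is the usual convention, since it weights every voxel equally and avoids division by zero for classes absent from a frame.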
FILE: visualization.py
class SpheresApp (line 20) | class SpheresApp:
method __init__ (line 25) | def __init__(self, opt):
method _on_menu_before (line 63) | def _on_menu_before(self):
method _on_menu_quit (line 84) | def _on_menu_quit(self):
method _on_menu_scene (line 87) | def _on_menu_scene(self):
function get_voxel (line 108) | def get_voxel(opt):
function get_input (line 142) | def get_input(opt):
function main (line 150) | def main(opt):
About this extraction
This document contains the full source code of the zoomin-lee/scene-scale-diffusion GitHub repository, formatted as plain text: 30 files (163.3 KB, approximately 45.8k tokens) plus a symbol index of 295 extracted functions, classes, methods, constants, and types. Extracted with GitExtract, a GitHub repo-to-text converter by Nikandr Surkov.