Repository: edwardzhou130/Panoptic-PolarNet
Branch: main
Commit: 3a72f2380a4e
Files: 24
Total size: 52.7 MB
Directory structure:
gitextract_pewizc2i/
├── LICENSE
├── README.md
├── configs/
│   └── SemanticKITTI_model/
│       └── Panoptic-PolarNet.yaml
├── data/
│   └── README.md
├── dataloader/
│   ├── __init__.py
│   ├── dataset.py
│   ├── instance_augmentation.py
│   └── process_panoptic.py
├── instance_preprocess.py
├── network/
│   ├── BEV_Unet.py
│   ├── __init__.py
│   ├── instance_post_processing.py
│   ├── loss.py
│   ├── lovasz_losses.py
│   └── ptBEV.py
├── pretrained_weight/
│   └── Panoptic_SemKITTI_PolarNet.pt
├── requirements.txt
├── semantic-kitti.yaml
├── test_pretrain.py
├── train.py
└── utils/
    ├── __init__.py
    ├── configs.py
    ├── eval_pq.py
    └── visual.py
================================================
FILE CONTENTS
================================================
================================================
FILE: LICENSE
================================================
BSD 3-Clause License
Copyright (c) 2020, Zixiang Zhou
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
================================================
FILE: README.md
================================================
# Panoptic-PolarNet
This is the official implementation of Panoptic-PolarNet.
[[**ArXiv paper**]](https://arxiv.org/abs/2103.14962)
# Introduction
Panoptic-PolarNet is a fast and robust LiDAR point cloud panoptic segmentation framework. We learn both semantic segmentation and class-agnostic instance clustering in a single inference network using a polar Bird's Eye View (BEV) representation. Predictions from the semantic and instance heads are then fused through majority voting to create the final panoptic segmentation.
We test Panoptic-PolarNet on the SemanticKITTI and nuScenes datasets. Experiments show that Panoptic-PolarNet reaches state-of-the-art performance at real-time inference speed.
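The fusion step can be pictured with a small NumPy sketch. This is an illustration, not the repository's post-processing code. The helper name `majority_vote_fusion`, the per-point array layout, the convention that instance id `0` means "no instance", and the rule that only thing classes may vote are all assumptions. Each predicted instance is relabeled with the most frequent semantic class among its points.

```python
import numpy as np

def majority_vote_fusion(sem_pred, inst_pred, thing_list):
    """Hypothetical helper: give every predicted instance one semantic class.

    sem_pred:   (N,) int array of per-point semantic predictions
    inst_pred:  (N,) int array of per-point instance ids (0 = no instance, assumed)
    thing_list: iterable of semantic ids treated as "thing" classes
    """
    sem_out = sem_pred.copy()
    thing_ids = np.asarray(list(thing_list))
    for inst_id in np.unique(inst_pred):
        if inst_id == 0:
            continue  # stuff / unclustered points keep their semantic prediction
        members = inst_pred == inst_id
        votes = sem_pred[members]
        # assumption: only thing-class predictions take part in the vote
        votes = votes[np.isin(votes, thing_ids)]
        if votes.size == 0:
            continue
        classes, counts = np.unique(votes, return_counts=True)
        sem_out[members] = classes[np.argmax(counts)]  # most frequent class wins
    return sem_out
```

Because the vote is per instance, a few misclassified points inside an object cannot split it into several classes. The trade-off is that a wrong majority relabels the whole object.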
## Prepare dataset and environment
This code is tested on Ubuntu 16.04 with Python 3.8, CUDA 10.2 and PyTorch 1.7.0.
1, Install the following dependencies by either `pip install -r requirements.txt` or manual installation.
* numpy
* pytorch
* tqdm
* yaml
* Cython
* [numba](https://github.com/numba/numba)
* [torch-scatter](https://github.com/rusty1s/pytorch_scatter)
* [dropblock](https://github.com/miguelvr/dropblock)
* (Optional) [open3d](https://github.com/intel-isl/Open3D)
2, Download Velodyne point clouds and label data in SemanticKITTI dataset [here](http://www.semantic-kitti.org/dataset.html#overview).
3, Extract everything into the same folder. The folder structure inside the zip files of label data matches the folder structure of the LiDAR point cloud data.
4, Data file structure should look like this:
```
./
├── train.py
├── ...
└── data/
    └── sequences/
        ├── 00/
        │   ├── velodyne/   # Unzip from KITTI Odometry Benchmark Velodyne point clouds.
        │   │   ├── 000000.bin
        │   │   ├── 000001.bin
        │   │   └── ...
        │   └── labels/     # Unzip from SemanticKITTI label data.
        │       ├── 000000.label
        │       ├── 000001.label
        │       └── ...
        ├── ...
        └── 21/
            └── ...
```
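5, Instance preprocessing. `-d`/`--data_path` is the folder that contains `sequences/`, and `-o`/`--out_path` is where the extracted instances are saved. Both options default to `data`: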
```shell
python instance_preprocess.py -d -o
```
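The scan and label files laid out above can be read as in the sketch below. It follows `SemKITTI.__getitem__`: a `.bin` file stores `float32` rows of `(x, y, z, remission)`. A `.label` file stores one `uint32` per point: the lower 16 bits hold the semantic class and the upper 16 bits hold the instance id. The helper name `load_scan` is hypothetical. The label path is derived the same way the dataloader does it: replace `velodyne` with `labels` and swap the extension.

```python
import numpy as np

def load_scan(bin_path):
    """Hypothetical helper mirroring the decoding in SemKITTI.__getitem__."""
    points = np.fromfile(bin_path, dtype=np.float32).reshape((-1, 4))
    label_path = bin_path.replace('velodyne', 'labels')[:-3] + 'label'
    raw = np.fromfile(label_path, dtype=np.uint32)
    sem = raw & 0xFFFF  # lower 16 bits: raw semantic class (before learning_map)
    inst = raw >> 16    # upper 16 bits: instance id within the scan
    return points[:, :3], points[:, 3], sem, inst
```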
## Training
Run
```shell
python train.py
```
The code will automatically train, validate and save the model that has the best validation PQ.
Panoptic-PolarNet with the default settings requires around 11 GB of GPU memory for training. Training on a GPU with less memory will likely run out of memory. In that case, set ``grid_size`` in the config file to ``[320,240,32]`` or lower.
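`grid_size` sets the cell size of the polar grid. `spherical_dataset` uses a fixed volume by default: radius 3 to 50 m, the full azimuth range, and height -3 to 1.5 m. The interval along each axis is `(max - min) / (grid_size - 1)`. The sketch below applies that formula to the default grid and to the smaller grid suggested above. The helper name `polar_intervals` is illustrative.

```python
import numpy as np

def polar_intervals(grid_size,
                    max_bound=(50.0, np.pi, 1.5),
                    min_bound=(3.0, -np.pi, -3.0)):
    """Cell size per axis (radius in m, azimuth in rad, height in m),
    using the same formula as spherical_dataset.__getitem__."""
    crop_range = np.asarray(max_bound) - np.asarray(min_bound)
    return crop_range / (np.asarray(grid_size) - 1)

for grid in ([480, 360, 32], [320, 240, 32]):
    dr, dphi, dz = polar_intervals(grid)
    print(f"{grid}: dr={dr:.3f} m, dphi={np.degrees(dphi):.2f} deg, dz={dz:.3f} m")
```

The two grids compare as follows:

* **Default `[480,360,32]`:** cells are roughly 0.098 m by 1.00°.
* **Smaller `[320,240,32]`:** cells are roughly 0.147 m by 1.51°.

With the smaller grid the BEV plane shrinks from 480×360 to 320×240 cells, which is about 44% of the original area.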
## Evaluate our pretrained model
We also provide pretrained Panoptic-PolarNet weights (`pretrained_weight/Panoptic_SemKITTI_PolarNet.pt`).
```shell
python test_pretrain.py
```
Results will be stored in the `./out` folder. Test performance can be evaluated by uploading the label results to the SemanticKITTI competition website [here](https://competitions.codalab.org/competitions/24025).
## Citation
Please cite our paper if this code benefits your research:
```
@inproceedings{Zhou2021PanopticPolarNet,
author={Zhou, Zixiang and Zhang, Yang and Foroosh, Hassan},
title={Panoptic-PolarNet: Proposal-free LiDAR Point Cloud Panoptic Segmentation},
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
year={2021}
}
@InProceedings{Zhang_2020_CVPR,
author = {Zhang, Yang and Zhou, Zixiang and David, Philip and Yue, Xiangyu and Xi, Zerong and Gong, Boqing and Foroosh, Hassan},
title = {PolarNet: An Improved Grid Representation for Online LiDAR Point Clouds Semantic Segmentation},
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2020}
}
```
================================================
FILE: configs/SemanticKITTI_model/Panoptic-PolarNet.yaml
================================================
model_name: Panoptic_PolarNet
dataset:
    name: semantickitti
    path: data
    output_path: out/SemKITTI
    instance_pkl_path: data
    rotate_aug: True
    flip_aug: True
    inst_aug: True
    inst_aug_type:
        inst_os: True
        inst_loc_aug: True
        inst_global_aug: True
    gt_generator:
        sigma: 5
    grid_size: [480,360,32]
model:
    model_save_path: ./Panoptic_SemKITTI.pt
    pretrained_model: /pretrained_weight/Panoptic_SemKITTI_PolarNet.pt
    polar: True
    visibility: True
    train_batch_size: 2
    val_batch_size: 2
    test_batch_size: 1
    check_iter: 4000
    max_epoch: 100
    post_proc:
        threshold: 0.1
        nms_kernel: 5
        top_k: 100
    center_loss: MSE
    offset_loss: L1
    center_loss_weight: 100
    offset_loss_weight: 10
    enable_SAP: True
    SAP:
        start_epoch: 30
        rate: 0.01
================================================
FILE: data/README.md
================================================
# PolarSeg
Download Velodyne point clouds and label data in SemanticKITTI dataset [here](http://www.semantic-kitti.org/dataset.html#overview).
================================================
FILE: dataloader/__init__.py
================================================
================================================
FILE: dataloader/dataset.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
SemKITTI dataloader
"""
import os
import numpy as np
import torch
import random
import time
import numba as nb
import yaml
import pickle
import errno
from torch.utils import data
from .process_panoptic import PanopticLabelGenerator
from .instance_augmentation import instance_augmentation
class SemKITTI(data.Dataset):
    def __init__(self, data_path, imageset = 'train', return_ref = False, instance_pkl_path ='data'):
        self.return_ref = return_ref
        with open("semantic-kitti.yaml", 'r') as stream:
            semkittiyaml = yaml.safe_load(stream)
        self.learning_map = semkittiyaml['learning_map']
        thing_class = semkittiyaml['thing_class']
        # keep only the classes flagged as "thing" (countable objects)
        self.thing_list = [cl for cl, is_thing in thing_class.items() if is_thing]
        self.imageset = imageset
        if imageset == 'train':
            split = semkittiyaml['split']['train']
        elif imageset == 'val':
            split = semkittiyaml['split']['valid']
        elif imageset == 'test':
            split = semkittiyaml['split']['test']
        else:
            raise Exception('Split must be train/val/test')

        self.im_idx = []
        for i_folder in split:
            self.im_idx += absoluteFilePaths('/'.join([data_path,str(i_folder).zfill(2),'velodyne']))
        self.im_idx.sort()

        # get class distribution weight (inverse class frequency)
        epsilon_w = 0.001
        origin_class = semkittiyaml['content'].keys()
        weights = np.zeros((len(semkittiyaml['learning_map_inv'])-1,),dtype = np.float32)
        for class_num in origin_class:
            if semkittiyaml['learning_map'][class_num] != 0:
                weights[semkittiyaml['learning_map'][class_num]-1] += semkittiyaml['content'][class_num]
        self.CLS_LOSS_WEIGHT = 1/(weights + epsilon_w)
        self.instance_pkl_path = instance_pkl_path

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.im_idx)

    def __getitem__(self, index):
        raw_data = np.fromfile(self.im_idx[index], dtype=np.float32).reshape((-1, 4))
        if self.imageset == 'test':
            sem_data = np.expand_dims(np.zeros_like(raw_data[:,0],dtype=int),axis=1)
            inst_data = np.expand_dims(np.zeros_like(raw_data[:,0],dtype=np.uint32),axis=1)
        else:
            annotated_data = np.fromfile(self.im_idx[index].replace('velodyne','labels')[:-3]+'label', dtype=np.uint32).reshape((-1,1))
            sem_data = annotated_data & 0xFFFF # keep the lower 16 bits (semantic label)
            sem_data = np.vectorize(self.learning_map.__getitem__)(sem_data)
            inst_data = annotated_data
        data_tuple = (raw_data[:,:3], sem_data.astype(np.uint8),inst_data)
        if self.return_ref:
            data_tuple += (raw_data[:,3],)
        return data_tuple

    def save_instance(self, out_dir, min_points = 10):
        'instance data preparation'
        instance_dict={label:[] for label in self.thing_list}
        for data_path in self.im_idx:
            print('process instance for:'+data_path)
            # get x,y,z,ref,semantic label and instance label
            raw_data = np.fromfile(data_path, dtype=np.float32).reshape((-1, 4))
            annotated_data = np.fromfile(data_path.replace('velodyne','labels')[:-3]+'label', dtype=np.uint32).reshape((-1,1))
            sem_data = annotated_data & 0xFFFF # keep the lower 16 bits (semantic label)
            sem_data = np.vectorize(self.learning_map.__getitem__)(sem_data)
            inst_data = annotated_data

            # instance mask
            mask = np.zeros_like(sem_data,dtype=bool)
            for label in self.thing_list:
                mask[sem_data == label] = True

            # create unique instance list
            inst_label = inst_data[mask].squeeze()
            unique_label = np.unique(inst_label)
            num_inst = len(unique_label)

            inst_count = 0
            for inst in unique_label:
                # get instance index
                index = np.where(inst_data == inst)[0]
                # get semantic label
                class_label = sem_data[index[0]]
                # skip small instance
if index.size 1:
            inst_label = np.vectorize(unique_label_dict.__getitem__)(inst_label)
            # process panoptic
            processed_inst = np.ones(self.grid_size[:2],dtype = np.uint8)*self.ignore_label
            inst_voxel_pair = np.concatenate([grid_ind[mask[:,0],:2],inst_label[..., np.newaxis]],axis = 1)
            inst_voxel_pair = inst_voxel_pair[np.lexsort((grid_ind[mask[:,0],0],grid_ind[mask[:,0],1])),:]
            processed_inst = nb_process_inst(np.copy(processed_inst),inst_voxel_pair)
        else:
            processed_inst = None
        center,center_points,offset = self.panoptic_proc(insts[mask],xyz[mask[:,0]],processed_inst,voxel_position[:2,:,:,0],unique_label_dict,min_bound,intervals)
        data_tuple = (voxel_position,processed_label,center,offset)
        # center data on each voxel for PTnet
        voxel_centers = (grid_ind.astype(np.float32) + 0.5)*intervals + min_bound
        return_xyz = xyz - voxel_centers
        return_xyz = np.concatenate((return_xyz,xyz),axis = 1)
        if len(data) == 3:
            return_fea = return_xyz
        elif len(data) == 4:
            return_fea = np.concatenate((return_xyz,feat),axis = 1)
        if self.return_test:
            data_tuple += (grid_ind,labels,insts,return_fea,index)
        else:
            data_tuple += (grid_ind,labels,insts,return_fea)
        return data_tuple
# transformation between Cartesian coordinates and polar coordinates
def cart2polar(input_xyz):
    rho = np.sqrt(input_xyz[:,0]**2 + input_xyz[:,1]**2)
    phi = np.arctan2(input_xyz[:,1],input_xyz[:,0])
    return np.stack((rho,phi,input_xyz[:,2]),axis=1)

def polar2cat(input_xyz_polar):
    x = input_xyz_polar[0]*np.cos(input_xyz_polar[1])
    y = input_xyz_polar[0]*np.sin(input_xyz_polar[1])
    return np.stack((x,y,input_xyz_polar[2]),axis=0)
class spherical_dataset(data.Dataset):
    def __init__(self, in_dataset, args, grid_size, ignore_label = 0, return_test = False, use_aug = False, fixed_volume_space= True, max_volume_space = [50,np.pi,1.5], min_volume_space = [3,-np.pi,-3]):
        'Initialization'
        self.point_cloud_dataset = in_dataset
        self.grid_size = np.asarray(grid_size)
        self.rotate_aug = args['rotate_aug'] if use_aug else False
        self.flip_aug = args['flip_aug'] if use_aug else False
        self.instance_aug = args['inst_aug'] if use_aug else False
        self.ignore_label = ignore_label
        self.return_test = return_test
        self.fixed_volume_space = fixed_volume_space
        self.max_volume_space = max_volume_space
        self.min_volume_space = min_volume_space

        self.panoptic_proc = PanopticLabelGenerator(self.grid_size,sigma=args['gt_generator']['sigma'],polar=True)
        if self.instance_aug:
            self.inst_aug = instance_augmentation(self.point_cloud_dataset.instance_pkl_path+'/instance_path.pkl',self.point_cloud_dataset.thing_list,self.point_cloud_dataset.CLS_LOSS_WEIGHT,\
                random_flip=args['inst_aug_type']['inst_global_aug'],random_add=args['inst_aug_type']['inst_os'],\
                random_rotate=args['inst_aug_type']['inst_global_aug'],local_transformation=args['inst_aug_type']['inst_loc_aug'])

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.point_cloud_dataset)

    def __getitem__(self, index):
        'Generates one sample of data'
        data = self.point_cloud_dataset[index]
        if len(data) == 3:
            xyz,labels,insts = data
        elif len(data) == 4:
            xyz,labels,insts,feat = data
            if len(feat.shape) == 1: feat = feat[..., np.newaxis]
        else: raise Exception('Return invalid data tuple')
        if len(labels.shape) == 1: labels = labels[..., np.newaxis]
        if len(insts.shape) == 1: insts = insts[..., np.newaxis]

        # random data augmentation by rotation
        if self.rotate_aug:
            rotate_rad = np.deg2rad(np.random.random()*360)
            c, s = np.cos(rotate_rad), np.sin(rotate_rad)
            j = np.array([[c, s], [-s, c]])  # np.matrix is deprecated; a plain 2x2 array gives the same rotation
            xyz[:,:2] = np.dot( xyz[:,:2],j)

        # random data augmentation by flip x , y or x+y
        if self.flip_aug:
            flip_type = np.random.choice(4,1)
            if flip_type==1:
                xyz[:,0] = -xyz[:,0]
            elif flip_type==2:
                xyz[:,1] = -xyz[:,1]
            elif flip_type==3:
                xyz[:,:2] = -xyz[:,:2]

        # random instance augmentation
        if self.instance_aug:
            xyz,labels,insts,feat = self.inst_aug.instance_aug(xyz,labels.squeeze(),insts.squeeze(),feat)

        # convert coordinate into polar coordinates
        xyz_pol = cart2polar(xyz)

        max_bound_r = np.percentile(xyz_pol[:,0],100,axis = 0)
        min_bound_r = np.percentile(xyz_pol[:,0],0,axis = 0)
        max_bound = np.max(xyz_pol[:,1:],axis = 0)
        min_bound = np.min(xyz_pol[:,1:],axis = 0)
        max_bound = np.concatenate(([max_bound_r],max_bound))
        min_bound = np.concatenate(([min_bound_r],min_bound))
        if self.fixed_volume_space:
            max_bound = np.asarray(self.max_volume_space)
            min_bound = np.asarray(self.min_volume_space)

        # get grid index
        crop_range = max_bound - min_bound
        cur_grid_size = self.grid_size
        intervals = crop_range/(cur_grid_size-1)
        if (intervals==0).any(): print("Zero interval!")
        # np.int was removed in NumPy 1.24; int64 also matches the i8 numba signatures below
        grid_ind = (np.floor((np.clip(xyz_pol,min_bound,max_bound)-min_bound)/intervals)).astype(np.int64)

        current_grid = grid_ind[:np.size(labels)]

        # process voxel position
        voxel_position = np.zeros(self.grid_size,dtype = np.float32)
        dim_array = np.ones(len(self.grid_size)+1,int)
        dim_array[0] = -1
        voxel_position = np.indices(self.grid_size)*intervals.reshape(dim_array) + min_bound.reshape(dim_array)
        # voxel_position = polar2cat(voxel_position)

        # process labels
        processed_label = np.ones(self.grid_size,dtype = np.uint8)*self.ignore_label
        label_voxel_pair = np.concatenate([current_grid,labels],axis = 1)
        label_voxel_pair = label_voxel_pair[np.lexsort((current_grid[:,0],current_grid[:,1],current_grid[:,2])),:]
        processed_label = nb_process_label(np.copy(processed_label),label_voxel_pair)
        # data_tuple = (voxel_position,processed_label)

        # get thing points mask
        mask = np.zeros_like(labels,dtype=bool)
        for label in self.point_cloud_dataset.thing_list:
            mask[labels == label] = True

        inst_label = insts[mask].squeeze()
        unique_label = np.unique(inst_label)
        unique_label_dict = {label:idx+1 for idx , label in enumerate(unique_label)}
        if inst_label.size > 1:
            inst_label = np.vectorize(unique_label_dict.__getitem__)(inst_label)

            # process panoptic
            processed_inst = np.ones(self.grid_size[:2],dtype = np.uint8)*self.ignore_label
            inst_voxel_pair = np.concatenate([current_grid[mask[:,0],:2],inst_label[..., np.newaxis]],axis = 1)
            inst_voxel_pair = inst_voxel_pair[np.lexsort((current_grid[mask[:,0],0],current_grid[mask[:,0],1])),:]
            processed_inst = nb_process_inst(np.copy(processed_inst),inst_voxel_pair)
        else:
            processed_inst = None

        center,center_points,offset = self.panoptic_proc(insts[mask],xyz[:np.size(labels)][mask[:,0]],processed_inst,voxel_position[:2,:,:,0],unique_label_dict,min_bound,intervals)

        # prepare visibility feature
        # find max distance index in each angle,height pair
        valid_label = np.zeros_like(processed_label,dtype=bool)
        valid_label[current_grid[:,0],current_grid[:,1],current_grid[:,2]] = True
        valid_label = valid_label[::-1]
        max_distance_index = np.argmax(valid_label,axis=0)
        max_distance = max_bound[0]-intervals[0]*(max_distance_index)
        distance_feature = np.expand_dims(max_distance, axis=2)-np.transpose(voxel_position[0],(1,2,0))
        distance_feature = np.transpose(distance_feature,(1,2,0))
        # convert to boolean feature
        distance_feature = (distance_feature>0)*-1.
        distance_feature[current_grid[:,2],current_grid[:,0],current_grid[:,1]]=1.

        data_tuple = (distance_feature,processed_label,center,offset)

        # center data on each voxel for PTnet
        voxel_centers = (grid_ind.astype(np.float32) + 0.5)*intervals + min_bound
        return_xyz = xyz_pol - voxel_centers
        return_xyz = np.concatenate((return_xyz,xyz_pol,xyz[:,:2]),axis = 1)

        if len(data) == 3:
            return_fea = return_xyz
        elif len(data) == 4:
            return_fea = np.concatenate((return_xyz,feat),axis = 1)

        if self.return_test:
            data_tuple += (grid_ind,labels,insts,return_fea,index)
        else:
            data_tuple += (grid_ind,labels,insts,return_fea)
        return data_tuple
@nb.jit('u1[:,:,:](u1[:,:,:],i8[:,:])',nopython=True,cache=True,parallel = False)
def nb_process_label(processed_label,sorted_label_voxel_pair):
    # rows are sorted by voxel index, so each run of equal (x,y,z) is one voxel;
    # count its labels and write the most frequent one (majority vote)
    label_size = 256
    counter = np.zeros((label_size,),dtype = np.uint16)
    counter[sorted_label_voxel_pair[0,3]] = 1
    cur_sear_ind = sorted_label_voxel_pair[0,:3]
    for i in range(1,sorted_label_voxel_pair.shape[0]):
        cur_ind = sorted_label_voxel_pair[i,:3]
        if not np.all(np.equal(cur_ind,cur_sear_ind)):
            processed_label[cur_sear_ind[0],cur_sear_ind[1],cur_sear_ind[2]] = np.argmax(counter)
            counter = np.zeros((label_size,),dtype = np.uint16)
            cur_sear_ind = cur_ind
        counter[sorted_label_voxel_pair[i,3]] += 1
    processed_label[cur_sear_ind[0],cur_sear_ind[1],cur_sear_ind[2]] = np.argmax(counter)
    return processed_label

@nb.jit('u1[:,:](u1[:,:],i8[:,:])',nopython=True,cache=True,parallel = False)
def nb_process_inst(processed_inst,sorted_inst_voxel_pair):
    # same majority vote as nb_process_label, but over 2D BEV cells and instance ids
    label_size = 256
    counter = np.zeros((label_size,),dtype = np.uint16)
    counter[sorted_inst_voxel_pair[0,2]] = 1
    cur_sear_ind = sorted_inst_voxel_pair[0,:2]
    for i in range(1,sorted_inst_voxel_pair.shape[0]):
        cur_ind = sorted_inst_voxel_pair[i,:2]
        if not np.all(np.equal(cur_ind,cur_sear_ind)):
            processed_inst[cur_sear_ind[0],cur_sear_ind[1]] = np.argmax(counter)
            counter = np.zeros((label_size,),dtype = np.uint16)
            cur_sear_ind = cur_ind
        counter[sorted_inst_voxel_pair[i,2]] += 1
    processed_inst[cur_sear_ind[0],cur_sear_ind[1]] = np.argmax(counter)
    return processed_inst
def collate_fn_BEV(data):
    data2stack=np.stack([d[0] for d in data]).astype(np.float32)
    label2stack=np.stack([d[1] for d in data])
    center2stack=np.stack([d[2] for d in data])
    offset2stack=np.stack([d[3] for d in data])
    grid_ind_stack = [d[4] for d in data]
    point_label = [d[5] for d in data]
    point_inst = [d[6] for d in data]
    xyz = [d[7] for d in data]
    return torch.from_numpy(data2stack),torch.from_numpy(label2stack),torch.from_numpy(center2stack),torch.from_numpy(offset2stack),grid_ind_stack,point_label,point_inst,xyz

def collate_fn_BEV_test(data):
    data2stack=np.stack([d[0] for d in data]).astype(np.float32)
    label2stack=np.stack([d[1] for d in data])
    center2stack=np.stack([d[2] for d in data])
    offset2stack=np.stack([d[3] for d in data])
    grid_ind_stack = [d[4] for d in data]
    point_label = [d[5] for d in data]
    point_inst = [d[6] for d in data]
    xyz = [d[7] for d in data]
    index = [d[8] for d in data]
    return torch.from_numpy(data2stack),torch.from_numpy(label2stack),torch.from_numpy(center2stack),torch.from_numpy(offset2stack),grid_ind_stack,point_label,point_inst,xyz,index
# load Semantic KITTI class info
with open("semantic-kitti.yaml", 'r') as stream:
    semkittiyaml = yaml.safe_load(stream)
SemKITTI_label_name = dict()
for i in sorted(list(semkittiyaml['learning_map'].keys()))[::-1]:
    SemKITTI_label_name[semkittiyaml['learning_map'][i]] = semkittiyaml['labels'][i]
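The core of `spherical_dataset.__getitem__` has two steps. Points are voxelized on a clipped polar grid, and then each voxel keeps the majority label of its points (`nb_process_label`). The sketch below reproduces both steps in plain NumPy on a toy cloud. The Numba sort-and-scan loop is replaced with `np.unique` counting. That is a simplification for readability, not the repository's implementation. `voxelize_polar` is a hypothetical name. Ties resolve to the smallest label, which is what `np.argmax` over the counter does in `nb_process_label`.

```python
import numpy as np

def cart2polar(xyz):
    # same transform as dataloader/dataset.py: (rho, phi, z)
    rho = np.sqrt(xyz[:, 0] ** 2 + xyz[:, 1] ** 2)
    phi = np.arctan2(xyz[:, 1], xyz[:, 0])
    return np.stack((rho, phi, xyz[:, 2]), axis=1)

def voxelize_polar(xyz, labels, grid_size,
                   max_bound=(50.0, np.pi, 1.5), min_bound=(3.0, -np.pi, -3.0),
                   ignore_label=0):
    grid_size = np.asarray(grid_size)
    max_bound, min_bound = np.asarray(max_bound), np.asarray(min_bound)
    intervals = (max_bound - min_bound) / (grid_size - 1)
    # clip into the fixed volume, then floor to integer cell indices
    xyz_pol = np.clip(cart2polar(xyz), min_bound, max_bound)
    grid_ind = np.floor((xyz_pol - min_bound) / intervals).astype(np.int64)

    voxel_label = np.full(tuple(grid_size), ignore_label, dtype=np.uint8)
    # count every distinct (voxel, label) pair
    pairs = np.concatenate([grid_ind, labels.reshape(-1, 1).astype(np.int64)], axis=1)
    uniq, counts = np.unique(pairs, axis=0, return_counts=True)
    # order by voxel, then count descending, then label ascending (ties -> smallest label)
    order = np.lexsort((uniq[:, 3], -counts, uniq[:, 2], uniq[:, 1], uniq[:, 0]))
    uniq = uniq[order]
    # the first row of each voxel group is that voxel's majority label
    first = np.ones(len(uniq), dtype=bool)
    first[1:] = np.any(uniq[1:, :3] != uniq[:-1, :3], axis=1)
    best = uniq[first]
    voxel_label[best[:, 0], best[:, 1], best[:, 2]] = best[:, 3]
    return grid_ind, voxel_label
```

The intervals use `grid_size - 1`, so a point clipped to the upper bound lands on index `grid_size - 1` and stays inside the grid.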
================================================
FILE: dataloader/instance_augmentation.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import numpy as np
import pickle
class instance_augmentation(object):
    def __init__(self,instance_pkl_path,thing_list,class_weight,random_flip = False,random_add = False,random_rotate = False,local_transformation = False):
        self.thing_list = thing_list
        self.instance_weight = [class_weight[thing_class_num-1] for thing_class_num in thing_list]
        self.instance_weight = np.asarray(self.instance_weight)/np.sum(self.instance_weight)
        self.random_flip = random_flip
        self.random_add = random_add
        self.random_rotate = random_rotate
        self.local_transformation = local_transformation

        self.add_num = 5

        with open(instance_pkl_path, 'rb') as f:
            self.instance_path = pickle.load(f)

    def instance_aug(self, point_xyz, point_label, point_inst, point_feat = None):
        """random rotate and flip each instance independently.

        Args:
            point_xyz: [N, 3], point location
            point_label: [N, 1], class label
            point_inst: [N, 1], instance label
        """
        # random add instance to this scan
        if self.random_add:
            # choose which instance to add
            instance_choice = np.random.choice(len(self.thing_list),self.add_num,replace=True,p=self.instance_weight)
            uni_inst, uni_inst_count = np.unique(instance_choice,return_counts=True)
            add_idx = 1
            total_point_num = 0
            early_break = False
            for n, count in zip(uni_inst, uni_inst_count):
                # find random instance
                random_choice = np.random.choice(len(self.instance_path[self.thing_list[n]]),count)
                # add to current scan
                for idx in random_choice:
                    points = np.fromfile(self.instance_path[self.thing_list[n]][idx], dtype=np.float32).reshape((-1, 4))
                    add_xyz = points[:,:3]
                    center = np.mean(add_xyz,axis=0)

                    # need to check occlusion
                    fail_flag = True
                    if self.random_rotate:
                        # random rotate (separate name so the outer random_choice is not shadowed)
                        random_angles = np.random.random(20)*np.pi*2
                        for r in random_angles:
                            center_r = self.rotate_origin(center[np.newaxis,...],r)
                            # check if occluded
                            if self.check_occlusion(point_xyz,center_r[0]):
                                fail_flag = False
                                break
                        # rotate to empty space
                        if fail_flag: continue
                        add_xyz = self.rotate_origin(add_xyz,r)
                    else:
                        fail_flag = not self.check_occlusion(point_xyz,center)
                        if fail_flag: continue

                    add_label = np.ones((points.shape[0],),dtype=np.uint8)*(self.thing_list[n])
                    add_inst = np.ones((points.shape[0],),dtype=np.uint32)*(add_idx<<16)
                    point_xyz = np.concatenate((point_xyz,add_xyz),axis=0)
                    point_label = np.concatenate((point_label,add_label),axis=0)
                    point_inst = np.concatenate((point_inst,add_inst),axis=0)
                    if point_feat is not None:
                        add_fea = points[:,3:]
                        if len(add_fea.shape) == 1: add_fea = add_fea[..., np.newaxis]
                        point_feat = np.concatenate((point_feat,add_fea),axis=0)
                    add_idx +=1
                    total_point_num += points.shape[0]
                    if total_point_num>5000:
                        early_break=True
                        break
                # prevent adding too many points which cause GPU memory error
                if early_break: break
        # instance mask
        mask = np.zeros_like(point_label,dtype=bool)
        for label in self.thing_list:
            mask[point_label == label] = True

        # create unique instance list
        inst_label = point_inst[mask].squeeze()
        unique_label = np.unique(inst_label)
        num_inst = len(unique_label)

        for inst in unique_label:
            # get instance index
            index = np.where(point_inst == inst)[0]
            # skip small instance
            if index.size<10: continue
            # get center
            center = np.mean(point_xyz[index,:],axis=0)

            if self.local_transformation:
                # random translation and rotation
                point_xyz[index,:] = self.local_tranform(point_xyz[index,:],center)

            # random flip instance based on its center
            if self.random_flip:
                # get axis
                long_axis = [center[0], center[1]]/(center[0]**2+center[1]**2)**0.5
                short_axis = [-long_axis[1],long_axis[0]]
                # random flip
                flip_type = np.random.choice(5,1)
                if flip_type==3:
                    point_xyz[index,:2] = self.instance_flip(point_xyz[index,:2],[long_axis,short_axis],[center[0], center[1]],flip_type)

            # 20% random rotate
            random_num = np.random.random_sample()
            if self.random_rotate:
                if random_num>0.8 and inst & 0xFFFF > 0:
                    random_angles = np.random.random(20)*np.pi*2
                    fail_flag = True
                    for r in random_angles:
                        center_r = self.rotate_origin(center[np.newaxis,...],r)
                        # check if occluded
                        if self.check_occlusion(np.delete(point_xyz, index, axis=0),center_r[0]):
                            fail_flag = False
                            break
                    if not fail_flag:
                        # rotate to empty space
                        point_xyz[index,:] = self.rotate_origin(point_xyz[index,:],r)

        if len(point_label.shape) == 1: point_label = point_label[..., np.newaxis]
        if len(point_inst.shape) == 1: point_inst = point_inst[..., np.newaxis]
        if point_feat is not None:
            return point_xyz,point_label,point_inst,point_feat
        else:
            return point_xyz,point_label,point_inst
    def instance_flip(self, points,axis,center,flip_type = 1):
        points = points[:]-center
        if flip_type == 1:
            # rotate 180 degree
            points = -points+center
        elif flip_type == 2:
            # flip over long axis
            a = axis[0][0]
            b = axis[0][1]
            flip_matrix = np.array([[b**2 - a**2, -2 * a * b],[-2 * a * b, a**2 - b**2]])
            points = np.matmul(flip_matrix,np.transpose(points, (1, 0)))
            points = np.transpose(points, (1, 0))+center
        elif flip_type == 3:
            # flip over short axis
            a = axis[1][0]
            b = axis[1][1]
            flip_matrix = np.array([[b**2 - a**2, -2 * a * b],[-2 * a * b, a**2 - b**2]])
            points = np.matmul(flip_matrix,np.transpose(points, (1, 0)))
            points = np.transpose(points, (1, 0))+center
        return points

    def check_occlusion(self,points,center,min_dist=2):
        'return True if no point lies within min_dist of center (the spot is free)'
        if points.ndim == 1:
            dist = np.linalg.norm(points[np.newaxis,:]-center,axis=1)
        else:
            dist = np.linalg.norm(points-center,axis=1)
        return np.all(dist>min_dist)

    def rotate_origin(self,xyz,radians):
        'rotate points around the origin (z axis)'
        x = xyz[:,0]
        y = xyz[:,1]
        new_xyz = xyz.copy()
        new_xyz[:,0] = x * np.cos(radians) + y * np.sin(radians)
        new_xyz[:,1] = -x * np.sin(radians) + y * np.cos(radians)
        return new_xyz

    def local_tranform(self,xyz,center):
        'translate and rotate point cloud according to its center'
        # random xyz
        loc_noise = np.random.normal(scale = 0.25, size=(1,3))
        # random angle
        rot_noise = np.random.uniform(-np.pi/20, np.pi/20)

        xyz = xyz-center
        xyz = self.rotate_origin(xyz,rot_noise)
        xyz = xyz+loc_noise

        return xyz+center
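When `instance_aug` pastes an object into a scan and rotation is enabled, it tries up to 20 random azimuth rotations about the origin. It keeps the first rotation that puts the instance center more than `min_dist` (2 m) from every existing point, and skips the object otherwise. The object's thing class is drawn with probabilities taken from the normalized `CLS_LOSS_WEIGHT` entries.

Below is a standalone sketch of that placement loop:

* `place_instance`, `is_free` and the `rng` argument are hypothetical names.
* The instance points are passed in directly instead of being read from `instance_path.pkl`.

```python
import numpy as np

def rotate_origin(xyz, radians):
    # same rotation as instance_augmentation.rotate_origin (about the z axis)
    out = xyz.copy()
    out[:, 0] = xyz[:, 0] * np.cos(radians) + xyz[:, 1] * np.sin(radians)
    out[:, 1] = -xyz[:, 0] * np.sin(radians) + xyz[:, 1] * np.cos(radians)
    return out

def is_free(scene_xyz, center, min_dist=2.0):
    # True when no scene point lies within min_dist of center
    return bool(np.all(np.linalg.norm(scene_xyz - center, axis=1) > min_dist))

def place_instance(scene_xyz, inst_xyz, rng, n_tries=20, min_dist=2.0):
    """Rotate inst_xyz about the origin to a free spot; return None on failure."""
    center = inst_xyz.mean(axis=0, keepdims=True)
    for r in rng.random(n_tries) * 2 * np.pi:
        if is_free(scene_xyz, rotate_origin(center, r)[0], min_dist):
            return rotate_origin(inst_xyz, r)
    return None
```

Rotating about the sensor origin keeps each point's range, so a pasted object keeps a plausible point density for its distance.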
================================================
FILE: dataloader/process_panoptic.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import numpy as np
class PanopticLabelGenerator(object):
    def __init__(self,grid_size,sigma=5,polar=False):
        """Initialize panoptic ground truth generator

        Args:
            grid_size: voxel size.
            sigma (int, optional): Gaussian distribution parameter. Create heatmap in +-3*sigma window. Defaults to 5.
            polar (bool, optional): Is under polar coordinate. Defaults to False.
        """
        self.grid_size = grid_size
        self.polar = polar

        self.sigma = sigma
        size = 6 * sigma + 3
        x = np.arange(0, size, 1, float)
        y = x[:, np.newaxis]
        x0, y0 = 3 * sigma + 1, 3 * sigma + 1
        self.g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))

    def __call__(self,inst,xyz,voxel_inst,voxel_position,label_dict,min_bound,intervals):
        """Generate instance center and offset ground truth

        Args:
            inst : instance panoptic label (N)
            xyz : point location (N x 3)
            voxel_inst : voxel panoptic label on the BEV (H x W)
            voxel_position : voxel location on the BEV (3 x H x W)
            label_dict : unique instance label dict
            min_bound : space minimal bound
            intervals : voxelization intervals

        Returns:
            center, center_pts, offset
        """
        height, width = self.grid_size[0],self.grid_size[1]

        center = np.zeros((1, height, width), dtype=np.float32)
        center_pts = []
        offset = np.zeros((2, height, width), dtype=np.float32)
        #skip empty instances
        if inst.size < 2: return center, center_pts, offset

        # find unique instances
        inst_labels = np.unique(inst)
        for inst_label in inst_labels:
            # get mask for each unique instance
            mask = np.where(inst == inst_label)
            voxel_mask = np.where(voxel_inst == label_dict[inst_label])
            # get center
            center_x, center_y = np.mean(xyz[mask,0]), np.mean(xyz[mask,1])
            if self.polar:
                # convert to polar coordinate
                center_x_pol, center_y_pol = np.sqrt(center_x**2 + center_y**2),np.arctan2(center_y,center_x)
                center_x = center_x_pol
                center_y = center_y_pol

            # generate center heatmap
            x, y = int(np.floor((center_x-min_bound[0])/intervals[0])), int(np.floor((center_y-min_bound[1])/intervals[1]))
            center_pts.append([x, y])
            # outside image boundary
            if x < 0 or y < 0 or \
                    x >= height or y >= width:
                continue
            sigma = self.sigma
            # upper left
            ul = int(np.round(x - 3 * sigma - 1)), int(np.round(y - 3 * sigma - 1))
            # bottom right
            br = int(np.round(x + 3 * sigma + 2)), int(np.round(y + 3 * sigma + 2))

            if self.polar:
                c, d = max(0, -ul[0]), min(br[0], height) - ul[0]
                a, b = 0, br[1] - ul[1]
                cc, dd = max(0, ul[0]), min(br[0], height)
                # the angle axis is periodic, so wrap the column indices around
                angle_list = [angle_id % width for angle_id in range(ul[1],br[1])]
                center[0, cc:dd, angle_list] = np.maximum(
                    center[0, cc:dd, angle_list], np.transpose(self.g[c:d,a:b]))
            else:
                c, d = max(0, -ul[0]), min(br[0], height) - ul[0]
                a, b = max(0, -ul[1]), min(br[1], width) - ul[1]
                cc, dd = max(0, ul[0]), min(br[0], height)
                aa, bb = max(0, ul[1]), min(br[1], width)
                center[0, cc:dd, aa:bb] = np.maximum(
                    center[0, cc:dd, aa:bb], self.g[c:d,a:b])

            if self.polar:
                # generate offset (2, h, w) -> (y-dir, x-dir); angular offset wrapped into [-pi, pi)
                offset[0,voxel_mask[0],voxel_mask[1]] = (center_x - voxel_position[0,voxel_mask[0],voxel_mask[1]])/intervals[0]
                offset[1,voxel_mask[0],voxel_mask[1]] = ((center_y - voxel_position[1,voxel_mask[0],voxel_mask[1]]+np.pi)%(2*np.pi) - np.pi)/intervals[1]
            else:
                # generate offset (2, h, w) -> (y-dir, x-dir)
                offset[0,voxel_mask[0],voxel_mask[1]] = (center_x - voxel_position[0,voxel_mask[0],voxel_mask[1]])/intervals[0]
                offset[1,voxel_mask[0],voxel_mask[1]] = (center_y - voxel_position[1,voxel_mask[0],voxel_mask[1]])/intervals[1]
        # print('gt center',center_pts)
        return center, center_pts, offset
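In polar mode, `PanopticLabelGenerator` does two things that a Cartesian grid does not need:

* It stamps the Gaussian around each center with the azimuth axis wrapped, so an object near phi = ±pi is not cut in half.
* It wraps the angular offset into [-pi, pi) before dividing by the cell size.

The reduced sketch below shows both ideas on a small 2D grid. The kernel is built as in `__init__`: size `6*sigma + 3`, with the peak at index `3*sigma + 1`. The names `gaussian_kernel`, `stamp_center_polar` and `wrapped_angle_offset` are hypothetical.

```python
import numpy as np

def gaussian_kernel(sigma):
    # same construction as PanopticLabelGenerator.__init__
    size = 6 * sigma + 3
    x = np.arange(0, size, 1, float)
    y = x[:, np.newaxis]
    x0 = y0 = 3 * sigma + 1
    return np.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))

def stamp_center_polar(heatmap, x, y, sigma):
    """Max-paste a Gaussian at (x = radius bin, y = angle bin) on an (H, W) map."""
    height, width = heatmap.shape
    g = gaussian_kernel(sigma)
    ul = (x - 3 * sigma - 1, y - 3 * sigma - 1)
    br = (x + 3 * sigma + 2, y + 3 * sigma + 2)
    # radius axis: clip the kernel at the grid border
    c, d = max(0, -ul[0]), min(br[0], height) - ul[0]
    cc, dd = max(0, ul[0]), min(br[0], height)
    # angle axis: take column indices modulo width so the kernel wraps around
    cols = np.arange(ul[1], br[1]) % width
    heatmap[cc:dd, cols] = np.maximum(heatmap[cc:dd, cols], g[c:d, :])
    return heatmap

def wrapped_angle_offset(center_phi, voxel_phi, phi_interval):
    # shortest signed angular difference, expressed in angle bins
    return ((center_phi - voxel_phi + np.pi) % (2 * np.pi) - np.pi) / phi_interval
```

Without the wrap, a voxel just across the ±pi seam would get an offset of almost a full turn toward its own object's center.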
================================================
FILE: instance_preprocess.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
from dataloader.dataset import SemKITTI
if __name__ == '__main__':
    # instance preprocessing
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('-d', '--data_path', default='data')
    parser.add_argument('-o', '--out_path', default='data')
    args = parser.parse_args()

    train_pt_dataset = SemKITTI(args.data_path + '/sequences/', imageset = 'train', return_ref = True)
    train_pt_dataset.save_instance(args.out_path)

    print('instance preprocessing finished.')
================================================
FILE: network/BEV_Unet.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from dropblock import DropBlock2D
class BEV_Unet(nn.Module):
def __init__(self,n_class,n_height,dilation = 1,group_conv=False,input_batch_norm = False,dropout = 0.,circular_padding = False, dropblock = True, use_vis_fea=False):
super(BEV_Unet, self).__init__()
self.n_class = n_class
self.n_height = n_height
if use_vis_fea:
self.network = UNet(n_class*n_height,2*n_height,dilation,group_conv,input_batch_norm,dropout,circular_padding,dropblock)
else:
self.network = UNet(n_class*n_height,n_height,dilation,group_conv,input_batch_norm,dropout,circular_padding,dropblock)
def forward(self, x):
x,center,offset = self.network(x)
x = x.permute(0,2,3,1)
new_shape = list(x.size())[:3] + [self.n_height,self.n_class]
x = x.view(new_shape)
x = x.permute(0,4,1,2,3)
return x,center,offset
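`BEV_Unet.forward` turns the UNet's `n_class*n_height` output channels into a per-voxel class map. Take the channel index as `k = z*n_class + c`. The permute/view/permute sequence then yields `out[b, c, h, w, z] = x[b, z*n_class + c, h, w]`. The NumPy sketch below performs the same reshuffle, with `transpose`/`reshape` standing in for `permute`/`view`. The function name is hypothetical:

```python
import numpy as np

def channels_to_voxels(x, n_class, n_height):
    """(B, n_class*n_height, H, W) -> (B, n_class, H, W, n_height)."""
    x = x.transpose(0, 2, 3, 1)                       # (B, H, W, n_height*n_class)
    x = x.reshape(x.shape[:3] + (n_height, n_class))  # split channel k into (z, c)
    return x.transpose(0, 4, 1, 2, 3)                 # (B, n_class, H, W, n_height)

B, C, Z, H, W = 1, 3, 4, 2, 5
x = np.arange(B * C * Z * H * W).reshape(B, C * Z, H, W)
out = channels_to_voxels(x, C, Z)
```

The height slices of one class are interleaved with those of the other classes, not stored as contiguous blocks. Any code that builds targets for this head has to use the same ordering.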
class UNet(nn.Module):
def __init__(self, n_class,n_height,dilation,group_conv,input_batch_norm, dropout,circular_padding,dropblock):
super(UNet, self).__init__()
# encoder
self.inc = inconv(n_height, 64, dilation, input_batch_norm, circular_padding)
self.down1 = down(64, 128, dilation, group_conv, circular_padding)
self.down2 = down(128, 256, dilation, group_conv, circular_padding)
self.down3 = down(256, 512, dilation, group_conv, circular_padding)
self.down4 = down(512, 512, dilation, group_conv, circular_padding)
# semantic decoder
self.up1 = up(1024, 256, circular_padding, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout)
self.up2 = up(512, 128, circular_padding, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout)
self.up3 = up(256, 64, circular_padding, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout)
self.up4 = up(128, 64, circular_padding, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout)
self.dropout = nn.Dropout(p=0. if dropblock else dropout)
# semantic head
self.outc = outconv(64, n_class)
# instance decoder
# self.i_up1 = up(1024, 256, circular_padding, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout)
# self.i_up2 = up(512, 128, circular_padding, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout)
# self.i_up3 = up(256, 64, circular_padding, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout)
# self.i_up4 = up(128, 32, circular_padding, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout)
self.i_up4_center = up(128, 32, circular_padding, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout)
self.i_up4_offset = up(128, 32, circular_padding, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout)
# instance head
self.i_outc_center = outconv(32, 1)
self.i_outc_offset = outconv(32, 2)
def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
# semantic
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
s_x = self.up4(x, x1)
s_x = self.outc(self.dropout(s_x))
# instance
# i_x = self.i_up1(x5, x4)
# i_x = self.i_up2(i_x, x3)
# i_x = self.i_up3(i_x, x2)
# i_x = self.i_up4(i_x, x1)
i_x_center = self.i_up4_center(x, x1)
i_x_center = self.i_outc_center(self.dropout(i_x_center))
i_x_offset = self.i_up4_offset(x, x1)
i_x_offset = self.i_outc_offset(self.dropout(i_x_offset))
return s_x, i_x_center, i_x_offset
class double_conv(nn.Module):
'''(conv => BN => LeakyReLU) * 2'''
def __init__(self, in_ch, out_ch,group_conv,dilation=1):
super(double_conv, self).__init__()
if group_conv:
self.conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1,groups = min(out_ch,in_ch)),
nn.BatchNorm2d(out_ch),
nn.LeakyReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, 3, padding=1,groups = out_ch),
nn.BatchNorm2d(out_ch),
nn.LeakyReLU(inplace=True)
)
else:
self.conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.LeakyReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.LeakyReLU(inplace=True)
)
def forward(self, x):
x = self.conv(x)
return x
class double_conv_circular(nn.Module):
'''(circular pad => conv => BN => LeakyReLU) * 2'''
def __init__(self, in_ch, out_ch,group_conv,dilation=1):
super(double_conv_circular, self).__init__()
if group_conv:
self.conv1 = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=(1,0),groups = min(out_ch,in_ch)),
nn.BatchNorm2d(out_ch),
nn.LeakyReLU(inplace=True)
)
self.conv2 = nn.Sequential(
nn.Conv2d(out_ch, out_ch, 3, padding=(1,0),groups = out_ch),
nn.BatchNorm2d(out_ch),
nn.LeakyReLU(inplace=True)
)
else:
self.conv1 = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=(1,0)),
nn.BatchNorm2d(out_ch),
nn.LeakyReLU(inplace=True)
)
self.conv2 = nn.Sequential(
nn.Conv2d(out_ch, out_ch, 3, padding=(1,0)),
nn.BatchNorm2d(out_ch),
nn.LeakyReLU(inplace=True)
)
def forward(self, x):
#add circular padding
x = F.pad(x,(1,1,0,0),mode = 'circular')
x = self.conv1(x)
x = F.pad(x,(1,1,0,0),mode = 'circular')
x = self.conv2(x)
return x
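`double_conv_circular` splits the padding across two steps. It pads the last (azimuth) axis with `F.pad(..., mode='circular')`, and `nn.Conv2d(padding=(1, 0))` zero-pads only the first (range) axis. The first and last azimuth columns are therefore treated as neighbors. The NumPy sketch below applies that padding scheme to one 3x3 convolution. It uses a plain loop-based cross-correlation rather than the PyTorch kernel, and the function name is hypothetical:

```python
import numpy as np

def conv3x3_circular_w(x, k):
    """3x3 cross-correlation on an (H, W) map: zero padding on H (range), wrap-around on W (azimuth)."""
    p = np.pad(x, ((0, 0), (1, 1)), mode="wrap")       # like F.pad(x, (1, 1, 0, 0), mode='circular')
    p = np.pad(p, ((1, 1), (0, 0)), mode="constant")   # like nn.Conv2d(..., padding=(1, 0))
    h, w = x.shape
    out = np.zeros((h, w), dtype=float)
    for i in range(3):
        for j in range(3):
            out += k[i, j] * p[i:i + h, j:j + w]
    return out

x = np.zeros((3, 5))
x[1, 4] = 1.0                                   # a response in the last azimuth column...
out = conv3x3_circular_w(x, np.ones((3, 3)))    # ...also reaches out[1, 0], the first column
```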
class inconv(nn.Module):
def __init__(self, in_ch, out_ch, dilation, input_batch_norm, circular_padding):
super(inconv, self).__init__()
if input_batch_norm:
if circular_padding:
self.conv = nn.Sequential(
nn.BatchNorm2d(in_ch),
double_conv_circular(in_ch, out_ch,group_conv = False,dilation = dilation)
)
else:
self.conv = nn.Sequential(
nn.BatchNorm2d(in_ch),
double_conv(in_ch, out_ch,group_conv = False,dilation = dilation)
)
else:
if circular_padding:
self.conv = double_conv_circular(in_ch, out_ch,group_conv = False,dilation = dilation)
else:
self.conv = double_conv(in_ch, out_ch,group_conv = False,dilation = dilation)
def forward(self, x):
x = self.conv(x)
return x
class down(nn.Module):
def __init__(self, in_ch, out_ch, dilation, group_conv, circular_padding):
super(down, self).__init__()
if circular_padding:
self.mpconv = nn.Sequential(
nn.MaxPool2d(2),
double_conv_circular(in_ch, out_ch,group_conv = group_conv,dilation = dilation)
)
else:
self.mpconv = nn.Sequential(
nn.MaxPool2d(2),
double_conv(in_ch, out_ch,group_conv = group_conv,dilation = dilation)
)
def forward(self, x):
x = self.mpconv(x)
return x
class up(nn.Module):
def __init__(self, in_ch, out_ch, circular_padding, bilinear=True, group_conv=False, use_dropblock = False, drop_p = 0.5):
super(up, self).__init__()
# would be a nice idea if the upsampling could be learned too,
# but my machine does not have enough memory to handle all those weights
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
elif group_conv:
self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2,groups = in_ch//2)
else:
self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)
if circular_padding:
self.conv = double_conv_circular(in_ch, out_ch,group_conv = group_conv)
else:
self.conv = double_conv(in_ch, out_ch,group_conv = group_conv)
self.use_dropblock = use_dropblock
if self.use_dropblock:
self.dropblock = DropBlock2D(block_size=7, drop_prob=drop_p)
def forward(self, x1, x2):
x1 = self.up(x1)
# input is NCHW
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(x1, (diffX // 2, diffX - diffX//2,
diffY // 2, diffY - diffY//2))
# for padding issues, see
# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
x = torch.cat([x2, x1], dim=1)
x = self.conv(x)
if self.use_dropblock:
x = self.dropblock(x)
return x
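`down` halves the map with `MaxPool2d(2)`, which floors odd sizes, and `up` doubles it again. Odd sizes therefore come back one pixel short, and `up.forward` pads the upsampled map to match the skip connection. The sketch below works through that arithmetic. It assumes a 480 x 360 BEV grid purely as an example input, and the helper names are hypothetical:

```python
def encoder_sizes(h, w, depth=4):
    """Spatial sizes after inconv and each of `depth` MaxPool2d(2) stages."""
    sizes = [(h, w)]
    for _ in range(depth):
        h, w = h // 2, w // 2
        sizes.append((h, w))
    return sizes

def up_pad(skip_hw, low_hw):
    """(left, right, top, bottom) padding applied after the x2 upsample in up.forward."""
    diff_y = skip_hw[0] - 2 * low_hw[0]
    diff_x = skip_hw[1] - 2 * low_hw[1]
    return (diff_x // 2, diff_x - diff_x // 2, diff_y // 2, diff_y - diff_y // 2)

sizes = encoder_sizes(480, 360)
# Decoder pairs: up1(x5, x4), up2(., x3), up3(., x2), up4(., x1); each output matches its skip size.
pads = [up_pad(sizes[i], sizes[i + 1]) for i in reversed(range(4))]
```

At this example size only `up1` needs padding: the 45-pixel width halves to 22 and comes back as 44. With circular padding enabled, this extra column is zero-padded rather than wrapped.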
class outconv(nn.Module):
def __init__(self, in_ch, out_ch):
super(outconv, self).__init__()
self.conv = nn.Conv2d(in_ch, out_ch, 1)
def forward(self, x):
x = self.conv(x)
return x
================================================
FILE: network/__init__.py
================================================
================================================
FILE: network/instance_post_processing.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
import torch.nn.functional as F
import torch_scatter
def find_instance_center(ctr_hmp, threshold=0.1, nms_kernel=5, top_k=None, polar=False):
"""
Find the center points from the center heatmap.
Arguments:
ctr_hmp: A Tensor of shape [N, 1, H, W] of raw center heatmap output, where N is the batch size,
for consistency, we only support N=1.
threshold: A Float, threshold applied to center heatmap score.
nms_kernel: An Integer, NMS max pooling kernel size.
top_k: An Integer, top k centers to keep.
Returns:
A Tensor of shape [K, 2] where K is the number of center points. The order of second dim is (y, x).
"""
if ctr_hmp.size(0) != 1:
raise ValueError('Only supports inference for batch size = 1')
# thresholding, setting values below threshold to -1
ctr_hmp = F.threshold(ctr_hmp, threshold, -1)
# NMS
if polar:
nms_padding = (nms_kernel - 1) // 2
ctr_hmp_max_pooled = F.pad(ctr_hmp,(nms_padding,nms_padding,0,0),mode = 'circular')
ctr_hmp_max_pooled = F.max_pool2d(ctr_hmp_max_pooled, kernel_size=nms_kernel, stride=1, padding=(nms_padding,0))
else:
nms_padding = (nms_kernel - 1) // 2
ctr_hmp_max_pooled = F.max_pool2d(ctr_hmp, kernel_size=nms_kernel, stride=1, padding=nms_padding)
ctr_hmp[ctr_hmp != ctr_hmp_max_pooled] = -1
# squeeze first two dimensions
ctr_hmp = ctr_hmp.squeeze()
assert len(ctr_hmp.size()) == 2, 'Something is wrong with center heatmap dimension.'
# find the (y, x) coordinates of the remaining positive peaks
ctr_all = torch.nonzero(ctr_hmp > 0)
if top_k is None:
return ctr_all
elif ctr_all.size(0) < top_k:
return ctr_all
else:
# find top k centers.
top_k_scores, _ = torch.topk(torch.flatten(ctr_hmp), top_k)
return torch.nonzero(ctr_hmp > top_k_scores[-1])
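`find_instance_center` runs in two steps. `F.threshold(x, threshold, -1)` keeps only values strictly above the threshold. Stride-1 max pooling then suppresses every cell that is not the maximum of its window. In polar mode the window wraps around the azimuth axis. The NumPy sketch below reproduces this for a single (H, W) heatmap and omits the `top_k` step. The function names are hypothetical:

```python
import numpy as np

def max_pool_same(x, k, polar=False):
    """Stride-1 k x k max pooling that keeps the (H, W) size; pads with -inf like F.max_pool2d."""
    r = (k - 1) // 2
    if polar:
        # azimuth (last) axis wraps around, as with F.pad(..., mode='circular')
        x = np.pad(x, ((0, 0), (r, r)), mode="wrap")
    else:
        x = np.pad(x, ((0, 0), (r, r)), mode="constant", constant_values=-np.inf)
    x = np.pad(x, ((r, r), (0, 0)), mode="constant", constant_values=-np.inf)
    h, w = x.shape[0] - 2 * r, x.shape[1] - 2 * r
    out = np.full((h, w), -np.inf)
    for i in range(k):
        for j in range(k):
            out = np.maximum(out, x[i:i + h, j:j + w])
    return out

def find_centers_np(hmp, threshold=0.1, nms_kernel=5, polar=False):
    """Return a (K, 2) array of (y, x) peaks from a single (H, W) heatmap."""
    hmp = np.where(hmp > threshold, hmp, -1.0)   # like F.threshold(hmp, threshold, -1)
    pooled = max_pool_same(hmp, nms_kernel, polar)
    hmp = np.where(hmp == pooled, hmp, -1.0)     # keep only local maxima
    return np.argwhere(hmp > 0)

hmp = np.zeros((10, 12))
hmp[5, 0], hmp[5, 11] = 0.9, 0.8                 # two peaks on opposite ends of the azimuth axis
```

With `polar=True` the two peaks fall in the same wrapped window, so only the stronger one survives. Without wrapping, both are kept.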
def group_pixels(ctr, offsets, polar=False):
"""
Gives each pixel in the image an instance id.
Arguments:
ctr: A Tensor of shape [K, 2] where K is the number of center points. The order of second dim is (y, x).
offsets: A Tensor of shape [N, 2, H, W] of raw offset output, where N is the batch size,
for consistency, we only support N=1. The order of second dim is (offset_y, offset_x).
Returns:
A Tensor of shape [1, H, W] (to be gathered by distributed data parallel).
"""
if offsets.size(0) != 1:
raise ValueError('Only supports inference for batch size = 1')
offsets = offsets.squeeze(0)
height, width = offsets.size()[1:]
# generates a coordinate map, where each location is the coordinate of that loc
y_coord = torch.arange(height, dtype=offsets.dtype, device=offsets.device).repeat(1, width, 1).transpose(1, 2)
x_coord = torch.arange(width, dtype=offsets.dtype, device=offsets.device).repeat(1, height, 1)
coord = torch.cat((y_coord, x_coord), dim=0)
ctr_loc = coord + offsets
ctr_loc = ctr_loc.reshape((2, height * width)).transpose(1, 0)
# ctr: [K, 2] -> [K, 1, 2]
# ctr_loc = [H*W, 2] -> [1, H*W, 2]
ctr = ctr.unsqueeze(1)
ctr_loc = ctr_loc.unsqueeze(0)
# distance: [K, H*W]
distance = ctr - ctr_loc
if polar:
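# wrap the azimuth (x / width) difference into [-width/2, width/2); unlike torch.fmod, torch.remainder is non-negative for negative inputs
distance[:,:,1] = torch.remainder(distance[:,:,1] + width/2, width) - width/2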
distance = torch.norm(distance, dim=-1)
# finds center with minimum distance at each location, offset by 1, to reserve id=0 for stuff
instance_id = torch.argmin(distance, dim=0).reshape((1, height, width)) + 1
return instance_id
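`group_pixels` lets each pixel vote for a center location `coord + offset`. The pixel is then assigned to the nearest detected center. Ids start at 1 so that 0 stays free for stuff. The NumPy sketch below follows the same steps. Two points are assumptions. First, it treats the width (x) axis as the azimuth, based on the circular padding used elsewhere in this code. Second, it wraps that component with `np.mod`, which is non-negative for a positive divisor. The function name is hypothetical:

```python
import numpy as np

def group_pixels_np(ctr, offsets, polar=False):
    """Assign each pixel of an (H, W) grid to its nearest center; returns ids starting at 1.

    ctr: (K, 2) array of (y, x) centers; offsets: (2, H, W) array of (offset_y, offset_x).
    """
    _, h, w = offsets.shape
    yy, xx = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    votes = (np.stack([yy, xx]) + offsets).reshape(2, -1).T   # (H*W, 2) voted centers
    d = ctr[:, None, :].astype(float) - votes[None, :, :]      # (K, H*W, 2)
    if polar:
        # wrap the x (azimuth) difference into [-w/2, w/2)
        d[..., 1] = np.mod(d[..., 1] + w / 2, w) - w / 2
    dist = np.linalg.norm(d, axis=-1)                           # (K, H*W)
    return np.argmin(dist, axis=0).reshape(h, w) + 1            # 0 stays reserved for stuff

ctr = np.array([[1, 0], [1, 5]])
offsets = np.zeros((2, 4, 10))
```

With zero offsets, pixel (1, 9) joins center 2 on a flat grid, at distance 4. With `polar=True` it joins center 1, which is only one column away across the seam.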
def get_instance_segmentation(sem_seg, ctr_hmp, offsets, thing_list, threshold=0.1, nms_kernel=5, top_k=None,
thing_seg=None, polar=False):
"""
Post-processing for instance segmentation; gets a class-agnostic instance id map.
Arguments:
sem_seg: A Tensor of shape [1, H, W, Z], predicted semantic label.
ctr_hmp: A Tensor of shape [N, 1, H, W] of raw center heatmap output, where N is the batch size,
for consistency, we only support N=1.
offsets: A Tensor of shape [N, 2, H, W] of raw offset output, where N is the batch size,
for consistency, we only support N=1. The order of second dim is (offset_y, offset_x).
thing_list: A List of thing class id.
threshold: A Float, threshold applied to center heatmap score.
nms_kernel: An Integer, NMS max pooling kernel size.
top_k: An Integer, top k centers to keep.
thing_seg: A Tensor of shape [1, H, W, Z], predicted foreground mask. It must be provided: the fallback that
infers it from the semantic prediction is commented out below.
Returns:
A Tensor of shape [1, H, W] (to be gathered by distributed data parallel).
A Tensor of shape [1, K, 2] where K is the number of center points. The order of second dim is (y, x).
"""
# if thing_seg is None:
# # gets foreground segmentation
# thing_seg = torch.zeros_like(sem_seg)
# for thing_class in thing_list:
# thing_seg[sem_seg == thing_class] = 1
# if thing_seg.dim() == 4:
# # [1, H, W, Z] --> [1, H, W]
# thing_seg = torch.max(thing_seg,dim=3)
ctr = find_instance_center(ctr_hmp, threshold=threshold, nms_kernel=nms_kernel, top_k=top_k, polar=polar)
if ctr.size(0) == 0:
return torch.zeros_like(thing_seg[:,:,:,0]), ctr.unsqueeze(0)
ins_seg = group_pixels(ctr, offsets, polar=polar)
return ins_seg, ctr.unsqueeze(0)
def merge_semantic_and_instance(sem_seg, sem, ins_seg, label_divisor, thing_list, void_label,thing_seg):
"""
Post-processing for panoptic segmentation, by merging the semantic segmentation label with the class-agnostic
instance segmentation label.
Arguments:
sem_seg: A Tensor of shape [1, H, W, Z], predicted semantic label.
sem: A Tensor of shape [1, C, H, W, Z], predicted semantic logit.
ins_seg: A Tensor of shape [1, H, W], predicted instance label.
label_divisor: An Integer, used to convert panoptic id = instance_id * label_divisor + semantic_id.
thing_list: A List of thing class id.
void_label: An Integer, indicates the region has no confident prediction.
thing_seg: A Tensor of shape [1, H, W, Z], predicted foreground mask.
Returns:
A Tensor of shape [1, H, W, Z] (to be gathered by distributed data parallel).
Raises:
ValueError, if batch size is not 1.
"""
# In case thing mask does not align with semantic prediction
# semantic_thing_seg = torch.zeros_like(sem_seg,dtype=torch.bool)
# for thing_class in thing_list:
# semantic_thing_seg[sem_seg == thing_class] = True
# try to avoid the for loop
semantic_thing_seg = sem_seg<=max(thing_list)
ins_seg = torch.unsqueeze(ins_seg,3).expand_as(sem_seg)
thing_mask = (ins_seg > 0) & semantic_thing_seg & thing_seg
if not torch.nonzero(thing_mask).size(0) == 0:
sem_sum = torch_scatter.scatter_add(sem.permute(0,2,3,4,1)[thing_mask],ins_seg[thing_mask],dim=0)
class_id = torch.argmax(sem_sum[:,:max(thing_list)],dim=1)
sem_seg[thing_mask] = (ins_seg[thing_mask] * label_divisor) + class_id[ins_seg[thing_mask]]+1
else:
sem_seg[semantic_thing_seg & thing_seg] = void_label
return sem_seg
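For thing points, `merge_semantic_and_instance` first sums the class probabilities of every point in an instance (`torch_scatter.scatter_add`). It then picks the best thing class by majority vote. Finally, it encodes the result as `instance_id * label_divisor + semantic_id`. The NumPy sketch below uses `np.add.at` in place of `scatter_add`. The function name is hypothetical. Probability column `k - 1` is assumed to hold class `k`, matching the `+1` shift in the code:

```python
import numpy as np

def vote_and_encode(sem_prob, ins_id, n_thing, label_divisor=2**16):
    """sem_prob: (P, C) class probabilities of thing points; ins_id: (P,) instance ids >= 1."""
    votes = np.zeros((ins_id.max() + 1, sem_prob.shape[1]))
    np.add.at(votes, ins_id, sem_prob)                      # per-instance sum, like scatter_add(dim=0)
    class_id = votes[:, :n_thing].argmax(axis=1)            # restrict the vote to thing classes
    return ins_id * label_divisor + class_id[ins_id] + 1    # +1: back to 1-based semantic labels

sem_prob = np.array([[0.9, 0.1, 0.0],
                     [0.8, 0.2, 0.0],
                     [0.1, 0.3, 0.6]])
ins_id = np.array([1, 1, 2])
pan = vote_and_encode(sem_prob, ins_id, n_thing=2)
```

Decoding reverses the encoding: `pan % label_divisor` gives the semantic label and `pan // label_divisor` gives the instance id. The third point's highest score belongs to a stuff class, so the vote over thing columns assigns it class 2.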
def get_panoptic_segmentation(sem, ctr_hmp, offsets, thing_list, label_divisor=2**16, void_label=0,
threshold=0.1, nms_kernel=5, top_k=100, foreground_mask=None, polar=False):
"""
Post-processing for panoptic segmentation.
Arguments:
sem: A Tensor of shape [N, C, H, W, Z] of raw semantic output, or [N, H, W, Z] of hard semantic labels, where N is the
batch size; for consistency, we only support N=1.
ctr_hmp: A Tensor of shape [N, 1, H, W] of raw center heatmap output, where N is the batch size,
for consistency, we only support N=1.
offsets: A Tensor of shape [N, 2, H, W] of raw offset output, where N is the batch size,
for consistency, we only support N=1. The order of second dim is (offset_y, offset_x).
thing_list: A List of thing class id.
label_divisor: An Integer, used to convert panoptic id = instance_id * label_divisor + semantic_id.
void_label: An Integer, indicates the region has no confident prediction.
threshold: A Float, threshold applied to center heatmap score.
nms_kernel: An Integer, NMS max pooling kernel size.
top_k: An Integer, top k centers to keep.
foreground_mask: A processed Tensor of shape [N, H, W, Z], we only support N=1.
Returns:
A Tensor of shape [1, H, W, Z] (to be gathered by distributed data parallel), int64.
Raises:
ValueError, if batch size is not 1.
"""
if sem.dim() != 5 and sem.dim() != 4:
raise ValueError('Semantic prediction with un-supported dimension: {}.'.format(sem.dim()))
if sem.dim() == 5 and sem.size(0) != 1:
raise ValueError('Only supports inference for batch size = 1')
if ctr_hmp.size(0) != 1:
raise ValueError('Only supports inference for batch size = 1')
if offsets.size(0) != 1:
raise ValueError('Only supports inference for batch size = 1')
if foreground_mask is not None:
if foreground_mask.dim() != 4:
raise ValueError('Foreground prediction with un-supported dimension: {}.'.format(foreground_mask.dim()))
if sem.dim() == 5:
semantic = torch.argmax(sem, dim=1)
# shift back to original label idx
semantic = torch.add(semantic, 1)
sem = F.softmax(sem, dim=1)
else:
semantic = sem.type(torch.ByteTensor).cuda()
# shift back to original label idx
semantic = torch.add(semantic, 1).type(torch.LongTensor).cuda()
one_hot = torch.zeros((sem.size(0),torch.max(semantic).item()+1,sem.size(1),sem.size(2),sem.size(3))).cuda()
sem = one_hot.scatter_(1,torch.unsqueeze(semantic,1),1.)
sem = sem[:,1:,:,:,:]
if foreground_mask is not None:
thing_seg = foreground_mask
else:
thing_seg = None
instance, center = get_instance_segmentation(semantic, ctr_hmp, offsets, thing_list,
threshold=threshold, nms_kernel=nms_kernel, top_k=top_k,
thing_seg=thing_seg, polar=polar)
panoptic = merge_semantic_and_instance(semantic, sem, instance, label_divisor, thing_list, void_label, thing_seg)
return panoptic, center
================================================
FILE: network/loss.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import numpy as np
import torch
from .lovasz_losses import lovasz_softmax
def _neg_loss(pred, gt):
''' Modified focal loss. Exactly the same as CornerNet.
Runs faster and costs a little bit more memory
(https://github.com/tianweiy/CenterPoint)
Arguments:
pred (batch x c x h x w)
gt (batch x c x h x w)
'''
pos_inds = gt.eq(1).float()
neg_inds = gt.lt(1).float()
neg_weights = torch.pow(1 - gt, 4)
# loss = 0
pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds
return - (pos_loss + neg_loss)
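`_neg_loss` is the penalty-reduced focal loss. Only cells where the target equals 1 count as positives. Every other cell is a negative, down-weighted by `(1 - gt)^4`, so cells near a Gaussian peak are punished less. The function expects predictions strictly inside (0, 1); otherwise `log` returns `-inf`. The NumPy sketch below adds an `eps` clip for that reason. The clip is the sketch's own addition and is not in the original:

```python
import numpy as np

def neg_loss_np(pred, gt, eps=1e-6):
    """Per-element CornerNet-style focal loss; the eps clip is an addition for numerical safety."""
    pred = np.clip(pred, eps, 1 - eps)
    pos = (gt == 1).astype(float)
    neg = (gt < 1).astype(float)
    neg_w = (1 - gt) ** 4
    pos_loss = np.log(pred) * (1 - pred) ** 2 * pos
    neg_loss = np.log(1 - pred) * pred ** 2 * neg_w * neg
    return -(pos_loss + neg_loss)

gt = np.array([1.0, 0.5, 0.0])     # peak, Gaussian shoulder, background
pred = np.array([0.9, 0.5, 0.1])
loss = neg_loss_np(pred, gt)       # per-element, unreduced, like FocalLoss.forward
```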
class FocalLoss(torch.nn.Module):
'''nn.Module wrapper for focal loss'''
def __init__(self):
super(FocalLoss, self).__init__()
self.neg_loss = _neg_loss
def forward(self, out, target):
return self.neg_loss(out, target)
class panoptic_loss(torch.nn.Module):
def __init__(self, ignore_label = 255, center_loss_weight = 100, offset_loss_weight = 1, center_loss = 'MSE', offset_loss = 'L1'):
super(panoptic_loss, self).__init__()
self.CE_loss = torch.nn.CrossEntropyLoss(ignore_index=ignore_label)
assert center_loss in ['MSE','FocalLoss']
assert offset_loss in ['L1','SmoothL1']
if center_loss == 'MSE':
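# reduction='none' keeps per-element losses so the masks in forward() select the supervised elements
self.center_loss_fn = torch.nn.MSELoss(reduction='none')
elif center_loss == 'FocalLoss':
self.center_loss_fn = FocalLoss()
else: raise NotImplementedError
if offset_loss == 'L1':
self.offset_loss_fn = torch.nn.L1Loss(reduction='none')
elif offset_loss == 'SmoothL1':
self.offset_loss_fn = torch.nn.SmoothL1Loss(reduction='none')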
else: raise NotImplementedError
self.center_loss_weight = center_loss_weight
self.offset_loss_weight = offset_loss_weight
print('Using '+ center_loss +' for heatmap regression, weight: '+str(center_loss_weight))
print('Using '+ offset_loss +' for offset regression, weight: '+str(offset_loss_weight))
self.lost_dict={'semantic_loss':[],
'heatmap_loss':[],
'offset_loss':[]}
def reset_loss_dict(self):
self.lost_dict={'semantic_loss':[],
'heatmap_loss':[],
'offset_loss':[]}
def forward(self,prediction,center,offset,gt_label,gt_center,gt_offset,save_loss = True):
# semantic loss
loss = lovasz_softmax(torch.nn.functional.softmax(prediction, dim=1), gt_label,ignore=255) + self.CE_loss(prediction,gt_label)
if save_loss:
self.lost_dict['semantic_loss'].append(loss.item())
# center heatmap loss
center_mask = (gt_center>0) | (torch.min(torch.unsqueeze(gt_label, 1),dim=4)[0]<255)
center_loss = self.center_loss_fn(center,gt_center) * center_mask
# safe division
if center_mask.sum() > 0:
center_loss = center_loss.sum() / center_mask.sum() * self.center_loss_weight
else:
center_loss = center_loss.sum() * 0
if save_loss:
self.lost_dict['heatmap_loss'].append(center_loss.item())
loss += center_loss
# offset loss
offset_mask = gt_offset != 0
offset_loss = self.offset_loss_fn(offset,gt_offset) * offset_mask
# safe division
if offset_mask.sum() > 0:
offset_loss = offset_loss.sum() / offset_mask.sum() * self.offset_loss_weight
else:
offset_loss = offset_loss.sum() * 0
if save_loss:
self.lost_dict['offset_loss'].append(offset_loss.item())
loss += offset_loss
return loss
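The `* center_mask` and `* offset_mask` terms only restrict the loss to the masked elements if the loss function returns per-element values, for example `torch.nn.MSELoss(reduction='none')`. With the default `reduction='mean'` the loss is already a scalar, and `(scalar * mask).sum() / mask.sum()` just returns that scalar. The NumPy sketch below computes the masked mean (with the same zero-mask guard as the code) and contrasts it with the scalar case. The helper name is hypothetical:

```python
import numpy as np

def masked_mean(per_elem_loss, mask, weight=1.0):
    """Average per-element losses over masked elements only, then apply the loss weight."""
    m = mask.astype(per_elem_loss.dtype)
    if m.sum() > 0:
        return (per_elem_loss * m).sum() / m.sum() * weight
    return (per_elem_loss * m).sum() * 0.0   # keep a zero instead of dividing by zero

pred = np.array([1.0, 2.0, 3.0, 4.0])
gt = np.array([1.0, 0.0, 3.0, 0.0])
mask = gt != 0                               # like offset_mask = gt_offset != 0

per_elem = (pred - gt) ** 2                  # what MSELoss(reduction='none') would return
scalar = per_elem.mean()                     # what the default reduction='mean' returns
masked = masked_mean(per_elem, mask)         # 0.0: only the two supervised elements count
unmasked = (scalar * mask).sum() / mask.sum()   # 5.0: the mask cancels out
```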
================================================
FILE: network/lovasz_losses.py
================================================
"""
Lovasz-Softmax and Jaccard hinge loss in PyTorch
Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)
"""
from __future__ import print_function, division
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
try:
from itertools import ifilterfalse
except ImportError: # py3k
from itertools import filterfalse as ifilterfalse
def lovasz_grad(gt_sorted):
"""
Computes gradient of the Lovasz extension w.r.t sorted errors
See Alg. 1 in paper
"""
p = len(gt_sorted)
gts = gt_sorted.sum()
intersection = gts - gt_sorted.float().cumsum(0)
union = gts + (1 - gt_sorted).float().cumsum(0)
jaccard = 1. - intersection / union
if p > 1: # cover 1-pixel case
jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
return jaccard
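`lovasz_grad` starts from the ground truth sorted by decreasing error and computes the running Jaccard loss. It then takes successive differences, which become per-position weights. Because the differences telescope, the weights sum to the final Jaccard loss, which is 1 once every element has been counted. Below is a direct NumPy transcription with a worked example. The function name is hypothetical:

```python
import numpy as np

def lovasz_grad_np(gt_sorted):
    """Gradient of the Lovasz extension w.r.t. errors sorted in decreasing order."""
    gt_sorted = np.asarray(gt_sorted, dtype=float)
    gts = gt_sorted.sum()
    intersection = gts - np.cumsum(gt_sorted)
    union = gts + np.cumsum(1.0 - gt_sorted)
    jaccard = 1.0 - intersection / union
    if len(gt_sorted) > 1:   # cover the 1-element case
        jaccard[1:] = jaccard[1:] - jaccard[:-1]
    return jaccard

g = lovasz_grad_np([1, 0, 1, 0])
# running Jaccard loss: [1/2, 2/3, 1, 1] -> weights [1/2, 1/6, 1/3, 0]
```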
def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
"""
IoU for foreground class
binary: 1 foreground, 0 background
"""
if not per_image:
preds, labels = (preds,), (labels,)
ious = []
for pred, label in zip(preds, labels):
intersection = ((label == 1) & (pred == 1)).sum()
union = ((label == 1) | ((pred == 1) & (label != ignore))).sum()
if not union:
iou = EMPTY
else:
iou = float(intersection) / float(union)
ious.append(iou)
iou = mean(ious) # mean across images if per_image
return 100 * iou
def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
"""
Array of IoU for each (non ignored) class
"""
if not per_image:
preds, labels = (preds,), (labels,)
ious = []
for pred, label in zip(preds, labels):
iou = []
for i in range(C):
if i != ignore: # The ignored label is sometimes among predicted classes (ENet - CityScapes)
intersection = ((label == i) & (pred == i)).sum()
union = ((label == i) | ((pred == i) & (label != ignore))).sum()
if not union:
iou.append(EMPTY)
else:
iou.append(float(intersection) / float(union))
ious.append(iou)
ious = [mean(iou) for iou in zip(*ious)] # mean across images if per_image
return 100 * np.array(ious)
# --------------------------- BINARY LOSSES ---------------------------
def lovasz_hinge(logits, labels, per_image=True, ignore=None):
"""
Binary Lovasz hinge loss
logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
per_image: compute the loss per image instead of per batch
ignore: void class id
"""
if per_image:
loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))
for log, lab in zip(logits, labels))
else:
loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
return loss
def lovasz_hinge_flat(logits, labels):
"""
Binary Lovasz hinge loss
logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
labels: [P] Tensor, binary ground truth labels (0 or 1)
ignore: label to ignore
"""
if len(labels) == 0:
# only void pixels, the gradients should be 0
return logits.sum() * 0.
signs = 2. * labels.float() - 1.
errors = (1. - logits * Variable(signs))
errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
perm = perm.data
gt_sorted = labels[perm]
grad = lovasz_grad(gt_sorted)
loss = torch.dot(F.relu(errors_sorted), Variable(grad))
return loss
def flatten_binary_scores(scores, labels, ignore=None):
"""
Flattens predictions in the batch (binary case)
Remove labels equal to 'ignore'
"""
scores = scores.view(-1)
labels = labels.view(-1)
if ignore is None:
return scores, labels
valid = (labels != ignore)
vscores = scores[valid]
vlabels = labels[valid]
return vscores, vlabels
class StableBCELoss(torch.nn.modules.Module):
def __init__(self):
super(StableBCELoss, self).__init__()
def forward(self, input, target):
neg_abs = - input.abs()
loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
return loss.mean()
def binary_xloss(logits, labels, ignore=None):
"""
Binary Cross entropy loss
logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
ignore: void class id
"""
logits, labels = flatten_binary_scores(logits, labels, ignore)
loss = StableBCELoss()(logits, Variable(labels.float()))
return loss
# --------------------------- MULTICLASS LOSSES ---------------------------
def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None):
"""
Multi-class Lovasz-Softmax loss
probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
per_image: compute the loss per image instead of per batch
ignore: void class labels
"""
if per_image:
loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes)
for prob, lab in zip(probas, labels))
else:
loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes)
return loss
def lovasz_softmax_flat(probas, labels, classes='present'):
"""
Multi-class Lovasz-Softmax loss
probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
labels: [P] Tensor, ground truth labels (between 0 and C - 1)
classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
"""
if probas.numel() == 0:
# only void pixels, the gradients should be 0
return probas.sum() * 0.
C = probas.size(1)
losses = []
class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
for c in class_to_sum:
fg = (labels == c).float() # foreground for class c
if (classes == 'present' and fg.sum() == 0):
continue
if C == 1:
if len(classes) > 1:
raise ValueError('Sigmoid output possible only with 1 class')
class_pred = probas[:, 0]
else:
class_pred = probas[:, c]
errors = (Variable(fg) - class_pred).abs()
errors_sorted, perm = torch.sort(errors, 0, descending=True)
perm = perm.data
fg_sorted = fg[perm]
losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
return mean(losses)
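For each class present in the labels, `lovasz_softmax_flat` computes the absolute errors `|fg - p_c|` and sorts them in decreasing order. It then takes their dot product with the Lovasz gradient of the sorted foreground, and the final loss is the mean over classes. The NumPy sketch below mirrors the default `classes='present'` path. It is self-contained, so it repeats a small `lovasz_grad` transcription. Function names are hypothetical, and a stable sort is used for ties:

```python
import numpy as np

def lovasz_grad_np(gt_sorted):
    gts = gt_sorted.sum()
    intersection = gts - np.cumsum(gt_sorted)
    union = gts + np.cumsum(1.0 - gt_sorted)
    jaccard = 1.0 - intersection / union
    jaccard[1:] = jaccard[1:] - jaccard[:-1]
    return jaccard

def lovasz_softmax_flat_np(probas, labels):
    """probas: (P, C) probabilities, labels: (P,) ints; averages over classes present in labels."""
    losses = []
    for c in range(probas.shape[1]):
        fg = (labels == c).astype(float)
        if fg.sum() == 0:                          # the 'present' rule: skip absent classes
            continue
        errors = np.abs(fg - probas[:, c])
        perm = np.argsort(-errors, kind="stable")  # decreasing errors
        losses.append(np.dot(errors[perm], lovasz_grad_np(fg[perm])))
    return float(np.mean(losses)) if losses else 0.0

labels = np.array([0, 1])
```

A perfect prediction gives 0. Uniform 0.5 probabilities give 0.5, and a fully swapped prediction gives 1.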
def flatten_probas(probas, labels, ignore=None):
"""
Flattens predictions in the batch
"""
if probas.dim() == 3:
# assumes output of a sigmoid layer
B, H, W = probas.size()
probas = probas.view(B, 1, H, W)
elif probas.dim() == 5:
#3D segmentation
B, C, L, H, W = probas.size()
probas = probas.contiguous().view(B, C, L, H*W)
B, C, H, W = probas.size()
probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C
labels = labels.view(-1)
if ignore is None:
return probas, labels
valid = (labels != ignore)
vprobas = probas[valid]
vlabels = labels[valid]
return vprobas, vlabels
def xloss(logits, labels, ignore=None):
"""
Cross entropy loss
"""
return F.cross_entropy(logits, Variable(labels), ignore_index=255)
def jaccard_loss(probas, labels,ignore=None, smooth = 100, bk_class = None):
"""
Something wrong with this loss
Multi-class Lovasz-Softmax loss
probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
per_image: compute the loss per image instead of per batch
ignore: void class labels
"""
vprobas, vlabels = flatten_probas(probas, labels, ignore)
true_1_hot = torch.eye(vprobas.shape[1])[vlabels]
if bk_class:
one_hot_assignment = torch.ones_like(vlabels)
one_hot_assignment[vlabels == bk_class] = 0
one_hot_assignment = one_hot_assignment.float().unsqueeze(1)
true_1_hot = true_1_hot*one_hot_assignment
true_1_hot = true_1_hot.to(vprobas.device)
intersection = torch.sum(vprobas * true_1_hot)
cardinality = torch.sum(vprobas + true_1_hot)
loss = ((intersection + smooth) / (cardinality - intersection + smooth)).mean()
return (1-loss)*smooth
def hinge_jaccard_loss(probas, labels,ignore=None, classes = 'present', hinge = 0.1, smooth =100):
"""
Multi-class Hinge Jaccard loss
probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
ignore: void class labels
"""
vprobas, vlabels = flatten_probas(probas, labels, ignore)
C = vprobas.size(1)
losses = []
class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
for c in class_to_sum:
if c in vlabels:
c_sample_ind = vlabels == c
cprobas = vprobas[c_sample_ind,:]
non_c_ind =np.array([a for a in class_to_sum if a != c])
class_pred = cprobas[:,c]
max_non_class_pred = torch.max(cprobas[:,non_c_ind],dim = 1)[0]
TP = torch.sum(torch.clamp(class_pred - max_non_class_pred, max = hinge)+1.) + smooth
FN = torch.sum(torch.clamp(max_non_class_pred - class_pred, min = -hinge)+hinge)
if (~c_sample_ind).sum() == 0:
FP = 0
else:
nonc_probas = vprobas[~c_sample_ind,:]
class_pred = nonc_probas[:,c]
max_non_class_pred = torch.max(nonc_probas[:,non_c_ind],dim = 1)[0]
FP = torch.sum(torch.clamp(class_pred - max_non_class_pred, max = hinge)+1.)
losses.append(1 - TP/(TP+FP+FN))
if len(losses) == 0: return 0
return mean(losses)
# --------------------------- HELPER FUNCTIONS ---------------------------
def isnan(x):
return x != x
def mean(l, ignore_nan=False, empty=0):
"""
nanmean compatible with generators.
"""
l = iter(l)
if ignore_nan:
l = ifilterfalse(isnan, l)
try:
n = 1
acc = next(l)
except StopIteration:
if empty == 'raise':
raise ValueError('Empty mean')
return empty
for n, v in enumerate(l, 2):
acc += v
if n == 1:
return acc
return acc / n
================================================
FILE: network/ptBEV.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import numba as nb
import multiprocessing
import torch_scatter
class ptBEVnet(nn.Module):
def __init__(self, BEV_net, grid_size, pt_model = 'pointnet', fea_dim = 3, pt_pooling = 'max', kernal_size = 3,
out_pt_fea_dim = 64, max_pt_per_encode = 64, cluster_num = 4, pt_selection = 'farthest', fea_compre = None):
super(ptBEVnet, self).__init__()
assert pt_pooling in ['max']
assert pt_selection in ['random','farthest']
if pt_model == 'pointnet':
self.PPmodel = nn.Sequential(
nn.BatchNorm1d(fea_dim),
nn.Linear(fea_dim, 64),
nn.BatchNorm1d(64),
nn.ReLU(inplace=True),
nn.Linear(64, 128),
nn.BatchNorm1d(128),
nn.ReLU(inplace=True),
nn.Linear(128, 256),
nn.BatchNorm1d(256),
nn.ReLU(inplace=True),
nn.Linear(256, out_pt_fea_dim)
)
self.pt_model = pt_model
self.BEV_model = BEV_net
self.pt_pooling = pt_pooling
self.max_pt = max_pt_per_encode
self.pt_selection = pt_selection
self.fea_compre = fea_compre
self.grid_size = grid_size
# NN stuff
if kernal_size != 1:
if self.pt_pooling == 'max':
self.local_pool_op = torch.nn.MaxPool2d(kernal_size, stride=1, padding=(kernal_size-1)//2, dilation=1)
else: raise NotImplementedError
else: self.local_pool_op = None
# parametric pooling
if self.pt_pooling == 'max':
self.pool_dim = out_pt_fea_dim
# point feature compression
if self.fea_compre is not None:
self.fea_compression = nn.Sequential(
nn.Linear(self.pool_dim, self.fea_compre),
nn.ReLU())
self.pt_fea_dim = self.fea_compre
else:
self.pt_fea_dim = self.pool_dim
def forward(self, pt_fea, xy_ind, voxel_fea=None):
cur_dev = pt_fea[0].get_device()
# concatenate everything
cat_pt_ind = []
for i_batch in range(len(xy_ind)):
cat_pt_ind.append(F.pad(xy_ind[i_batch],(1,0),'constant',value = i_batch))
cat_pt_fea = torch.cat(pt_fea,dim = 0)
cat_pt_ind = torch.cat(cat_pt_ind,dim = 0)
pt_num = cat_pt_ind.shape[0]
# shuffle the data
shuffled_ind = torch.randperm(pt_num,device = cur_dev)
cat_pt_fea = cat_pt_fea[shuffled_ind,:]
cat_pt_ind = cat_pt_ind[shuffled_ind,:]
# unique xy grid index
unq, unq_inv, unq_cnt = torch.unique(cat_pt_ind,return_inverse=True, return_counts=True, dim=0)
unq = unq.type(torch.int64)
# subsample pts
if self.pt_selection == 'random':
grp_ind = grp_range_torch(unq_cnt,cur_dev)[torch.argsort(torch.argsort(unq_inv))]
remain_ind = grp_ind < self.max_pt
elif self.pt_selection == 'farthest':
unq_ind = np.split(np.argsort(unq_inv.detach().cpu().numpy()), np.cumsum(unq_cnt.detach().cpu().numpy()[:-1]))
remain_ind = np.zeros((pt_num,),dtype = np.bool_)
np_cat_fea = cat_pt_fea.detach().cpu().numpy()[:,:3]
pool_in = []
for i_inds in unq_ind:
if len(i_inds) > self.max_pt:
pool_in.append((np_cat_fea[i_inds,:],self.max_pt))
if len(pool_in) > 0:
pool = multiprocessing.Pool(multiprocessing.cpu_count())
FPS_results = pool.starmap(parallel_FPS, pool_in)
pool.close()
pool.join()
count = 0
for i_inds in unq_ind:
if len(i_inds) <= self.max_pt:
remain_ind[i_inds] = True
else:
remain_ind[i_inds[FPS_results[count]]] = True
count += 1
cat_pt_fea = cat_pt_fea[remain_ind,:]
cat_pt_ind = cat_pt_ind[remain_ind,:]
unq_inv = unq_inv[remain_ind]
unq_cnt = torch.clamp(unq_cnt,max=self.max_pt)
# process feature
if self.pt_model == 'pointnet':
processed_cat_pt_fea = self.PPmodel(cat_pt_fea)
if self.pt_pooling == 'max':
pooled_data = torch_scatter.scatter_max(processed_cat_pt_fea, unq_inv, dim=0)[0]
else: raise NotImplementedError
if self.fea_compre:
processed_pooled_data = self.fea_compression(pooled_data)
else:
processed_pooled_data = pooled_data
# scatter pooled features into a dense (batch, H, W, C) BEV tensor
out_data_dim = [len(pt_fea),self.grid_size[0],self.grid_size[1],self.pt_fea_dim]
out_data = torch.zeros(out_data_dim, dtype=torch.float32).to(cur_dev)
out_data[unq[:,0],unq[:,1],unq[:,2],:] = processed_pooled_data
out_data = out_data.permute(0,3,1,2)
if self.local_pool_op is not None:
out_data = self.local_pool_op(out_data)
if voxel_fea is not None:
out_data = torch.cat((out_data, voxel_fea), 1)
# run through network
sem_prediction, center, offset = self.BEV_model(out_data)
return sem_prediction, center, offset
def grp_range_torch(a,dev):
idx = torch.cumsum(a,0)
id_arr = torch.ones(idx[-1],dtype = torch.int64,device=dev)
id_arr[0] = 0
id_arr[idx[:-1]] = -a[:-1]+1
return torch.cumsum(id_arr,0)
def parallel_FPS(np_cat_fea,K):
return nb_greedy_FPS(np_cat_fea,K)
@nb.jit('b1[:](f4[:,:],i4)',nopython=True,cache=True)
def nb_greedy_FPS(xyz,K):
start_element = 0
sample_num = xyz.shape[0]
sum_vec = np.zeros((sample_num,1),dtype = np.float32)
xyz_sq = xyz**2
for j in range(sample_num):
sum_vec[j,0] = np.sum(xyz_sq[j,:])
pairwise_distance = sum_vec + np.transpose(sum_vec) - 2*np.dot(xyz, np.transpose(xyz))
candidates_ind = np.zeros((sample_num,),dtype = np.bool_)
candidates_ind[start_element] = True
remain_ind = np.ones((sample_num,),dtype = np.bool_)
remain_ind[start_element] = False
all_ind = np.arange(sample_num)
for i in range(1,K):
if i == 1:
min_remain_pt_dis = pairwise_distance[:,start_element]
min_remain_pt_dis = min_remain_pt_dis[remain_ind]
else:
cur_dis = pairwise_distance[remain_ind,:]
cur_dis = cur_dis[:,candidates_ind]
min_remain_pt_dis = np.zeros((cur_dis.shape[0],),dtype = np.float32)
for j in range(cur_dis.shape[0]):
min_remain_pt_dis[j] = np.min(cur_dis[j,:])
next_ind_in_remain = np.argmax(min_remain_pt_dis)
next_ind = all_ind[remain_ind][next_ind_in_remain]
candidates_ind[next_ind] = True
remain_ind[next_ind] = False
return candidates_ind
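The point-to-BEV encoder in `ptBEVnet.forward` above groups points by grid cell, caps each cell at `max_pt_per_encode` points, max-pools the per-point features within each cell, and scatters the result into a dense grid. The NumPy sketch below reproduces that data flow on plain arrays so it can be checked without a GPU. It is an illustration, not the repository's implementation: the helper names `group_rank`, `pool_to_bev` and `greedy_fps` are hypothetical, the PointNet MLP is omitted (features are pooled as given), and `np.maximum.at` stands in for `torch_scatter.scatter_max`.

```python
import numpy as np


def group_rank(inverse):
    """Rank of each point within its cell (0, 1, 2, ...), in input order.

    `inverse[i]` is the cell id of point i, as returned by
    np.unique(..., return_inverse=True).
    """
    order = np.argsort(inverse, kind="stable")      # points grouped by cell
    counts = np.bincount(inverse)
    starts = np.concatenate(([0], np.cumsum(counts)[:-1]))
    rank_sorted = np.arange(len(inverse)) - np.repeat(starts, counts)
    ranks = np.empty_like(rank_sorted)
    ranks[order] = rank_sorted                      # undo the sort
    return ranks


def pool_to_bev(point_feat, voxel_xy, grid_shape, max_pt):
    """Max-pool (N, C) point features into a dense (H, W, C) BEV grid.

    voxel_xy holds (N, 2) integer grid indices. At most `max_pt` points
    per cell are kept: the first ones in input order.
    """
    unq, inv = np.unique(voxel_xy, axis=0, return_inverse=True)
    inv = inv.reshape(-1)                           # keep 1-D across NumPy versions
    keep = group_rank(inv) < max_pt
    feat, inv = point_feat[keep], inv[keep]
    channels = point_feat.shape[1]
    pooled = np.full((len(unq), channels), -np.inf, dtype=point_feat.dtype)
    np.maximum.at(pooled, inv, feat)                # unbuffered scatter-max
    bev = np.zeros(tuple(grid_shape) + (channels,), dtype=point_feat.dtype)
    bev[unq[:, 0], unq[:, 1]] = pooled              # empty cells stay 0
    return bev


def greedy_fps(xyz, k):
    """Boolean mask selecting min(k, N) points by greedy farthest point sampling.

    Starts from point 0 and repeatedly adds the point whose distance to the
    nearest already-selected point is largest.
    """
    n = xyz.shape[0]
    selected = np.zeros(n, dtype=bool)
    selected[0] = True
    # squared distance from every point to its nearest selected point
    min_d2 = np.sum((xyz - xyz[0]) ** 2, axis=1)
    min_d2[0] = -1.0                                # never re-pick a selected point
    for _ in range(1, min(k, n)):
        nxt = int(np.argmax(min_d2))
        selected[nxt] = True
        min_d2 = np.minimum(min_d2, np.sum((xyz - xyz[nxt]) ** 2, axis=1))
        min_d2[nxt] = -1.0
    return selected
```

`group_rank` plays the role of `grp_range_torch` re-indexed to input order; because `forward` shuffles the points before ranking, keeping the first `max_pt` points of each cell amounts to a random subset. `greedy_fps` keeps a running nearest-selected distance (O(N·K) work) instead of building the full N x N distance matrix as `nb_greedy_FPS` does; both start from point 0, so they should pick the same points up to ties and floating-point rounding.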
================================================
FILE: pretrained_weight/Panoptic_SemKITTI_PolarNet.pt
================================================
[File too large to display: 52.5 MB]
================================================
FILE: requirements.txt
================================================
numpy
torch>=1.7.0
tqdm
pyyaml
numba>=0.39.0
torch_scatter>=1.3.1
Cython
scipy
dropblock
================================================
FILE: semantic-kitti.yaml
================================================
# This file is covered by the LICENSE file in the root of this project.
labels:
0 : "unlabeled"
1 : "outlier"
10: "car"
11: "bicycle"
13: "bus"
15: "motorcycle"
16: "on-rails"
18: "truck"
20: "other-vehicle"
30: "person"
31: "bicyclist"
32: "motorcyclist"
40: "road"
44: "parking"
48: "sidewalk"
49: "other-ground"
50: "building"
51: "fence"
52: "other-structure"
60: "lane-marking"
70: "vegetation"
71: "trunk"
72: "terrain"
80: "pole"
81: "traffic-sign"
99: "other-object"
252: "moving-car"
253: "moving-bicyclist"
254: "moving-person"
255: "moving-motorcyclist"
256: "moving-on-rails"
257: "moving-bus"
258: "moving-truck"
259: "moving-other-vehicle"
color_map: # bgr
0 : [0, 0, 0]
1 : [0, 0, 255]
10: [245, 150, 100]
11: [245, 230, 100]
13: [250, 80, 100]
15: [150, 60, 30]
16: [255, 0, 0]
18: [180, 30, 80]
20: [255, 0, 0]
30: [30, 30, 255]
31: [200, 40, 255]
32: [90, 30, 150]
40: [255, 0, 255]
44: [255, 150, 255]
48: [75, 0, 75]
49: [75, 0, 175]
50: [0, 200, 255]
51: [50, 120, 255]
52: [0, 150, 255]
60: [170, 255, 150]
70: [0, 175, 0]
71: [0, 60, 135]
72: [80, 240, 150]
80: [150, 240, 255]
81: [0, 0, 255]
99: [255, 255, 50]
252: [245, 150, 100]
256: [255, 0, 0]
253: [200, 40, 255]
254: [30, 30, 255]
255: [90, 30, 150]
257: [250, 80, 100]
258: [180, 30, 80]
259: [255, 0, 0]
content: # as a ratio with the total number of points
0: 0.018889854628292943
1: 0.0002937197336781505
10: 0.040818519255974316
11: 0.00016609538710764618
13: 2.7879693665067774e-05
15: 0.00039838616015114444
16: 0.0
18: 0.0020633612104619787
20: 0.0016218197275284021
30: 0.00017698551338515307
31: 1.1065903904919655e-08
32: 5.532951952459828e-09
40: 0.1987493871255525
44: 0.014717169549888214
48: 0.14392298360372
49: 0.0039048553037472045
50: 0.1326861944777486
51: 0.0723592229456223
52: 0.002395131480328884
60: 4.7084144280367186e-05
70: 0.26681502148037506
71: 0.006035012012626033
72: 0.07814222006271769
80: 0.002855498193863172
81: 0.0006155958086189918
99: 0.009923127583046915
252: 0.001789309418528068
253: 0.00012709999297008662
254: 0.00016059776092534436
255: 3.745553104802113e-05
256: 0.0
257: 0.00011351574470342043
258: 0.00010157861367183268
259: 4.3840131989471124e-05
# classes that are indistinguishable from single scan or inconsistent in
# ground truth are mapped to their closest equivalent
learning_map:
0 : 0 # "unlabeled"
1 : 0 # "outlier" mapped to "unlabeled" --------------------------mapped
10: 1 # "car"
11: 2 # "bicycle"
13: 5 # "bus" mapped to "other-vehicle" --------------------------mapped
15: 3 # "motorcycle"
16: 5 # "on-rails" mapped to "other-vehicle" ---------------------mapped
18: 4 # "truck"
20: 5 # "other-vehicle"
30: 6 # "person"
31: 7 # "bicyclist"
32: 8 # "motorcyclist"
40: 9 # "road"
44: 10 # "parking"
48: 11 # "sidewalk"
49: 12 # "other-ground"
50: 13 # "building"
51: 14 # "fence"
52: 0 # "other-structure" mapped to "unlabeled" ------------------mapped
60: 9 # "lane-marking" to "road" ---------------------------------mapped
70: 15 # "vegetation"
71: 16 # "trunk"
72: 17 # "terrain"
80: 18 # "pole"
81: 19 # "traffic-sign"
99: 0 # "other-object" to "unlabeled" ----------------------------mapped
252: 1 # "moving-car" to "car" ------------------------------------mapped
253: 7 # "moving-bicyclist" to "bicyclist" ------------------------mapped
254: 6 # "moving-person" to "person" ------------------------------mapped
255: 8 # "moving-motorcyclist" to "motorcyclist" ------------------mapped
256: 5 # "moving-on-rails" mapped to "other-vehicle" --------------mapped
257: 5 # "moving-bus" mapped to "other-vehicle" -------------------mapped
258: 4 # "moving-truck" to "truck" --------------------------------mapped
259: 5 # "moving-other-vehicle" to "other-vehicle" ----------------mapped
learning_map_inv: # inverse of previous map
0: 0 # "unlabeled", and others ignored
1: 10 # "car"
2: 11 # "bicycle"
3: 15 # "motorcycle"
4: 18 # "truck"
5: 20 # "other-vehicle"
6: 30 # "person"
7: 31 # "bicyclist"
8: 32 # "motorcyclist"
9: 40 # "road"
10: 44 # "parking"
11: 48 # "sidewalk"
12: 49 # "other-ground"
13: 50 # "building"
14: 51 # "fence"
15: 70 # "vegetation"
16: 71 # "trunk"
17: 72 # "terrain"
18: 80 # "pole"
19: 81 # "traffic-sign"
learning_ignore: # Ignore classes
0: True # "unlabeled", and others ignored
1: False # "car"
2: False # "bicycle"
3: False # "motorcycle"
4: False # "truck"
5: False # "other-vehicle"
6: False # "person"
7: False # "bicyclist"
8: False # "motorcyclist"
9: False # "road"
10: False # "parking"
11: False # "sidewalk"
12: False # "other-ground"
13: False # "building"
14: False # "fence"
15: False # "vegetation"
16: False # "trunk"
17: False # "terrain"
18: False # "pole"
19: False # "traffic-sign"
thing_class: # thing class in panoptic segmentation
0: False # "unlabeled", and others ignored
1: True # "car"
2: True # "bicycle"
3: True # "motorcycle"
4: True # "truck"
5: True # "other-vehicle"
6: True # "person"
7: True # "bicyclist"
8: True # "motorcyclist"
9: False # "road"
10: False # "parking"
11: False # "sidewalk"
12: False # "other-ground"
13: False # "building"
14: False # "fence"
15: False # "vegetation"
16: False # "trunk"
17: False # "terrain"
18: False # "pole"
19: False # "traffic-sign"
split: # sequence numbers
train:
- 0
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 9
- 10
valid:
- 8
test:
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
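`learning_map` collapses the raw SemanticKITTI label ids into the 20 training ids (0 = unlabeled/ignored). One common way to apply such a map to a whole scan is a dense lookup table indexed by the raw id. The sketch below assumes the standard SemanticKITTI `.label` layout (one `uint32` per point, semantic id in the lower 16 bits, instance id in the upper 16 bits); `build_lut` and `decode_and_map` are illustrative names, not functions from this repository.

```python
import numpy as np

# learning_map from semantic-kitti.yaml: raw SemanticKITTI id -> training id
LEARNING_MAP = {
    0: 0, 1: 0, 10: 1, 11: 2, 13: 5, 15: 3, 16: 5, 18: 4, 20: 5,
    30: 6, 31: 7, 32: 8, 40: 9, 44: 10, 48: 11, 49: 12, 50: 13,
    51: 14, 52: 0, 60: 9, 70: 15, 71: 16, 72: 17, 80: 18, 81: 19,
    99: 0, 252: 1, 253: 7, 254: 6, 255: 8, 256: 5, 257: 5, 258: 4, 259: 5,
}


def build_lut(mapping):
    """Dense lookup table: lut[raw_id] == training id (unlisted ids -> 0)."""
    lut = np.zeros(max(mapping) + 1, dtype=np.uint32)
    for raw_id, train_id in mapping.items():
        lut[raw_id] = train_id
    return lut


def decode_and_map(raw_labels, lut):
    """Split packed uint32 labels and remap the semantic part."""
    raw_labels = np.asarray(raw_labels, dtype=np.uint32)
    sem = raw_labels & 0xFFFF   # lower 16 bits: raw semantic id
    inst = raw_labels >> 16     # upper 16 bits: instance id
    return lut[sem], inst
```

Raw ids that are missing from the map but below 259 fall through to 0 (unlabeled); an id above 259 raises an `IndexError`, which surfaces unexpected labels instead of silently mislabeling them.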
================================================
FILE: test_pretrain.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import time
import argparse
import sys
import yaml
import numpy as np
import torch
import torch.optim as optim
from tqdm import tqdm
import errno
from network.BEV_Unet import BEV_Unet
from network.ptBEV import ptBEVnet
from dataloader.dataset import collate_fn_BEV,SemKITTI,SemKITTI_label_name,spherical_dataset,voxel_dataset,collate_fn_BEV_test
from network.instance_post_processing import get_panoptic_segmentation
from utils.eval_pq import PanopticEval
from utils.configs import merge_configs
# suppress noisy NumPy warnings (note: this silences all warnings)
import warnings
warnings.filterwarnings("ignore")
def SemKITTI2train(label):
if isinstance(label, list):
return [SemKITTI2train_single(a) for a in label]
else:
return SemKITTI2train_single(label)
def SemKITTI2train_single(label):
return label - 1 # uint8 trick: unlabeled (0) wraps around to 255
def main(args):
data_path = args['dataset']['path']
test_batch_size = args['model']['test_batch_size']
pretrained_model = args['model']['pretrained_model']
output_path = args['dataset']['output_path']
compression_model = args['dataset']['grid_size'][2]
grid_size = args['dataset']['grid_size']
visibility = args['model']['visibility']
pytorch_device = torch.device('cuda:0')
if args['model']['polar']:
fea_dim = 9
circular_padding = True
else:
fea_dim = 7
circular_padding = False
# prepare class ids and names for mIoU/PQ evaluation
unique_label=np.asarray(sorted(list(SemKITTI_label_name.keys())))[1:] - 1
unique_label_str=[SemKITTI_label_name[x] for x in unique_label+1]
# prepare model
my_BEV_model=BEV_Unet(n_class=len(unique_label), n_height = compression_model, input_batch_norm = True, dropout = 0.5, circular_padding = circular_padding, use_vis_fea=visibility)
my_model = ptBEVnet(my_BEV_model, pt_model = 'pointnet', grid_size = grid_size, fea_dim = fea_dim, max_pt_per_encode = 256,
out_pt_fea_dim = 512, kernal_size = 1, pt_selection = 'random', fea_compre = compression_model)
if os.path.exists(pretrained_model):
my_model.load_state_dict(torch.load(pretrained_model))
pytorch_total_params = sum(p.numel() for p in my_model.parameters())
print('params: ',pytorch_total_params)
my_model.to(pytorch_device)
my_model.eval()
# prepare dataset
test_pt_dataset = SemKITTI(data_path + '/sequences/', imageset = 'test', return_ref = True, instance_pkl_path=args['dataset']['instance_pkl_path'])
val_pt_dataset = SemKITTI(data_path + '/sequences/', imageset = 'val', return_ref = True, instance_pkl_path=args['dataset']['instance_pkl_path'])
if args['model']['polar']:
test_dataset=spherical_dataset(test_pt_dataset, args['dataset'], grid_size = grid_size, ignore_label = 0, return_test= True)
val_dataset=spherical_dataset(val_pt_dataset, args['dataset'], grid_size = grid_size, ignore_label = 0)
else:
test_dataset=voxel_dataset(test_pt_dataset, args['dataset'], grid_size = grid_size, ignore_label = 0, return_test= True)
val_dataset=voxel_dataset(val_pt_dataset, args['dataset'], grid_size = grid_size, ignore_label = 0)
test_dataset_loader = torch.utils.data.DataLoader(dataset = test_dataset,
batch_size = test_batch_size,
collate_fn = collate_fn_BEV_test,
shuffle = False,
num_workers = 4)
val_dataset_loader = torch.utils.data.DataLoader(dataset = val_dataset,
batch_size = test_batch_size,
collate_fn = collate_fn_BEV,
shuffle = False,
num_workers = 4)
# validation
print('*'*80)
print('Test network performance on validation split')
print('*'*80)
pbar = tqdm(total=len(val_dataset_loader))
time_list = []
pp_time_list = []
evaluator = PanopticEval(len(unique_label)+1, None, [0], min_points=50)
with torch.no_grad():
for i_iter_val,(val_vox_fea,val_vox_label,val_gt_center,val_gt_offset,val_grid,val_pt_labels,val_pt_ints,val_pt_fea) in enumerate(val_dataset_loader):
val_vox_fea_ten = val_vox_fea.to(pytorch_device)
val_vox_label = SemKITTI2train(val_vox_label)
val_pt_fea_ten = [torch.from_numpy(i).type(torch.FloatTensor).to(pytorch_device) for i in val_pt_fea]
val_grid_ten = [torch.from_numpy(i[:,:2]).to(pytorch_device) for i in val_grid]
val_label_tensor=val_vox_label.type(torch.LongTensor).to(pytorch_device)
val_gt_center_tensor = val_gt_center.to(pytorch_device)
val_gt_offset_tensor = val_gt_offset.to(pytorch_device)
torch.cuda.synchronize()
start_time = time.time()
if visibility:
predict_labels,center,offset = my_model(val_pt_fea_ten, val_grid_ten, val_vox_fea_ten)
else:
predict_labels,center,offset = my_model(val_pt_fea_ten, val_grid_ten)
torch.cuda.synchronize()
time_list.append(time.time()-start_time)
for count,i_val_grid in enumerate(val_grid):
# get foreground_mask
for_mask = torch.zeros(1,grid_size[0],grid_size[1],grid_size[2],dtype=torch.bool).to(pytorch_device)
for_mask[0,val_grid[count][:,0],val_grid[count][:,1],val_grid[count][:,2]] = True
# post processing
torch.cuda.synchronize()
start_time = time.time()
panoptic_labels,center_points = get_panoptic_segmentation(torch.unsqueeze(predict_labels[count], 0),torch.unsqueeze(center[count], 0),torch.unsqueeze(offset[count], 0),val_pt_dataset.thing_list,\
threshold=args['model']['post_proc']['threshold'], nms_kernel=args['model']['post_proc']['nms_kernel'],\
top_k=args['model']['post_proc']['top_k'], polar=circular_padding,foreground_mask=for_mask)
torch.cuda.synchronize()
pp_time_list.append(time.time()-start_time)
panoptic_labels = panoptic_labels.cpu().detach().numpy().astype(np.uint32)
panoptic = panoptic_labels[0,val_grid[count][:,0],val_grid[count][:,1],val_grid[count][:,2]]
evaluator.addBatch(panoptic & 0xFFFF,panoptic,np.squeeze(val_pt_labels[count]),np.squeeze(val_pt_ints[count]))
del val_vox_label,val_pt_fea_ten,val_label_tensor,val_grid_ten,val_gt_center,val_gt_center_tensor,val_gt_offset,val_gt_offset_tensor,predict_labels,center,offset,panoptic_labels,center_points
pbar.update(1)
class_PQ, class_SQ, class_RQ, class_all_PQ, class_all_SQ, class_all_RQ = evaluator.getPQ()
miou,ious = evaluator.getSemIoU()
print('Validation per class PQ, SQ, RQ and IoU: ')
for class_name, class_pq, class_sq, class_rq, class_iou in zip(unique_label_str,class_all_PQ[1:],class_all_SQ[1:],class_all_RQ[1:],ious[1:]):
print('%15s : %6.2f%% %6.2f%% %6.2f%% %6.2f%%' % (class_name, class_pq*100, class_sq*100, class_rq*100, class_iou*100))
pbar.close()
print('Current val PQ is %.3f' %
(class_PQ*100))
print('Current val miou is %.3f'%
(miou*100))
print('Inference time per batch of %d scans is %.4f seconds, postprocessing time is %.4f seconds per scan' %
(test_batch_size,np.mean(time_list),np.mean(pp_time_list)))
# test
print('*'*80)
print('Generate predictions for test split')
print('*'*80)
pbar = tqdm(total=len(test_dataset_loader))
with torch.no_grad():
for i_iter_test,(test_vox_fea,_,_,_,test_grid,_,_,test_pt_fea,test_index) in enumerate(test_dataset_loader):
# predict
test_vox_fea_ten = test_vox_fea.to(pytorch_device)
test_pt_fea_ten = [torch.from_numpy(i).type(torch.FloatTensor).to(pytorch_device) for i in test_pt_fea]
test_grid_ten = [torch.from_numpy(i[:,:2]).to(pytorch_device) for i in test_grid]
if visibility:
predict_labels,center,offset = my_model(test_pt_fea_ten,test_grid_ten,test_vox_fea_ten)
else:
predict_labels,center,offset = my_model(test_pt_fea_ten,test_grid_ten)
# write to label file
for count,i_test_grid in enumerate(test_grid):
# get foreground_mask
for_mask = torch.zeros(1,grid_size[0],grid_size[1],grid_size[2],dtype=torch.bool).to(pytorch_device)
for_mask[0,test_grid[count][:,0],test_grid[count][:,1],test_grid[count][:,2]] = True
# post processing
panoptic_labels,center_points = get_panoptic_segmentation(torch.unsqueeze(predict_labels[count], 0),torch.unsqueeze(center[count], 0),torch.unsqueeze(offset[count], 0),test_pt_dataset.thing_list,\
threshold=args['model']['post_proc']['threshold'], nms_kernel=args['model']['post_proc']['nms_kernel'],\
top_k=args['model']['post_proc']['top_k'], polar=circular_padding,foreground_mask=for_mask)
panoptic_labels = panoptic_labels.cpu().detach().numpy().astype(np.uint32)
panoptic = panoptic_labels[0,test_grid[count][:,0],test_grid[count][:,1],test_grid[count][:,2]]
save_dir = test_pt_dataset.im_idx[test_index[count]]
_,dir2 = save_dir.split('/sequences/',1)
new_save_dir = output_path + '/sequences/' +dir2.replace('velodyne','predictions')[:-3]+'label'
if not os.path.exists(os.path.dirname(new_save_dir)):
try:
os.makedirs(os.path.dirname(new_save_dir))
except OSError as exc:
if exc.errno != errno.EEXIST:
raise
panoptic.tofile(new_save_dir)
del test_pt_fea_ten,test_grid_ten,test_pt_fea,predict_labels,center,offset
pbar.update(1)
pbar.close()
print('Predicted test labels are saved in %s. They must be remapped to the original label format before submitting to the competition website.' % output_path)
print('A remapping script is provided in semantic-kitti-api.')
if __name__ == '__main__':
# Testing settings
parser = argparse.ArgumentParser(description='')
parser.add_argument('-d', '--data_dir', default='data')
parser.add_argument('-p', '--pretrained_model', default='pretrained_weight/Panoptic_SemKITTI_PolarNet.pt')
parser.add_argument('-c', '--configs', default='configs/SemanticKITTI_model/Panoptic-PolarNet.yaml')
args = parser.parse_args()
with open(args.configs, 'r') as s:
new_args = yaml.safe_load(s)
args = merge_configs(args,new_args)
print(' '.join(sys.argv))
print(args)
main(args)
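`test_pretrain.py` writes each scan's prediction as raw `uint32` values and notes that they must be shifted back to the original label format before submission. A minimal sketch of that remapping follows. It assumes, as the `panoptic & 0xFFFF` used for evaluation suggests, that the lower 16 bits hold the training class id (0-19) and the upper 16 bits the instance id, and it uses `learning_map_inv` from `semantic-kitti.yaml`. `to_submission_format` is a hypothetical helper; the semantic-kitti-api remapping script remains the reference tool.

```python
import numpy as np

# learning_map_inv from semantic-kitti.yaml: training id -> raw SemanticKITTI id
LEARNING_MAP_INV = {
    0: 0, 1: 10, 2: 11, 3: 15, 4: 18, 5: 20, 6: 30, 7: 31, 8: 32, 9: 40,
    10: 44, 11: 48, 12: 49, 13: 50, 14: 51, 15: 70, 16: 71, 17: 72, 18: 80, 19: 81,
}


def to_submission_format(panoptic):
    """Map the semantic part of packed uint32 predictions back to raw ids."""
    panoptic = np.asarray(panoptic, dtype=np.uint32)
    lut = np.zeros(len(LEARNING_MAP_INV), dtype=np.uint32)
    for train_id, raw_id in LEARNING_MAP_INV.items():
        lut[train_id] = raw_id
    sem = lut[panoptic & 0xFFFF]   # raises IndexError for ids outside 0-19
    inst = panoptic >> 16          # instance id is kept unchanged
    return (inst << 16) | sem
```

The returned array could then be written with `.tofile(...)`, exactly as the script does with `panoptic`.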
================================================
FILE: train.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import argparse
import sys
import numpy as np
import yaml
import torch
import torch.optim as optim
from tqdm import tqdm
from network.BEV_Unet import BEV_Unet
from network.ptBEV import ptBEVnet
from dataloader.dataset import collate_fn_BEV,SemKITTI,SemKITTI_label_name,spherical_dataset,voxel_dataset
from network.instance_post_processing import get_panoptic_segmentation
from network.loss import panoptic_loss
from utils.eval_pq import PanopticEval
from utils.configs import merge_configs
# suppress noisy NumPy warnings (note: this silences all warnings)
import warnings
warnings.filterwarnings("ignore")
def SemKITTI2train(label):
if isinstance(label, list):
return [SemKITTI2train_single(a) for a in label]
else:
return SemKITTI2train_single(label)
def SemKITTI2train_single(label):
return label - 1 # uint8 trick: unlabeled (0) wraps around to 255
def load_pretrained_model(model,pretrained_model):
model_dict = model.state_dict()
pretrained_model = {k: v for k, v in pretrained_model.items() if k in model_dict}
model_dict.update(pretrained_model)
model.load_state_dict(model_dict)
return model
def main(args):
data_path = args['dataset']['path']
train_batch_size = args['model']['train_batch_size']
val_batch_size = args['model']['val_batch_size']
check_iter = args['model']['check_iter']
model_save_path = args['model']['model_save_path']
pretrained_model = args['model']['pretrained_model']
compression_model = args['dataset']['grid_size'][2]
grid_size = args['dataset']['grid_size']
visibility = args['model']['visibility']
pytorch_device = torch.device('cuda:0')
if args['model']['polar']:
fea_dim = 9
circular_padding = True
else:
fea_dim = 7
circular_padding = False
# prepare class ids and names for mIoU/PQ evaluation
unique_label=np.asarray(sorted(list(SemKITTI_label_name.keys())))[1:] - 1
unique_label_str=[SemKITTI_label_name[x] for x in unique_label+1]
#prepare model
my_BEV_model=BEV_Unet(n_class=len(unique_label), n_height = compression_model, input_batch_norm = True, dropout = 0.5, circular_padding = circular_padding, use_vis_fea=visibility)
my_model = ptBEVnet(my_BEV_model, pt_model = 'pointnet', grid_size = grid_size, fea_dim = fea_dim, max_pt_per_encode = 256,
out_pt_fea_dim = 512, kernal_size = 1, pt_selection = 'random', fea_compre = compression_model)
if os.path.exists(model_save_path):
my_model = load_pretrained_model(my_model,torch.load(model_save_path))
elif os.path.exists(pretrained_model):
my_model = load_pretrained_model(my_model,torch.load(pretrained_model))
my_model.to(pytorch_device)
optimizer = optim.Adam(my_model.parameters())
loss_fn = panoptic_loss(center_loss_weight = args['model']['center_loss_weight'], offset_loss_weight = args['model']['offset_loss_weight'],\
center_loss = args['model']['center_loss'], offset_loss=args['model']['offset_loss'])
#prepare dataset
train_pt_dataset = SemKITTI(data_path + '/sequences/', imageset = 'train', return_ref = True, instance_pkl_path=args['dataset']['instance_pkl_path'])
val_pt_dataset = SemKITTI(data_path + '/sequences/', imageset = 'val', return_ref = True, instance_pkl_path=args['dataset']['instance_pkl_path'])
if args['model']['polar']:
train_dataset=spherical_dataset(train_pt_dataset, args['dataset'], grid_size = grid_size, ignore_label = 0, use_aug = True)
val_dataset=spherical_dataset(val_pt_dataset, args['dataset'], grid_size = grid_size, ignore_label = 0)
else:
train_dataset=voxel_dataset(train_pt_dataset, args['dataset'], grid_size = grid_size, ignore_label = 0,use_aug = True)
val_dataset=voxel_dataset(val_pt_dataset, args['dataset'], grid_size = grid_size, ignore_label = 0)
train_dataset_loader = torch.utils.data.DataLoader(dataset = train_dataset,
batch_size = train_batch_size,
collate_fn = collate_fn_BEV,
shuffle = True,
num_workers = 4)
val_dataset_loader = torch.utils.data.DataLoader(dataset = val_dataset,
batch_size = val_batch_size,
collate_fn = collate_fn_BEV,
shuffle = False,
num_workers = 4)
# training
epoch=0
best_val_PQ=0
start_training=False
my_model.train()
global_iter = 0
exce_counter = 0
evaluator = PanopticEval(len(unique_label)+1, None, [0], min_points=50)
while epoch < args['model']['max_epoch']:
pbar = tqdm(total=len(train_dataset_loader))
for i_iter,(train_vox_fea,train_label_tensor,train_gt_center,train_gt_offset,train_grid,_,_,train_pt_fea) in enumerate(train_dataset_loader):
# validation
if global_iter % check_iter == 0:
my_model.eval()
evaluator.reset()
with torch.no_grad():
for i_iter_val,(val_vox_fea,val_vox_label,val_gt_center,val_gt_offset,val_grid,val_pt_labels,val_pt_ints,val_pt_fea) in enumerate(val_dataset_loader):
val_vox_fea_ten = val_vox_fea.to(pytorch_device)
val_vox_label = SemKITTI2train(val_vox_label)
val_pt_fea_ten = [torch.from_numpy(i).type(torch.FloatTensor).to(pytorch_device) for i in val_pt_fea]
val_grid_ten = [torch.from_numpy(i[:,:2]).to(pytorch_device) for i in val_grid]
val_label_tensor=val_vox_label.type(torch.LongTensor).to(pytorch_device)
val_gt_center_tensor = val_gt_center.to(pytorch_device)
val_gt_offset_tensor = val_gt_offset.to(pytorch_device)
if visibility:
predict_labels,center,offset = my_model(val_pt_fea_ten, val_grid_ten, val_vox_fea_ten)
else:
predict_labels,center,offset = my_model(val_pt_fea_ten, val_grid_ten)
for count,i_val_grid in enumerate(val_grid):
# get foreground_mask
for_mask = torch.zeros(1,grid_size[0],grid_size[1],grid_size[2],dtype=torch.bool).to(pytorch_device)
for_mask[0,val_grid[count][:,0],val_grid[count][:,1],val_grid[count][:,2]] = True
# post processing
panoptic_labels,center_points = get_panoptic_segmentation(torch.unsqueeze(predict_labels[count], 0),torch.unsqueeze(center[count], 0),torch.unsqueeze(offset[count], 0),\
val_pt_dataset.thing_list, threshold=args['model']['post_proc']['threshold'], nms_kernel=args['model']['post_proc']['nms_kernel'],\
top_k=args['model']['post_proc']['top_k'], polar=circular_padding,foreground_mask=for_mask)
panoptic_labels = panoptic_labels.cpu().detach().numpy().astype(np.int32)
panoptic = panoptic_labels[0,val_grid[count][:,0],val_grid[count][:,1],val_grid[count][:,2]]
evaluator.addBatch(panoptic & 0xFFFF,panoptic,np.squeeze(val_pt_labels[count]),np.squeeze(val_pt_ints[count]))
del val_vox_label,val_pt_fea_ten,val_label_tensor,val_grid_ten,val_gt_center,val_gt_center_tensor,val_gt_offset,val_gt_offset_tensor,predict_labels,center,offset,panoptic_labels,center_points
my_model.train()
class_PQ, class_SQ, class_RQ, class_all_PQ, class_all_SQ, class_all_RQ = evaluator.getPQ()
miou,ious = evaluator.getSemIoU()
print('Validation per class PQ, SQ, RQ and IoU: ')
for class_name, class_pq, class_sq, class_rq, class_iou in zip(unique_label_str,class_all_PQ[1:],class_all_SQ[1:],class_all_RQ[1:],ious[1:]):
print('%15s : %6.2f%% %6.2f%% %6.2f%% %6.2f%%' % (class_name, class_pq*100, class_sq*100, class_rq*100, class_iou*100))
# save model if performance is improved
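if best_val_PQ<class_PQ:
best_val_PQ=class_PQ
torch.save(my_model.state_dict(), model_save_path)
print('Current val PQ is %.3f while the best val PQ is %.3f' % (class_PQ*100,best_val_PQ*100))
print('Current val miou is %.3f' % (miou*100))
# training
try:
train_vox_fea_ten = train_vox_fea.to(pytorch_device)
train_label_tensor=SemKITTI2train(train_label_tensor)
train_pt_fea_ten = [torch.from_numpy(i).type(torch.FloatTensor).to(pytorch_device) for i in train_pt_fea]
train_grid_ten = [torch.from_numpy(i[:,:2]).to(pytorch_device) for i in train_grid]
train_label_tensor=train_label_tensor.type(torch.LongTensor).to(pytorch_device)
train_gt_center_tensor = train_gt_center.to(pytorch_device)
train_gt_offset_tensor = train_gt_offset.to(pytorch_device)
if args['model']['enable_SAP'] and epoch>=args['model']['SAP']['start_epoch']: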
for fea in train_pt_fea_ten:
fea.requires_grad_()
# forward
if visibility:
sem_prediction,center,offset = my_model(train_pt_fea_ten,train_grid_ten,train_vox_fea_ten)
else:
sem_prediction,center,offset = my_model(train_pt_fea_ten,train_grid_ten)
# loss
loss = loss_fn(sem_prediction,center,offset,train_label_tensor,train_gt_center_tensor,train_gt_offset_tensor)
# self-adversarial pruning (SAP): drop the points with the largest input-gradient norms, then run a second pass
if args['model']['enable_SAP'] and epoch>=args['model']['SAP']['start_epoch']:
loss.backward()
for i,fea in enumerate(train_pt_fea_ten):
fea_grad = torch.norm(fea.grad,dim=1)
top_k_grad, _ = torch.topk(fea_grad, int(args['model']['SAP']['rate']*fea_grad.shape[0]))
# drop the most influential points (largest gradient norm)
train_pt_fea_ten[i] = train_pt_fea_ten[i][fea_grad < top_k_grad[-1]]
train_grid_ten[i] = train_grid_ten[i][fea_grad < top_k_grad[-1]]
optimizer.zero_grad()
# second pass
# forward
if visibility:
sem_prediction,center,offset = my_model(train_pt_fea_ten,train_grid_ten,train_vox_fea_ten)
else:
sem_prediction,center,offset = my_model(train_pt_fea_ten,train_grid_ten)
# loss
loss = loss_fn(sem_prediction,center,offset,train_label_tensor,train_gt_center_tensor,train_gt_offset_tensor)
# backward + optimize
loss.backward()
optimizer.step()
except Exception as error:
if exce_counter == 0:
print(error)
exce_counter += 1
# zero the parameter gradients
optimizer.zero_grad()
pbar.update(1)
start_training=True
global_iter += 1
pbar.close()
epoch += 1
if __name__ == '__main__':
# Training settings
parser = argparse.ArgumentParser(description='')
parser.add_argument('-d', '--data_dir', default='data')
parser.add_argument('-p', '--model_save_path', default='./Panoptic_SemKITTI.pt')
parser.add_argument('-c', '--configs', default='configs/SemanticKITTI_model/Panoptic-PolarNet.yaml')
parser.add_argument('--pretrained_model', default='empty')
args = parser.parse_args()
with open(args.configs, 'r') as s:
new_args = yaml.safe_load(s)
args = merge_configs(args,new_args)
print(' '.join(sys.argv))
print(args)
main(args)
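The self-adversarial pruning (SAP) step in `train.py` backpropagates once, measures the L2 norm of each point's input-feature gradient, and drops the points whose norm is among the top `rate` fraction before a second forward/backward pass. The NumPy sketch below isolates just that selection rule; `sap_keep_mask` is a hypothetical name. Like the training code, it keeps only points strictly below the k-th largest norm, so points tied at the threshold are dropped too.

```python
import numpy as np


def sap_keep_mask(grad, rate):
    """Mask of points to keep after self-adversarial pruning (sketch).

    grad: (N, C) gradient of the loss w.r.t. each point's input features.
    Points whose gradient L2 norm is among the top `rate` fraction are dropped.
    """
    norms = np.linalg.norm(grad, axis=1)
    k = int(rate * norms.shape[0])
    if k == 0:
        return np.ones(norms.shape[0], dtype=bool)  # nothing to prune
    # k-th largest norm, i.e. the smallest value inside the top-k set
    threshold = np.partition(norms, -k)[-k]
    return norms < threshold
```

One difference: when `int(rate * N)` is 0, `torch.topk` in the training loop returns an empty tensor and `top_k_grad[-1]` raises, which the `except Exception` branch would count as a failed iteration; the sketch returns an all-`True` mask in that case instead.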
================================================
FILE: utils/__init__.py
================================================
================================================
FILE: utils/configs.py
================================================
#!/usr/bin/env python3
import yaml
def merge_configs(cfgs,new_cfgs):
if hasattr(cfgs, 'data_dir'):
new_cfgs['dataset']['path']=cfgs.data_dir
if hasattr(cfgs, 'model_save_path'):
new_cfgs['model']['model_save_path']=cfgs.model_save_path
if hasattr(cfgs, 'pretrained_model'):
new_cfgs['model']['pretrained_model']=cfgs.pretrained_model
return new_cfgs
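`merge_configs` copies selected command-line arguments into the nested YAML config, using `hasattr` so that each script only overrides the options it actually defines (`test_pretrain.py` has no `--model_save_path`, for instance). A table-driven variant is sketched below. `OVERRIDES` and `merge_overrides` are hypothetical names, and unlike the original this version deep-copies the config instead of mutating the loaded dictionary in place.

```python
import argparse
import copy

# Hypothetical table: CLI attribute -> (section, key) in the YAML config
OVERRIDES = {
    "data_dir": ("dataset", "path"),
    "model_save_path": ("model", "model_save_path"),
    "pretrained_model": ("model", "pretrained_model"),
}


def merge_overrides(cli_args, cfg):
    """Return a copy of `cfg` with any CLI arguments present in OVERRIDES applied."""
    merged = copy.deepcopy(cfg)
    for attr, (section, key) in OVERRIDES.items():
        if hasattr(cli_args, attr):   # the script defines this option
            merged[section][key] = getattr(cli_args, attr)
    return merged
```

Adding a new overridable option then only needs one more entry in the table.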
================================================
FILE: utils/eval_pq.py
================================================
#!/usr/bin/env python3
import numpy as np
import time
class PanopticEval:
""" Panoptic evaluation using numpy
authors: Andres Milioto and Jens Behley
"""
def __init__(self, n_classes, device=None, ignore=None, offset=2**32, min_points=30):
self.n_classes = n_classes
assert device is None  # only the NumPy (CPU) implementation is supported
self.ignore = np.array(ignore, dtype=np.int64)
self.include = np.array([n for n in range(self.n_classes) if n not in self.ignore], dtype=np.int64)
print("[PANOPTIC EVAL] IGNORE: ", self.ignore)
print("[PANOPTIC EVAL] INCLUDE: ", self.include)
self.reset()
self.offset = offset # largest number of instances in a given scan
self.min_points = min_points # smallest number of points to consider instances in gt
self.eps = 1e-15
def num_classes(self):
return self.n_classes
def reset(self):
# general things
# iou stuff
self.px_iou_conf_matrix = np.zeros((self.n_classes, self.n_classes), dtype=np.int64)
# panoptic stuff
self.pan_tp = np.zeros(self.n_classes, dtype=np.int64)
self.pan_iou = np.zeros(self.n_classes, dtype=np.double)
self.pan_fp = np.zeros(self.n_classes, dtype=np.int64)
self.pan_fn = np.zeros(self.n_classes, dtype=np.int64)
################################# IoU STUFF ##################################
def addBatchSemIoU(self, x_sem, y_sem):
# idxs are labels and predictions
idxs = np.stack([x_sem, y_sem], axis=0)
# make confusion matrix (cols = gt, rows = pred)
np.add.at(self.px_iou_conf_matrix, tuple(idxs), 1)
def getSemIoUStats(self):
# clone to avoid modifying the real deal
conf = self.px_iou_conf_matrix.copy().astype(np.double)
# remove fp from confusion on the ignore classes predictions
# points that were predicted of another class, but were ignore
# (corresponds to zeroing the cols of those classes, since the predictions
# go on the rows)
conf[:, self.ignore] = 0
# get the clean stats
tp = conf.diagonal()
fp = conf.sum(axis=1) - tp
fn = conf.sum(axis=0) - tp
return tp, fp, fn
def getSemIoU(self):
tp, fp, fn = self.getSemIoUStats()
# print(f"tp={tp}")
# print(f"fp={fp}")
# print(f"fn={fn}")
intersection = tp
union = tp + fp + fn
union = np.maximum(union, self.eps)
iou = intersection.astype(np.double) / union.astype(np.double)
iou_mean = (intersection[self.include].astype(np.double) / union[self.include].astype(np.double)).mean()
return iou_mean, iou # returns "iou mean", "iou per class" ALL CLASSES
def getSemAcc(self):
tp, fp, fn = self.getSemIoUStats()
total_tp = tp.sum()
total = tp[self.include].sum() + fp[self.include].sum()
total = np.maximum(total, self.eps)
acc_mean = total_tp.astype(np.double) / total.astype(np.double)
return acc_mean # returns "acc mean"
################################# IoU STUFF ##################################
##############################################################################
############################# Panoptic STUFF ################################
def addBatchPanoptic(self, x_sem_row, x_inst_row, y_sem_row, y_inst_row):
# make sure instances are not zeros (it messes with my approach)
x_inst_row = x_inst_row + 1
y_inst_row = y_inst_row + 1
# only interested in points that are outside the void area (not in excluded classes)
for cl in self.ignore:
# make a mask for this class
gt_not_in_excl_mask = y_sem_row != cl
# remove all other points
x_sem_row = x_sem_row[gt_not_in_excl_mask]
y_sem_row = y_sem_row[gt_not_in_excl_mask]
x_inst_row = x_inst_row[gt_not_in_excl_mask]
y_inst_row = y_inst_row[gt_not_in_excl_mask]
# first step is to count intersections > 0.5 IoU for each class (except the ignored ones)
for cl in self.include:
# print("*"*80)
# print("CLASS", cl.item())
# get a class mask
x_inst_in_cl_mask = x_sem_row == cl
y_inst_in_cl_mask = y_sem_row == cl
# get instance points in class (makes outside stuff 0)
x_inst_in_cl = x_inst_row * x_inst_in_cl_mask.astype(np.int64)
y_inst_in_cl = y_inst_row * y_inst_in_cl_mask.astype(np.int64)
# generate the areas for each unique instance prediction
unique_pred, counts_pred = np.unique(x_inst_in_cl[x_inst_in_cl > 0], return_counts=True)
id2idx_pred = {id: idx for idx, id in enumerate(unique_pred)}
matched_pred = np.array([False] * unique_pred.shape[0])
# print("Unique predictions:", unique_pred)
# generate the areas for each unique ground-truth instance
unique_gt, counts_gt = np.unique(y_inst_in_cl[y_inst_in_cl > 0], return_counts=True)
id2idx_gt = {id: idx for idx, id in enumerate(unique_gt)}
matched_gt = np.array([False] * unique_gt.shape[0])
# print("Unique ground truth:", unique_gt)
# generate intersection using offset
valid_combos = np.logical_and(x_inst_in_cl > 0, y_inst_in_cl > 0)
offset_combo = x_inst_in_cl[valid_combos] + self.offset * y_inst_in_cl[valid_combos]
unique_combo, counts_combo = np.unique(offset_combo, return_counts=True)
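The offset trick encodes each (prediction id, ground-truth id) pair as a single integer `pred + offset * gt`, so one `np.unique` call yields every overlapping pair and its intersection size, and `//` and `%` recover the two ids. The sketch below assumes `offset = 2**32`, which must exceed every shifted prediction id; the actual value comes from the constructor, outside this excerpt.

```python
import numpy as np

OFFSET = 2 ** 32  # assumption: larger than any (shifted) prediction instance id

# per-point instance ids within one class (0 = not in this class)
pred_ids = np.array([1, 1, 2, 2, 2, 0], dtype=np.int64)
gt_ids = np.array([5, 5, 5, 7, 0, 7], dtype=np.int64)

valid = (pred_ids > 0) & (gt_ids > 0)            # points present in both
combo = pred_ids[valid] + OFFSET * gt_ids[valid]  # one integer per (pred, gt) pair
unique_combo, counts = np.unique(combo, return_counts=True)
gt_of_pair = unique_combo // OFFSET
pred_of_pair = unique_combo % OFFSET
print(list(zip(pred_of_pair, gt_of_pair, counts)))  # pairs (1,5):2, (2,5):1, (2,7):1
```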
# generate an intersection map
# count the intersections with over 0.5 IoU as TP
gt_labels = unique_combo // self.offset
pred_labels = unique_combo % self.offset
gt_areas = np.array([counts_gt[id2idx_gt[id]] for id in gt_labels])
pred_areas = np.array([counts_pred[id2idx_pred[id]] for id in pred_labels])
intersections = counts_combo
unions = gt_areas + pred_areas - intersections
ious = intersections.astype(np.double) / unions.astype(np.double)  # np.float was removed in NumPy 1.24
tp_indexes = ious > 0.5
self.pan_tp[cl] += np.sum(tp_indexes)
self.pan_iou[cl] += np.sum(ious[tp_indexes])
matched_gt[[id2idx_gt[id] for id in gt_labels[tp_indexes]]] = True
matched_pred[[id2idx_pred[id] for id in pred_labels[tp_indexes]]] = True
# count the FN
self.pan_fn[cl] += np.sum(np.logical_and(counts_gt >= self.min_points, matched_gt == False))
# count the FP
self.pan_fp[cl] += np.sum(np.logical_and(counts_pred >= self.min_points, matched_pred == False))
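The per-class matching above can be condensed into one standalone function: a pair is a TP when its IoU exceeds 0.5 (which makes the match unique), unmatched ground-truth instances are FN, and unmatched predictions are FP, both only if they have at least `min_points` points. This is a sketch with a hypothetical name `match_instances` and an assumed offset of `2**32`. It is fed the "person" points of the toy problem in the `__main__` block, with ids already shifted by +1.

```python
import numpy as np


def match_instances(pred_inst, gt_inst, offset=2 ** 32, min_points=1):
    """Return (TP, FP, FN, summed TP IoU) for one class.

    pred_inst / gt_inst hold per-point instance ids restricted to the class,
    with 0 meaning "not this class".
    """
    unique_pred, counts_pred = np.unique(pred_inst[pred_inst > 0], return_counts=True)
    unique_gt, counts_gt = np.unique(gt_inst[gt_inst > 0], return_counts=True)
    pred_idx = {i: k for k, i in enumerate(unique_pred)}
    gt_idx = {i: k for k, i in enumerate(unique_gt)}

    # intersections of every overlapping (pred, gt) pair via the offset encoding
    valid = (pred_inst > 0) & (gt_inst > 0)
    combo, inter = np.unique(pred_inst[valid] + offset * gt_inst[valid], return_counts=True)
    gt_lab, pred_lab = combo // offset, combo % offset
    gt_area = counts_gt[[gt_idx[i] for i in gt_lab]]
    pred_area = counts_pred[[pred_idx[i] for i in pred_lab]]
    ious = inter / (gt_area + pred_area - inter)

    is_tp = ious > 0.5
    matched_gt = np.zeros(len(unique_gt), dtype=bool)
    matched_pred = np.zeros(len(unique_pred), dtype=bool)
    matched_gt[[gt_idx[i] for i in gt_lab[is_tp]]] = True
    matched_pred[[pred_idx[i] for i in pred_lab[is_tp]]] = True

    tp = int(is_tp.sum())
    fn = int(np.sum((counts_gt >= min_points) & ~matched_gt))
    fp = int(np.sum((counts_pred >= min_points) & ~matched_pred))
    return tp, fp, fn, float(ious[is_tp].sum())


# "person" points of the toy problem: a dog predicted as person (id 36, no person gt),
# then prediction 9 covering gt 34 and gt 43, and prediction 96 covering gt 12
pred = np.array([36] * 50 + [9] * 200 + [96] * 100, dtype=np.int64)
gt = np.array([0] * 50 + [34] * 150 + [43] * 50 + [12] * 100, dtype=np.int64)
result = match_instances(pred, gt)
print(result)  # (2, 1, 1, 1.75): IoUs 1.0 and 0.75 match, gt 43 is FN, pred 36 is FP
```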
def getPQ(self):
# first calculate for all classes
sq_all = self.pan_iou.astype(np.double) / np.maximum(self.pan_tp.astype(np.double), self.eps)
rq_all = self.pan_tp.astype(np.double) / np.maximum(
self.pan_tp.astype(np.double) + 0.5 * self.pan_fp.astype(np.double) + 0.5 * self.pan_fn.astype(np.double),
self.eps)
pq_all = sq_all * rq_all
# then do the REAL mean (no ignored classes)
SQ = sq_all[self.include].mean()
RQ = rq_all[self.include].mean()
PQ = pq_all[self.include].mean()
return PQ, SQ, RQ, pq_all, sq_all, rq_all
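The formulas above are SQ = (sum of TP IoUs) / TP, RQ = TP / (TP + FP/2 + FN/2) and PQ = SQ * RQ, averaged over the included classes. As a check, the per-class counts below were derived by hand from the toy problem in the `__main__` block (grass, sky, person, dog); the helper `pq_from_counts` and the `eps` default are assumptions of this sketch.

```python
import numpy as np


def pq_from_counts(tp, fp, fn, iou_sum, eps=1e-15):
    """Per-class PQ, SQ, RQ from accumulated counts."""
    tp, fp, fn, iou_sum = (np.asarray(a, dtype=np.double) for a in (tp, fp, fn, iou_sum))
    sq = iou_sum / np.maximum(tp, eps)
    rq = tp / np.maximum(tp + 0.5 * fp + 0.5 * fn, eps)
    return sq * rq, sq, rq


# grass, sky, person, dog (the ignored class is left out)
pq, sq, rq = pq_from_counts(tp=[1, 1, 2, 0],
                            fp=[0, 0, 1, 0],
                            fn=[0, 0, 1, 1],
                            iou_sum=[40 / 60, 40 / 60, 1.75, 0.0])
print(pq.mean(), sq.mean(), rq.mean())  # about 0.4792, 0.5521, 0.6667
```

These means agree with the TOTALS listed in the comments of the `__main__` block.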
############################# Panoptic STUFF ################################
##############################################################################
def addBatch(self, x_sem, x_inst, y_sem, y_inst): # x=preds, y=targets
''' IMPORTANT: Inputs must be batched. Either [N,H,W], or [N, P]
'''
# add to IoU calculation (for checking purposes)
self.addBatchSemIoU(x_sem, y_sem)
# now do the panoptic stuff
self.addBatchPanoptic(x_sem, x_inst, y_sem, y_inst)
if __name__ == "__main__":
# reproduce the toy example from Kirillov et al., "Panoptic Segmentation" (https://arxiv.org/pdf/1801.00868.pdf)
classes = 5 # ignore, grass, sky, person, dog
cl_strings = ["ignore", "grass", "sky", "person", "dog"]
ignore = [0] # only ignore ignore class
min_points = 1 # for this example we care about all points
# generate ground truth and prediction
sem_pred = []
inst_pred = []
sem_gt = []
inst_gt = []
# some ignore stuff
N_ignore = 50
sem_pred.extend([0 for i in range(N_ignore)])
inst_pred.extend([0 for i in range(N_ignore)])
sem_gt.extend([0 for i in range(N_ignore)])
inst_gt.extend([0 for i in range(N_ignore)])
# grass segment
N_grass = 50
N_grass_pred = 40 # rest is sky
sem_pred.extend([1 for i in range(N_grass_pred)]) # grass
sem_pred.extend([2 for i in range(N_grass - N_grass_pred)]) # sky
inst_pred.extend([0 for i in range(N_grass)])
sem_gt.extend([1 for i in range(N_grass)]) # grass
inst_gt.extend([0 for i in range(N_grass)])
# sky segment
N_sky = 50
N_sky_pred = 40 # rest is grass
sem_pred.extend([2 for i in range(N_sky_pred)]) # sky
sem_pred.extend([1 for i in range(N_sky - N_sky_pred)]) # grass
inst_pred.extend([0 for i in range(N_sky)]) # stuff class, no instance id
sem_gt.extend([2 for i in range(N_sky)]) # sky
inst_gt.extend([0 for i in range(N_sky)]) # stuff class, no instance id
# wrong dog as person prediction
N_dog = 50
N_person = N_dog
sem_pred.extend([3 for i in range(N_person)])
inst_pred.extend([35 for i in range(N_person)])
sem_gt.extend([4 for i in range(N_dog)])
inst_gt.extend([22 for i in range(N_dog)])
# two persons in prediction, but three in gt
N_person = 50
sem_pred.extend([3 for i in range(6 * N_person)])
inst_pred.extend([8 for i in range(4 * N_person)])
inst_pred.extend([95 for i in range(2 * N_person)])
sem_gt.extend([3 for i in range(6 * N_person)])
inst_gt.extend([33 for i in range(3 * N_person)])
inst_gt.extend([42 for i in range(N_person)])
inst_gt.extend([11 for i in range(2 * N_person)])
# gt and pred to numpy
sem_pred = np.array(sem_pred, dtype=np.int64).reshape(1, -1)
inst_pred = np.array(inst_pred, dtype=np.int64).reshape(1, -1)
sem_gt = np.array(sem_gt, dtype=np.int64).reshape(1, -1)
inst_gt = np.array(inst_gt, dtype=np.int64).reshape(1, -1)
# evaluator
evaluator = PanopticEval(classes, ignore=ignore, min_points=min_points)
evaluator.addBatch(sem_pred, inst_pred, sem_gt, inst_gt)
pq, sq, rq, all_pq, all_sq, all_rq = evaluator.getPQ()
iou, all_iou = evaluator.getSemIoU()
# [PANOPTIC EVAL] IGNORE: [0]
# [PANOPTIC EVAL] INCLUDE: [1 2 3 4]
# TOTALS
# PQ: 0.47916666666666663
# SQ: 0.5520833333333333
# RQ: 0.6666666666666666
# IoU: 0.5476190476190476
# Class ignore PQ: 0.0 SQ: 0.0 RQ: 0.0 IoU: 0.0
# Class grass PQ: 0.6666666666666666 SQ: 0.6666666666666666 RQ: 1.0 IoU: 0.6666666666666666
# Class sky PQ: 0.6666666666666666 SQ: 0.6666666666666666 RQ: 1.0 IoU: 0.6666666666666666
# Class person PQ: 0.5833333333333333 SQ: 0.875 RQ: 0.6666666666666666 IoU: 0.8571428571428571
# Class dog PQ: 0.0 SQ: 0.0 RQ: 0.0 IoU: 0.0
print("TOTALS")
print("PQ:", pq.item(), np.isclose(pq, 0.47916666666666663))
print("SQ:", sq.item(), np.isclose(sq, 0.5520833333333333))
print("RQ:", rq.item(), np.isclose(rq, 0.6666666666666666))
print("IoU:", iou.item(), np.isclose(iou, 0.5476190476190476))
for i, (pq, sq, rq, iou) in enumerate(zip(all_pq, all_sq, all_rq, all_iou)):
print("Class", cl_strings[i], "\t", "PQ:", pq.item(), "SQ:", sq.item(), "RQ:", rq.item(), "IoU:", iou.item())
================================================
FILE: utils/visual.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import cv2
import numpy as np
def flow_to_img(flow, normalize=True):
"""Convert flow to viewable image, using color hue to encode flow vector orientation, and color saturation to
encode vector length. This is similar to the OpenCV tutorial on dense optical flow, except that they map vector
length to the value plane of the HSV color model, instead of the saturation plane, as we do here.
Args:
flow: optical flow
normalize: Normalize flow to 0..255
Returns:
img: viewable representation of the dense optical flow in RGB format
Ref:
https://github.com/philferriere/tfoptflow/blob/33e8a701e34c8ce061f17297d40619afbd459ade/tfoptflow/optflow.py
"""
hsv = np.zeros((flow.shape[0], flow.shape[1], 3), dtype=np.uint8)
flow_magnitude, flow_angle = cv2.cartToPolar(flow[..., 0].astype(np.float32), flow[..., 1].astype(np.float32))
# A couple times, we've gotten NaNs out of the above...
nans = np.isnan(flow_magnitude)
if np.any(nans):
nans = np.where(nans)
flow_magnitude[nans] = 0.
# Normalize
hsv[..., 0] = flow_angle * 180 / np.pi / 2
if normalize:
hsv[..., 1] = cv2.normalize(flow_magnitude, None, 0, 255, cv2.NORM_MINMAX)
else:
hsv[..., 1] = np.clip(flow_magnitude, 0, 255)  # avoid uint8 wrap-around for large magnitudes
hsv[..., 2] = 255
img = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
return img
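The HSV encoding in `flow_to_img` can be reproduced with NumPy alone, which makes the mapping easier to inspect: the direction becomes the hue (OpenCV's 8-bit hue uses half-degree steps, so 0..179), the magnitude becomes the saturation, and the value is fixed at 255. This is a sketch, not a drop-in replacement. The name `flow_to_hsv_sketch` is hypothetical, it rounds where the original's uint8 assignment truncates, and it stops before the `cv2.cvtColor` HSV-to-RGB step.

```python
import numpy as np


def flow_to_hsv_sketch(flow, normalize=True):
    """Return the uint8 HSV image (OpenCV conventions) encoding a 2-channel flow field."""
    fx = flow[..., 0].astype(np.float64)
    fy = flow[..., 1].astype(np.float64)
    magnitude = np.nan_to_num(np.hypot(fx, fy))    # zero out NaNs, as flow_to_img does
    angle = np.mod(np.arctan2(fy, fx), 2 * np.pi)  # radians in [0, 2*pi), like cv2.cartToPolar
    hsv = np.zeros(flow.shape[:2] + (3,), dtype=np.uint8)
    # 8-bit OpenCV hue spans [0, 180): degrees / 2
    hsv[..., 0] = (np.rint(np.degrees(angle) / 2) % 180).astype(np.uint8)
    if normalize:
        lo, hi = magnitude.min(), magnitude.max()
        magnitude = (magnitude - lo) / max(hi - lo, 1e-12) * 255  # min-max scaling to [0, 255]
    hsv[..., 1] = np.clip(np.rint(magnitude), 0, 255).astype(np.uint8)
    hsv[..., 2] = 255  # full brightness; only hue and saturation carry information
    return hsv


# one pixel moving right (length 2), one moving down in image coordinates (length 1)
flow = np.array([[[2.0, 0.0], [0.0, 1.0]]], dtype=np.float32)  # shape (1, 2, 2)
hsv = flow_to_hsv_sketch(flow)
print(hsv[0].tolist())  # [[0, 255, 255], [45, 0, 255]]
```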