main 3a72f2380a4e cached
24 files
52.7 MB
34.2k tokens
111 symbols
1 requests
Download .txt
Repository: edwardzhou130/Panoptic-PolarNet
Branch: main
Commit: 3a72f2380a4e
Files: 24
Total size: 52.7 MB

Directory structure:
gitextract_pewizc2i/

├── LICENSE
├── README.md
├── configs/
│   └── SemanticKITTI_model/
│       └── Panoptic-PolarNet.yaml
├── data/
│   └── README.md
├── dataloader/
│   ├── __init__.py
│   ├── dataset.py
│   ├── instance_augmentation.py
│   └── process_panoptic.py
├── instance_preprocess.py
├── network/
│   ├── BEV_Unet.py
│   ├── __init__.py
│   ├── instance_post_processing.py
│   ├── loss.py
│   ├── lovasz_losses.py
│   └── ptBEV.py
├── pretrained_weight/
│   └── Panoptic_SemKITTI_PolarNet.pt
├── requirements.txt
├── semantic-kitti.yaml
├── test_pretrain.py
├── train.py
└── utils/
    ├── __init__.py
    ├── configs.py
    ├── eval_pq.py
    └── visual.py

================================================
FILE CONTENTS
================================================

================================================
FILE: LICENSE
================================================
BSD 3-Clause License

Copyright (c) 2020, Zixiang Zhou
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
   contributors may be used to endorse or promote products derived from
   this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


================================================
FILE: README.md
================================================
# Panoptic-PolarNet
This is the official implementation of Panoptic-PolarNet.

[[**ArXiv paper**]](https://arxiv.org/abs/2103.14962)

<p align="center">
        <img src="imgs/Visualization.gif" width="60%"> 
</p>

# Introduction

Panoptic-PolarNet is a fast and robust LiDAR point cloud panoptic segmentation framework. We learn both semantic segmentation and class-agnostic instance clustering in a single inference network using a polar Bird's Eye View (BEV) representation. Predictions from the semantic and instance head are then fused through a majority voting to create the final panopticsegmentation.

<p align="center">
        <img src="imgs/CVPR_pipeline.png" width="100%"> 
</p>

We test Panoptic-PolarNet on SemanticKITTI and nuScenes datasets. Experiment shows that Panoptic-PolarNet reaches state-of-the-art performances with a real-time inference speed.

<p align="center">
        <img src="imgs/result.png" width="100%"> 
</p>


## Prepare dataset and environment

This code is tested on Ubuntu 16.04 with Python 3.8, CUDA 10.2 and Pytorch 1.7.0.

1, Install the following dependencies by either `pip install -r requirements.txt` or manual installation.
* numpy
* pytorch
* tqdm
* yaml
* Cython
* [numba](https://github.com/numba/numba)
* [torch-scatter](https://github.com/rusty1s/pytorch_scatter)
* [dropblock](https://github.com/miguelvr/dropblock)
* (Optional) [open3d](https://github.com/intel-isl/Open3D)

2, Download Velodyne point clouds and label data in SemanticKITTI dataset [here](http://www.semantic-kitti.org/dataset.html#overview).

3, Extract everything into the same folder. The folder structure inside the zip files of label data matches the folder structure of the LiDAR point cloud data.

4, Data file structure should look like this:

```
./
├── train.py
├── ...
└── data/
    ├──sequences
        ├── 00/           
        │   ├── velodyne/	# Unzip from KITTI Odometry Benchmark Velodyne point clouds.
        |   |	├── 000000.bin
        |   |	├── 000001.bin
        |   |	└── ...
        │   └── labels/ 	# Unzip from SemanticKITTI label data.
        |       ├── 000000.label
        |       ├── 000001.label
        |       └── ...
        ├── ...
        └── 21/
	    └── ...
```

5, Instance preprocessing:
```shell
python instance_preprocess.py -d </your data path> -o </preprocessed file output path>
``` 

## Training

Run
```shell
python train.py
```

The code will automatically train, validate and save the model that has the best validation PQ. 

Panoptic-PolarNet with default setting requires around 11GB GPU memory for the training. Training model on GPU with less memory would likely cause GPU out-of-memory. In this case, you can set the ``grid_size`` in the config file to ``[320,240,32]`` or lower.

## Evaluate our pretrained model

We also provide a pretrained Panoptic-PolarNet weight.
```shell
python test_pretrain.py
```
Result will be stored in `./out` folder. Test performance can be evaluated by uploading label results onto the SemanticKITTI competition website [here](https://competitions.codalab.org/competitions/24025).

## Citation
Please cite our paper if this code benefits your research:
```
@inproceedings{Zhou2021PanopticPolarNet,
author={Zhou, Zixiang and Zhang, Yang and Foroosh, Hassan},
title={Panoptic-PolarNet: Proposal-free LiDAR Point Cloud Panoptic Segmentation},
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
year={2021}
}

@InProceedings{Zhang_2020_CVPR,
author = {Zhang, Yang and Zhou, Zixiang and David, Philip and Yue, Xiangyu and Xi, Zerong and Gong, Boqing and Foroosh, Hassan},
title = {PolarNet: An Improved Grid Representation for Online LiDAR Point Clouds Semantic Segmentation},
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2020}
}
```

================================================
FILE: configs/SemanticKITTI_model/Panoptic-PolarNet.yaml
================================================
model_name: Panoptic_PolarNet

dataset:
    name: semantickitti
    path: data
    output_path: out/SemKITTI
    instance_pkl_path: data
    rotate_aug: True
    flip_aug: True
    inst_aug: True
    inst_aug_type:
        inst_os: True
        inst_loc_aug: True
        inst_global_aug: True
    gt_generator:
        sigma: 5
    grid_size: [480,360,32]

model:
    model_save_path: ./Panoptic_SemKITTI.pt
    pretrained_model: /pretrained_weight/Panoptic_SemKITTI_PolarNet.pt
    polar: True
    visibility: True
    
    train_batch_size: 2
    val_batch_size: 2
    test_batch_size: 1
    check_iter: 4000
    max_epoch: 100
    post_proc:
        threshold: 0.1
        nms_kernel: 5
        top_k: 100
    center_loss: MSE
    offset_loss: L1
    center_loss_weight: 100
    offset_loss_weight: 10
    enable_SAP: True
    SAP:
        start_epoch: 30
        rate: 0.01

================================================
FILE: data/README.md
================================================
# PolarSeg

Download Velodyne point clouds and label data in SemanticKITTI dataset [here](http://www.semantic-kitti.org/dataset.html#overview).


================================================
FILE: dataloader/__init__.py
================================================


================================================
FILE: dataloader/dataset.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
SemKITTI dataloader
"""
import os
import numpy as np
import torch
import random
import time
import numba as nb
import yaml
import pickle
import errno
from torch.utils import data

from .process_panoptic import PanopticLabelGenerator
from .instance_augmentation import instance_augmentation

class SemKITTI(data.Dataset):
    def __init__(self, data_path, imageset = 'train', return_ref = False, instance_pkl_path ='data'):
        self.return_ref = return_ref
        with open("semantic-kitti.yaml", 'r') as stream:
            semkittiyaml = yaml.safe_load(stream)
        self.learning_map = semkittiyaml['learning_map']
        thing_class = semkittiyaml['thing_class']
        self.thing_list = [cl for cl, ignored in thing_class.items() if ignored]
        self.imageset = imageset
        if imageset == 'train':
            split = semkittiyaml['split']['train']
        elif imageset == 'val':
            split = semkittiyaml['split']['valid']
        elif imageset == 'test':
            split = semkittiyaml['split']['test']
        else:
            raise Exception('Split must be train/val/test')
        
        self.im_idx = []
        for i_folder in split:
            self.im_idx += absoluteFilePaths('/'.join([data_path,str(i_folder).zfill(2),'velodyne']))
        self.im_idx.sort()

        # get class distribution weight 
        epsilon_w = 0.001
        origin_class = semkittiyaml['content'].keys()
        weights = np.zeros((len(semkittiyaml['learning_map_inv'])-1,),dtype = np.float32)
        for class_num in origin_class:
            if semkittiyaml['learning_map'][class_num] != 0:
                weights[semkittiyaml['learning_map'][class_num]-1] += semkittiyaml['content'][class_num]
        self.CLS_LOSS_WEIGHT = 1/(weights + epsilon_w)
        self.instance_pkl_path = instance_pkl_path
         
    def __len__(self):
        'Denotes the total number of samples'
        return len(self.im_idx)
    
    def __getitem__(self, index):
        raw_data = np.fromfile(self.im_idx[index], dtype=np.float32).reshape((-1, 4))
        if self.imageset == 'test':
            sem_data = np.expand_dims(np.zeros_like(raw_data[:,0],dtype=int),axis=1)
            inst_data = np.expand_dims(np.zeros_like(raw_data[:,0],dtype=np.uint32),axis=1)
        else:
            annotated_data = np.fromfile(self.im_idx[index].replace('velodyne','labels')[:-3]+'label', dtype=np.uint32).reshape((-1,1))
            sem_data = annotated_data & 0xFFFF #delete high 16 digits binary
            sem_data = np.vectorize(self.learning_map.__getitem__)(sem_data)
            inst_data = annotated_data
        data_tuple = (raw_data[:,:3], sem_data.astype(np.uint8),inst_data)
        if self.return_ref:
            data_tuple += (raw_data[:,3],)
        return data_tuple

    def save_instance(self, out_dir, min_points = 10):
        'instance data preparation'
        instance_dict={label:[] for label in self.thing_list}
        for data_path in self.im_idx:
            print('process instance for:'+data_path)
            # get x,y,z,ref,semantic label and instance label
            raw_data = np.fromfile(data_path, dtype=np.float32).reshape((-1, 4))
            annotated_data = np.fromfile(data_path.replace('velodyne','labels')[:-3]+'label', dtype=np.uint32).reshape((-1,1))
            sem_data = annotated_data & 0xFFFF #delete high 16 digits binary
            sem_data = np.vectorize(self.learning_map.__getitem__)(sem_data)
            inst_data = annotated_data

            # instance mask
            mask = np.zeros_like(sem_data,dtype=bool)
            for label in self.thing_list:
                mask[sem_data == label] = True

            # create unqiue instance list
            inst_label = inst_data[mask].squeeze()
            unique_label = np.unique(inst_label)
            num_inst = len(unique_label)

            inst_count = 0
            for inst in unique_label:
                # get instance index
                index = np.where(inst_data == inst)[0]
                # get semantic label
                class_label = sem_data[index[0]]
                # skip small instance
                if index.size<min_points: continue
                # save
                _,dir2 = data_path.split('/sequences/',1)
                new_save_dir = out_dir + '/sequences/' +dir2.replace('velodyne','instance')[:-4]+'_'+str(inst_count)+'.bin'
                if not os.path.exists(os.path.dirname(new_save_dir)):
                    try:
                        os.makedirs(os.path.dirname(new_save_dir))
                    except OSError as exc:
                        if exc.errno != errno.EEXIST:
                            raise
                inst_fea = raw_data[index]
                inst_fea.tofile(new_save_dir)
                instance_dict[int(class_label)].append(new_save_dir)
                inst_count+=1
        with open(out_dir+'/instance_path.pkl', 'wb') as f:
            pickle.dump(instance_dict, f)

def absoluteFilePaths(directory):
   for dirpath,_,filenames in os.walk(directory):
       for f in filenames:
           yield os.path.abspath(os.path.join(dirpath, f))

class voxel_dataset(data.Dataset):
  def __init__(self, in_dataset, args, grid_size, ignore_label = 0, return_test = False, fixed_volume_space= True, use_aug = False, max_volume_space = [50,50,1.5], min_volume_space = [-50,-50,-3]):
        'Initialization'
        self.point_cloud_dataset = in_dataset
        self.grid_size = np.asarray(grid_size)
        self.rotate_aug = args['rotate_aug'] if use_aug else False
        self.ignore_label = ignore_label
        self.return_test = return_test
        self.flip_aug = args['flip_aug'] if use_aug else False
        self.instance_aug = args['inst_aug'] if use_aug else False
        self.fixed_volume_space = fixed_volume_space
        self.max_volume_space = max_volume_space
        self.min_volume_space = min_volume_space

        self.panoptic_proc = PanopticLabelGenerator(self.grid_size,sigma=args['gt_generator']['sigma'])
        if self.instance_aug:
            self.inst_aug = instance_augmentation(self.point_cloud_dataset.instance_pkl_path+'/instance_path.pkl',self.point_cloud_dataset.thing_list,self.point_cloud_dataset.CLS_LOSS_WEIGHT,\
                                                random_flip=args['inst_aug_type']['inst_global_aug'],random_add=args['inst_aug_type']['inst_os'],\
                                                random_rotate=args['inst_aug_type']['inst_global_aug'],local_transformation=args['inst_aug_type']['inst_loc_aug'])

  def __len__(self):
        'Denotes the total number of samples'
        return len(self.point_cloud_dataset)

  def __getitem__(self, index):
        'Generates one sample of data'
        data = self.point_cloud_dataset[index]
        if len(data) == 3:
            xyz,labels,insts = data
        elif len(data) == 4:
            xyz,labels,insts,feat = data
            if len(feat.shape) == 1: feat = feat[..., np.newaxis]
        else: raise Exception('Return invalid data tuple')
        if len(labels.shape) == 1: labels = labels[..., np.newaxis]
        if len(insts.shape) == 1: insts = insts[..., np.newaxis]
        
        # random data augmentation by rotation
        if self.rotate_aug:
            rotate_rad = np.deg2rad(np.random.random()*360)
            c, s = np.cos(rotate_rad), np.sin(rotate_rad)
            j = np.matrix([[c, s], [-s, c]])
            xyz[:,:2] = np.dot( xyz[:,:2],j)

        # random data augmentation by flip x , y or x+y
        if self.flip_aug:
            flip_type = np.random.choice(4,1)
            if flip_type==1:
                xyz[:,0] = -xyz[:,0]
            elif flip_type==2:
                xyz[:,1] = -xyz[:,1]
            elif flip_type==3:
                xyz[:,:2] = -xyz[:,:2]

        # random instance augmentation
        if self.instance_aug:
            xyz,labels,insts,feat = self.inst_aug.instance_aug(xyz,labels.squeeze(),insts.squeeze(),feat)

        max_bound = np.percentile(xyz,100,axis = 0)
        min_bound = np.percentile(xyz,0,axis = 0)
        
        if self.fixed_volume_space:
            max_bound = np.asarray(self.max_volume_space)
            min_bound = np.asarray(self.min_volume_space)

        # get grid index
        crop_range = max_bound - min_bound
        cur_grid_size = self.grid_size
        
        intervals = crop_range/(cur_grid_size-1)
        if (intervals==0).any(): print("Zero interval!")
        
        grid_ind = (np.floor((np.clip(xyz,min_bound,max_bound)-min_bound)/intervals)).astype(np.int)

        # process voxel position
        voxel_position = np.zeros(self.grid_size,dtype = np.float32)
        dim_array = np.ones(len(self.grid_size)+1,int)
        dim_array[0] = -1 
        voxel_position = (np.indices(self.grid_size) + 0.5)*intervals.reshape(dim_array) + min_bound.reshape(dim_array)

        # process labels
        processed_label = np.ones(self.grid_size,dtype = np.uint8)*self.ignore_label
        label_voxel_pair = np.concatenate([grid_ind,labels],axis = 1)
        label_voxel_pair = label_voxel_pair[np.lexsort((grid_ind[:,0],grid_ind[:,1],grid_ind[:,2])),:]
        processed_label = nb_process_label(np.copy(processed_label),label_voxel_pair)

        # get thing points mask
        mask = np.zeros_like(labels,dtype=bool)
        for label in self.point_cloud_dataset.thing_list:
            mask[labels == label] = True
        
        inst_label = insts[mask].squeeze()
        unique_label = np.unique(inst_label)
        unique_label_dict = {label:idx+1 for idx , label in enumerate(unique_label)}
        if inst_label.size > 1:            
            inst_label = np.vectorize(unique_label_dict.__getitem__)(inst_label)
            
            # process panoptic
            processed_inst = np.ones(self.grid_size[:2],dtype = np.uint8)*self.ignore_label
            inst_voxel_pair = np.concatenate([grid_ind[mask[:,0],:2],inst_label[..., np.newaxis]],axis = 1)
            inst_voxel_pair = inst_voxel_pair[np.lexsort((grid_ind[mask[:,0],0],grid_ind[mask[:,0],1])),:]
            processed_inst = nb_process_inst(np.copy(processed_inst),inst_voxel_pair)
        else:
            processed_inst = None

        center,center_points,offset = self.panoptic_proc(insts[mask],xyz[mask[:,0]],processed_inst,voxel_position[:2,:,:,0],unique_label_dict,min_bound,intervals)
        
        data_tuple = (voxel_position,processed_label,center,offset)

        # center data on each voxel for PTnet
        voxel_centers = (grid_ind.astype(np.float32) + 0.5)*intervals + min_bound
        return_xyz = xyz - voxel_centers
        return_xyz = np.concatenate((return_xyz,xyz),axis = 1)
        
        if len(data) == 3:
            return_fea = return_xyz
        elif len(data) == 4:
            return_fea = np.concatenate((return_xyz,feat),axis = 1)
        
        if self.return_test:
            data_tuple += (grid_ind,labels,insts,return_fea,index)
        else:
            data_tuple += (grid_ind,labels,insts,return_fea)
        return data_tuple

# transformation between Cartesian coordinates and polar coordinates
def cart2polar(input_xyz):
    rho = np.sqrt(input_xyz[:,0]**2 + input_xyz[:,1]**2)
    phi = np.arctan2(input_xyz[:,1],input_xyz[:,0])
    return np.stack((rho,phi,input_xyz[:,2]),axis=1)

def polar2cat(input_xyz_polar):
    x = input_xyz_polar[0]*np.cos(input_xyz_polar[1])
    y = input_xyz_polar[0]*np.sin(input_xyz_polar[1])
    return np.stack((x,y,input_xyz_polar[2]),axis=0)

class spherical_dataset(data.Dataset):
  def __init__(self, in_dataset, args, grid_size, ignore_label = 0, return_test = False, use_aug = False, fixed_volume_space= True, max_volume_space = [50,np.pi,1.5], min_volume_space = [3,-np.pi,-3]):
        'Initialization'
        self.point_cloud_dataset = in_dataset
        self.grid_size = np.asarray(grid_size)
        self.rotate_aug = args['rotate_aug'] if use_aug else False
        self.flip_aug = args['flip_aug'] if use_aug else False
        self.instance_aug = args['inst_aug'] if use_aug else False
        self.ignore_label = ignore_label
        self.return_test = return_test
        self.fixed_volume_space = fixed_volume_space
        self.max_volume_space = max_volume_space
        self.min_volume_space = min_volume_space

        self.panoptic_proc = PanopticLabelGenerator(self.grid_size,sigma=args['gt_generator']['sigma'],polar=True)
        if self.instance_aug:
            self.inst_aug = instance_augmentation(self.point_cloud_dataset.instance_pkl_path+'/instance_path.pkl',self.point_cloud_dataset.thing_list,self.point_cloud_dataset.CLS_LOSS_WEIGHT,\
                                                random_flip=args['inst_aug_type']['inst_global_aug'],random_add=args['inst_aug_type']['inst_os'],\
                                                random_rotate=args['inst_aug_type']['inst_global_aug'],local_transformation=args['inst_aug_type']['inst_loc_aug'])

  def __len__(self):
        'Denotes the total number of samples'
        return len(self.point_cloud_dataset)

  def __getitem__(self, index):
        'Generates one sample of data'
        data = self.point_cloud_dataset[index]
        if len(data) == 3:
            xyz,labels,insts = data
        elif len(data) == 4:
            xyz,labels,insts,feat = data
            if len(feat.shape) == 1: feat = feat[..., np.newaxis]
        else: raise Exception('Return invalid data tuple')
        if len(labels.shape) == 1: labels = labels[..., np.newaxis]
        if len(insts.shape) == 1: insts = insts[..., np.newaxis]
        
        # random data augmentation by rotation
        if self.rotate_aug:
            rotate_rad = np.deg2rad(np.random.random()*360)
            c, s = np.cos(rotate_rad), np.sin(rotate_rad)
            j = np.matrix([[c, s], [-s, c]])
            xyz[:,:2] = np.dot( xyz[:,:2],j)

        # random data augmentation by flip x , y or x+y
        if self.flip_aug:
            flip_type = np.random.choice(4,1)
            if flip_type==1:
                xyz[:,0] = -xyz[:,0]
            elif flip_type==2:
                xyz[:,1] = -xyz[:,1]
            elif flip_type==3:
                xyz[:,:2] = -xyz[:,:2]

        # random instance augmentation
        if self.instance_aug:
            xyz,labels,insts,feat = self.inst_aug.instance_aug(xyz,labels.squeeze(),insts.squeeze(),feat)
        
        # convert coordinate into polar coordinates
        xyz_pol = cart2polar(xyz)
        
        max_bound_r = np.percentile(xyz_pol[:,0],100,axis = 0)
        min_bound_r = np.percentile(xyz_pol[:,0],0,axis = 0)
        max_bound = np.max(xyz_pol[:,1:],axis = 0)
        min_bound = np.min(xyz_pol[:,1:],axis = 0)
        max_bound = np.concatenate(([max_bound_r],max_bound))
        min_bound = np.concatenate(([min_bound_r],min_bound))
        if self.fixed_volume_space:
            max_bound = np.asarray(self.max_volume_space)
            min_bound = np.asarray(self.min_volume_space)

        # get grid index
        crop_range = max_bound - min_bound
        cur_grid_size = self.grid_size
        intervals = crop_range/(cur_grid_size-1)

        if (intervals==0).any(): print("Zero interval!")
        grid_ind = (np.floor((np.clip(xyz_pol,min_bound,max_bound)-min_bound)/intervals)).astype(np.int)

        current_grid = grid_ind[:np.size(labels)]

        # process voxel position
        voxel_position = np.zeros(self.grid_size,dtype = np.float32)
        dim_array = np.ones(len(self.grid_size)+1,int)
        dim_array[0] = -1 
        voxel_position = np.indices(self.grid_size)*intervals.reshape(dim_array) + min_bound.reshape(dim_array)
        # voxel_position = polar2cat(voxel_position)
        
        # process labels
        processed_label = np.ones(self.grid_size,dtype = np.uint8)*self.ignore_label
        label_voxel_pair = np.concatenate([current_grid,labels],axis = 1)
        label_voxel_pair = label_voxel_pair[np.lexsort((current_grid[:,0],current_grid[:,1],current_grid[:,2])),:]
        processed_label = nb_process_label(np.copy(processed_label),label_voxel_pair)
        # data_tuple = (voxel_position,processed_label)

        # get thing points mask
        mask = np.zeros_like(labels,dtype=bool)
        for label in self.point_cloud_dataset.thing_list:
            mask[labels == label] = True
        
        inst_label = insts[mask].squeeze()
        unique_label = np.unique(inst_label)
        unique_label_dict = {label:idx+1 for idx , label in enumerate(unique_label)}
        if inst_label.size > 1:            
            inst_label = np.vectorize(unique_label_dict.__getitem__)(inst_label)
            
            # process panoptic
            processed_inst = np.ones(self.grid_size[:2],dtype = np.uint8)*self.ignore_label
            inst_voxel_pair = np.concatenate([current_grid[mask[:,0],:2],inst_label[..., np.newaxis]],axis = 1)
            inst_voxel_pair = inst_voxel_pair[np.lexsort((current_grid[mask[:,0],0],current_grid[mask[:,0],1])),:]
            processed_inst = nb_process_inst(np.copy(processed_inst),inst_voxel_pair)
        else:
            processed_inst = None

        center,center_points,offset = self.panoptic_proc(insts[mask],xyz[:np.size(labels)][mask[:,0]],processed_inst,voxel_position[:2,:,:,0],unique_label_dict,min_bound,intervals)

        # prepare visiblity feature
        # find max distance index in each angle,height pair
        valid_label = np.zeros_like(processed_label,dtype=bool)
        valid_label[current_grid[:,0],current_grid[:,1],current_grid[:,2]] = True
        valid_label = valid_label[::-1]
        max_distance_index = np.argmax(valid_label,axis=0)
        max_distance = max_bound[0]-intervals[0]*(max_distance_index)
        distance_feature = np.expand_dims(max_distance, axis=2)-np.transpose(voxel_position[0],(1,2,0))
        distance_feature = np.transpose(distance_feature,(1,2,0))
        # convert to boolean feature
        distance_feature = (distance_feature>0)*-1.
        distance_feature[current_grid[:,2],current_grid[:,0],current_grid[:,1]]=1.

        data_tuple = (distance_feature,processed_label,center,offset)

        # center data on each voxel for PTnet
        voxel_centers = (grid_ind.astype(np.float32) + 0.5)*intervals + min_bound
        return_xyz = xyz_pol - voxel_centers
        return_xyz = np.concatenate((return_xyz,xyz_pol,xyz[:,:2]),axis = 1)

        if len(data) == 3:
            return_fea = return_xyz
        elif len(data) == 4:
            return_fea = np.concatenate((return_xyz,feat),axis = 1)
        
        if self.return_test:
            data_tuple += (grid_ind,labels,insts,return_fea,index)
        else:
            data_tuple += (grid_ind,labels,insts,return_fea)
        return data_tuple
    
@nb.jit('u1[:,:,:](u1[:,:,:],i8[:,:])',nopython=True,cache=True,parallel = False)
def nb_process_label(processed_label,sorted_label_voxel_pair):
    label_size = 256
    counter = np.zeros((label_size,),dtype = np.uint16)
    counter[sorted_label_voxel_pair[0,3]] = 1
    cur_sear_ind = sorted_label_voxel_pair[0,:3]
    for i in range(1,sorted_label_voxel_pair.shape[0]):
        cur_ind = sorted_label_voxel_pair[i,:3]
        if not np.all(np.equal(cur_ind,cur_sear_ind)):
            processed_label[cur_sear_ind[0],cur_sear_ind[1],cur_sear_ind[2]] = np.argmax(counter)
            counter = np.zeros((label_size,),dtype = np.uint16)
            cur_sear_ind = cur_ind
        counter[sorted_label_voxel_pair[i,3]] += 1
    processed_label[cur_sear_ind[0],cur_sear_ind[1],cur_sear_ind[2]] = np.argmax(counter)
    return processed_label

@nb.jit('u1[:,:](u1[:,:],i8[:,:])',nopython=True,cache=True,parallel = False)
def nb_process_inst(processed_inst,sorted_inst_voxel_pair):
    label_size = 256
    counter = np.zeros((label_size,),dtype = np.uint16)
    counter[sorted_inst_voxel_pair[0,2]] = 1
    cur_sear_ind = sorted_inst_voxel_pair[0,:2]
    for i in range(1,sorted_inst_voxel_pair.shape[0]):
        cur_ind = sorted_inst_voxel_pair[i,:2]
        if not np.all(np.equal(cur_ind,cur_sear_ind)):
            processed_inst[cur_sear_ind[0],cur_sear_ind[1]] = np.argmax(counter)
            counter = np.zeros((label_size,),dtype = np.uint16)
            cur_sear_ind = cur_ind
        counter[sorted_inst_voxel_pair[i,2]] += 1
    processed_inst[cur_sear_ind[0],cur_sear_ind[1]] = np.argmax(counter)
    return processed_inst

def collate_fn_BEV(data):
    data2stack=np.stack([d[0] for d in data]).astype(np.float32)
    label2stack=np.stack([d[1] for d in data])
    center2stack=np.stack([d[2] for d in data])
    offset2stack=np.stack([d[3] for d in data])
    grid_ind_stack = [d[4] for d in data]
    point_label = [d[5] for d in data]
    point_inst = [d[6] for d in data]
    xyz = [d[7] for d in data]
    return torch.from_numpy(data2stack),torch.from_numpy(label2stack),torch.from_numpy(center2stack),torch.from_numpy(offset2stack),grid_ind_stack,point_label,point_inst,xyz

def collate_fn_BEV_test(data):    
    data2stack=np.stack([d[0] for d in data]).astype(np.float32)
    label2stack=np.stack([d[1] for d in data])
    center2stack=np.stack([d[2] for d in data])
    offset2stack=np.stack([d[3] for d in data])
    grid_ind_stack = [d[4] for d in data]
    point_label = [d[5] for d in data]
    point_inst = [d[6] for d in data]
    xyz = [d[7] for d in data]
    index = [d[8] for d in data]
    return torch.from_numpy(data2stack),torch.from_numpy(label2stack),torch.from_numpy(center2stack),torch.from_numpy(offset2stack),grid_ind_stack,point_label,point_inst,xyz,index

# load Semantic KITTI class info
with open("semantic-kitti.yaml", 'r') as stream:
    semkittiyaml = yaml.safe_load(stream)
SemKITTI_label_name = dict()
for i in sorted(list(semkittiyaml['learning_map'].keys()))[::-1]:
    SemKITTI_label_name[semkittiyaml['learning_map'][i]] = semkittiyaml['labels'][i]

================================================
FILE: dataloader/instance_augmentation.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import pickle

class instance_augmentation(object):
    def __init__(self,instance_pkl_path,thing_list,class_weight,random_flip = False,random_add = False,random_rotate = False,local_transformation = False):
        self.thing_list = thing_list
        self.instance_weight = [class_weight[thing_class_num-1] for thing_class_num in thing_list]
        self.instance_weight = np.asarray(self.instance_weight)/np.sum(self.instance_weight)
        self.random_flip = random_flip
        self.random_add = random_add
        self.random_rotate = random_rotate
        self.local_transformation = local_transformation

        self.add_num = 5

        with open(instance_pkl_path, 'rb') as f:
            self.instance_path = pickle.load(f)

    def instance_aug(self, point_xyz, point_label, point_inst, point_feat = None):
        """random rotate and flip each instance independently.

        Args:
            point_xyz: [N, 3], point location
            point_label: [N, 1], class label
            point_inst: [N, 1], instance label
        """        
        # random add instance to this scan
        if self.random_add:
            # choose which instance to add
            instance_choice = np.random.choice(len(self.thing_list),self.add_num,replace=True,p=self.instance_weight)
            uni_inst, uni_inst_count = np.unique(instance_choice,return_counts=True)
            add_idx = 1
            total_point_num = 0
            early_break = False
            for n, count in zip(uni_inst, uni_inst_count):
                # find random instance
                random_choice = np.random.choice(len(self.instance_path[self.thing_list[n]]),count)
                # add to current scan
                for idx in random_choice:
                    points = np.fromfile(self.instance_path[self.thing_list[n]][idx], dtype=np.float32).reshape((-1, 4))
                    add_xyz = points[:,:3]
                    center = np.mean(add_xyz,axis=0)

                    # need to check occlusion
                    fail_flag = True
                    if self.random_rotate:
                        # random rotate
                        random_choice = np.random.random(20)*np.pi*2
                        for r in random_choice:
                            center_r = self.rotate_origin(center[np.newaxis,...],r)
                            # check if occluded
                            if self.check_occlusion(point_xyz,center_r[0]):
                                fail_flag = False
                                break
                        # rotate to empty space
                        if fail_flag: continue
                        add_xyz = self.rotate_origin(add_xyz,r)
                    else:
                        fail_flag = not self.check_occlusion(point_xyz,center)
                    if fail_flag: continue

                    add_label = np.ones((points.shape[0],),dtype=np.uint8)*(self.thing_list[n])
                    add_inst = np.ones((points.shape[0],),dtype=np.uint32)*(add_idx<<16)
                    point_xyz = np.concatenate((point_xyz,add_xyz),axis=0)
                    point_label = np.concatenate((point_label,add_label),axis=0)
                    point_inst = np.concatenate((point_inst,add_inst),axis=0)
                    if point_feat is not None:
                        add_fea =  points[:,3:]
                        if len(add_fea.shape) == 1: add_fea = add_fea[..., np.newaxis]
                        point_feat = np.concatenate((point_feat,add_fea),axis=0)
                    add_idx +=1
                    total_point_num += points.shape[0]
                    if total_point_num>5000:
                        early_break=True
                        break
                # prevent adding too many points which cause GPU memory error
                if early_break: break

        # instance mask
        mask = np.zeros_like(point_label,dtype=bool)
        for label in self.thing_list:
            mask[point_label == label] = True

        # create unqiue instance list
        inst_label = point_inst[mask].squeeze()
        unique_label = np.unique(inst_label)
        num_inst = len(unique_label)

        
        for inst in unique_label:
            # get instance index
            index = np.where(point_inst == inst)[0]
            # skip small instance
            if index.size<10: continue
            # get center
            center = np.mean(point_xyz[index,:],axis=0)

            if self.local_transformation:
                # random translation and rotation
                point_xyz[index,:] = self.local_tranform(point_xyz[index,:],center)
            
            # random flip instance based on it center 
            if self.random_flip:
                # get axis
                long_axis = [center[0], center[1]]/(center[0]**2+center[1]**2)**0.5
                short_axis = [-long_axis[1],long_axis[0]]
                # random flip
                flip_type = np.random.choice(5,1)
                if flip_type==3:
                    point_xyz[index,:2] = self.instance_flip(point_xyz[index,:2],[long_axis,short_axis],[center[0], center[1]],flip_type)
            
            # 20% random rotate
            random_num = np.random.random_sample()
            if self.random_rotate:
                if random_num>0.8 and inst & 0xFFFF > 0:
                    random_choice = np.random.random(20)*np.pi*2
                    fail_flag = True
                    for r in random_choice:
                        center_r = self.rotate_origin(center[np.newaxis,...],r)
                        # check if occluded
                        if self.check_occlusion(np.delete(point_xyz, index, axis=0),center_r[0]):
                            fail_flag = False
                            break
                    if not fail_flag:
                        # rotate to empty space
                        point_xyz[index,:] = self.rotate_origin(point_xyz[index,:],r)

        if len(point_label.shape) == 1: point_label = point_label[..., np.newaxis]
        if len(point_inst.shape) == 1: point_inst = point_inst[..., np.newaxis]
        if point_feat is not None:
            return point_xyz,point_label,point_inst,point_feat
        else:
            return point_xyz,point_label,point_inst

    def instance_flip(self, points,axis,center,flip_type = 1):
        points = points[:]-center
        if flip_type == 1:
            # rotate 180 degree
            points = -points+center
        elif flip_type == 2:
            # flip over long axis
            a = axis[0][0]
            b = axis[0][1]
            flip_matrix = np.array([[b**2 - a**2, -2 * a * b],[-2 * a * b, a**2 - b**2]])
            points = np.matmul(flip_matrix,np.transpose(points, (1, 0)))
            points = np.transpose(points, (1, 0))+center
        elif flip_type == 3:
            # flip over short axis
            a = axis[1][0]
            b = axis[1][1]
            flip_matrix = np.array([[b**2 - a**2, -2 * a * b],[-2 * a * b, a**2 - b**2]])
            points = np.matmul(flip_matrix,np.transpose(points, (1, 0)))
            points = np.transpose(points, (1, 0))+center

        return points

    def check_occlusion(self,points,center,min_dist=2):
        'check if close to a point'
        if points.ndim == 1:
            dist = np.linalg.norm(points[np.newaxis,:]-center,axis=1)
        else:
            dist = np.linalg.norm(points-center,axis=1)
        return np.all(dist>min_dist)

    def rotate_origin(self,xyz,radians):
        'rotate a point around the origin'
        x = xyz[:,0]
        y = xyz[:,1]
        new_xyz = xyz.copy()
        new_xyz[:,0] = x * np.cos(radians) + y * np.sin(radians)
        new_xyz[:,1] = -x * np.sin(radians) + y * np.cos(radians)
        return new_xyz

    def local_tranform(self,xyz,center):
        'translate and rotate point cloud according to its center'
        # random xyz
        loc_noise = np.random.normal(scale = 0.25, size=(1,3))
        # random angle
        rot_noise = np.random.uniform(-np.pi/20, np.pi/20)

        xyz = xyz-center
        xyz = self.rotate_origin(xyz,rot_noise)
        xyz = xyz+loc_noise
        
        return xyz+center


================================================
FILE: dataloader/process_panoptic.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import numpy as np

class PanopticLabelGenerator(object):
    def __init__(self,grid_size,sigma=5,polar=False):
        """Initialize panoptic ground truth generator

        Args:
            grid_size: voxel size.
            sigma (int, optional):  Gaussian distribution paramter. Create heatmap in +-3*sigma window. Defaults to 5.
            polar (bool, optional): Is under polar coordinate. Defaults to False.
        """        
        self.grid_size = grid_size
        self.polar = polar

        self.sigma = sigma
        size = 6 * sigma + 3
        x = np.arange(0, size, 1, float)
        y = x[:, np.newaxis]
        x0, y0 = 3 * sigma + 1, 3 * sigma + 1
        self.g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
    
    def __call__(self,inst,xyz,voxel_inst,voxel_position,label_dict,min_bound,intervals):
        """Generate instance center and offset ground truth

        Args:
            inst : instance panoptic label (N)
            xyz : point location (N x 3)
            voxel_inst : voxel panoptic label on the BEV (H x W)
            voxel_position : voxel location on the BEV (3 x H x W)
            label_dict : unqiue instance label dict
            min_bound : space minimal bound
            intervals : voxelization intervals

        Returns:
            center, center_pts, offset
        """        
        height, width = self.grid_size[0],self.grid_size[1]

        center = np.zeros((1, height, width), dtype=np.float32)
        center_pts = []
        offset = np.zeros((2, height, width), dtype=np.float32)
        #skip empty instances
        if inst.size < 2: return center, center_pts, offset
        # find unique instances
        inst_labels = np.unique(inst)
        for inst_label in inst_labels:
            # get mask for each unique instance
            mask = np.where(inst == inst_label)
            voxel_mask = np.where(voxel_inst == label_dict[inst_label])
            # get center
            center_x, center_y = np.mean(xyz[mask,0]), np.mean(xyz[mask,1])
            if self.polar:
                # convert to polar coordinate
                center_x_pol, center_y_pol = np.sqrt(center_x**2 + center_y**2),np.arctan2(center_y,center_x)
                center_x = center_x_pol
                center_y = center_y_pol

            # generate center heatmap
            x, y = int(np.floor((center_x-min_bound[0])/intervals[0])), int(np.floor((center_y-min_bound[1])/intervals[1]))
            center_pts.append([x, y])
            # outside image boundary
            if x < 0 or y < 0 or \
                    x >= height or y >= width:
                continue
            sigma = self.sigma
            # upper left
            ul = int(np.round(x - 3 * sigma - 1)), int(np.round(y - 3 * sigma - 1))
            # bottom right
            br = int(np.round(x + 3 * sigma + 2)), int(np.round(y + 3 * sigma + 2))

            if self.polar:
                c, d = max(0, -ul[0]), min(br[0], height) - ul[0]
                a, b = 0, br[1] - ul[1]

                cc, dd = max(0, ul[0]), min(br[0], height)
                angle_list = [angle_id % width for angle_id in range(ul[1],br[1])]
                center[0, cc:dd, angle_list] = np.maximum(
                    center[0, cc:dd, angle_list], np.transpose(self.g[c:d,a:b]))
            else:
                c, d = max(0, -ul[0]), min(br[0], height) - ul[0]
                a, b = max(0, -ul[1]), min(br[1], width) - ul[1]

                cc, dd = max(0, ul[0]), min(br[0], height)
                aa, bb = max(0, ul[1]), min(br[1], width)
                center[0, cc:dd, aa:bb] = np.maximum(
                    center[0, cc:dd, aa:bb], self.g[c:d,a:b])

            if self.polar:
                # generate offset (2, h, w) -> (y-dir, x-dir)
                offset[0,voxel_mask[0],voxel_mask[1]] = (center_x - voxel_position[0,voxel_mask[0],voxel_mask[1]])/intervals[0]
                offset[1,voxel_mask[0],voxel_mask[1]] = ((center_y - voxel_position[1,voxel_mask[0],voxel_mask[1]]+np.pi)%(2*np.pi) - np.pi)/intervals[1]
            else:
                # generate offset (2, h, w) -> (y-dir, x-dir)
                offset[0,voxel_mask[0],voxel_mask[1]] = (center_x - voxel_position[0,voxel_mask[0],voxel_mask[1]])/intervals[0]
                offset[1,voxel_mask[0],voxel_mask[1]] = (center_y - voxel_position[1,voxel_mask[0],voxel_mask[1]])/intervals[1]

        # print('gt center',center_pts)

        return center, center_pts, offset

================================================
FILE: instance_preprocess.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
from dataloader.dataset import SemKITTI

if __name__ == '__main__':
    # instance preprocessing
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('-d', '--data_path', default='data')
    parser.add_argument('-o', '--out_path', default='data')

    args = parser.parse_args()

    train_pt_dataset = SemKITTI(args.data_path + '/sequences/', imageset = 'train', return_ref = True)
    train_pt_dataset.save_instance(args.out_path)
    print('instance preprocessing finished.')

================================================
FILE: network/BEV_Unet.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from dropblock import DropBlock2D


class BEV_Unet(nn.Module):

    def __init__(self,n_class,n_height,dilation = 1,group_conv=False,input_batch_norm = False,dropout = 0.,circular_padding = False, dropblock = True, use_vis_fea=False):
        super(BEV_Unet, self).__init__()
        self.n_class = n_class
        self.n_height = n_height
        if use_vis_fea:
            self.network = UNet(n_class*n_height,2*n_height,dilation,group_conv,input_batch_norm,dropout,circular_padding,dropblock)
        else:
            self.network = UNet(n_class*n_height,n_height,dilation,group_conv,input_batch_norm,dropout,circular_padding,dropblock)

    def forward(self, x):
        x,center,offset = self.network(x)
        
        x = x.permute(0,2,3,1)
        new_shape = list(x.size())[:3] + [self.n_height,self.n_class]
        x = x.view(new_shape)
        x = x.permute(0,4,1,2,3)

        return x,center,offset
    
class UNet(nn.Module):
    def __init__(self, n_class,n_height,dilation,group_conv,input_batch_norm, dropout,circular_padding,dropblock):
        super(UNet, self).__init__()
        # encoder
        self.inc = inconv(n_height, 64, dilation, input_batch_norm, circular_padding)
        self.down1 = down(64, 128, dilation, group_conv, circular_padding)
        self.down2 = down(128, 256, dilation, group_conv, circular_padding)
        self.down3 = down(256, 512, dilation, group_conv, circular_padding)
        self.down4 = down(512, 512, dilation, group_conv, circular_padding)

        # semantic decoder
        self.up1 = up(1024, 256, circular_padding, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout)
        self.up2 = up(512, 128, circular_padding, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout)
        self.up3 = up(256, 64, circular_padding, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout)
        self.up4 = up(128, 64, circular_padding, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout)
        self.dropout = nn.Dropout(p=0. if dropblock else dropout)
        # semantic head
        self.outc = outconv(64, n_class)

        # instance decoder
        # self.i_up1 = up(1024, 256, circular_padding, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout)
        # self.i_up2 = up(512, 128, circular_padding, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout)
        # self.i_up3 = up(256, 64, circular_padding, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout)
        # self.i_up4 = up(128, 32, circular_padding, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout)
        self.i_up4_center = up(128, 32, circular_padding, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout)
        self.i_up4_offset = up(128, 32, circular_padding, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout)
        # instance head
        self.i_outc_center = outconv(32, 1)
        self.i_outc_offset = outconv(32, 2)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        # semantic
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        s_x = self.up4(x, x1)
        s_x = self.outc(self.dropout(s_x))
        # instance
        # i_x = self.i_up1(x5, x4)
        # i_x = self.i_up2(i_x, x3)
        # i_x = self.i_up3(i_x, x2)x

        # i_x = self.i_up4(i_x, x1)
        i_x_center = self.i_up4_center(x, x1)
        i_x_center = self.i_outc_center(self.dropout(i_x_center))

        i_x_offset = self.i_up4_offset(x, x1)
        i_x_offset = self.i_outc_offset(self.dropout(i_x_offset))

        return s_x, i_x_center, i_x_offset

class double_conv(nn.Module):
    '''(conv => BN => ReLU) * 2'''
    def __init__(self, in_ch, out_ch,group_conv,dilation=1):
        super(double_conv, self).__init__()
        if group_conv:
            self.conv = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1,groups = min(out_ch,in_ch)),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, 3, padding=1,groups = out_ch),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(inplace=True)
            )
        else:
            self.conv = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(inplace=True)
            )

    def forward(self, x):
        x = self.conv(x)
        return x

class double_conv_circular(nn.Module):
    '''(conv => BN => ReLU) * 2'''
    def __init__(self, in_ch, out_ch,group_conv,dilation=1):
        super(double_conv_circular, self).__init__()
        if group_conv:
            self.conv1 = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=(1,0),groups = min(out_ch,in_ch)),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(inplace=True)
            )
            self.conv2 = nn.Sequential(
                nn.Conv2d(out_ch, out_ch, 3, padding=(1,0),groups = out_ch),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(inplace=True)
            )
        else:
            self.conv1 = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=(1,0)),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(inplace=True)
            )
            self.conv2 = nn.Sequential(
                nn.Conv2d(out_ch, out_ch, 3, padding=(1,0)),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(inplace=True)
            )

    def forward(self, x):
        #add circular padding
        x = F.pad(x,(1,1,0,0),mode = 'circular')
        x = self.conv1(x)
        x = F.pad(x,(1,1,0,0),mode = 'circular')
        x = self.conv2(x)
        return x


class inconv(nn.Module):
    def __init__(self, in_ch, out_ch, dilation, input_batch_norm, circular_padding):
        super(inconv, self).__init__()
        if input_batch_norm:
            if circular_padding:
                self.conv = nn.Sequential(
                    nn.BatchNorm2d(in_ch),
                    double_conv_circular(in_ch, out_ch,group_conv = False,dilation = dilation)
                )
            else:
                self.conv = nn.Sequential(
                    nn.BatchNorm2d(in_ch),
                    double_conv(in_ch, out_ch,group_conv = False,dilation = dilation)
                )
        else:
            if circular_padding:
                self.conv = double_conv_circular(in_ch, out_ch,group_conv = False,dilation = dilation)
            else:
                self.conv = double_conv(in_ch, out_ch,group_conv = False,dilation = dilation)

    def forward(self, x):
        x = self.conv(x)
        return x


class down(nn.Module):
    def __init__(self, in_ch, out_ch, dilation, group_conv, circular_padding):
        super(down, self).__init__()
        if circular_padding:
            self.mpconv = nn.Sequential(
                nn.MaxPool2d(2),
                double_conv_circular(in_ch, out_ch,group_conv = group_conv,dilation = dilation)
            )
        else:
            self.mpconv = nn.Sequential(
                nn.MaxPool2d(2),
                double_conv(in_ch, out_ch,group_conv = group_conv,dilation = dilation)
            )                

    def forward(self, x):
        x = self.mpconv(x)
        return x


class up(nn.Module):
    def __init__(self, in_ch, out_ch, circular_padding, bilinear=True, group_conv=False, use_dropblock = False, drop_p = 0.5):
        super(up, self).__init__()

        #  would be a nice idea if the upsampling could be learned too,
        #  but my machine do not have enough memory to handle all those weights
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        elif group_conv:
            self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2,groups = in_ch//2)
        else:
            self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)

        if circular_padding:
            self.conv = double_conv_circular(in_ch, out_ch,group_conv = group_conv)
        else:
            self.conv = double_conv(in_ch, out_ch,group_conv = group_conv)

        self.use_dropblock = use_dropblock
        if self.use_dropblock:
            self.dropblock = DropBlock2D(block_size=7, drop_prob=drop_p)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, (diffX // 2, diffX - diffX//2,
                        diffY // 2, diffY - diffY//2))
        
        # for padding issues, see 
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd

        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        if self.use_dropblock:
            x = self.dropblock(x)
        return x


class outconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        x = self.conv(x)
        return x

================================================
FILE: network/__init__.py
================================================


================================================
FILE: network/instance_post_processing.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
import torch.nn.functional as F
import torch_scatter


def find_instance_center(ctr_hmp, threshold=0.1, nms_kernel=5, top_k=None, polar=False):
    """
    Find the center points from the center heatmap.
    Arguments:
        ctr_hmp: A Tensor of shape [N, 1, H, W] of raw center heatmap output, where N is the batch size,
            for consistent, we only support N=1.
        threshold: A Float, threshold applied to center heatmap score.
        nms_kernel: An Integer, NMS max pooling kernel size.
        top_k: An Integer, top k centers to keep.
    Returns:
        A Tensor of shape [K, 2] where K is the number of center points. The order of second dim is (y, x).
    """
    if ctr_hmp.size(0) != 1:
        raise ValueError('Only supports inference for batch size = 1')

    # thresholding, setting values below threshold to -1
    ctr_hmp = F.threshold(ctr_hmp, threshold, -1)

    # NMS
    if polar:
        nms_padding = (nms_kernel - 1) // 2
        ctr_hmp_max_pooled = F.pad(ctr_hmp,(nms_padding,nms_padding,0,0),mode = 'circular')
        ctr_hmp_max_pooled = F.max_pool2d(ctr_hmp_max_pooled, kernel_size=nms_kernel, stride=1, padding=(nms_padding,0))
    else:
        nms_padding = (nms_kernel - 1) // 2
        ctr_hmp_max_pooled = F.max_pool2d(ctr_hmp, kernel_size=nms_kernel, stride=1, padding=nms_padding)
    ctr_hmp[ctr_hmp != ctr_hmp_max_pooled] = -1

    # squeeze first two dimensions
    ctr_hmp = ctr_hmp.squeeze()
    assert len(ctr_hmp.size()) == 2, 'Something is wrong with center heatmap dimension.'

    # find non-zero elements
    ctr_all = torch.nonzero(ctr_hmp > 0)
    if top_k is None:
        return ctr_all
    elif ctr_all.size(0) < top_k:
        return ctr_all
    else:
        # find top k centers.
        top_k_scores, _ = torch.topk(torch.flatten(ctr_hmp), top_k)
        return torch.nonzero(ctr_hmp > top_k_scores[-1])


def group_pixels(ctr, offsets, polar=False):
    """
    Gives each pixel in the image an instance id.
    Arguments:
        ctr: A Tensor of shape [K, 2] where K is the number of center points. The order of second dim is (y, x).
        offsets: A Tensor of shape [N, 2, H, W] of raw offset output, where N is the batch size,
            for consistent, we only support N=1. The order of second dim is (offset_y, offset_x).
    Returns:
        A Tensor of shape [1, H, W] (to be gathered by distributed data parallel).
    """
    if offsets.size(0) != 1:
        raise ValueError('Only supports inference for batch size = 1')

    offsets = offsets.squeeze(0)
    height, width = offsets.size()[1:]

    # generates a coordinate map, where each location is the coordinate of that loc
    y_coord = torch.arange(height, dtype=offsets.dtype, device=offsets.device).repeat(1, width, 1).transpose(1, 2)
    x_coord = torch.arange(width, dtype=offsets.dtype, device=offsets.device).repeat(1, height, 1)
    coord = torch.cat((y_coord, x_coord), dim=0)

    ctr_loc = coord + offsets
    ctr_loc = ctr_loc.reshape((2, height * width)).transpose(1, 0)

    # ctr: [K, 2] -> [K, 1, 2]
    # ctr_loc = [H*W, 2] -> [1, H*W, 2]
    ctr = ctr.unsqueeze(1)
    ctr_loc = ctr_loc.unsqueeze(0)

    # distance: [K, H*W]
    distance = ctr - ctr_loc
    if polar:
        distance[:,:,0] = torch.add(torch.fmod(torch.add(distance[:,:,0],width/2),width),-width/2)
    distance = torch.norm(distance, dim=-1)

    # finds center with minimum distance at each location, offset by 1, to reserve id=0 for stuff
    instance_id = torch.argmin(distance, dim=0).reshape((1, height, width)) + 1
    return instance_id


def get_instance_segmentation(sem_seg, ctr_hmp, offsets, thing_list, threshold=0.1, nms_kernel=5, top_k=None,
                              thing_seg=None, polar=False):
    """
    Post-processing for instance segmentation, gets class agnostic instance id map.
    Arguments:
        sem_seg: A Tensor of shape [1, H, W, Z], predicted semantic label.
        ctr_hmp: A Tensor of shape [N, 1, H, W] of raw center heatmap output, where N is the batch size,
            for consistent, we only support N=1.
        offsets: A Tensor of shape [N, 2, H, W] of raw offset output, where N is the batch size,
            for consistent, we only support N=1. The order of second dim is (offset_y, offset_x).
        thing_list: A List of thing class id.
        threshold: A Float, threshold applied to center heatmap score.
        nms_kernel: An Integer, NMS max pooling kernel size.
        top_k: An Integer, top k centers to keep.
        thing_seg: A Tensor of shape [1, H, W, Z], predicted foreground mask, if not provided, inference from
            semantic prediction.
    Returns:
        A Tensor of shape [1, H, W] (to be gathered by distributed data parallel).
        A Tensor of shape [1, K, 2] where K is the number of center points. The order of second dim is (y, x).
    """
    # if thing_seg is None:
    #     # gets foreground segmentation
    #     thing_seg = torch.zeros_like(sem_seg)
    #     for thing_class in thing_list:
    #         thing_seg[sem_seg == thing_class] = 1
    # if thing_seg.dim() == 4:
    #     # [1, H, W, Z] --> [1, H, W]
    #     thing_seg = torch.max(thing_seg,dim=3)

    ctr = find_instance_center(ctr_hmp, threshold=threshold, nms_kernel=nms_kernel, top_k=top_k, polar=polar)
    if ctr.size(0) == 0:
        return torch.zeros_like(thing_seg[:,:,:,0]), ctr.unsqueeze(0)
    ins_seg = group_pixels(ctr, offsets, polar=polar)
    return ins_seg, ctr.unsqueeze(0)


def merge_semantic_and_instance(sem_seg, sem, ins_seg, label_divisor, thing_list, void_label,thing_seg):
    """
    Post-processing for panoptic segmentation, by merging semantic segmentation label and class agnostic
        instance segmentation label.
    Arguments:
        sem_seg: A Tensor of shape [1, H, W, Z], predicted semantic label.
        sem: A Tensor of shape [1, C, H, W, Z], predicted semantic logit.
        ins_seg: A Tensor of shape [1, H, W], predicted instance label.
        label_divisor: An Integer, used to convert panoptic id = semantic id * label_divisor + instance_id.
        thing_list: A List of thing class id.
        void_label: An Integer, indicates the region has no confident prediction.
        thing_seg: A Tensor of shape [1, H, W, Z], predicted foreground mask.
    Returns:
        A Tensor of shape [1, H, W, Z] (to be gathered by distributed data parallel).
    Raises:
        ValueError, if batch size is not 1.
    """
    # In case thing mask does not align with semantic prediction
    # semantic_thing_seg = torch.zeros_like(sem_seg,dtype=torch.bool)
    # for thing_class in thing_list:
    #     semantic_thing_seg[sem_seg == thing_class] = True
    
    # try to avoid the for loop
    semantic_thing_seg = sem_seg<=max(thing_list)

    ins_seg = torch.unsqueeze(ins_seg,3).expand_as(sem_seg)
    thing_mask = (ins_seg > 0) & semantic_thing_seg & thing_seg
    if not torch.nonzero(thing_mask).size(0) == 0:
        sem_sum = torch_scatter.scatter_add(sem.permute(0,2,3,4,1)[thing_mask],ins_seg[thing_mask],dim=0)
        class_id = torch.argmax(sem_sum[:,:max(thing_list)],dim=1)
        sem_seg[thing_mask] = (ins_seg[thing_mask] * label_divisor) + class_id[ins_seg[thing_mask]]+1
    else:
        sem_seg[semantic_thing_seg & thing_seg] = void_label
    return sem_seg


def get_panoptic_segmentation(sem, ctr_hmp, offsets, thing_list, label_divisor=2**16, void_label=0,
                              threshold=0.1, nms_kernel=5, top_k=100, foreground_mask=None, polar=False):
    """
    Post-processing for panoptic segmentation.
    Arguments:
        sem: A Tensor of shape [N, C, H, W, Z] of raw semantic output, where N is the batch size, for consistent,
            we only support N=1.
        ctr_hmp: A Tensor of shape [N, 1, H, W] of raw center heatmap output, where N is the batch size,
            for consistent, we only support N=1.
        offsets: A Tensor of shape [N, 2, H, W] of raw offset output, where N is the batch size,
            for consistent, we only support N=1. The order of second dim is (offset_y, offset_x).
        thing_list: A List of thing class id.
        label_divisor: An Integer, used to convert panoptic id = instance_id * label_divisor + semantic_id.
        void_label: An Integer, indicates the region has no confident prediction.
        threshold: A Float, threshold applied to center heatmap score.
        nms_kernel: An Integer, NMS max pooling kernel size.
        top_k: An Integer, top k centers to keep.
        foreground_mask: A processed Tensor of shape [N, H, W, Z], we only support N=1.
    Returns:
        A Tensor of shape [1, H, W, Z] (to be gathered by distributed data parallel), int64.
    Raises:
        ValueError, if batch size is not 1.
    """
    if sem.dim() != 5 and sem.dim() != 4:
        raise ValueError('Semantic prediction with un-supported dimension: {}.'.format(sem.dim()))
    if sem.dim() == 5 and sem.size(0) != 1:
        raise ValueError('Only supports inference for batch size = 1')
    if ctr_hmp.size(0) != 1:
        raise ValueError('Only supports inference for batch size = 1')
    if offsets.size(0) != 1:
        raise ValueError('Only supports inference for batch size = 1')
    if foreground_mask is not None:
        if foreground_mask.dim() != 4:
            raise ValueError('Foreground prediction with un-supported dimension: {}.'.format(sem.dim()))

    if sem.dim() == 5:
        semantic = torch.argmax(sem, dim=1)
        # shift back to original label idx 
        semantic = torch.add(semantic, 1)
        sem = F.softmax(sem)
    else:
        semantic = sem.type(torch.ByteTensor).cuda()
        # shift back to original label idx 
        semantic = torch.add(semantic, 1).type(torch.LongTensor).cuda()
        one_hot = torch.zeros((sem.size(0),torch.max(semantic).item()+1,sem.size(1),sem.size(2),sem.size(3))).cuda()
        sem = one_hot.scatter_(1,torch.unsqueeze(semantic,1),1.)
        sem = sem[:,1:,:,:,:]


    if foreground_mask is not None:
        thing_seg = foreground_mask
    else:
        thing_seg = None

    
    instance, center = get_instance_segmentation(semantic, ctr_hmp, offsets, thing_list,
                                                 threshold=threshold, nms_kernel=nms_kernel, top_k=top_k,
                                                 thing_seg=thing_seg, polar=polar)
    panoptic = merge_semantic_and_instance(semantic, sem, instance, label_divisor, thing_list, void_label, thing_seg)

    return panoptic, center


================================================
FILE: network/loss.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import torch
from .lovasz_losses import lovasz_softmax

def _neg_loss(pred, gt):
    ''' Modified focal loss. Exactly the same as CornerNet.
        Runs faster and costs a little bit more memory
        (https://github.com/tianweiy/CenterPoint)
    Arguments:
        pred (batch x c x h x w)
        gt (batch x c x h x w)
    '''
    pos_inds = gt.eq(1).float()
    neg_inds = gt.lt(1).float()

    neg_weights = torch.pow(1 - gt, 4)

    # loss = 0

    pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
    neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds
    return - (pos_loss + neg_loss)

class FocalLoss(torch.nn.Module):
    '''nn.Module warpper for focal loss'''
    def __init__(self):
        super(FocalLoss, self).__init__()
        self.neg_loss = _neg_loss

    def forward(self, out, target):
        return self.neg_loss(out, target)

class panoptic_loss(torch.nn.Module):
    def __init__(self, ignore_label = 255, center_loss_weight = 100, offset_loss_weight = 1, center_loss = 'MSE', offset_loss = 'L1'):
        super(panoptic_loss, self).__init__()
        self.CE_loss = torch.nn.CrossEntropyLoss(ignore_index=ignore_label)
        assert center_loss in ['MSE','FocalLoss']
        assert offset_loss in ['L1','SmoothL1']
        if center_loss == 'MSE':
            self.center_loss_fn = torch.nn.MSELoss()
        elif center_loss == 'FocalLoss':
            self.center_loss_fn = FocalLoss()
        else: raise NotImplementedError
        if offset_loss == 'L1':
            self.offset_loss_fn = torch.nn.L1Loss()
        elif offset_loss == 'SmoothL1':
            self.offset_loss_fn = torch.nn.SmoothL1Loss()
        else: raise NotImplementedError
        self.center_loss_weight = center_loss_weight
        self.offset_loss_weight = offset_loss_weight

        print('Using '+ center_loss +' for heatmap regression, weight: '+str(center_loss_weight))
        print('Using '+ offset_loss +' for offset regression, weight: '+str(offset_loss_weight))

        self.lost_dict={'semantic_loss':[],
                        'heatmap_loss':[],
                        'offset_loss':[]}

    def reset_loss_dict(self):
        self.lost_dict={'semantic_loss':[],
                        'heatmap_loss':[],
                        'offset_loss':[]}

    def forward(self,prediction,center,offset,gt_label,gt_center,gt_offset,save_loss = True):
        # semantic loss
        loss = lovasz_softmax(torch.nn.functional.softmax(prediction), gt_label,ignore=255) + self.CE_loss(prediction,gt_label)
        if save_loss:
            self.lost_dict['semantic_loss'].append(loss.item())
        # center heatmap loss
        center_mask = (gt_center>0) | (torch.min(torch.unsqueeze(gt_label, 1),dim=4)[0]<255)
        center_loss = self.center_loss_fn(center,gt_center) * center_mask
        # safe division
        if center_mask.sum() > 0:
            center_loss = center_loss.sum() / center_mask.sum() * self.center_loss_weight
        else:
            center_loss = center_loss.sum() * 0
        if save_loss:
            self.lost_dict['heatmap_loss'].append(center_loss.item())
        loss += center_loss
        # offset loss
        offset_mask = gt_offset != 0
        offset_loss = self.offset_loss_fn(offset,gt_offset) * offset_mask
        # safe division
        if offset_mask.sum() > 0:
            offset_loss = offset_loss.sum() / offset_mask.sum() * self.offset_loss_weight
        else:
            offset_loss = offset_loss.sum() * 0
        if save_loss:
            self.lost_dict['offset_loss'].append(offset_loss.item())
        loss += offset_loss
        return loss

================================================
FILE: network/lovasz_losses.py
================================================
"""
Lovasz-Softmax and Jaccard hinge loss in PyTorch
Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)
"""

from __future__ import print_function, division

import torch
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
try:
    from itertools import  ifilterfalse
except ImportError: # py3k
    from itertools import  filterfalse as ifilterfalse

def lovasz_grad(gt_sorted):
    """
    Computes gradient of the Lovasz extension w.r.t sorted errors
    See Alg. 1 in paper
    """
    p = len(gt_sorted)
    gts = gt_sorted.sum()
    intersection = gts - gt_sorted.float().cumsum(0)
    union = gts + (1 - gt_sorted).float().cumsum(0)
    jaccard = 1. - intersection / union
    if p > 1: # cover 1-pixel case
        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
    return jaccard


def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
    """
    IoU for foreground class
    binary: 1 foreground, 0 background
    """
    if not per_image:
        preds, labels = (preds,), (labels,)
    ious = []
    for pred, label in zip(preds, labels):
        intersection = ((label == 1) & (pred == 1)).sum()
        union = ((label == 1) | ((pred == 1) & (label != ignore))).sum()
        if not union:
            iou = EMPTY
        else:
            iou = float(intersection) / float(union)
        ious.append(iou)
    iou = mean(ious)    # mean accross images if per_image
    return 100 * iou


def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
    """
    Array of IoU for each (non ignored) class
    """
    if not per_image:
        preds, labels = (preds,), (labels,)
    ious = []
    for pred, label in zip(preds, labels):
        iou = []    
        for i in range(C):
            if i != ignore: # The ignored label is sometimes among predicted classes (ENet - CityScapes)
                intersection = ((label == i) & (pred == i)).sum()
                union = ((label == i) | ((pred == i) & (label != ignore))).sum()
                if not union:
                    iou.append(EMPTY)
                else:
                    iou.append(float(intersection) / float(union))
        ious.append(iou)
    ious = [mean(iou) for iou in zip(*ious)] # mean accross images if per_image
    return 100 * np.array(ious)


# --------------------------- BINARY LOSSES ---------------------------


def lovasz_hinge(logits, labels, per_image=True, ignore=None):
    """
    Binary Lovasz hinge loss
      logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
      labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
      per_image: compute the loss per image instead of per batch
      ignore: void class id
    """
    if per_image:
        loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))
                          for log, lab in zip(logits, labels))
    else:
        loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
    return loss


def lovasz_hinge_flat(logits, labels):
    """
    Binary Lovasz hinge loss
      logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
      labels: [P] Tensor, binary ground truth labels (0 or 1)
      ignore: label to ignore
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.
    signs = 2. * labels.float() - 1.
    errors = (1. - logits * Variable(signs))
    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
    perm = perm.data
    gt_sorted = labels[perm]
    grad = lovasz_grad(gt_sorted)
    loss = torch.dot(F.relu(errors_sorted), Variable(grad))
    return loss


def flatten_binary_scores(scores, labels, ignore=None):
    """
    Flattens predictions in the batch (binary case)
    Remove labels equal to 'ignore'
    """
    scores = scores.view(-1)
    labels = labels.view(-1)
    if ignore is None:
        return scores, labels
    valid = (labels != ignore)
    vscores = scores[valid]
    vlabels = labels[valid]
    return vscores, vlabels


class StableBCELoss(torch.nn.modules.Module):
    def __init__(self):
         super(StableBCELoss, self).__init__()
    def forward(self, input, target):
         neg_abs = - input.abs()
         loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
         return loss.mean()


def binary_xloss(logits, labels, ignore=None):
    """
    Binary Cross entropy loss
      logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
      labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
      ignore: void class id
    """
    logits, labels = flatten_binary_scores(logits, labels, ignore)
    loss = StableBCELoss()(logits, Variable(labels.float()))
    return loss


# --------------------------- MULTICLASS LOSSES ---------------------------


def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None):
    """
    Multi-class Lovasz-Softmax loss
      probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
              Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
      labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
      classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
      per_image: compute the loss per image instead of per batch
      ignore: void class labels
    """
    if per_image:
        loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes)
                          for prob, lab in zip(probas, labels))
    else:
        loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes)
    return loss


def lovasz_softmax_flat(probas, labels, classes='present'):
    """
    Multi-class Lovasz-Softmax loss
      probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
      labels: [P] Tensor, ground truth labels (between 0 and C - 1)
      classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
    """
    if probas.numel() == 0:
        # only void pixels, the gradients should be 0
        return probas * 0.
    C = probas.size(1)
    losses = []
    class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
    for c in class_to_sum:
        fg = (labels == c).float() # foreground for class c
        if (classes is 'present' and fg.sum() == 0):
            continue
        if C == 1:
            if len(classes) > 1:
                raise ValueError('Sigmoid output possible only with 1 class')
            class_pred = probas[:, 0]
        else:
            class_pred = probas[:, c]
        errors = (Variable(fg) - class_pred).abs()
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        perm = perm.data
        fg_sorted = fg[perm]
        losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
    return mean(losses)


def flatten_probas(probas, labels, ignore=None):
    """
    Flattens predictions in the batch
    """
    if probas.dim() == 3:
        # assumes output of a sigmoid layer
        B, H, W = probas.size()
        probas = probas.view(B, 1, H, W)
    elif probas.dim() == 5:
        #3D segmentation
        B, C, L, H, W = probas.size()
        probas = probas.contiguous().view(B, C, L, H*W)
    B, C, H, W = probas.size()
    probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B * H * W, C = P, C
    labels = labels.view(-1)
    if ignore is None:
        return probas, labels
    valid = (labels != ignore)
    vprobas = probas[valid.nonzero().squeeze()]
    vlabels = labels[valid]
    return vprobas, vlabels

def xloss(logits, labels, ignore=None):
    """
    Cross entropy loss
    """
    return F.cross_entropy(logits, Variable(labels), ignore_index=255)

def jaccard_loss(probas, labels,ignore=None, smooth = 100, bk_class = None):
    """
    Something wrong with this loss
    Multi-class Lovasz-Softmax loss
      probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
              Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
      labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
      classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
      per_image: compute the loss per image instead of per batch
      ignore: void class labels
    """
    vprobas, vlabels = flatten_probas(probas, labels, ignore)
    
    
    true_1_hot = torch.eye(vprobas.shape[1])[vlabels]
    
    if bk_class:
        one_hot_assignment = torch.ones_like(vlabels)
        one_hot_assignment[vlabels == bk_class] = 0
        one_hot_assignment = one_hot_assignment.float().unsqueeze(1)
        true_1_hot = true_1_hot*one_hot_assignment
    
    true_1_hot = true_1_hot.to(vprobas.device)
    intersection = torch.sum(vprobas * true_1_hot)
    cardinality = torch.sum(vprobas + true_1_hot)
    loss = (intersection + smooth / (cardinality - intersection + smooth)).mean()
    return (1-loss)*smooth

def hinge_jaccard_loss(probas, labels,ignore=None, classes = 'present', hinge = 0.1, smooth =100):
    """
    Multi-class Hinge Jaccard loss
      probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
              Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
      labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
      classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
      ignore: void class labels
    """
    vprobas, vlabels = flatten_probas(probas, labels, ignore)
    C = vprobas.size(1)
    losses = []
    class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
    for c in class_to_sum:
        if c in vlabels:
            c_sample_ind = vlabels == c
            cprobas = vprobas[c_sample_ind,:]
            non_c_ind =np.array([a for a in class_to_sum if a != c])
            class_pred = cprobas[:,c]
            max_non_class_pred = torch.max(cprobas[:,non_c_ind],dim = 1)[0]
            TP = torch.sum(torch.clamp(class_pred - max_non_class_pred, max = hinge)+1.) + smooth
            FN = torch.sum(torch.clamp(max_non_class_pred - class_pred, min = -hinge)+hinge)
            
            if (~c_sample_ind).sum() == 0:
                FP = 0
            else:
                nonc_probas = vprobas[~c_sample_ind,:]
                class_pred = nonc_probas[:,c]
                max_non_class_pred = torch.max(nonc_probas[:,non_c_ind],dim = 1)[0]
                FP = torch.sum(torch.clamp(class_pred - max_non_class_pred, max = hinge)+1.)
            
            losses.append(1 - TP/(TP+FP+FN))
    
    if len(losses) == 0: return 0
    return mean(losses)

# --------------------------- HELPER FUNCTIONS ---------------------------
def isnan(x):
    return x != x
    
    
def mean(l, ignore_nan=False, empty=0):
    """
    nanmean compatible with generators.
    """
    l = iter(l)
    if ignore_nan:
        l = ifilterfalse(isnan, l)
    try:
        n = 1
        acc = next(l)
    except StopIteration:
        if empty == 'raise':
            raise ValueError('Empty mean')
        return empty
    for n, v in enumerate(l, 2):
        acc += v
    if n == 1:
        return acc
    return acc / n


================================================
FILE: network/ptBEV.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import numba as nb
import multiprocessing
import torch_scatter

class ptBEVnet(nn.Module):
    
    def __init__(self, BEV_net, grid_size, pt_model = 'pointnet', fea_dim = 3, pt_pooling = 'max', kernal_size = 3,
                 out_pt_fea_dim = 64, max_pt_per_encode = 64, cluster_num = 4, pt_selection = 'farthest', fea_compre = None):
        super(ptBEVnet, self).__init__()
        assert pt_pooling in ['max']
        assert pt_selection in ['random','farthest']
        
        if pt_model == 'pointnet':
            self.PPmodel = nn.Sequential(
                nn.BatchNorm1d(fea_dim),
                
                nn.Linear(fea_dim, 64),
                nn.BatchNorm1d(64),
                nn.ReLU(inplace=True),
                
                nn.Linear(64, 128),
                nn.BatchNorm1d(128),
                nn.ReLU(inplace=True),
                
                nn.Linear(128, 256),
                nn.BatchNorm1d(256),
                nn.ReLU(inplace=True),
                
                nn.Linear(256, out_pt_fea_dim)
            )
        
        self.pt_model = pt_model
        self.BEV_model = BEV_net
        self.pt_pooling = pt_pooling
        self.max_pt = max_pt_per_encode
        self.pt_selection = pt_selection
        self.fea_compre = fea_compre
        self.grid_size = grid_size
        
        # NN stuff
        if kernal_size != 1:
            if self.pt_pooling == 'max':
                self.local_pool_op = torch.nn.MaxPool2d(kernal_size, stride=1, padding=(kernal_size-1)//2, dilation=1)
            else: raise NotImplementedError
        else: self.local_pool_op = None
        
        # parametric pooling        
        if self.pt_pooling == 'max':
            self.pool_dim = out_pt_fea_dim
        
        # point feature compression
        if self.fea_compre is not None:
            self.fea_compression = nn.Sequential(
                    nn.Linear(self.pool_dim, self.fea_compre),
                    nn.ReLU())
            self.pt_fea_dim = self.fea_compre
        else:
            self.pt_fea_dim = self.pool_dim
        
    def forward(self, pt_fea, xy_ind, voxel_fea=None):
        cur_dev = pt_fea[0].get_device()
        
        # concate everything
        cat_pt_ind = []
        for i_batch in range(len(xy_ind)):
            cat_pt_ind.append(F.pad(xy_ind[i_batch],(1,0),'constant',value = i_batch))

        cat_pt_fea = torch.cat(pt_fea,dim = 0)
        cat_pt_ind = torch.cat(cat_pt_ind,dim = 0)
        pt_num = cat_pt_ind.shape[0]

        # shuffle the data
        shuffled_ind = torch.randperm(pt_num,device = cur_dev)
        cat_pt_fea = cat_pt_fea[shuffled_ind,:]
        cat_pt_ind = cat_pt_ind[shuffled_ind,:]
        
        # unique xy grid index
        unq, unq_inv, unq_cnt = torch.unique(cat_pt_ind,return_inverse=True, return_counts=True, dim=0)
        unq = unq.type(torch.int64)
        
        # subsample pts
        if self.pt_selection == 'random':
            grp_ind = grp_range_torch(unq_cnt,cur_dev)[torch.argsort(torch.argsort(unq_inv))]
            remain_ind = grp_ind < self.max_pt
        elif self.pt_selection == 'farthest':
            unq_ind = np.split(np.argsort(unq_inv.detach().cpu().numpy()), np.cumsum(unq_cnt.detach().cpu().numpy()[:-1]))
            remain_ind = np.zeros((pt_num,),dtype = np.bool)
            np_cat_fea = cat_pt_fea.detach().cpu().numpy()[:,:3]
            pool_in = []
            for i_inds in unq_ind:
                if len(i_inds) > self.max_pt:
                    pool_in.append((np_cat_fea[i_inds,:],self.max_pt))
            if len(pool_in) > 0:
                pool = multiprocessing.Pool(multiprocessing.cpu_count())
                FPS_results = pool.starmap(parallel_FPS, pool_in)
                pool.close()
                pool.join()
            count = 0
            for i_inds in unq_ind:
                if len(i_inds) <= self.max_pt:
                    remain_ind[i_inds] = True
                else:
                    remain_ind[i_inds[FPS_results[count]]] = True
                    count += 1
            
        cat_pt_fea = cat_pt_fea[remain_ind,:]
        cat_pt_ind = cat_pt_ind[remain_ind,:]
        unq_inv = unq_inv[remain_ind]
        unq_cnt = torch.clamp(unq_cnt,max=self.max_pt)
        
        # process feature
        if self.pt_model == 'pointnet':
            processed_cat_pt_fea = self.PPmodel(cat_pt_fea)
        
        if self.pt_pooling == 'max':
            pooled_data = torch_scatter.scatter_max(processed_cat_pt_fea, unq_inv, dim=0)[0]
        else: raise NotImplementedError
        
        if self.fea_compre:
            processed_pooled_data = self.fea_compression(pooled_data)
        else:
            processed_pooled_data = pooled_data
        
        # stuff pooled data into 4D tensor
        out_data_dim = [len(pt_fea),self.grid_size[0],self.grid_size[1],self.pt_fea_dim]
        out_data = torch.zeros(out_data_dim, dtype=torch.float32).to(cur_dev)
        out_data[unq[:,0],unq[:,1],unq[:,2],:] = processed_pooled_data
        out_data = out_data.permute(0,3,1,2)
        if self.local_pool_op != None:
            out_data = self.local_pool_op(out_data)
        if voxel_fea is not None:
            out_data = torch.cat((out_data, voxel_fea), 1)
        
        # run through network
        sem_prediction, center, offset = self.BEV_model(out_data)
       
        return sem_prediction, center, offset
    
def grp_range_torch(a,dev):
    idx = torch.cumsum(a,0)
    id_arr = torch.ones(idx[-1],dtype = torch.int64,device=dev)
    id_arr[0] = 0
    id_arr[idx[:-1]] = -a[:-1]+1
    return torch.cumsum(id_arr,0)

def parallel_FPS(np_cat_fea,K):
    return  nb_greedy_FPS(np_cat_fea,K)

@nb.jit('b1[:](f4[:,:],i4)',nopython=True,cache=True)
def nb_greedy_FPS(xyz,K):
    start_element = 0
    sample_num = xyz.shape[0]
    sum_vec = np.zeros((sample_num,1),dtype = np.float32)
    xyz_sq = xyz**2
    for j in range(sample_num):
        sum_vec[j,0] = np.sum(xyz_sq[j,:])
    pairwise_distance = sum_vec + np.transpose(sum_vec) - 2*np.dot(xyz, np.transpose(xyz))
    
    candidates_ind = np.zeros((sample_num,),dtype = np.bool_)
    candidates_ind[start_element] = True
    remain_ind = np.ones((sample_num,),dtype = np.bool_)
    remain_ind[start_element] = False
    all_ind = np.arange(sample_num)
    
    for i in range(1,K):
        if i == 1:
            min_remain_pt_dis = pairwise_distance[:,start_element]
            min_remain_pt_dis = min_remain_pt_dis[remain_ind]
        else:
            cur_dis = pairwise_distance[remain_ind,:]
            cur_dis = cur_dis[:,candidates_ind]
            min_remain_pt_dis = np.zeros((cur_dis.shape[0],),dtype = np.float32)
            for j in range(cur_dis.shape[0]):
                min_remain_pt_dis[j] = np.min(cur_dis[j,:])
        next_ind_in_remain = np.argmax(min_remain_pt_dis)
        next_ind = all_ind[remain_ind][next_ind_in_remain]
        candidates_ind[next_ind] = True
        remain_ind[next_ind] = False
        
    return candidates_ind

================================================
FILE: pretrained_weight/Panoptic_SemKITTI_PolarNet.pt
================================================
[File too large to display: 52.5 MB]

================================================
FILE: requirements.txt
================================================
numpy
torch>=1.7.0
tqdm
pyyaml
numba>=0.39.0
torch_scatter>=1.3.1
Cython
scipy
dropblock

================================================
FILE: semantic-kitti.yaml
================================================
# This file is covered by the LICENSE file in the root of this project.
labels: 
  0 : "unlabeled"
  1 : "outlier"
  10: "car"
  11: "bicycle"
  13: "bus"
  15: "motorcycle"
  16: "on-rails"
  18: "truck"
  20: "other-vehicle"
  30: "person"
  31: "bicyclist"
  32: "motorcyclist"
  40: "road"
  44: "parking"
  48: "sidewalk"
  49: "other-ground"
  50: "building"
  51: "fence"
  52: "other-structure"
  60: "lane-marking"
  70: "vegetation"
  71: "trunk"
  72: "terrain"
  80: "pole"
  81: "traffic-sign"
  99: "other-object"
  252: "moving-car"
  253: "moving-bicyclist"
  254: "moving-person"
  255: "moving-motorcyclist"
  256: "moving-on-rails"
  257: "moving-bus"
  258: "moving-truck"
  259: "moving-other-vehicle"
color_map: # bgr
  0 : [0, 0, 0]
  1 : [0, 0, 255]
  10: [245, 150, 100]
  11: [245, 230, 100]
  13: [250, 80, 100]
  15: [150, 60, 30]
  16: [255, 0, 0]
  18: [180, 30, 80]
  20: [255, 0, 0]
  30: [30, 30, 255]
  31: [200, 40, 255]
  32: [90, 30, 150]
  40: [255, 0, 255]
  44: [255, 150, 255]
  48: [75, 0, 75]
  49: [75, 0, 175]
  50: [0, 200, 255]
  51: [50, 120, 255]
  52: [0, 150, 255]
  60: [170, 255, 150]
  70: [0, 175, 0]
  71: [0, 60, 135]
  72: [80, 240, 150]
  80: [150, 240, 255]
  81: [0, 0, 255]
  99: [255, 255, 50]
  252: [245, 150, 100]
  256: [255, 0, 0]
  253: [200, 40, 255]
  254: [30, 30, 255]
  255: [90, 30, 150]
  257: [250, 80, 100]
  258: [180, 30, 80]
  259: [255, 0, 0]
content: # as a ratio with the total number of points
  0: 0.018889854628292943
  1: 0.0002937197336781505
  10: 0.040818519255974316
  11: 0.00016609538710764618
  13: 2.7879693665067774e-05
  15: 0.00039838616015114444
  16: 0.0
  18: 0.0020633612104619787
  20: 0.0016218197275284021
  30: 0.00017698551338515307
  31: 1.1065903904919655e-08
  32: 5.532951952459828e-09
  40: 0.1987493871255525
  44: 0.014717169549888214
  48: 0.14392298360372
  49: 0.0039048553037472045
  50: 0.1326861944777486
  51: 0.0723592229456223
  52: 0.002395131480328884
  60: 4.7084144280367186e-05
  70: 0.26681502148037506
  71: 0.006035012012626033
  72: 0.07814222006271769
  80: 0.002855498193863172
  81: 0.0006155958086189918
  99: 0.009923127583046915
  252: 0.001789309418528068
  253: 0.00012709999297008662
  254: 0.00016059776092534436
  255: 3.745553104802113e-05
  256: 0.0
  257: 0.00011351574470342043
  258: 0.00010157861367183268
  259: 4.3840131989471124e-05
# classes that are indistinguishable from single scan or inconsistent in
# ground truth are mapped to their closest equivalent
learning_map:
  0 : 0     # "unlabeled"
  1 : 0     # "outlier" mapped to "unlabeled" --------------------------mapped
  10: 1     # "car"
  11: 2     # "bicycle"
  13: 5     # "bus" mapped to "other-vehicle" --------------------------mapped
  15: 3     # "motorcycle"
  16: 5     # "on-rails" mapped to "other-vehicle" ---------------------mapped
  18: 4     # "truck"
  20: 5     # "other-vehicle"
  30: 6     # "person"
  31: 7     # "bicyclist"
  32: 8     # "motorcyclist"
  40: 9     # "road"
  44: 10    # "parking"
  48: 11    # "sidewalk"
  49: 12    # "other-ground"
  50: 13    # "building"
  51: 14    # "fence"
  52: 0     # "other-structure" mapped to "unlabeled" ------------------mapped
  60: 9     # "lane-marking" to "road" ---------------------------------mapped
  70: 15    # "vegetation"
  71: 16    # "trunk"
  72: 17    # "terrain"
  80: 18    # "pole"
  81: 19    # "traffic-sign"
  99: 0     # "other-object" to "unlabeled" ----------------------------mapped
  252: 1    # "moving-car" to "car" ------------------------------------mapped
  253: 7    # "moving-bicyclist" to "bicyclist" ------------------------mapped
  254: 6    # "moving-person" to "person" ------------------------------mapped
  255: 8    # "moving-motorcyclist" to "motorcyclist" ------------------mapped
  256: 5    # "moving-on-rails" mapped to "other-vehicle" --------------mapped
  257: 5    # "moving-bus" mapped to "other-vehicle" -------------------mapped
  258: 4    # "moving-truck" to "truck" --------------------------------mapped
  259: 5    # "moving-other"-vehicle to "other-vehicle" ----------------mapped
learning_map_inv: # inverse of previous map
  0: 0      # "unlabeled", and others ignored
  1: 10     # "car"
  2: 11     # "bicycle"
  3: 15     # "motorcycle"
  4: 18     # "truck"
  5: 20     # "other-vehicle"
  6: 30     # "person"
  7: 31     # "bicyclist"
  8: 32     # "motorcyclist"
  9: 40     # "road"
  10: 44    # "parking"
  11: 48    # "sidewalk"
  12: 49    # "other-ground"
  13: 50    # "building"
  14: 51    # "fence"
  15: 70    # "vegetation"
  16: 71    # "trunk"
  17: 72    # "terrain"
  18: 80    # "pole"
  19: 81    # "traffic-sign"
learning_ignore: # Ignore classes
  0: True      # "unlabeled", and others ignored
  1: False     # "car"
  2: False     # "bicycle"
  3: False     # "motorcycle"
  4: False     # "truck"
  5: False     # "other-vehicle"
  6: False     # "person"
  7: False     # "bicyclist"
  8: False     # "motorcyclist"
  9: False     # "road"
  10: False    # "parking"
  11: False    # "sidewalk"
  12: False    # "other-ground"
  13: False    # "building"
  14: False    # "fence"
  15: False    # "vegetation"
  16: False    # "trunk"
  17: False    # "terrain"
  18: False    # "pole"
  19: False    # "traffic-sign"
thing_class: # thing class in panoptic segmentation
  0: False      # "unlabeled", and others ignored
  1: True     # "car"
  2: True     # "bicycle"
  3: True     # "motorcycle"
  4: True     # "truck"
  5: True     # "other-vehicle"
  6: True     # "person"
  7: True     # "bicyclist"
  8: True     # "motorcyclist"
  9: False     # "road"
  10: False    # "parking"
  11: False    # "sidewalk"
  12: False    # "other-ground"
  13: False    # "building"
  14: False    # "fence"
  15: False    # "vegetation"
  16: False    # "trunk"
  17: False    # "terrain"
  18: False    # "pole"
  19: False    # "traffic-sign"  
split: # sequence numbers
  train:
    - 0
    - 1
    - 2
    - 3
    - 4
    - 5
    - 6
    - 7
    - 9
    - 10
  valid:
    - 8
  test:
    - 11
    - 12
    - 13
    - 14
    - 15
    - 16
    - 17
    - 18
    - 19
    - 20
    - 21


================================================
FILE: test_pretrain.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import time
import argparse
import sys
import yaml
import numpy as np
import torch
import torch.optim as optim
from tqdm import tqdm
import errno

from network.BEV_Unet import BEV_Unet
from network.ptBEV import ptBEVnet
from dataloader.dataset import collate_fn_BEV,SemKITTI,SemKITTI_label_name,spherical_dataset,voxel_dataset,collate_fn_BEV_test
from network.instance_post_processing import get_panoptic_segmentation
from utils.eval_pq import PanopticEval
from utils.configs import merge_configs
#ignore weird np warning
import warnings
warnings.filterwarnings("ignore")

def SemKITTI2train(label):
    if isinstance(label, list):
        return [SemKITTI2train_single(a) for a in label]
    else:
        return SemKITTI2train_single(label)

def SemKITTI2train_single(label):
    return label - 1 # uint8 trick

def main(args):
    data_path = args['dataset']['path']
    test_batch_size = args['model']['test_batch_size']
    pretrained_model = args['model']['pretrained_model']
    output_path = args['dataset']['output_path']
    compression_model = args['dataset']['grid_size'][2]
    grid_size = args['dataset']['grid_size']
    visibility = args['model']['visibility']
    pytorch_device = torch.device('cuda:0')
    if args['model']['polar']:
        fea_dim = 9
        circular_padding = True
    else:
        fea_dim = 7
        circular_padding = False

    # prepare miou fun
    unique_label=np.asarray(sorted(list(SemKITTI_label_name.keys())))[1:] - 1
    unique_label_str=[SemKITTI_label_name[x] for x in unique_label+1]

    # prepare model
    my_BEV_model=BEV_Unet(n_class=len(unique_label), n_height = compression_model, input_batch_norm = True, dropout = 0.5, circular_padding = circular_padding, use_vis_fea=visibility)
    my_model = ptBEVnet(my_BEV_model, pt_model = 'pointnet', grid_size =  grid_size, fea_dim = fea_dim, max_pt_per_encode = 256,
                            out_pt_fea_dim = 512, kernal_size = 1, pt_selection = 'random', fea_compre = compression_model)
    if os.path.exists(pretrained_model):
        my_model.load_state_dict(torch.load(pretrained_model))
    pytorch_total_params = sum(p.numel() for p in my_model.parameters())
    print('params: ',pytorch_total_params)
    my_model.to(pytorch_device)
    my_model.eval()

    # prepare dataset
    test_pt_dataset = SemKITTI(data_path + '/sequences/', imageset = 'test', return_ref = True, instance_pkl_path=args['dataset']['instance_pkl_path'])
    val_pt_dataset = SemKITTI(data_path + '/sequences/', imageset = 'val', return_ref = True, instance_pkl_path=args['dataset']['instance_pkl_path'])
    if args['model']['polar']:
        test_dataset=spherical_dataset(test_pt_dataset, args['dataset'], grid_size = grid_size, ignore_label = 0, return_test= True)
        val_dataset=spherical_dataset(val_pt_dataset, args['dataset'], grid_size = grid_size, ignore_label = 0)
    else:
        test_dataset=voxel_dataset(test_pt_dataset, args['dataset'], grid_size = grid_size, ignore_label = 0, return_test= True)
        val_dataset=voxel_dataset(val_pt_dataset, args['dataset'], grid_size = grid_size, ignore_label = 0)
    test_dataset_loader = torch.utils.data.DataLoader(dataset = test_dataset,
                                                    batch_size = test_batch_size,
                                                    collate_fn = collate_fn_BEV_test,
                                                    shuffle = False,
                                                    num_workers = 4)
    val_dataset_loader = torch.utils.data.DataLoader(dataset = val_dataset,
                                                    batch_size = test_batch_size,
                                                    collate_fn = collate_fn_BEV,
                                                    shuffle = False,
                                                    num_workers = 4)

    # validation
    print('*'*80)
    print('Test network performance on validation split')
    print('*'*80)
    pbar = tqdm(total=len(val_dataset_loader))
    time_list = []
    pp_time_list = []
    evaluator = PanopticEval(len(unique_label)+1, None, [0], min_points=50)
    with torch.no_grad():
        for i_iter_val,(val_vox_fea,val_vox_label,val_gt_center,val_gt_offset,val_grid,val_pt_labels,val_pt_ints,val_pt_fea) in enumerate(val_dataset_loader):
            val_vox_fea_ten = val_vox_fea.to(pytorch_device)
            val_vox_label = SemKITTI2train(val_vox_label)
            val_pt_fea_ten = [torch.from_numpy(i).type(torch.FloatTensor).to(pytorch_device) for i in val_pt_fea]
            val_grid_ten = [torch.from_numpy(i[:,:2]).to(pytorch_device) for i in val_grid]
            val_label_tensor=val_vox_label.type(torch.LongTensor).to(pytorch_device)
            val_gt_center_tensor = val_gt_center.to(pytorch_device)
            val_gt_offset_tensor = val_gt_offset.to(pytorch_device)

            torch.cuda.synchronize()
            start_time = time.time()
            if visibility:            
                predict_labels,center,offset = my_model(val_pt_fea_ten, val_grid_ten, val_vox_fea_ten)
            else:
                predict_labels,center,offset = my_model(val_pt_fea_ten, val_grid_ten)
            torch.cuda.synchronize()
            time_list.append(time.time()-start_time)

            for count,i_val_grid in enumerate(val_grid):
                # get foreground_mask
                for_mask = torch.zeros(1,grid_size[0],grid_size[1],grid_size[2],dtype=torch.bool).to(pytorch_device)
                for_mask[0,val_grid[count][:,0],val_grid[count][:,1],val_grid[count][:,2]] = True
                # post processing
                torch.cuda.synchronize()
                start_time = time.time()
                panoptic_labels,center_points = get_panoptic_segmentation(torch.unsqueeze(predict_labels[count], 0),torch.unsqueeze(center[count], 0),torch.unsqueeze(offset[count], 0),val_pt_dataset.thing_list,\
                                                                          threshold=args['model']['post_proc']['threshold'], nms_kernel=args['model']['post_proc']['nms_kernel'],\
                                                                          top_k=args['model']['post_proc']['top_k'], polar=circular_padding,foreground_mask=for_mask)
                torch.cuda.synchronize()
                pp_time_list.append(time.time()-start_time)
                panoptic_labels = panoptic_labels.cpu().detach().numpy().astype(np.uint32)
                panoptic = panoptic_labels[0,val_grid[count][:,0],val_grid[count][:,1],val_grid[count][:,2]]

                evaluator.addBatch(panoptic & 0xFFFF,panoptic,np.squeeze(val_pt_labels[count]),np.squeeze(val_pt_ints[count]))
            del val_vox_label,val_pt_fea_ten,val_label_tensor,val_grid_ten,val_gt_center,val_gt_center_tensor,val_gt_offset,val_gt_offset_tensor,predict_labels,center,offset,panoptic_labels,center_points
            pbar.update(1)
    
    class_PQ, class_SQ, class_RQ, class_all_PQ, class_all_SQ, class_all_RQ = evaluator.getPQ()
    miou,ious = evaluator.getSemIoU()
    print('Validation per class PQ, SQ, RQ and IoU: ')
    for class_name, class_pq, class_sq, class_rq, class_iou in zip(unique_label_str,class_all_PQ[1:],class_all_SQ[1:],class_all_RQ[1:],ious[1:]):
        print('%15s : %6.2f%%  %6.2f%%  %6.2f%%  %6.2f%%' % (class_name, class_pq*100, class_sq*100, class_rq*100, class_iou*100))
    pbar.close()
    print('Current val PQ is %.3f' %
        (class_PQ*100))               
    print('Current val miou is %.3f'%
        (miou*100))
    print('Inference time per %d is %.4f seconds\n, postprocessing time is %.4f seconds per scan' %
        (test_batch_size,np.mean(time_list),np.mean(pp_time_list)))
    
    # test
    print('*'*80)
    print('Generate predictions for test split')
    print('*'*80)
    pbar = tqdm(total=len(test_dataset_loader))
    with torch.no_grad():
        for i_iter_test,(test_vox_fea,_,_,_,test_grid,_,_,test_pt_fea,test_index) in enumerate(test_dataset_loader):
            # predict
            test_vox_fea_ten = test_vox_fea.to(pytorch_device)
            test_pt_fea_ten = [torch.from_numpy(i).type(torch.FloatTensor).to(pytorch_device) for i in test_pt_fea]
            test_grid_ten = [torch.from_numpy(i[:,:2]).to(pytorch_device) for i in test_grid]

            if visibility:
                predict_labels,center,offset = my_model(test_pt_fea_ten,test_grid_ten,test_vox_fea_ten)
            else:
                predict_labels,center,offset = my_model(test_pt_fea_ten,test_grid_ten)
            # write to label file
            for count,i_test_grid in enumerate(test_grid):
                # get foreground_mask
                for_mask = torch.zeros(1,grid_size[0],grid_size[1],grid_size[2],dtype=torch.bool).to(pytorch_device)
                for_mask[0,test_grid[count][:,0],test_grid[count][:,1],test_grid[count][:,2]] = True
                # post processing
                panoptic_labels,center_points = get_panoptic_segmentation(torch.unsqueeze(predict_labels[count], 0),torch.unsqueeze(center[count], 0),torch.unsqueeze(offset[count], 0),test_pt_dataset.thing_list,\
                                                                                          threshold=args['model']['post_proc']['threshold'], nms_kernel=args['model']['post_proc']['nms_kernel'],\
                                                                                          top_k=args['model']['post_proc']['top_k'], polar=circular_padding,foreground_mask=for_mask)
                panoptic_labels = panoptic_labels.cpu().detach().numpy().astype(np.uint32)
                panoptic = panoptic_labels[0,test_grid[count][:,0],test_grid[count][:,1],test_grid[count][:,2]]
                save_dir = test_pt_dataset.im_idx[test_index[count]]
                _,dir2 = save_dir.split('/sequences/',1)
                new_save_dir = output_path + '/sequences/' +dir2.replace('velodyne','predictions')[:-3]+'label'
                if not os.path.exists(os.path.dirname(new_save_dir)):
                    try:
                        os.makedirs(os.path.dirname(new_save_dir))
                    except OSError as exc:
                        if exc.errno != errno.EEXIST:
                            raise
                panoptic.tofile(new_save_dir)
            del test_pt_fea_ten,test_grid_ten,test_pt_fea,predict_labels,center,offset
            pbar.update(1)
    pbar.close()
    print('Predicted test labels are saved in %s. Need to be shifted to original label format before submitting to the Competition website.' % output_path)
    print('Remapping script can be found in semantic-kitti-api.')

if __name__ == '__main__':
    # Testing settings
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('-d', '--data_dir', default='data')
    parser.add_argument('-p', '--pretrained_model', default='pretrained_weight/Panoptic_SemKITTI_PolarNet.pt')
    parser.add_argument('-c', '--configs', default='configs/SemanticKITTI_model/Panoptic-PolarNet.yaml')
    
    args = parser.parse_args()
    with open(args.configs, 'r') as s:
        new_args = yaml.safe_load(s)
    args = merge_configs(args,new_args)

    print(' '.join(sys.argv))
    print(args)
    main(args)

================================================
FILE: train.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import argparse
import sys
import numpy as np
import yaml
import torch
import torch.optim as optim
from tqdm import tqdm

from network.BEV_Unet import BEV_Unet
from network.ptBEV import ptBEVnet
from dataloader.dataset import collate_fn_BEV,SemKITTI,SemKITTI_label_name,spherical_dataset,voxel_dataset
from network.instance_post_processing import get_panoptic_segmentation
from network.loss import panoptic_loss
from utils.eval_pq import PanopticEval
from utils.configs import merge_configs
#ignore weird np warning
import warnings
warnings.filterwarnings("ignore")

def SemKITTI2train(label):
    if isinstance(label, list):
        return [SemKITTI2train_single(a) for a in label]
    else:
        return SemKITTI2train_single(label)

def SemKITTI2train_single(label):
    return label - 1 # uint8 trick

def load_pretrained_model(model,pretrained_model):
    model_dict = model.state_dict()
    pretrained_model = {k: v for k, v in pretrained_model.items() if k in model_dict}
    model_dict.update(pretrained_model) 
    model.load_state_dict(model_dict)
    return model

def main(args):
    data_path = args['dataset']['path']
    train_batch_size = args['model']['train_batch_size']
    val_batch_size = args['model']['val_batch_size']
    check_iter = args['model']['check_iter']
    model_save_path = args['model']['model_save_path']
    pretrained_model = args['model']['pretrained_model']
    compression_model = args['dataset']['grid_size'][2]
    grid_size = args['dataset']['grid_size']
    visibility = args['model']['visibility']
    pytorch_device = torch.device('cuda:0')
    if args['model']['polar']:
        fea_dim = 9
        circular_padding = True
    else:
        fea_dim = 7
        circular_padding = False

    #prepare miou fun
    unique_label=np.asarray(sorted(list(SemKITTI_label_name.keys())))[1:] - 1
    unique_label_str=[SemKITTI_label_name[x] for x in unique_label+1]

    #prepare model
    my_BEV_model=BEV_Unet(n_class=len(unique_label), n_height = compression_model, input_batch_norm = True, dropout = 0.5, circular_padding = circular_padding, use_vis_fea=visibility)
    my_model = ptBEVnet(my_BEV_model, pt_model = 'pointnet', grid_size =  grid_size, fea_dim = fea_dim, max_pt_per_encode = 256,
                            out_pt_fea_dim = 512, kernal_size = 1, pt_selection = 'random', fea_compre = compression_model)
    if os.path.exists(model_save_path):
        my_model = load_pretrained_model(my_model,torch.load(model_save_path))
    elif os.path.exists(pretrained_model):
        my_model = load_pretrained_model(my_model,torch.load(pretrained_model))
    my_model.to(pytorch_device)

    optimizer = optim.Adam(my_model.parameters())
    loss_fn = panoptic_loss(center_loss_weight = args['model']['center_loss_weight'], offset_loss_weight = args['model']['offset_loss_weight'],\
                            center_loss = args['model']['center_loss'], offset_loss=args['model']['offset_loss'])

    #prepare dataset
    train_pt_dataset = SemKITTI(data_path + '/sequences/', imageset = 'train', return_ref = True, instance_pkl_path=args['dataset']['instance_pkl_path'])
    val_pt_dataset = SemKITTI(data_path + '/sequences/', imageset = 'val', return_ref = True, instance_pkl_path=args['dataset']['instance_pkl_path'])
    if args['model']['polar']:
        train_dataset=spherical_dataset(train_pt_dataset, args['dataset'], grid_size = grid_size, ignore_label = 0, use_aug = True)
        val_dataset=spherical_dataset(val_pt_dataset, args['dataset'], grid_size = grid_size, ignore_label = 0)
    else:
        train_dataset=voxel_dataset(train_pt_dataset, args['dataset'], grid_size = grid_size, ignore_label = 0,use_aug = True)
        val_dataset=voxel_dataset(val_pt_dataset, args['dataset'], grid_size = grid_size, ignore_label = 0)
    train_dataset_loader = torch.utils.data.DataLoader(dataset = train_dataset,
                                                    batch_size = train_batch_size,
                                                    collate_fn = collate_fn_BEV,
                                                    shuffle = True,
                                                    num_workers = 4)
    val_dataset_loader = torch.utils.data.DataLoader(dataset = val_dataset,
                                                    batch_size = val_batch_size,
                                                    collate_fn = collate_fn_BEV,
                                                    shuffle = False,
                                                    num_workers = 4)

    # training
    epoch=0
    best_val_PQ=0
    start_training=False
    my_model.train()
    global_iter = 0
    exce_counter = 0
    evaluator = PanopticEval(len(unique_label)+1, None, [0], min_points=50)

    while epoch < args['model']['max_epoch']:
        pbar = tqdm(total=len(train_dataset_loader))
        for i_iter,(train_vox_fea,train_label_tensor,train_gt_center,train_gt_offset,train_grid,_,_,train_pt_fea) in enumerate(train_dataset_loader):
            # validation
            if global_iter % check_iter == 0:
                my_model.eval()
                evaluator.reset()
                with torch.no_grad():
                    for i_iter_val,(val_vox_fea,val_vox_label,val_gt_center,val_gt_offset,val_grid,val_pt_labels,val_pt_ints,val_pt_fea) in enumerate(val_dataset_loader):
                        val_vox_fea_ten = val_vox_fea.to(pytorch_device)
                        val_vox_label = SemKITTI2train(val_vox_label)
                        val_pt_fea_ten = [torch.from_numpy(i).type(torch.FloatTensor).to(pytorch_device) for i in val_pt_fea]
                        val_grid_ten = [torch.from_numpy(i[:,:2]).to(pytorch_device) for i in val_grid]
                        val_label_tensor=val_vox_label.type(torch.LongTensor).to(pytorch_device)
                        val_gt_center_tensor = val_gt_center.to(pytorch_device)
                        val_gt_offset_tensor = val_gt_offset.to(pytorch_device)

                        if visibility:
                            predict_labels,center,offset = my_model(val_pt_fea_ten, val_grid_ten, val_vox_fea_ten)
                        else:
                            predict_labels,center,offset = my_model(val_pt_fea_ten, val_grid_ten)

                        for count,i_val_grid in enumerate(val_grid):
                            # get foreground_mask
                            for_mask = torch.zeros(1,grid_size[0],grid_size[1],grid_size[2],dtype=torch.bool).to(pytorch_device)
                            for_mask[0,val_grid[count][:,0],val_grid[count][:,1],val_grid[count][:,2]] = True
                            # post processing
                            panoptic_labels,center_points = get_panoptic_segmentation(torch.unsqueeze(predict_labels[count], 0),torch.unsqueeze(center[count], 0),torch.unsqueeze(offset[count], 0),\
                                                                                      val_pt_dataset.thing_list, threshold=args['model']['post_proc']['threshold'], nms_kernel=args['model']['post_proc']['nms_kernel'],\
                                                                                      top_k=args['model']['post_proc']['top_k'], polar=circular_padding,foreground_mask=for_mask)
                            panoptic_labels = panoptic_labels.cpu().detach().numpy().astype(np.int32)
                            panoptic = panoptic_labels[0,val_grid[count][:,0],val_grid[count][:,1],val_grid[count][:,2]]
                            evaluator.addBatch(panoptic & 0xFFFF,panoptic,np.squeeze(val_pt_labels[count]),np.squeeze(val_pt_ints[count]))
                        del val_vox_label,val_pt_fea_ten,val_label_tensor,val_grid_ten,val_gt_center,val_gt_center_tensor,val_gt_offset,val_gt_offset_tensor,predict_labels,center,offset,panoptic_labels,center_points
                my_model.train()
                class_PQ, class_SQ, class_RQ, class_all_PQ, class_all_SQ, class_all_RQ = evaluator.getPQ()
                miou,ious = evaluator.getSemIoU()
                print('Validation per class PQ, SQ, RQ and IoU: ')
                for class_name, class_pq, class_sq, class_rq, class_iou in zip(unique_label_str,class_all_PQ[1:],class_all_SQ[1:],class_all_RQ[1:],ious[1:]):
                    print('%15s : %6.2f%%  %6.2f%%  %6.2f%%  %6.2f%%' % (class_name, class_pq*100, class_sq*100, class_rq*100, class_iou*100))                                  
                # save model if performance is improved
                if best_val_PQ<class_PQ:
                    best_val_PQ=class_PQ
                    torch.save(my_model.state_dict(), model_save_path)
                print('Current val PQ is %.3f while the best val PQ is %.3f' %
                    (class_PQ*100,best_val_PQ*100))               
                print('Current val miou is %.3f'%
                    (miou*100))

                if start_training:
                    sem_l ,hm_l, os_l = np.mean(loss_fn.lost_dict['semantic_loss']), np.mean(loss_fn.lost_dict['heatmap_loss']), np.mean(loss_fn.lost_dict['offset_loss'])
                    print('epoch %d iter %5d, loss: %.3f, semantic loss: %.3f, heatmap loss: %.3f, offset loss: %.3f\n' %
                        (epoch, i_iter, sem_l+hm_l+os_l, sem_l, hm_l, os_l))
                print('%d exceptions encountered during last training\n' %
                    exce_counter)
                exce_counter = 0
                loss_fn.reset_loss_dict()

            # training
            try:
                train_vox_fea_ten = train_vox_fea.to(pytorch_device)
                train_label_tensor = SemKITTI2train(train_label_tensor)
                train_pt_fea_ten = [torch.from_numpy(i).type(torch.FloatTensor).to(pytorch_device) for i in train_pt_fea]
                train_grid_ten = [torch.from_numpy(i[:,:2]).to(pytorch_device) for i in train_grid]
                train_label_tensor=train_label_tensor.type(torch.LongTensor).to(pytorch_device)
                train_gt_center_tensor = train_gt_center.to(pytorch_device)
                train_gt_offset_tensor = train_gt_offset.to(pytorch_device)

                if args['model']['enable_SAP'] and epoch>=args['model']['SAP']['start_epoch']:
                    for fea in train_pt_fea_ten:
                        fea.requires_grad_()
        
                # forward
                if visibility:
                    sem_prediction,center,offset = my_model(train_pt_fea_ten,train_grid_ten,train_vox_fea_ten)
                else:
                    sem_prediction,center,offset = my_model(train_pt_fea_ten,train_grid_ten)
                # loss
                loss = loss_fn(sem_prediction,center,offset,train_label_tensor,train_gt_center_tensor,train_gt_offset_tensor)
                
                
                # self adversarial pruning
                if args['model']['enable_SAP'] and epoch>=args['model']['SAP']['start_epoch']:
                    loss.backward()
                    for i,fea in enumerate(train_pt_fea_ten):
                        fea_grad = torch.norm(fea.grad,dim=1)
                        top_k_grad, _ = torch.topk(fea_grad, int(args['model']['SAP']['rate']*fea_grad.shape[0]))
                        # delete high influential points
                        train_pt_fea_ten[i] = train_pt_fea_ten[i][fea_grad < top_k_grad[-1]]
                        train_grid_ten[i] = train_grid_ten[i][fea_grad < top_k_grad[-1]]
                    optimizer.zero_grad()

                    # second pass
                    # forward
                    if visibility:
                        sem_prediction,center,offset = my_model(train_pt_fea_ten,train_grid_ten,train_vox_fea_ten)
                    else:
                        sem_prediction,center,offset = my_model(train_pt_fea_ten,train_grid_ten)
                    # loss
                    loss = loss_fn(sem_prediction,center,offset,train_label_tensor,train_gt_center_tensor,train_gt_offset_tensor)
                    
                # backward + optimize
                loss.backward()
                optimizer.step()
            except Exception as error: 
                if exce_counter == 0:
                    print(error)
                exce_counter += 1
            
            # zero the parameter gradients
            optimizer.zero_grad()
            pbar.update(1)
            start_training=True
            global_iter += 1
        pbar.close()
        epoch += 1

if __name__ == '__main__':
    # Training settings
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('-d', '--data_dir', default='data')
    parser.add_argument('-p', '--model_save_path', default='./Panoptic_SemKITTI.pt')
    parser.add_argument('-c', '--configs', default='configs/SemanticKITTI_model/Panoptic-PolarNet.yaml')
    parser.add_argument('--pretrained_model', default='empty')

    args = parser.parse_args()
    with open(args.configs, 'r') as s:
        new_args = yaml.safe_load(s)
    args = merge_configs(args,new_args)

    print(' '.join(sys.argv))
    print(args)
    main(args)

================================================
FILE: utils/__init__.py
================================================


================================================
FILE: utils/configs.py
================================================
#!/usr/bin/env python3
import yaml

def merge_configs(cfgs,new_cfgs):
    if hasattr(cfgs, 'data_dir'):
        new_cfgs['dataset']['path']=cfgs.data_dir
    if hasattr(cfgs, 'model_save_path'):
        new_cfgs['model']['model_save_path']=cfgs.model_save_path
    if hasattr(cfgs, 'pretrained_model'):
        new_cfgs['model']['pretrained_model']=cfgs.pretrained_model
    return new_cfgs

================================================
FILE: utils/eval_pq.py
================================================
#!/usr/bin/env python3
import numpy as np
import time


class PanopticEval:
  """ Panoptic evaluation using numpy
  
  authors: Andres Milioto and Jens Behley
  """

  def __init__(self, n_classes, device=None, ignore=None, offset=2**32, min_points=30):
    self.n_classes = n_classes
    assert (device == None)
    self.ignore = np.array(ignore, dtype=np.int64)
    self.include = np.array([n for n in range(self.n_classes) if n not in self.ignore], dtype=np.int64)

    print("[PANOPTIC EVAL] IGNORE: ", self.ignore)
    print("[PANOPTIC EVAL] INCLUDE: ", self.include)

    self.reset()
    self.offset = offset  # largest number of instances in a given scan
    self.min_points = min_points  # smallest number of points to consider instances in gt
    self.eps = 1e-15

  def num_classes(self):
    return self.n_classes

  def reset(self):
    # general things
    # iou stuff
    self.px_iou_conf_matrix = np.zeros((self.n_classes, self.n_classes), dtype=np.int64)
    # panoptic stuff
    self.pan_tp = np.zeros(self.n_classes, dtype=np.int64)
    self.pan_iou = np.zeros(self.n_classes, dtype=np.double)
    self.pan_fp = np.zeros(self.n_classes, dtype=np.int64)
    self.pan_fn = np.zeros(self.n_classes, dtype=np.int64)

  ################################# IoU STUFF ##################################
  def addBatchSemIoU(self, x_sem, y_sem):
    # idxs are labels and predictions
    idxs = np.stack([x_sem, y_sem], axis=0)

    # make confusion matrix (cols = gt, rows = pred)
    np.add.at(self.px_iou_conf_matrix, tuple(idxs), 1)

  def getSemIoUStats(self):
    # clone to avoid modifying the real deal
    conf = self.px_iou_conf_matrix.copy().astype(np.double)
    # remove fp from confusion on the ignore classes predictions
    # points that were predicted of another class, but were ignore
    # (corresponds to zeroing the cols of those classes, since the predictions
    # go on the rows)
    conf[:, self.ignore] = 0

    # get the clean stats
    tp = conf.diagonal()
    fp = conf.sum(axis=1) - tp
    fn = conf.sum(axis=0) - tp
    return tp, fp, fn

  def getSemIoU(self):
    tp, fp, fn = self.getSemIoUStats()
    # print(f"tp={tp}")
    # print(f"fp={fp}")
    # print(f"fn={fn}")
    intersection = tp
    union = tp + fp + fn
    union = np.maximum(union, self.eps)
    iou = intersection.astype(np.double) / union.astype(np.double)
    iou_mean = (intersection[self.include].astype(np.double) / union[self.include].astype(np.double)).mean()

    return iou_mean, iou  # returns "iou mean", "iou per class" ALL CLASSES

  def getSemAcc(self):
    tp, fp, fn = self.getSemIoUStats()
    total_tp = tp.sum()
    total = tp[self.include].sum() + fp[self.include].sum()
    total = np.maximum(total, self.eps)
    acc_mean = total_tp.astype(np.double) / total.astype(np.double)

    return acc_mean  # returns "acc mean"

  ################################# IoU STUFF ##################################
  ##############################################################################

  #############################  Panoptic STUFF ################################
  def addBatchPanoptic(self, x_sem_row, x_inst_row, y_sem_row, y_inst_row):
    # make sure instances are not zeros (it messes with my approach)
    x_inst_row = x_inst_row + 1
    y_inst_row = y_inst_row + 1

    # only interested in points that are outside the void area (not in excluded classes)
    for cl in self.ignore:
      # make a mask for this class
      gt_not_in_excl_mask = y_sem_row != cl
      # remove all other points
      x_sem_row = x_sem_row[gt_not_in_excl_mask]
      y_sem_row = y_sem_row[gt_not_in_excl_mask]
      x_inst_row = x_inst_row[gt_not_in_excl_mask]
      y_inst_row = y_inst_row[gt_not_in_excl_mask]

    # first step is to count intersections > 0.5 IoU for each class (except the ignored ones)
    for cl in self.include:
      # print("*"*80)
      # print("CLASS", cl.item())
      # get a class mask
      x_inst_in_cl_mask = x_sem_row == cl
      y_inst_in_cl_mask = y_sem_row == cl

      # get instance points in class (makes outside stuff 0)
      x_inst_in_cl = x_inst_row * x_inst_in_cl_mask.astype(np.int64)
      y_inst_in_cl = y_inst_row * y_inst_in_cl_mask.astype(np.int64)

      # generate the areas for each unique instance prediction
      unique_pred, counts_pred = np.unique(x_inst_in_cl[x_inst_in_cl > 0], return_counts=True)
      id2idx_pred = {id: idx for idx, id in enumerate(unique_pred)}
      matched_pred = np.array([False] * unique_pred.shape[0])
      # print("Unique predictions:", unique_pred)

      # generate the areas for each unique instance gt_np
      unique_gt, counts_gt = np.unique(y_inst_in_cl[y_inst_in_cl > 0], return_counts=True)
      id2idx_gt = {id: idx for idx, id in enumerate(unique_gt)}
      matched_gt = np.array([False] * unique_gt.shape[0])
      # print("Unique ground truth:", unique_gt)

      # generate intersection using offset
      valid_combos = np.logical_and(x_inst_in_cl > 0, y_inst_in_cl > 0)
      offset_combo = x_inst_in_cl[valid_combos] + self.offset * y_inst_in_cl[valid_combos]
      unique_combo, counts_combo = np.unique(offset_combo, return_counts=True)

      # generate an intersection map
      # count the intersections with over 0.5 IoU as TP
      gt_labels = unique_combo // self.offset
      pred_labels = unique_combo % self.offset
      gt_areas = np.array([counts_gt[id2idx_gt[id]] for id in gt_labels])
      pred_areas = np.array([counts_pred[id2idx_pred[id]] for id in pred_labels])
      intersections = counts_combo
      unions = gt_areas + pred_areas - intersections
      ious = intersections.astype(np.float) / unions.astype(np.float)


      tp_indexes = ious > 0.5
      self.pan_tp[cl] += np.sum(tp_indexes)
      self.pan_iou[cl] += np.sum(ious[tp_indexes])

      matched_gt[[id2idx_gt[id] for id in gt_labels[tp_indexes]]] = True
      matched_pred[[id2idx_pred[id] for id in pred_labels[tp_indexes]]] = True

      # count the FN
      self.pan_fn[cl] += np.sum(np.logical_and(counts_gt >= self.min_points, matched_gt == False))

      # count the FP
      self.pan_fp[cl] += np.sum(np.logical_and(counts_pred >= self.min_points, matched_pred == False))

  def getPQ(self):
    # first calculate for all classes
    sq_all = self.pan_iou.astype(np.double) / np.maximum(self.pan_tp.astype(np.double), self.eps)
    rq_all = self.pan_tp.astype(np.double) / np.maximum(
        self.pan_tp.astype(np.double) + 0.5 * self.pan_fp.astype(np.double) + 0.5 * self.pan_fn.astype(np.double),
        self.eps)
    pq_all = sq_all * rq_all

    # then do the REAL mean (no ignored classes)
    SQ = sq_all[self.include].mean()
    RQ = rq_all[self.include].mean()
    PQ = pq_all[self.include].mean()

    return PQ, SQ, RQ, pq_all, sq_all, rq_all

  #############################  Panoptic STUFF ################################
  ##############################################################################

  def addBatch(self, x_sem, x_inst, y_sem, y_inst):  # x=preds, y=targets
    ''' IMPORTANT: Inputs must be batched. Either [N,H,W], or [N, P]
    '''
    # add to IoU calculation (for checking purposes)
    self.addBatchSemIoU(x_sem, y_sem)

    # now do the panoptic stuff
    self.addBatchPanoptic(x_sem, x_inst, y_sem, y_inst)


if __name__ == "__main__":
  # generate problem from He paper (https://arxiv.org/pdf/1801.00868.pdf)
  classes = 5  # ignore, grass, sky, person, dog
  cl_strings = ["ignore", "grass", "sky", "person", "dog"]
  ignore = [0]  # only ignore ignore class
  min_points = 1  # for this example we care about all points

  # generate ground truth and prediction
  sem_pred = []
  inst_pred = []
  sem_gt = []
  inst_gt = []

  # some ignore stuff
  N_ignore = 50
  sem_pred.extend([0 for i in range(N_ignore)])
  inst_pred.extend([0 for i in range(N_ignore)])
  sem_gt.extend([0 for i in range(N_ignore)])
  inst_gt.extend([0 for i in range(N_ignore)])

  # grass segment
  N_grass = 50
  N_grass_pred = 40  # rest is sky
  sem_pred.extend([1 for i in range(N_grass_pred)])  # grass
  sem_pred.extend([2 for i in range(N_grass - N_grass_pred)])  # sky
  inst_pred.extend([0 for i in range(N_grass)])
  sem_gt.extend([1 for i in range(N_grass)])  # grass
  inst_gt.extend([0 for i in range(N_grass)])

  # sky segment
  N_sky = 50
  N_sky_pred = 40  # rest is grass
  sem_pred.extend([2 for i in range(N_sky_pred)])  # sky
  sem_pred.extend([1 for i in range(N_sky - N_sky_pred)])  # grass
  inst_pred.extend([0 for i in range(N_sky)])  # first instance
  sem_gt.extend([2 for i in range(N_sky)])  # sky
  inst_gt.extend([0 for i in range(N_sky)])  # first instance

  # wrong dog as person prediction
  N_dog = 50
  N_person = N_dog
  sem_pred.extend([3 for i in range(N_person)])
  inst_pred.extend([35 for i in range(N_person)])
  sem_gt.extend([4 for i in range(N_dog)])
  inst_gt.extend([22 for i in range(N_dog)])

  # two persons in prediction, but three in gt
  N_person = 50
  sem_pred.extend([3 for i in range(6 * N_person)])
  inst_pred.extend([8 for i in range(4 * N_person)])
  inst_pred.extend([95 for i in range(2 * N_person)])
  sem_gt.extend([3 for i in range(6 * N_person)])
  inst_gt.extend([33 for i in range(3 * N_person)])
  inst_gt.extend([42 for i in range(N_person)])
  inst_gt.extend([11 for i in range(2 * N_person)])

  # gt and pred to numpy
  sem_pred = np.array(sem_pred, dtype=np.int64).reshape(1, -1)
  inst_pred = np.array(inst_pred, dtype=np.int64).reshape(1, -1)
  sem_gt = np.array(sem_gt, dtype=np.int64).reshape(1, -1)
  inst_gt = np.array(inst_gt, dtype=np.int64).reshape(1, -1)

  # evaluator
  evaluator = PanopticEval(classes, ignore=ignore, min_points=1)
  evaluator.addBatch(sem_pred, inst_pred, sem_gt, inst_gt)
  pq, sq, rq, all_pq, all_sq, all_rq = evaluator.getPQ()
  iou, all_iou = evaluator.getSemIoU()

  # [PANOPTIC EVAL] IGNORE:  [0]
  # [PANOPTIC EVAL] INCLUDE:  [1 2 3 4]
  # TOTALS
  # PQ: 0.47916666666666663
  # SQ: 0.5520833333333333
  # RQ: 0.6666666666666666
  # IoU: 0.5476190476190476
  # Class ignore 	 PQ: 0.0 SQ: 0.0 RQ: 0.0 IoU: 0.0
  # Class grass 	 PQ: 0.6666666666666666 SQ: 0.6666666666666666 RQ: 1.0 IoU: 0.6666666666666666
  # Class sky 	 PQ: 0.6666666666666666 SQ: 0.6666666666666666 RQ: 1.0 IoU: 0.6666666666666666
  # Class person 	 PQ: 0.5833333333333333 SQ: 0.875 RQ: 0.6666666666666666 IoU: 0.8571428571428571
  # Class dog 	 PQ: 0.0 SQ: 0.0 RQ: 0.0 IoU: 0.0

  print("TOTALS")
  print("PQ:", pq.item(), pq.item() == 0.47916666666666663)
  print("SQ:", sq.item(), sq.item() == 0.5520833333333333)
  print("RQ:", rq.item(), rq.item() == 0.6666666666666666)
  print("IoU:", iou.item(), iou.item() == 0.5476190476190476)
  for i, (pq, sq, rq, iou) in enumerate(zip(all_pq, all_sq, all_rq, all_iou)):
    print("Class", cl_strings[i], "\t", "PQ:", pq.item(), "SQ:", sq.item(), "RQ:", rq.item(), "IoU:", iou.item())

================================================
FILE: utils/visual.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import cv2
import numpy as np

def flow_to_img(flow, normalize=True):
    """Convert flow to viewable image, using color hue to encode flow vector orientation, and color saturation to
    encode vector length. This is similar to the OpenCV tutorial on dense optical flow, except that they map vector
    length to the value plane of the HSV color model, instead of the saturation plane, as we do here.
    Args:
        flow: optical flow
        normalize: Normalize flow to 0..255
    Returns:
        img: viewable representation of the dense optical flow in RGB format
    Ref:
        https://github.com/philferriere/tfoptflow/blob/33e8a701e34c8ce061f17297d40619afbd459ade/tfoptflow/optflow.py
    """
    hsv = np.zeros((flow.shape[0], flow.shape[1], 3), dtype=np.uint8)
    flow_magnitude, flow_angle = cv2.cartToPolar(flow[..., 0].astype(np.float32), flow[..., 1].astype(np.float32))

    # A couple times, we've gotten NaNs out of the above...
    nans = np.isnan(flow_magnitude)
    if np.any(nans):
        nans = np.where(nans)
        flow_magnitude[nans] = 0.

    # Normalize
    hsv[..., 0] = flow_angle * 180 / np.pi / 2
    if normalize is True:
        hsv[..., 1] = cv2.normalize(flow_magnitude, None, 0, 255, cv2.NORM_MINMAX)
    else:
        hsv[..., 1] = flow_magnitude
    hsv[..., 2] = 255
    img = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)

    return img
Download .txt
gitextract_pewizc2i/

├── LICENSE
├── README.md
├── configs/
│   └── SemanticKITTI_model/
│       └── Panoptic-PolarNet.yaml
├── data/
│   └── README.md
├── dataloader/
│   ├── __init__.py
│   ├── dataset.py
│   ├── instance_augmentation.py
│   └── process_panoptic.py
├── instance_preprocess.py
├── network/
│   ├── BEV_Unet.py
│   ├── __init__.py
│   ├── instance_post_processing.py
│   ├── loss.py
│   ├── lovasz_losses.py
│   └── ptBEV.py
├── pretrained_weight/
│   └── Panoptic_SemKITTI_PolarNet.pt
├── requirements.txt
├── semantic-kitti.yaml
├── test_pretrain.py
├── train.py
└── utils/
    ├── __init__.py
    ├── configs.py
    ├── eval_pq.py
    └── visual.py
Download .txt
SYMBOL INDEX (111 symbols across 13 files)

FILE: dataloader/dataset.py
  class SemKITTI (line 20) | class SemKITTI(data.Dataset):
    method __init__ (line 21) | def __init__(self, data_path, imageset = 'train', return_ref = False, ...
    method __len__ (line 53) | def __len__(self):
    method __getitem__ (line 57) | def __getitem__(self, index):
    method save_instance (line 72) | def save_instance(self, out_dir, min_points = 10):
  function absoluteFilePaths (line 118) | def absoluteFilePaths(directory):
  class voxel_dataset (line 123) | class voxel_dataset(data.Dataset):
    method __init__ (line 124) | def __init__(self, in_dataset, args, grid_size, ignore_label = 0, retu...
    method __len__ (line 143) | def __len__(self):
    method __getitem__ (line 147) | def __getitem__(self, index):
  function cart2polar (line 248) | def cart2polar(input_xyz):
  function polar2cat (line 253) | def polar2cat(input_xyz_polar):
  class spherical_dataset (line 258) | class spherical_dataset(data.Dataset):
    method __init__ (line 259) | def __init__(self, in_dataset, args, grid_size, ignore_label = 0, retu...
    method __len__ (line 278) | def __len__(self):
    method __getitem__ (line 282) | def __getitem__(self, index):
  function nb_process_label (line 405) | def nb_process_label(processed_label,sorted_label_voxel_pair):
  function nb_process_inst (line 421) | def nb_process_inst(processed_inst,sorted_inst_voxel_pair):
  function collate_fn_BEV (line 436) | def collate_fn_BEV(data):
  function collate_fn_BEV_test (line 447) | def collate_fn_BEV_test(data):

FILE: dataloader/instance_augmentation.py
  class instance_augmentation (line 7) | class instance_augmentation(object):
    method __init__ (line 8) | def __init__(self,instance_pkl_path,thing_list,class_weight,random_fli...
    method instance_aug (line 22) | def instance_aug(self, point_xyz, point_label, point_inst, point_feat ...
    method instance_flip (line 138) | def instance_flip(self, points,axis,center,flip_type = 1):
    method check_occlusion (line 160) | def check_occlusion(self,points,center,min_dist=2):
    method rotate_origin (line 168) | def rotate_origin(self,xyz,radians):
    method local_tranform (line 177) | def local_tranform(self,xyz,center):

FILE: dataloader/process_panoptic.py
  class PanopticLabelGenerator (line 5) | class PanopticLabelGenerator(object):
    method __init__ (line 6) | def __init__(self,grid_size,sigma=5,polar=False):
    method __call__ (line 24) | def __call__(self,inst,xyz,voxel_inst,voxel_position,label_dict,min_bo...

FILE: network/BEV_Unet.py
  class BEV_Unet (line 11) | class BEV_Unet(nn.Module):
    method __init__ (line 13) | def __init__(self,n_class,n_height,dilation = 1,group_conv=False,input...
    method forward (line 22) | def forward(self, x):
  class UNet (line 32) | class UNet(nn.Module):
    method __init__ (line 33) | def __init__(self, n_class,n_height,dilation,group_conv,input_batch_no...
    method forward (line 62) | def forward(self, x):
  class double_conv (line 88) | class double_conv(nn.Module):
    method __init__ (line 90) | def __init__(self, in_ch, out_ch,group_conv,dilation=1):
    method forward (line 111) | def forward(self, x):
  class double_conv_circular (line 115) | class double_conv_circular(nn.Module):
    method __init__ (line 117) | def __init__(self, in_ch, out_ch,group_conv,dilation=1):
    method forward (line 142) | def forward(self, x):
  class inconv (line 151) | class inconv(nn.Module):
    method __init__ (line 152) | def __init__(self, in_ch, out_ch, dilation, input_batch_norm, circular...
    method forward (line 171) | def forward(self, x):
  class down (line 176) | class down(nn.Module):
    method __init__ (line 177) | def __init__(self, in_ch, out_ch, dilation, group_conv, circular_paddi...
    method forward (line 190) | def forward(self, x):
  class up (line 195) | class up(nn.Module):
    method __init__ (line 196) | def __init__(self, in_ch, out_ch, circular_padding, bilinear=True, gro...
    method forward (line 217) | def forward(self, x1, x2):
  class outconv (line 238) | class outconv(nn.Module):
    method __init__ (line 239) | def __init__(self, in_ch, out_ch):
    method forward (line 243) | def forward(self, x):

FILE: network/instance_post_processing.py
  function find_instance_center (line 8) | def find_instance_center(ctr_hmp, threshold=0.1, nms_kernel=5, top_k=Non...
  function group_pixels (line 52) | def group_pixels(ctr, offsets, polar=False):
  function get_instance_segmentation (line 92) | def get_instance_segmentation(sem_seg, ctr_hmp, offsets, thing_list, thr...
  function merge_semantic_and_instance (line 128) | def merge_semantic_and_instance(sem_seg, sem, ins_seg, label_divisor, th...
  function get_panoptic_segmentation (line 164) | def get_panoptic_segmentation(sem, ctr_hmp, offsets, thing_list, label_d...

FILE: network/loss.py
  function _neg_loss (line 8) | def _neg_loss(pred, gt):
  class FocalLoss (line 27) | class FocalLoss(torch.nn.Module):
    method __init__ (line 29) | def __init__(self):
    method forward (line 33) | def forward(self, out, target):
  class panoptic_loss (line 36) | class panoptic_loss(torch.nn.Module):
    method __init__ (line 37) | def __init__(self, ignore_label = 255, center_loss_weight = 100, offse...
    method reset_loss_dict (line 62) | def reset_loss_dict(self):
    method forward (line 67) | def forward(self,prediction,center,offset,gt_label,gt_center,gt_offset...

FILE: network/lovasz_losses.py
  function lovasz_grad (line 17) | def lovasz_grad(gt_sorted):
  function iou_binary (line 32) | def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
  function iou (line 52) | def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
  function lovasz_hinge (line 77) | def lovasz_hinge(logits, labels, per_image=True, ignore=None):
  function lovasz_hinge_flat (line 93) | def lovasz_hinge_flat(logits, labels):
  function flatten_binary_scores (line 113) | def flatten_binary_scores(scores, labels, ignore=None):
  class StableBCELoss (line 128) | class StableBCELoss(torch.nn.modules.Module):
    method __init__ (line 129) | def __init__(self):
    method forward (line 131) | def forward(self, input, target):
  function binary_xloss (line 137) | def binary_xloss(logits, labels, ignore=None):
  function lovasz_softmax (line 152) | def lovasz_softmax(probas, labels, classes='present', per_image=False, i...
  function lovasz_softmax_flat (line 170) | def lovasz_softmax_flat(probas, labels, classes='present'):
  function flatten_probas (line 201) | def flatten_probas(probas, labels, ignore=None):
  function xloss (line 223) | def xloss(logits, labels, ignore=None):
  function jaccard_loss (line 229) | def jaccard_loss(probas, labels,ignore=None, smooth = 100, bk_class = No...
  function hinge_jaccard_loss (line 257) | def hinge_jaccard_loss(probas, labels,ignore=None, classes = 'present', ...
  function isnan (line 294) | def isnan(x):
  function mean (line 298) | def mean(l, ignore_nan=False, empty=0):

FILE: network/ptBEV.py
  class ptBEVnet (line 11) | class ptBEVnet(nn.Module):
    method __init__ (line 13) | def __init__(self, BEV_net, grid_size, pt_model = 'pointnet', fea_dim ...
    method forward (line 66) | def forward(self, pt_fea, xy_ind, voxel_fea=None):
  function grp_range_torch (line 145) | def grp_range_torch(a,dev):
  function parallel_FPS (line 152) | def parallel_FPS(np_cat_fea,K):
  function nb_greedy_FPS (line 156) | def nb_greedy_FPS(xyz,K):

FILE: test_pretrain.py
  function SemKITTI2train (line 24) | def SemKITTI2train(label):
  function SemKITTI2train_single (line 30) | def SemKITTI2train_single(label):
  function main (line 33) | def main(args):

FILE: train.py
  function SemKITTI2train (line 23) | def SemKITTI2train(label):
  function SemKITTI2train_single (line 29) | def SemKITTI2train_single(label):
  function load_pretrained_model (line 32) | def load_pretrained_model(model,pretrained_model):
  function main (line 39) | def main(args):

FILE: utils/configs.py
  function merge_configs (line 4) | def merge_configs(cfgs,new_cfgs):

FILE: utils/eval_pq.py
  class PanopticEval (line 6) | class PanopticEval:
    method __init__ (line 12) | def __init__(self, n_classes, device=None, ignore=None, offset=2**32, ...
    method num_classes (line 26) | def num_classes(self):
    method reset (line 29) | def reset(self):
    method addBatchSemIoU (line 40) | def addBatchSemIoU(self, x_sem, y_sem):
    method getSemIoUStats (line 47) | def getSemIoUStats(self):
    method getSemIoU (line 62) | def getSemIoU(self):
    method getSemAcc (line 75) | def getSemAcc(self):
    method addBatchPanoptic (line 88) | def addBatchPanoptic(self, x_sem_row, x_inst_row, y_sem_row, y_inst_row):
    method getPQ (line 156) | def getPQ(self):
    method addBatch (line 174) | def addBatch(self, x_sem, x_inst, y_sem, y_inst):  # x=preds, y=targets

FILE: utils/visual.py
  function flow_to_img (line 6) | def flow_to_img(flow, normalize=True):
Condensed preview — 24 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (133K chars).
[
  {
    "path": "LICENSE",
    "chars": 1520,
    "preview": "BSD 3-Clause License\n\nCopyright (c) 2020, Zixiang Zhou\nAll rights reserved.\n\nRedistribution and use in source and binary"
  },
  {
    "path": "README.md",
    "chars": 3876,
    "preview": "# Panoptic-PolarNet\nThis is the official implementation of Panoptic-PolarNet.\n\n[[**ArXiv paper**]](https://arxiv.org/abs"
  },
  {
    "path": "configs/SemanticKITTI_model/Panoptic-PolarNet.yaml",
    "chars": 878,
    "preview": "model_name: Panoptic_PolarNet\n\ndataset:\n    name: semantickitti\n    path: data\n    output_path: out/SemKITTI\n    instanc"
  },
  {
    "path": "data/README.md",
    "chars": 144,
    "preview": "# PolarSeg\n\nDownload Velodyne point clouds and label data in SemanticKITTI dataset [here](http://www.semantic-kitti.org/"
  },
  {
    "path": "dataloader/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "dataloader/dataset.py",
    "chars": 21945,
    "preview": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n\"\"\"\nSemKITTI dataloader\n\"\"\"\nimport os\nimport numpy as np\nimport torch\nimp"
  },
  {
    "path": "dataloader/instance_augmentation.py",
    "chars": 8273,
    "preview": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n\nimport numpy as np\nimport pickle\n\nclass instance_augmentation(object):\n "
  },
  {
    "path": "dataloader/process_panoptic.py",
    "chars": 4538,
    "preview": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\nimport numpy as np\n\nclass PanopticLabelGenerator(object):\n    def __init_"
  },
  {
    "path": "instance_preprocess.py",
    "chars": 565,
    "preview": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\nimport argparse\nfrom dataloader.dataset import SemKITTI\n\nif __name__ == '"
  },
  {
    "path": "network/BEV_Unet.py",
    "chars": 9680,
    "preview": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.fu"
  },
  {
    "path": "network/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "network/instance_post_processing.py",
    "chars": 10581,
    "preview": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\nimport torch\nimport torch.nn.functional as F\nimport torch_scatter\n\n\ndef f"
  },
  {
    "path": "network/loss.py",
    "chars": 3724,
    "preview": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n\nimport numpy as np\nimport torch\nfrom .lovasz_losses import lovasz_softma"
  },
  {
    "path": "network/lovasz_losses.py",
    "chars": 11484,
    "preview": "\"\"\"\nLovasz-Softmax and Jaccard hinge loss in PyTorch\nMaxim Berman 2018 ESAT-PSI KU Leuven (MIT License)\n\"\"\"\n\nfrom __futu"
  },
  {
    "path": "network/ptBEV.py",
    "chars": 7176,
    "preview": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport"
  },
  {
    "path": "requirements.txt",
    "chars": 88,
    "preview": "numpy\ntorch>=1.7.0\ntqdm\npyyaml\nnumba>=0.39.0\ntorch_scatter>=1.3.1\nCython\nscipy\ndropblock"
  },
  {
    "path": "semantic-kitti.yaml",
    "chars": 6162,
    "preview": "# This file is covered by the LICENSE file in the root of this project.\nlabels: \n  0 : \"unlabeled\"\n  1 : \"outlier\"\n  10:"
  },
  {
    "path": "test_pretrain.py",
    "chars": 11309,
    "preview": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\nimport os\nimport time\nimport argparse\nimport sys\nimport yaml\nimport numpy"
  },
  {
    "path": "train.py",
    "chars": 13142,
    "preview": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\nimport os\nimport argparse\nimport sys\nimport numpy as np\nimport yaml\nimpor"
  },
  {
    "path": "utils/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "utils/configs.py",
    "chars": 390,
    "preview": "#!/usr/bin/env python3\nimport yaml\n\ndef merge_configs(cfgs,new_cfgs):\n    if hasattr(cfgs, 'data_dir'):\n        new_cfgs"
  },
  {
    "path": "utils/eval_pq.py",
    "chars": 10857,
    "preview": "#!/usr/bin/env python3\nimport numpy as np\nimport time\n\n\nclass PanopticEval:\n  \"\"\" Panoptic evaluation using numpy\n  \n  a"
  },
  {
    "path": "utils/visual.py",
    "chars": 1425,
    "preview": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\nimport cv2\nimport numpy as np\n\ndef flow_to_img(flow, normalize=True):\n   "
  }
]

// ... and 1 more files (download for full content)

About this extraction

This page contains the full source code of the edwardzhou130/Panoptic-PolarNet GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 24 files (52.7 MB), approximately 34.2k tokens, and a symbol index with 111 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!