[
  {
    "path": "LICENSE",
    "content": "BSD 3-Clause License\n\nCopyright (c) 2020, Zixiang Zhou\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions are met:\n\n1. Redistributions of source code must retain the above copyright notice, this\n   list of conditions and the following disclaimer.\n\n2. Redistributions in binary form must reproduce the above copyright notice,\n   this list of conditions and the following disclaimer in the documentation\n   and/or other materials provided with the distribution.\n\n3. Neither the name of the copyright holder nor the names of its\n   contributors may be used to endorse or promote products derived from\n   this software without specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\nAND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\nIMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\nDISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\nFOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\nDAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\nSERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\nCAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\nOR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\nOF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n"
  },
  {
    "path": "README.md",
    "content": "# Panoptic-PolarNet\nThis is the official implementation of Panoptic-PolarNet.\n\n[[**ArXiv paper**]](https://arxiv.org/abs/2103.14962)\n\n<p align=\"center\">\n        <img src=\"imgs/Visualization.gif\" width=\"60%\"> \n</p>\n\n# Introduction\n\nPanoptic-PolarNet is a fast and robust LiDAR point cloud panoptic segmentation framework. We learn both semantic segmentation and class-agnostic instance clustering in a single inference network using a polar Bird's Eye View (BEV) representation. Predictions from the semantic and instance head are then fused through a majority voting to create the final panopticsegmentation.\n\n<p align=\"center\">\n        <img src=\"imgs/CVPR_pipeline.png\" width=\"100%\"> \n</p>\n\nWe test Panoptic-PolarNet on SemanticKITTI and nuScenes datasets. Experiment shows that Panoptic-PolarNet reaches state-of-the-art performances with a real-time inference speed.\n\n<p align=\"center\">\n        <img src=\"imgs/result.png\" width=\"100%\"> \n</p>\n\n\n## Prepare dataset and environment\n\nThis code is tested on Ubuntu 16.04 with Python 3.8, CUDA 10.2 and Pytorch 1.7.0.\n\n1, Install the following dependencies by either `pip install -r requirements.txt` or manual installation.\n* numpy\n* pytorch\n* tqdm\n* yaml\n* Cython\n* [numba](https://github.com/numba/numba)\n* [torch-scatter](https://github.com/rusty1s/pytorch_scatter)\n* [dropblock](https://github.com/miguelvr/dropblock)\n* (Optional) [open3d](https://github.com/intel-isl/Open3D)\n\n2, Download Velodyne point clouds and label data in SemanticKITTI dataset [here](http://www.semantic-kitti.org/dataset.html#overview).\n\n3, Extract everything into the same folder. 
The folder structure inside the zip files of label data matches the folder structure of the LiDAR point cloud data.\n\n4, The data file structure should look like this:\n\n```\n./\n├── train.py\n├── ...\n└── data/\n    ├──sequences\n        ├── 00/           \n        │   ├── velodyne/\t# Unzip from KITTI Odometry Benchmark Velodyne point clouds.\n        |   |\t├── 000000.bin\n        |   |\t├── 000001.bin\n        |   |\t└── ...\n        │   └── labels/ \t# Unzip from SemanticKITTI label data.\n        |       ├── 000000.label\n        |       ├── 000001.label\n        |       └── ...\n        ├── ...\n        └── 21/\n\t    └── ...\n```\n\n5, Instance preprocessing:\n```shell\npython instance_preprocess.py -d </your data path> -o </preprocessed file output path>\n``` \n\n## Training\n\nRun\n```shell\npython train.py\n```\n\nThe code will automatically train, validate, and save the model with the best validation PQ.\n\nWith the default settings, Panoptic-PolarNet requires around 11 GB of GPU memory for training. Training on a GPU with less memory will likely cause an out-of-memory error. In this case, you can set ``grid_size`` in the config file to ``[320,240,32]`` or lower.\n\n## Evaluate our pretrained model\n\nWe also provide pretrained Panoptic-PolarNet weights.\n```shell\npython test_pretrain.py\n```\nResults will be stored in the `./out` folder. 
Test performance can be evaluated by uploading label results onto the SemanticKITTI competition website [here](https://competitions.codalab.org/competitions/24025).\n\n## Citation\nPlease cite our paper if this code benefits your research:\n```\n@inproceedings{Zhou2021PanopticPolarNet,\nauthor={Zhou, Zixiang and Zhang, Yang and Foroosh, Hassan},\ntitle={Panoptic-PolarNet: Proposal-free LiDAR Point Cloud Panoptic Segmentation},\nbooktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},\nyear={2021}\n}\n\n@InProceedings{Zhang_2020_CVPR,\nauthor = {Zhang, Yang and Zhou, Zixiang and David, Philip and Yue, Xiangyu and Xi, Zerong and Gong, Boqing and Foroosh, Hassan},\ntitle = {PolarNet: An Improved Grid Representation for Online LiDAR Point Clouds Semantic Segmentation},\nbooktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},\nmonth = {June},\nyear = {2020}\n}\n```"
  },
  {
    "path": "configs/SemanticKITTI_model/Panoptic-PolarNet.yaml",
    "content": "model_name: Panoptic_PolarNet\n\ndataset:\n    name: semantickitti\n    path: data\n    output_path: out/SemKITTI\n    instance_pkl_path: data\n    rotate_aug: True\n    flip_aug: True\n    inst_aug: True\n    inst_aug_type:\n        inst_os: True\n        inst_loc_aug: True\n        inst_global_aug: True\n    gt_generator:\n        sigma: 5\n    grid_size: [480,360,32]\n\nmodel:\n    model_save_path: ./Panoptic_SemKITTI.pt\n    pretrained_model: /pretrained_weight/Panoptic_SemKITTI_PolarNet.pt\n    polar: True\n    visibility: True\n    \n    train_batch_size: 2\n    val_batch_size: 2\n    test_batch_size: 1\n    check_iter: 4000\n    max_epoch: 100\n    post_proc:\n        threshold: 0.1\n        nms_kernel: 5\n        top_k: 100\n    center_loss: MSE\n    offset_loss: L1\n    center_loss_weight: 100\n    offset_loss_weight: 10\n    enable_SAP: True\n    SAP:\n        start_epoch: 30\n        rate: 0.01"
  },
  {
    "path": "data/README.md",
    "content": "# PolarSeg\n\nDownload Velodyne point clouds and label data in SemanticKITTI dataset [here](http://www.semantic-kitti.org/dataset.html#overview).\n"
  },
  {
    "path": "dataloader/__init__.py",
    "content": ""
  },
  {
    "path": "dataloader/dataset.py",
    "content": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n\"\"\"\nSemKITTI dataloader\n\"\"\"\nimport os\nimport numpy as np\nimport torch\nimport random\nimport time\nimport numba as nb\nimport yaml\nimport pickle\nimport errno\nfrom torch.utils import data\n\nfrom .process_panoptic import PanopticLabelGenerator\nfrom .instance_augmentation import instance_augmentation\n\nclass SemKITTI(data.Dataset):\n    def __init__(self, data_path, imageset = 'train', return_ref = False, instance_pkl_path ='data'):\n        self.return_ref = return_ref\n        with open(\"semantic-kitti.yaml\", 'r') as stream:\n            semkittiyaml = yaml.safe_load(stream)\n        self.learning_map = semkittiyaml['learning_map']\n        thing_class = semkittiyaml['thing_class']\n        self.thing_list = [cl for cl, ignored in thing_class.items() if ignored]\n        self.imageset = imageset\n        if imageset == 'train':\n            split = semkittiyaml['split']['train']\n        elif imageset == 'val':\n            split = semkittiyaml['split']['valid']\n        elif imageset == 'test':\n            split = semkittiyaml['split']['test']\n        else:\n            raise Exception('Split must be train/val/test')\n        \n        self.im_idx = []\n        for i_folder in split:\n            self.im_idx += absoluteFilePaths('/'.join([data_path,str(i_folder).zfill(2),'velodyne']))\n        self.im_idx.sort()\n\n        # get class distribution weight \n        epsilon_w = 0.001\n        origin_class = semkittiyaml['content'].keys()\n        weights = np.zeros((len(semkittiyaml['learning_map_inv'])-1,),dtype = np.float32)\n        for class_num in origin_class:\n            if semkittiyaml['learning_map'][class_num] != 0:\n                weights[semkittiyaml['learning_map'][class_num]-1] += semkittiyaml['content'][class_num]\n        self.CLS_LOSS_WEIGHT = 1/(weights + epsilon_w)\n        self.instance_pkl_path = instance_pkl_path\n         \n    def __len__(self):\n        
'Denotes the total number of samples'\n        return len(self.im_idx)\n    \n    def __getitem__(self, index):\n        raw_data = np.fromfile(self.im_idx[index], dtype=np.float32).reshape((-1, 4))\n        if self.imageset == 'test':\n            sem_data = np.expand_dims(np.zeros_like(raw_data[:,0],dtype=int),axis=1)\n            inst_data = np.expand_dims(np.zeros_like(raw_data[:,0],dtype=np.uint32),axis=1)\n        else:\n            annotated_data = np.fromfile(self.im_idx[index].replace('velodyne','labels')[:-3]+'label', dtype=np.uint32).reshape((-1,1))\n            sem_data = annotated_data & 0xFFFF #delete high 16 digits binary\n            sem_data = np.vectorize(self.learning_map.__getitem__)(sem_data)\n            inst_data = annotated_data\n        data_tuple = (raw_data[:,:3], sem_data.astype(np.uint8),inst_data)\n        if self.return_ref:\n            data_tuple += (raw_data[:,3],)\n        return data_tuple\n\n    def save_instance(self, out_dir, min_points = 10):\n        'instance data preparation'\n        instance_dict={label:[] for label in self.thing_list}\n        for data_path in self.im_idx:\n            print('process instance for:'+data_path)\n            # get x,y,z,ref,semantic label and instance label\n            raw_data = np.fromfile(data_path, dtype=np.float32).reshape((-1, 4))\n            annotated_data = np.fromfile(data_path.replace('velodyne','labels')[:-3]+'label', dtype=np.uint32).reshape((-1,1))\n            sem_data = annotated_data & 0xFFFF #delete high 16 digits binary\n            sem_data = np.vectorize(self.learning_map.__getitem__)(sem_data)\n            inst_data = annotated_data\n\n            # instance mask\n            mask = np.zeros_like(sem_data,dtype=bool)\n            for label in self.thing_list:\n                mask[sem_data == label] = True\n\n            # create unqiue instance list\n            inst_label = inst_data[mask].squeeze()\n            unique_label = np.unique(inst_label)\n            
num_inst = len(unique_label)\n\n            inst_count = 0\n            for inst in unique_label:\n                # get instance index\n                index = np.where(inst_data == inst)[0]\n                # get semantic label\n                class_label = sem_data[index[0]]\n                # skip small instance\n                if index.size<min_points: continue\n                # save\n                _,dir2 = data_path.split('/sequences/',1)\n                new_save_dir = out_dir + '/sequences/' +dir2.replace('velodyne','instance')[:-4]+'_'+str(inst_count)+'.bin'\n                if not os.path.exists(os.path.dirname(new_save_dir)):\n                    try:\n                        os.makedirs(os.path.dirname(new_save_dir))\n                    except OSError as exc:\n                        if exc.errno != errno.EEXIST:\n                            raise\n                inst_fea = raw_data[index]\n                inst_fea.tofile(new_save_dir)\n                instance_dict[int(class_label)].append(new_save_dir)\n                inst_count+=1\n        with open(out_dir+'/instance_path.pkl', 'wb') as f:\n            pickle.dump(instance_dict, f)\n\ndef absoluteFilePaths(directory):\n   for dirpath,_,filenames in os.walk(directory):\n       for f in filenames:\n           yield os.path.abspath(os.path.join(dirpath, f))\n\nclass voxel_dataset(data.Dataset):\n  def __init__(self, in_dataset, args, grid_size, ignore_label = 0, return_test = False, fixed_volume_space= True, use_aug = False, max_volume_space = [50,50,1.5], min_volume_space = [-50,-50,-3]):\n        'Initialization'\n        self.point_cloud_dataset = in_dataset\n        self.grid_size = np.asarray(grid_size)\n        self.rotate_aug = args['rotate_aug'] if use_aug else False\n        self.ignore_label = ignore_label\n        self.return_test = return_test\n        self.flip_aug = args['flip_aug'] if use_aug else False\n        self.instance_aug = args['inst_aug'] if use_aug else False\n        
self.fixed_volume_space = fixed_volume_space\n        self.max_volume_space = max_volume_space\n        self.min_volume_space = min_volume_space\n\n        self.panoptic_proc = PanopticLabelGenerator(self.grid_size,sigma=args['gt_generator']['sigma'])\n        if self.instance_aug:\n            self.inst_aug = instance_augmentation(self.point_cloud_dataset.instance_pkl_path+'/instance_path.pkl',self.point_cloud_dataset.thing_list,self.point_cloud_dataset.CLS_LOSS_WEIGHT,\\\n                                                random_flip=args['inst_aug_type']['inst_global_aug'],random_add=args['inst_aug_type']['inst_os'],\\\n                                                random_rotate=args['inst_aug_type']['inst_global_aug'],local_transformation=args['inst_aug_type']['inst_loc_aug'])\n\n  def __len__(self):\n        'Denotes the total number of samples'\n        return len(self.point_cloud_dataset)\n\n  def __getitem__(self, index):\n        'Generates one sample of data'\n        data = self.point_cloud_dataset[index]\n        if len(data) == 3:\n            xyz,labels,insts = data\n        elif len(data) == 4:\n            xyz,labels,insts,feat = data\n            if len(feat.shape) == 1: feat = feat[..., np.newaxis]\n        else: raise Exception('Return invalid data tuple')\n        if len(labels.shape) == 1: labels = labels[..., np.newaxis]\n        if len(insts.shape) == 1: insts = insts[..., np.newaxis]\n        \n        # random data augmentation by rotation\n        if self.rotate_aug:\n            rotate_rad = np.deg2rad(np.random.random()*360)\n            c, s = np.cos(rotate_rad), np.sin(rotate_rad)\n            j = np.matrix([[c, s], [-s, c]])\n            xyz[:,:2] = np.dot( xyz[:,:2],j)\n\n        # random data augmentation by flip x , y or x+y\n        if self.flip_aug:\n            flip_type = np.random.choice(4,1)\n            if flip_type==1:\n                xyz[:,0] = -xyz[:,0]\n            elif flip_type==2:\n                xyz[:,1] = 
-xyz[:,1]\n            elif flip_type==3:\n                xyz[:,:2] = -xyz[:,:2]\n\n        # random instance augmentation\n        if self.instance_aug:\n            xyz,labels,insts,feat = self.inst_aug.instance_aug(xyz,labels.squeeze(),insts.squeeze(),feat)\n\n        max_bound = np.percentile(xyz,100,axis = 0)\n        min_bound = np.percentile(xyz,0,axis = 0)\n        \n        if self.fixed_volume_space:\n            max_bound = np.asarray(self.max_volume_space)\n            min_bound = np.asarray(self.min_volume_space)\n\n        # get grid index\n        crop_range = max_bound - min_bound\n        cur_grid_size = self.grid_size\n        \n        intervals = crop_range/(cur_grid_size-1)\n        if (intervals==0).any(): print(\"Zero interval!\")\n        \n        # np.int was removed in NumPy 1.24; int64 matches the i8 numba signatures below\n        grid_ind = (np.floor((np.clip(xyz,min_bound,max_bound)-min_bound)/intervals)).astype(np.int64)\n\n        # process voxel position\n        voxel_position = np.zeros(self.grid_size,dtype = np.float32)\n        dim_array = np.ones(len(self.grid_size)+1,int)\n        dim_array[0] = -1 \n        voxel_position = (np.indices(self.grid_size) + 0.5)*intervals.reshape(dim_array) + min_bound.reshape(dim_array)\n\n        # process labels\n        processed_label = np.ones(self.grid_size,dtype = np.uint8)*self.ignore_label\n        label_voxel_pair = np.concatenate([grid_ind,labels],axis = 1)\n        label_voxel_pair = label_voxel_pair[np.lexsort((grid_ind[:,0],grid_ind[:,1],grid_ind[:,2])),:]\n        processed_label = nb_process_label(np.copy(processed_label),label_voxel_pair)\n\n        # get thing points mask\n        mask = np.zeros_like(labels,dtype=bool)\n        for label in self.point_cloud_dataset.thing_list:\n            mask[labels == label] = True\n        \n        inst_label = insts[mask].squeeze()\n        unique_label = np.unique(inst_label)\n        unique_label_dict = {label:idx+1 for idx , label in enumerate(unique_label)}\n        if inst_label.size > 1:            \n            
inst_label = np.vectorize(unique_label_dict.__getitem__)(inst_label)\n            \n            # process panoptic\n            processed_inst = np.ones(self.grid_size[:2],dtype = np.uint8)*self.ignore_label\n            inst_voxel_pair = np.concatenate([grid_ind[mask[:,0],:2],inst_label[..., np.newaxis]],axis = 1)\n            inst_voxel_pair = inst_voxel_pair[np.lexsort((grid_ind[mask[:,0],0],grid_ind[mask[:,0],1])),:]\n            processed_inst = nb_process_inst(np.copy(processed_inst),inst_voxel_pair)\n        else:\n            processed_inst = None\n\n        center,center_points,offset = self.panoptic_proc(insts[mask],xyz[mask[:,0]],processed_inst,voxel_position[:2,:,:,0],unique_label_dict,min_bound,intervals)\n        \n        data_tuple = (voxel_position,processed_label,center,offset)\n\n        # center data on each voxel for PTnet\n        voxel_centers = (grid_ind.astype(np.float32) + 0.5)*intervals + min_bound\n        return_xyz = xyz - voxel_centers\n        return_xyz = np.concatenate((return_xyz,xyz),axis = 1)\n        \n        if len(data) == 3:\n            return_fea = return_xyz\n        elif len(data) == 4:\n            return_fea = np.concatenate((return_xyz,feat),axis = 1)\n        \n        if self.return_test:\n            data_tuple += (grid_ind,labels,insts,return_fea,index)\n        else:\n            data_tuple += (grid_ind,labels,insts,return_fea)\n        return data_tuple\n\n# transformation between Cartesian coordinates and polar coordinates\ndef cart2polar(input_xyz):\n    rho = np.sqrt(input_xyz[:,0]**2 + input_xyz[:,1]**2)\n    phi = np.arctan2(input_xyz[:,1],input_xyz[:,0])\n    return np.stack((rho,phi,input_xyz[:,2]),axis=1)\n\ndef polar2cat(input_xyz_polar):\n    x = input_xyz_polar[0]*np.cos(input_xyz_polar[1])\n    y = input_xyz_polar[0]*np.sin(input_xyz_polar[1])\n    return np.stack((x,y,input_xyz_polar[2]),axis=0)\n\nclass spherical_dataset(data.Dataset):\n  def __init__(self, in_dataset, args, grid_size, 
ignore_label = 0, return_test = False, use_aug = False, fixed_volume_space= True, max_volume_space = [50,np.pi,1.5], min_volume_space = [3,-np.pi,-3]):\n        'Initialization'\n        self.point_cloud_dataset = in_dataset\n        self.grid_size = np.asarray(grid_size)\n        self.rotate_aug = args['rotate_aug'] if use_aug else False\n        self.flip_aug = args['flip_aug'] if use_aug else False\n        self.instance_aug = args['inst_aug'] if use_aug else False\n        self.ignore_label = ignore_label\n        self.return_test = return_test\n        self.fixed_volume_space = fixed_volume_space\n        self.max_volume_space = max_volume_space\n        self.min_volume_space = min_volume_space\n\n        self.panoptic_proc = PanopticLabelGenerator(self.grid_size,sigma=args['gt_generator']['sigma'],polar=True)\n        if self.instance_aug:\n            self.inst_aug = instance_augmentation(self.point_cloud_dataset.instance_pkl_path+'/instance_path.pkl',self.point_cloud_dataset.thing_list,self.point_cloud_dataset.CLS_LOSS_WEIGHT,\\\n                                                random_flip=args['inst_aug_type']['inst_global_aug'],random_add=args['inst_aug_type']['inst_os'],\\\n                                                random_rotate=args['inst_aug_type']['inst_global_aug'],local_transformation=args['inst_aug_type']['inst_loc_aug'])\n\n  def __len__(self):\n        'Denotes the total number of samples'\n        return len(self.point_cloud_dataset)\n\n  def __getitem__(self, index):\n        'Generates one sample of data'\n        data = self.point_cloud_dataset[index]\n        if len(data) == 3:\n            xyz,labels,insts = data\n        elif len(data) == 4:\n            xyz,labels,insts,feat = data\n            if len(feat.shape) == 1: feat = feat[..., np.newaxis]\n        else: raise Exception('Return invalid data tuple')\n        if len(labels.shape) == 1: labels = labels[..., np.newaxis]\n        if len(insts.shape) == 1: insts = insts[..., 
np.newaxis]\n        \n        # random data augmentation by rotation\n        if self.rotate_aug:\n            rotate_rad = np.deg2rad(np.random.random()*360)\n            c, s = np.cos(rotate_rad), np.sin(rotate_rad)\n            j = np.matrix([[c, s], [-s, c]])\n            xyz[:,:2] = np.dot( xyz[:,:2],j)\n\n        # random data augmentation by flip x , y or x+y\n        if self.flip_aug:\n            flip_type = np.random.choice(4,1)\n            if flip_type==1:\n                xyz[:,0] = -xyz[:,0]\n            elif flip_type==2:\n                xyz[:,1] = -xyz[:,1]\n            elif flip_type==3:\n                xyz[:,:2] = -xyz[:,:2]\n\n        # random instance augmentation\n        if self.instance_aug:\n            xyz,labels,insts,feat = self.inst_aug.instance_aug(xyz,labels.squeeze(),insts.squeeze(),feat)\n        \n        # convert coordinates into polar coordinates\n        xyz_pol = cart2polar(xyz)\n        \n        max_bound_r = np.percentile(xyz_pol[:,0],100,axis = 0)\n        min_bound_r = np.percentile(xyz_pol[:,0],0,axis = 0)\n        max_bound = np.max(xyz_pol[:,1:],axis = 0)\n        min_bound = np.min(xyz_pol[:,1:],axis = 0)\n        max_bound = np.concatenate(([max_bound_r],max_bound))\n        min_bound = np.concatenate(([min_bound_r],min_bound))\n        if self.fixed_volume_space:\n            max_bound = np.asarray(self.max_volume_space)\n            min_bound = np.asarray(self.min_volume_space)\n\n        # get grid index\n        crop_range = max_bound - min_bound\n        cur_grid_size = self.grid_size\n        intervals = crop_range/(cur_grid_size-1)\n\n        if (intervals==0).any(): print(\"Zero interval!\")\n        # np.int was removed in NumPy 1.24; int64 matches the i8 numba signatures below\n        grid_ind = (np.floor((np.clip(xyz_pol,min_bound,max_bound)-min_bound)/intervals)).astype(np.int64)\n\n        current_grid = grid_ind[:np.size(labels)]\n\n        # process voxel position\n        voxel_position = np.zeros(self.grid_size,dtype = np.float32)\n        dim_array = np.ones(len(self.grid_size)+1,int)\n 
       dim_array[0] = -1 \n        voxel_position = np.indices(self.grid_size)*intervals.reshape(dim_array) + min_bound.reshape(dim_array)\n        # voxel_position = polar2cat(voxel_position)\n        \n        # process labels\n        processed_label = np.ones(self.grid_size,dtype = np.uint8)*self.ignore_label\n        label_voxel_pair = np.concatenate([current_grid,labels],axis = 1)\n        label_voxel_pair = label_voxel_pair[np.lexsort((current_grid[:,0],current_grid[:,1],current_grid[:,2])),:]\n        processed_label = nb_process_label(np.copy(processed_label),label_voxel_pair)\n        # data_tuple = (voxel_position,processed_label)\n\n        # get thing points mask\n        mask = np.zeros_like(labels,dtype=bool)\n        for label in self.point_cloud_dataset.thing_list:\n            mask[labels == label] = True\n        \n        inst_label = insts[mask].squeeze()\n        unique_label = np.unique(inst_label)\n        unique_label_dict = {label:idx+1 for idx , label in enumerate(unique_label)}\n        if inst_label.size > 1:            \n            inst_label = np.vectorize(unique_label_dict.__getitem__)(inst_label)\n            \n            # process panoptic\n            processed_inst = np.ones(self.grid_size[:2],dtype = np.uint8)*self.ignore_label\n            inst_voxel_pair = np.concatenate([current_grid[mask[:,0],:2],inst_label[..., np.newaxis]],axis = 1)\n            inst_voxel_pair = inst_voxel_pair[np.lexsort((current_grid[mask[:,0],0],current_grid[mask[:,0],1])),:]\n            processed_inst = nb_process_inst(np.copy(processed_inst),inst_voxel_pair)\n        else:\n            processed_inst = None\n\n        center,center_points,offset = self.panoptic_proc(insts[mask],xyz[:np.size(labels)][mask[:,0]],processed_inst,voxel_position[:2,:,:,0],unique_label_dict,min_bound,intervals)\n\n        # prepare visibility feature\n        # find max distance index in each angle,height pair\n        valid_label = 
np.zeros_like(processed_label,dtype=bool)\n        valid_label[current_grid[:,0],current_grid[:,1],current_grid[:,2]] = True\n        valid_label = valid_label[::-1]\n        max_distance_index = np.argmax(valid_label,axis=0)\n        max_distance = max_bound[0]-intervals[0]*(max_distance_index)\n        distance_feature = np.expand_dims(max_distance, axis=2)-np.transpose(voxel_position[0],(1,2,0))\n        distance_feature = np.transpose(distance_feature,(1,2,0))\n        # convert to boolean feature\n        distance_feature = (distance_feature>0)*-1.\n        distance_feature[current_grid[:,2],current_grid[:,0],current_grid[:,1]]=1.\n\n        data_tuple = (distance_feature,processed_label,center,offset)\n\n        # center data on each voxel for PTnet\n        voxel_centers = (grid_ind.astype(np.float32) + 0.5)*intervals + min_bound\n        return_xyz = xyz_pol - voxel_centers\n        return_xyz = np.concatenate((return_xyz,xyz_pol,xyz[:,:2]),axis = 1)\n\n        if len(data) == 3:\n            return_fea = return_xyz\n        elif len(data) == 4:\n            return_fea = np.concatenate((return_xyz,feat),axis = 1)\n        \n        if self.return_test:\n            data_tuple += (grid_ind,labels,insts,return_fea,index)\n        else:\n            data_tuple += (grid_ind,labels,insts,return_fea)\n        return data_tuple\n    \n@nb.jit('u1[:,:,:](u1[:,:,:],i8[:,:])',nopython=True,cache=True,parallel = False)\ndef nb_process_label(processed_label,sorted_label_voxel_pair):\n    label_size = 256\n    counter = np.zeros((label_size,),dtype = np.uint16)\n    counter[sorted_label_voxel_pair[0,3]] = 1\n    cur_sear_ind = sorted_label_voxel_pair[0,:3]\n    for i in range(1,sorted_label_voxel_pair.shape[0]):\n        cur_ind = sorted_label_voxel_pair[i,:3]\n        if not np.all(np.equal(cur_ind,cur_sear_ind)):\n            processed_label[cur_sear_ind[0],cur_sear_ind[1],cur_sear_ind[2]] = np.argmax(counter)\n            counter = np.zeros((label_size,),dtype = 
np.uint16)\n            cur_sear_ind = cur_ind\n        counter[sorted_label_voxel_pair[i,3]] += 1\n    processed_label[cur_sear_ind[0],cur_sear_ind[1],cur_sear_ind[2]] = np.argmax(counter)\n    return processed_label\n\n@nb.jit('u1[:,:](u1[:,:],i8[:,:])',nopython=True,cache=True,parallel = False)\ndef nb_process_inst(processed_inst,sorted_inst_voxel_pair):\n    label_size = 256\n    counter = np.zeros((label_size,),dtype = np.uint16)\n    counter[sorted_inst_voxel_pair[0,2]] = 1\n    cur_sear_ind = sorted_inst_voxel_pair[0,:2]\n    for i in range(1,sorted_inst_voxel_pair.shape[0]):\n        cur_ind = sorted_inst_voxel_pair[i,:2]\n        if not np.all(np.equal(cur_ind,cur_sear_ind)):\n            processed_inst[cur_sear_ind[0],cur_sear_ind[1]] = np.argmax(counter)\n            counter = np.zeros((label_size,),dtype = np.uint16)\n            cur_sear_ind = cur_ind\n        counter[sorted_inst_voxel_pair[i,2]] += 1\n    processed_inst[cur_sear_ind[0],cur_sear_ind[1]] = np.argmax(counter)\n    return processed_inst\n\ndef collate_fn_BEV(data):\n    data2stack=np.stack([d[0] for d in data]).astype(np.float32)\n    label2stack=np.stack([d[1] for d in data])\n    center2stack=np.stack([d[2] for d in data])\n    offset2stack=np.stack([d[3] for d in data])\n    grid_ind_stack = [d[4] for d in data]\n    point_label = [d[5] for d in data]\n    point_inst = [d[6] for d in data]\n    xyz = [d[7] for d in data]\n    return torch.from_numpy(data2stack),torch.from_numpy(label2stack),torch.from_numpy(center2stack),torch.from_numpy(offset2stack),grid_ind_stack,point_label,point_inst,xyz\n\ndef collate_fn_BEV_test(data):    \n    data2stack=np.stack([d[0] for d in data]).astype(np.float32)\n    label2stack=np.stack([d[1] for d in data])\n    center2stack=np.stack([d[2] for d in data])\n    offset2stack=np.stack([d[3] for d in data])\n    grid_ind_stack = [d[4] for d in data]\n    point_label = [d[5] for d in data]\n    point_inst = [d[6] for d in data]\n    xyz = [d[7] for d in 
data]\n    index = [d[8] for d in data]\n    return torch.from_numpy(data2stack),torch.from_numpy(label2stack),torch.from_numpy(center2stack),torch.from_numpy(offset2stack),grid_ind_stack,point_label,point_inst,xyz,index\n\n# load Semantic KITTI class info\nwith open(\"semantic-kitti.yaml\", 'r') as stream:\n    semkittiyaml = yaml.safe_load(stream)\nSemKITTI_label_name = dict()\nfor i in sorted(list(semkittiyaml['learning_map'].keys()))[::-1]:\n    SemKITTI_label_name[semkittiyaml['learning_map'][i]] = semkittiyaml['labels'][i]"
  },
  {
    "path": "dataloader/instance_augmentation.py",
    "content": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n\nimport numpy as np\nimport pickle\n\nclass instance_augmentation(object):\n    def __init__(self,instance_pkl_path,thing_list,class_weight,random_flip = False,random_add = False,random_rotate = False,local_transformation = False):\n        self.thing_list = thing_list\n        self.instance_weight = [class_weight[thing_class_num-1] for thing_class_num in thing_list]\n        self.instance_weight = np.asarray(self.instance_weight)/np.sum(self.instance_weight)\n        self.random_flip = random_flip\n        self.random_add = random_add\n        self.random_rotate = random_rotate\n        self.local_transformation = local_transformation\n\n        self.add_num = 5\n\n        with open(instance_pkl_path, 'rb') as f:\n            self.instance_path = pickle.load(f)\n\n    def instance_aug(self, point_xyz, point_label, point_inst, point_feat = None):\n        \"\"\"random rotate and flip each instance independently.\n\n        Args:\n            point_xyz: [N, 3], point location\n            point_label: [N, 1], class label\n            point_inst: [N, 1], instance label\n        \"\"\"        \n        # random add instance to this scan\n        if self.random_add:\n            # choose which instance to add\n            instance_choice = np.random.choice(len(self.thing_list),self.add_num,replace=True,p=self.instance_weight)\n            uni_inst, uni_inst_count = np.unique(instance_choice,return_counts=True)\n            add_idx = 1\n            total_point_num = 0\n            early_break = False\n            for n, count in zip(uni_inst, uni_inst_count):\n                # find random instance\n                random_choice = np.random.choice(len(self.instance_path[self.thing_list[n]]),count)\n                # add to current scan\n                for idx in random_choice:\n                    points = np.fromfile(self.instance_path[self.thing_list[n]][idx], dtype=np.float32).reshape((-1, 4))\n       
             add_xyz = points[:,:3]\n                    center = np.mean(add_xyz,axis=0)\n\n                    # need to check occlusion\n                    fail_flag = True\n                    if self.random_rotate:\n                        # random rotate\n                        random_choice = np.random.random(20)*np.pi*2\n                        for r in random_choice:\n                            center_r = self.rotate_origin(center[np.newaxis,...],r)\n                            # check if occluded\n                            if self.check_occlusion(point_xyz,center_r[0]):\n                                fail_flag = False\n                                break\n                        # rotate to empty space\n                        if fail_flag: continue\n                        add_xyz = self.rotate_origin(add_xyz,r)\n                    else:\n                        fail_flag = not self.check_occlusion(point_xyz,center)\n                    if fail_flag: continue\n\n                    add_label = np.ones((points.shape[0],),dtype=np.uint8)*(self.thing_list[n])\n                    add_inst = np.ones((points.shape[0],),dtype=np.uint32)*(add_idx<<16)\n                    point_xyz = np.concatenate((point_xyz,add_xyz),axis=0)\n                    point_label = np.concatenate((point_label,add_label),axis=0)\n                    point_inst = np.concatenate((point_inst,add_inst),axis=0)\n                    if point_feat is not None:\n                        add_fea =  points[:,3:]\n                        if len(add_fea.shape) == 1: add_fea = add_fea[..., np.newaxis]\n                        point_feat = np.concatenate((point_feat,add_fea),axis=0)\n                    add_idx +=1\n                    total_point_num += points.shape[0]\n                    if total_point_num>5000:\n                        early_break=True\n                        break\n                # prevent adding too many points which cause GPU memory error\n                if 
early_break: break\n\n        # instance mask\n        mask = np.zeros_like(point_label,dtype=bool)\n        for label in self.thing_list:\n            mask[point_label == label] = True\n\n        # create unqiue instance list\n        inst_label = point_inst[mask].squeeze()\n        unique_label = np.unique(inst_label)\n        num_inst = len(unique_label)\n\n        \n        for inst in unique_label:\n            # get instance index\n            index = np.where(point_inst == inst)[0]\n            # skip small instance\n            if index.size<10: continue\n            # get center\n            center = np.mean(point_xyz[index,:],axis=0)\n\n            if self.local_transformation:\n                # random translation and rotation\n                point_xyz[index,:] = self.local_tranform(point_xyz[index,:],center)\n            \n            # random flip instance based on it center \n            if self.random_flip:\n                # get axis\n                long_axis = [center[0], center[1]]/(center[0]**2+center[1]**2)**0.5\n                short_axis = [-long_axis[1],long_axis[0]]\n                # random flip\n                flip_type = np.random.choice(5,1)\n                if flip_type==3:\n                    point_xyz[index,:2] = self.instance_flip(point_xyz[index,:2],[long_axis,short_axis],[center[0], center[1]],flip_type)\n            \n            # 20% random rotate\n            random_num = np.random.random_sample()\n            if self.random_rotate:\n                if random_num>0.8 and inst & 0xFFFF > 0:\n                    random_choice = np.random.random(20)*np.pi*2\n                    fail_flag = True\n                    for r in random_choice:\n                        center_r = self.rotate_origin(center[np.newaxis,...],r)\n                        # check if occluded\n                        if self.check_occlusion(np.delete(point_xyz, index, axis=0),center_r[0]):\n                            fail_flag = False\n                     
       break\n                    if not fail_flag:\n                        # rotate to empty space\n                        point_xyz[index,:] = self.rotate_origin(point_xyz[index,:],r)\n\n        if len(point_label.shape) == 1: point_label = point_label[..., np.newaxis]\n        if len(point_inst.shape) == 1: point_inst = point_inst[..., np.newaxis]\n        if point_feat is not None:\n            return point_xyz,point_label,point_inst,point_feat\n        else:\n            return point_xyz,point_label,point_inst\n\n    def instance_flip(self, points,axis,center,flip_type = 1):\n        points = points[:]-center\n        if flip_type == 1:\n            # rotate 180 degree\n            points = -points+center\n        elif flip_type == 2:\n            # flip over long axis\n            a = axis[0][0]\n            b = axis[0][1]\n            flip_matrix = np.array([[b**2 - a**2, -2 * a * b],[-2 * a * b, a**2 - b**2]])\n            points = np.matmul(flip_matrix,np.transpose(points, (1, 0)))\n            points = np.transpose(points, (1, 0))+center\n        elif flip_type == 3:\n            # flip over short axis\n            a = axis[1][0]\n            b = axis[1][1]\n            flip_matrix = np.array([[b**2 - a**2, -2 * a * b],[-2 * a * b, a**2 - b**2]])\n            points = np.matmul(flip_matrix,np.transpose(points, (1, 0)))\n            points = np.transpose(points, (1, 0))+center\n\n        return points\n\n    def check_occlusion(self,points,center,min_dist=2):\n        'check if close to a point'\n        if points.ndim == 1:\n            dist = np.linalg.norm(points[np.newaxis,:]-center,axis=1)\n        else:\n            dist = np.linalg.norm(points-center,axis=1)\n        return np.all(dist>min_dist)\n\n    def rotate_origin(self,xyz,radians):\n        'rotate a point around the origin'\n        x = xyz[:,0]\n        y = xyz[:,1]\n        new_xyz = xyz.copy()\n        new_xyz[:,0] = x * np.cos(radians) + y * np.sin(radians)\n        new_xyz[:,1] = -x * 
np.sin(radians) + y * np.cos(radians)\n        return new_xyz\n\n    def local_tranform(self,xyz,center):\n        'translate and rotate point cloud according to its center'\n        # random xyz\n        loc_noise = np.random.normal(scale = 0.25, size=(1,3))\n        # random angle\n        rot_noise = np.random.uniform(-np.pi/20, np.pi/20)\n\n        xyz = xyz-center\n        xyz = self.rotate_origin(xyz,rot_noise)\n        xyz = xyz+loc_noise\n        \n        return xyz+center\n"
  },
  {
    "path": "dataloader/process_panoptic.py",
    "content": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\nimport numpy as np\n\nclass PanopticLabelGenerator(object):\n    def __init__(self,grid_size,sigma=5,polar=False):\n        \"\"\"Initialize panoptic ground truth generator\n\n        Args:\n            grid_size: voxel size.\n            sigma (int, optional):  Gaussian distribution paramter. Create heatmap in +-3*sigma window. Defaults to 5.\n            polar (bool, optional): Is under polar coordinate. Defaults to False.\n        \"\"\"        \n        self.grid_size = grid_size\n        self.polar = polar\n\n        self.sigma = sigma\n        size = 6 * sigma + 3\n        x = np.arange(0, size, 1, float)\n        y = x[:, np.newaxis]\n        x0, y0 = 3 * sigma + 1, 3 * sigma + 1\n        self.g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))\n    \n    def __call__(self,inst,xyz,voxel_inst,voxel_position,label_dict,min_bound,intervals):\n        \"\"\"Generate instance center and offset ground truth\n\n        Args:\n            inst : instance panoptic label (N)\n            xyz : point location (N x 3)\n            voxel_inst : voxel panoptic label on the BEV (H x W)\n            voxel_position : voxel location on the BEV (3 x H x W)\n            label_dict : unqiue instance label dict\n            min_bound : space minimal bound\n            intervals : voxelization intervals\n\n        Returns:\n            center, center_pts, offset\n        \"\"\"        \n        height, width = self.grid_size[0],self.grid_size[1]\n\n        center = np.zeros((1, height, width), dtype=np.float32)\n        center_pts = []\n        offset = np.zeros((2, height, width), dtype=np.float32)\n        #skip empty instances\n        if inst.size < 2: return center, center_pts, offset\n        # find unique instances\n        inst_labels = np.unique(inst)\n        for inst_label in inst_labels:\n            # get mask for each unique instance\n            mask = np.where(inst == inst_label)\n       
     voxel_mask = np.where(voxel_inst == label_dict[inst_label])\n            # get center\n            center_x, center_y = np.mean(xyz[mask,0]), np.mean(xyz[mask,1])\n            if self.polar:\n                # convert to polar coordinate\n                center_x_pol, center_y_pol = np.sqrt(center_x**2 + center_y**2),np.arctan2(center_y,center_x)\n                center_x = center_x_pol\n                center_y = center_y_pol\n\n            # generate center heatmap\n            x, y = int(np.floor((center_x-min_bound[0])/intervals[0])), int(np.floor((center_y-min_bound[1])/intervals[1]))\n            center_pts.append([x, y])\n            # outside image boundary\n            if x < 0 or y < 0 or \\\n                    x >= height or y >= width:\n                continue\n            sigma = self.sigma\n            # upper left\n            ul = int(np.round(x - 3 * sigma - 1)), int(np.round(y - 3 * sigma - 1))\n            # bottom right\n            br = int(np.round(x + 3 * sigma + 2)), int(np.round(y + 3 * sigma + 2))\n\n            if self.polar:\n                c, d = max(0, -ul[0]), min(br[0], height) - ul[0]\n                a, b = 0, br[1] - ul[1]\n\n                cc, dd = max(0, ul[0]), min(br[0], height)\n                angle_list = [angle_id % width for angle_id in range(ul[1],br[1])]\n                center[0, cc:dd, angle_list] = np.maximum(\n                    center[0, cc:dd, angle_list], np.transpose(self.g[c:d,a:b]))\n            else:\n                c, d = max(0, -ul[0]), min(br[0], height) - ul[0]\n                a, b = max(0, -ul[1]), min(br[1], width) - ul[1]\n\n                cc, dd = max(0, ul[0]), min(br[0], height)\n                aa, bb = max(0, ul[1]), min(br[1], width)\n                center[0, cc:dd, aa:bb] = np.maximum(\n                    center[0, cc:dd, aa:bb], self.g[c:d,a:b])\n\n            if self.polar:\n                # generate offset (2, h, w) -> (y-dir, x-dir)\n                
offset[0,voxel_mask[0],voxel_mask[1]] = (center_x - voxel_position[0,voxel_mask[0],voxel_mask[1]])/intervals[0]\n                offset[1,voxel_mask[0],voxel_mask[1]] = ((center_y - voxel_position[1,voxel_mask[0],voxel_mask[1]]+np.pi)%(2*np.pi) - np.pi)/intervals[1]\n            else:\n                # generate offset (2, h, w) -> (y-dir, x-dir)\n                offset[0,voxel_mask[0],voxel_mask[1]] = (center_x - voxel_position[0,voxel_mask[0],voxel_mask[1]])/intervals[0]\n                offset[1,voxel_mask[0],voxel_mask[1]] = (center_y - voxel_position[1,voxel_mask[0],voxel_mask[1]])/intervals[1]\n\n        # print('gt center',center_pts)\n\n        return center, center_pts, offset"
  },
  {
    "path": "instance_preprocess.py",
    "content": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\nimport argparse\nfrom dataloader.dataset import SemKITTI\n\nif __name__ == '__main__':\n    # instance preprocessing\n    parser = argparse.ArgumentParser(description='')\n    parser.add_argument('-d', '--data_path', default='data')\n    parser.add_argument('-o', '--out_path', default='data')\n\n    args = parser.parse_args()\n\n    train_pt_dataset = SemKITTI(args.data_path + '/sequences/', imageset = 'train', return_ref = True)\n    train_pt_dataset.save_instance(args.out_path)\n    print('instance preprocessing finished.')"
  },
  {
    "path": "network/BEV_Unet.py",
    "content": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dropblock import DropBlock2D\n\n\nclass BEV_Unet(nn.Module):\n\n    def __init__(self,n_class,n_height,dilation = 1,group_conv=False,input_batch_norm = False,dropout = 0.,circular_padding = False, dropblock = True, use_vis_fea=False):\n        super(BEV_Unet, self).__init__()\n        self.n_class = n_class\n        self.n_height = n_height\n        if use_vis_fea:\n            self.network = UNet(n_class*n_height,2*n_height,dilation,group_conv,input_batch_norm,dropout,circular_padding,dropblock)\n        else:\n            self.network = UNet(n_class*n_height,n_height,dilation,group_conv,input_batch_norm,dropout,circular_padding,dropblock)\n\n    def forward(self, x):\n        x,center,offset = self.network(x)\n        \n        x = x.permute(0,2,3,1)\n        new_shape = list(x.size())[:3] + [self.n_height,self.n_class]\n        x = x.view(new_shape)\n        x = x.permute(0,4,1,2,3)\n\n        return x,center,offset\n    \nclass UNet(nn.Module):\n    def __init__(self, n_class,n_height,dilation,group_conv,input_batch_norm, dropout,circular_padding,dropblock):\n        super(UNet, self).__init__()\n        # encoder\n        self.inc = inconv(n_height, 64, dilation, input_batch_norm, circular_padding)\n        self.down1 = down(64, 128, dilation, group_conv, circular_padding)\n        self.down2 = down(128, 256, dilation, group_conv, circular_padding)\n        self.down3 = down(256, 512, dilation, group_conv, circular_padding)\n        self.down4 = down(512, 512, dilation, group_conv, circular_padding)\n\n        # semantic decoder\n        self.up1 = up(1024, 256, circular_padding, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout)\n        self.up2 = up(512, 128, circular_padding, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout)\n        self.up3 = up(256, 64, 
circular_padding, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout)\n        self.up4 = up(128, 64, circular_padding, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout)\n        self.dropout = nn.Dropout(p=0. if dropblock else dropout)\n        # semantic head\n        self.outc = outconv(64, n_class)\n\n        # instance decoder\n        # self.i_up1 = up(1024, 256, circular_padding, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout)\n        # self.i_up2 = up(512, 128, circular_padding, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout)\n        # self.i_up3 = up(256, 64, circular_padding, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout)\n        # self.i_up4 = up(128, 32, circular_padding, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout)\n        self.i_up4_center = up(128, 32, circular_padding, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout)\n        self.i_up4_offset = up(128, 32, circular_padding, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout)\n        # instance head\n        self.i_outc_center = outconv(32, 1)\n        self.i_outc_offset = outconv(32, 2)\n\n    def forward(self, x):\n        x1 = self.inc(x)\n        x2 = self.down1(x1)\n        x3 = self.down2(x2)\n        x4 = self.down3(x3)\n        x5 = self.down4(x4)\n        # semantic\n        x = self.up1(x5, x4)\n        x = self.up2(x, x3)\n        x = self.up3(x, x2)\n        s_x = self.up4(x, x1)\n        s_x = self.outc(self.dropout(s_x))\n        # instance\n        # i_x = self.i_up1(x5, x4)\n        # i_x = self.i_up2(i_x, x3)\n        # i_x = self.i_up3(i_x, x2)x\n\n        # i_x = self.i_up4(i_x, x1)\n        i_x_center = self.i_up4_center(x, x1)\n        i_x_center = self.i_outc_center(self.dropout(i_x_center))\n\n        i_x_offset = self.i_up4_offset(x, x1)\n        i_x_offset = self.i_outc_offset(self.dropout(i_x_offset))\n\n        return s_x, 
i_x_center, i_x_offset\n\nclass double_conv(nn.Module):\n    '''(conv => BN => ReLU) * 2'''\n    def __init__(self, in_ch, out_ch,group_conv,dilation=1):\n        super(double_conv, self).__init__()\n        if group_conv:\n            self.conv = nn.Sequential(\n                nn.Conv2d(in_ch, out_ch, 3, padding=1,groups = min(out_ch,in_ch)),\n                nn.BatchNorm2d(out_ch),\n                nn.LeakyReLU(inplace=True),\n                nn.Conv2d(out_ch, out_ch, 3, padding=1,groups = out_ch),\n                nn.BatchNorm2d(out_ch),\n                nn.LeakyReLU(inplace=True)\n            )\n        else:\n            self.conv = nn.Sequential(\n                nn.Conv2d(in_ch, out_ch, 3, padding=1),\n                nn.BatchNorm2d(out_ch),\n                nn.LeakyReLU(inplace=True),\n                nn.Conv2d(out_ch, out_ch, 3, padding=1),\n                nn.BatchNorm2d(out_ch),\n                nn.LeakyReLU(inplace=True)\n            )\n\n    def forward(self, x):\n        x = self.conv(x)\n        return x\n\nclass double_conv_circular(nn.Module):\n    '''(conv => BN => ReLU) * 2'''\n    def __init__(self, in_ch, out_ch,group_conv,dilation=1):\n        super(double_conv_circular, self).__init__()\n        if group_conv:\n            self.conv1 = nn.Sequential(\n                nn.Conv2d(in_ch, out_ch, 3, padding=(1,0),groups = min(out_ch,in_ch)),\n                nn.BatchNorm2d(out_ch),\n                nn.LeakyReLU(inplace=True)\n            )\n            self.conv2 = nn.Sequential(\n                nn.Conv2d(out_ch, out_ch, 3, padding=(1,0),groups = out_ch),\n                nn.BatchNorm2d(out_ch),\n                nn.LeakyReLU(inplace=True)\n            )\n        else:\n            self.conv1 = nn.Sequential(\n                nn.Conv2d(in_ch, out_ch, 3, padding=(1,0)),\n                nn.BatchNorm2d(out_ch),\n                nn.LeakyReLU(inplace=True)\n            )\n            self.conv2 = nn.Sequential(\n                nn.Conv2d(out_ch, 
out_ch, 3, padding=(1,0)),\n                nn.BatchNorm2d(out_ch),\n                nn.LeakyReLU(inplace=True)\n            )\n\n    def forward(self, x):\n        #add circular padding\n        x = F.pad(x,(1,1,0,0),mode = 'circular')\n        x = self.conv1(x)\n        x = F.pad(x,(1,1,0,0),mode = 'circular')\n        x = self.conv2(x)\n        return x\n\n\nclass inconv(nn.Module):\n    def __init__(self, in_ch, out_ch, dilation, input_batch_norm, circular_padding):\n        super(inconv, self).__init__()\n        if input_batch_norm:\n            if circular_padding:\n                self.conv = nn.Sequential(\n                    nn.BatchNorm2d(in_ch),\n                    double_conv_circular(in_ch, out_ch,group_conv = False,dilation = dilation)\n                )\n            else:\n                self.conv = nn.Sequential(\n                    nn.BatchNorm2d(in_ch),\n                    double_conv(in_ch, out_ch,group_conv = False,dilation = dilation)\n                )\n        else:\n            if circular_padding:\n                self.conv = double_conv_circular(in_ch, out_ch,group_conv = False,dilation = dilation)\n            else:\n                self.conv = double_conv(in_ch, out_ch,group_conv = False,dilation = dilation)\n\n    def forward(self, x):\n        x = self.conv(x)\n        return x\n\n\nclass down(nn.Module):\n    def __init__(self, in_ch, out_ch, dilation, group_conv, circular_padding):\n        super(down, self).__init__()\n        if circular_padding:\n            self.mpconv = nn.Sequential(\n                nn.MaxPool2d(2),\n                double_conv_circular(in_ch, out_ch,group_conv = group_conv,dilation = dilation)\n            )\n        else:\n            self.mpconv = nn.Sequential(\n                nn.MaxPool2d(2),\n                double_conv(in_ch, out_ch,group_conv = group_conv,dilation = dilation)\n            )                \n\n    def forward(self, x):\n        x = self.mpconv(x)\n        return x\n\n\nclass 
up(nn.Module):\n    def __init__(self, in_ch, out_ch, circular_padding, bilinear=True, group_conv=False, use_dropblock = False, drop_p = 0.5):\n        super(up, self).__init__()\n\n        #  would be a nice idea if the upsampling could be learned too,\n        #  but my machine do not have enough memory to handle all those weights\n        if bilinear:\n            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)\n        elif group_conv:\n            self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2,groups = in_ch//2)\n        else:\n            self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)\n\n        if circular_padding:\n            self.conv = double_conv_circular(in_ch, out_ch,group_conv = group_conv)\n        else:\n            self.conv = double_conv(in_ch, out_ch,group_conv = group_conv)\n\n        self.use_dropblock = use_dropblock\n        if self.use_dropblock:\n            self.dropblock = DropBlock2D(block_size=7, drop_prob=drop_p)\n\n    def forward(self, x1, x2):\n        x1 = self.up(x1)\n        \n        # input is CHW\n        diffY = x2.size()[2] - x1.size()[2]\n        diffX = x2.size()[3] - x1.size()[3]\n\n        x1 = F.pad(x1, (diffX // 2, diffX - diffX//2,\n                        diffY // 2, diffY - diffY//2))\n        \n        # for padding issues, see \n        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a\n        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd\n\n        x = torch.cat([x2, x1], dim=1)\n        x = self.conv(x)\n        if self.use_dropblock:\n            x = self.dropblock(x)\n        return x\n\n\nclass outconv(nn.Module):\n    def __init__(self, in_ch, out_ch):\n        super(outconv, self).__init__()\n        self.conv = nn.Conv2d(in_ch, out_ch, 1)\n\n    def forward(self, x):\n        x = self.conv(x)\n        return x"
  },
  {
    "path": "network/__init__.py",
    "content": ""
  },
  {
    "path": "network/instance_post_processing.py",
    "content": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\nimport torch\nimport torch.nn.functional as F\nimport torch_scatter\n\n\ndef find_instance_center(ctr_hmp, threshold=0.1, nms_kernel=5, top_k=None, polar=False):\n    \"\"\"\n    Find the center points from the center heatmap.\n    Arguments:\n        ctr_hmp: A Tensor of shape [N, 1, H, W] of raw center heatmap output, where N is the batch size,\n            for consistent, we only support N=1.\n        threshold: A Float, threshold applied to center heatmap score.\n        nms_kernel: An Integer, NMS max pooling kernel size.\n        top_k: An Integer, top k centers to keep.\n    Returns:\n        A Tensor of shape [K, 2] where K is the number of center points. The order of second dim is (y, x).\n    \"\"\"\n    if ctr_hmp.size(0) != 1:\n        raise ValueError('Only supports inference for batch size = 1')\n\n    # thresholding, setting values below threshold to -1\n    ctr_hmp = F.threshold(ctr_hmp, threshold, -1)\n\n    # NMS\n    if polar:\n        nms_padding = (nms_kernel - 1) // 2\n        ctr_hmp_max_pooled = F.pad(ctr_hmp,(nms_padding,nms_padding,0,0),mode = 'circular')\n        ctr_hmp_max_pooled = F.max_pool2d(ctr_hmp_max_pooled, kernel_size=nms_kernel, stride=1, padding=(nms_padding,0))\n    else:\n        nms_padding = (nms_kernel - 1) // 2\n        ctr_hmp_max_pooled = F.max_pool2d(ctr_hmp, kernel_size=nms_kernel, stride=1, padding=nms_padding)\n    ctr_hmp[ctr_hmp != ctr_hmp_max_pooled] = -1\n\n    # squeeze first two dimensions\n    ctr_hmp = ctr_hmp.squeeze()\n    assert len(ctr_hmp.size()) == 2, 'Something is wrong with center heatmap dimension.'\n\n    # find non-zero elements\n    ctr_all = torch.nonzero(ctr_hmp > 0)\n    if top_k is None:\n        return ctr_all\n    elif ctr_all.size(0) < top_k:\n        return ctr_all\n    else:\n        # find top k centers.\n        top_k_scores, _ = torch.topk(torch.flatten(ctr_hmp), top_k)\n        return torch.nonzero(ctr_hmp > 
top_k_scores[-1])\n\n\ndef group_pixels(ctr, offsets, polar=False):\n    \"\"\"\n    Gives each pixel in the image an instance id.\n    Arguments:\n        ctr: A Tensor of shape [K, 2] where K is the number of center points. The order of second dim is (y, x).\n        offsets: A Tensor of shape [N, 2, H, W] of raw offset output, where N is the batch size,\n            for consistent, we only support N=1. The order of second dim is (offset_y, offset_x).\n    Returns:\n        A Tensor of shape [1, H, W] (to be gathered by distributed data parallel).\n    \"\"\"\n    if offsets.size(0) != 1:\n        raise ValueError('Only supports inference for batch size = 1')\n\n    offsets = offsets.squeeze(0)\n    height, width = offsets.size()[1:]\n\n    # generates a coordinate map, where each location is the coordinate of that loc\n    y_coord = torch.arange(height, dtype=offsets.dtype, device=offsets.device).repeat(1, width, 1).transpose(1, 2)\n    x_coord = torch.arange(width, dtype=offsets.dtype, device=offsets.device).repeat(1, height, 1)\n    coord = torch.cat((y_coord, x_coord), dim=0)\n\n    ctr_loc = coord + offsets\n    ctr_loc = ctr_loc.reshape((2, height * width)).transpose(1, 0)\n\n    # ctr: [K, 2] -> [K, 1, 2]\n    # ctr_loc = [H*W, 2] -> [1, H*W, 2]\n    ctr = ctr.unsqueeze(1)\n    ctr_loc = ctr_loc.unsqueeze(0)\n\n    # distance: [K, H*W]\n    distance = ctr - ctr_loc\n    if polar:\n        distance[:,:,0] = torch.add(torch.fmod(torch.add(distance[:,:,0],width/2),width),-width/2)\n    distance = torch.norm(distance, dim=-1)\n\n    # finds center with minimum distance at each location, offset by 1, to reserve id=0 for stuff\n    instance_id = torch.argmin(distance, dim=0).reshape((1, height, width)) + 1\n    return instance_id\n\n\ndef get_instance_segmentation(sem_seg, ctr_hmp, offsets, thing_list, threshold=0.1, nms_kernel=5, top_k=None,\n                              thing_seg=None, polar=False):\n    \"\"\"\n    Post-processing for instance segmentation, 
gets the class-agnostic instance id map.\n    Arguments:\n        sem_seg: A Tensor of shape [1, H, W, Z], predicted semantic label.\n        ctr_hmp: A Tensor of shape [N, 1, H, W] of raw center heatmap output, where N is the batch size,\n            for consistency, we only support N=1.\n        offsets: A Tensor of shape [N, 2, H, W] of raw offset output, where N is the batch size,\n            for consistency, we only support N=1. The order of second dim is (offset_y, offset_x).\n        thing_list: A List of thing class ids.\n        threshold: A Float, threshold applied to center heatmap score.\n        nms_kernel: An Integer, NMS max pooling kernel size.\n        top_k: An Integer, top k centers to keep.\n        thing_seg: A Tensor of shape [1, H, W, Z], predicted foreground mask. The fallback that inferred it from the\n            semantic prediction is commented out, so it is needed when no center is found.\n    Returns:\n        A Tensor of shape [1, H, W] (to be gathered by distributed data parallel).\n        A Tensor of shape [1, K, 2] where K is the number of center points. 
The order of second dim is (y, x).\n    \"\"\"\n    # if thing_seg is None:\n    #     # gets foreground segmentation\n    #     thing_seg = torch.zeros_like(sem_seg)\n    #     for thing_class in thing_list:\n    #         thing_seg[sem_seg == thing_class] = 1\n    # if thing_seg.dim() == 4:\n    #     # [1, H, W, Z] --> [1, H, W]\n    #     thing_seg = torch.max(thing_seg,dim=3)\n\n    ctr = find_instance_center(ctr_hmp, threshold=threshold, nms_kernel=nms_kernel, top_k=top_k, polar=polar)\n    if ctr.size(0) == 0:\n        return torch.zeros_like(thing_seg[:,:,:,0]), ctr.unsqueeze(0)\n    ins_seg = group_pixels(ctr, offsets, polar=polar)\n    return ins_seg, ctr.unsqueeze(0)\n\n\ndef merge_semantic_and_instance(sem_seg, sem, ins_seg, label_divisor, thing_list, void_label,thing_seg):\n    \"\"\"\n    Post-processing for panoptic segmentation, by merging semantic segmentation label and class agnostic\n        instance segmentation label.\n    Arguments:\n        sem_seg: A Tensor of shape [1, H, W, Z], predicted semantic label.\n        sem: A Tensor of shape [1, C, H, W, Z], predicted semantic logit.\n        ins_seg: A Tensor of shape [1, H, W], predicted instance label.\n        label_divisor: An Integer, used to convert panoptic id = instance_id * label_divisor + semantic_id.\n        thing_list: A List of thing class ids.\n        void_label: An Integer, indicates the region has no confident prediction.\n        thing_seg: A Tensor of shape [1, H, W, Z], predicted foreground mask.\n    Returns:\n        A Tensor of shape [1, H, W, Z] (to be gathered by distributed data parallel).\n    Raises:\n        ValueError, if batch size is not 1.\n    \"\"\"\n    # In case thing mask does not align with semantic prediction\n    # semantic_thing_seg = torch.zeros_like(sem_seg,dtype=torch.bool)\n    # for thing_class in thing_list:\n    #     semantic_thing_seg[sem_seg == thing_class] = True\n    \n    # try to avoid the for loop; assumes thing classes occupy label ids 1..max(thing_list)\n    semantic_thing_seg = 
sem_seg<=max(thing_list)\n\n    ins_seg = torch.unsqueeze(ins_seg,3).expand_as(sem_seg)\n    thing_mask = (ins_seg > 0) & semantic_thing_seg & thing_seg\n    if not torch.nonzero(thing_mask).size(0) == 0:\n        sem_sum = torch_scatter.scatter_add(sem.permute(0,2,3,4,1)[thing_mask],ins_seg[thing_mask],dim=0)\n        class_id = torch.argmax(sem_sum[:,:max(thing_list)],dim=1)\n        sem_seg[thing_mask] = (ins_seg[thing_mask] * label_divisor) + class_id[ins_seg[thing_mask]]+1\n    else:\n        sem_seg[semantic_thing_seg & thing_seg] = void_label\n    return sem_seg\n\n\ndef get_panoptic_segmentation(sem, ctr_hmp, offsets, thing_list, label_divisor=2**16, void_label=0,\n                              threshold=0.1, nms_kernel=5, top_k=100, foreground_mask=None, polar=False):\n    \"\"\"\n    Post-processing for panoptic segmentation.\n    Arguments:\n        sem: A Tensor of shape [N, C, H, W, Z] of raw semantic output, where N is the batch size, for consistent,\n            we only support N=1.\n        ctr_hmp: A Tensor of shape [N, 1, H, W] of raw center heatmap output, where N is the batch size,\n            for consistent, we only support N=1.\n        offsets: A Tensor of shape [N, 2, H, W] of raw offset output, where N is the batch size,\n            for consistent, we only support N=1. 
The order of second dim is (offset_y, offset_x).\n        thing_list: A List of thing class ids.\n        label_divisor: An Integer, used to convert panoptic id = instance_id * label_divisor + semantic_id.\n        void_label: An Integer, indicates the region has no confident prediction.\n        threshold: A Float, threshold applied to center heatmap score.\n        nms_kernel: An Integer, NMS max pooling kernel size.\n        top_k: An Integer, top k centers to keep.\n        foreground_mask: A processed Tensor of shape [N, H, W, Z], we only support N=1.\n    Returns:\n        A Tensor of shape [1, H, W, Z] (to be gathered by distributed data parallel), int64.\n    Raises:\n        ValueError, if batch size is not 1.\n    \"\"\"\n    if sem.dim() != 5 and sem.dim() != 4:\n        raise ValueError('Semantic prediction with un-supported dimension: {}.'.format(sem.dim()))\n    if sem.dim() == 5 and sem.size(0) != 1:\n        raise ValueError('Only supports inference for batch size = 1')\n    if ctr_hmp.size(0) != 1:\n        raise ValueError('Only supports inference for batch size = 1')\n    if offsets.size(0) != 1:\n        raise ValueError('Only supports inference for batch size = 1')\n    if foreground_mask is not None:\n        if foreground_mask.dim() != 4:\n            raise ValueError('Foreground prediction with un-supported dimension: {}.'.format(foreground_mask.dim()))\n\n    if sem.dim() == 5:\n        semantic = torch.argmax(sem, dim=1)\n        # shift back to original label idx \n        semantic = torch.add(semantic, 1)\n        # softmax over the class dimension (explicit dim avoids the implicit-dim deprecation warning)\n        sem = F.softmax(sem, dim=1)\n    else:\n        semantic = sem.type(torch.ByteTensor).cuda()\n        # shift back to original label idx \n        semantic = torch.add(semantic, 1).type(torch.LongTensor).cuda()\n        one_hot = torch.zeros((sem.size(0),torch.max(semantic).item()+1,sem.size(1),sem.size(2),sem.size(3))).cuda()\n        sem = one_hot.scatter_(1,torch.unsqueeze(semantic,1),1.)\n        sem = sem[:,1:,:,:,:]\n\n\n    if 
foreground_mask is not None:\n        thing_seg = foreground_mask\n    else:\n        thing_seg = None\n\n    \n    instance, center = get_instance_segmentation(semantic, ctr_hmp, offsets, thing_list,\n                                                 threshold=threshold, nms_kernel=nms_kernel, top_k=top_k,\n                                                 thing_seg=thing_seg, polar=polar)\n    panoptic = merge_semantic_and_instance(semantic, sem, instance, label_divisor, thing_list, void_label, thing_seg)\n\n    return panoptic, center\n"
  },
  {
    "path": "network/loss.py",
    "content": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n\nimport numpy as np\nimport torch\nfrom .lovasz_losses import lovasz_softmax\n\ndef _neg_loss(pred, gt):\n    ''' Modified focal loss. Exactly the same as CornerNet.\n        Runs faster and costs a little bit more memory\n        (https://github.com/tianweiy/CenterPoint)\n    Arguments:\n        pred (batch x c x h x w)\n        gt (batch x c x h x w)\n    '''\n    pos_inds = gt.eq(1).float()\n    neg_inds = gt.lt(1).float()\n\n    neg_weights = torch.pow(1 - gt, 4)\n\n    # loss = 0\n\n    pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds\n    neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds\n    return - (pos_loss + neg_loss)\n\nclass FocalLoss(torch.nn.Module):\n    '''nn.Module warpper for focal loss'''\n    def __init__(self):\n        super(FocalLoss, self).__init__()\n        self.neg_loss = _neg_loss\n\n    def forward(self, out, target):\n        return self.neg_loss(out, target)\n\nclass panoptic_loss(torch.nn.Module):\n    def __init__(self, ignore_label = 255, center_loss_weight = 100, offset_loss_weight = 1, center_loss = 'MSE', offset_loss = 'L1'):\n        super(panoptic_loss, self).__init__()\n        self.CE_loss = torch.nn.CrossEntropyLoss(ignore_index=ignore_label)\n        assert center_loss in ['MSE','FocalLoss']\n        assert offset_loss in ['L1','SmoothL1']\n        if center_loss == 'MSE':\n            self.center_loss_fn = torch.nn.MSELoss()\n        elif center_loss == 'FocalLoss':\n            self.center_loss_fn = FocalLoss()\n        else: raise NotImplementedError\n        if offset_loss == 'L1':\n            self.offset_loss_fn = torch.nn.L1Loss()\n        elif offset_loss == 'SmoothL1':\n            self.offset_loss_fn = torch.nn.SmoothL1Loss()\n        else: raise NotImplementedError\n        self.center_loss_weight = center_loss_weight\n        self.offset_loss_weight = offset_loss_weight\n\n        print('Using '+ 
center_loss +' for heatmap regression, weight: '+str(center_loss_weight))\n        print('Using '+ offset_loss +' for offset regression, weight: '+str(offset_loss_weight))\n\n        self.lost_dict={'semantic_loss':[],\n                        'heatmap_loss':[],\n                        'offset_loss':[]}\n\n    def reset_loss_dict(self):\n        self.lost_dict={'semantic_loss':[],\n                        'heatmap_loss':[],\n                        'offset_loss':[]}\n\n    def forward(self,prediction,center,offset,gt_label,gt_center,gt_offset,save_loss = True):\n        # semantic loss (softmax over the class dimension, dim=1)\n        loss = lovasz_softmax(torch.nn.functional.softmax(prediction, dim=1), gt_label,ignore=255) + self.CE_loss(prediction,gt_label)\n        if save_loss:\n            self.lost_dict['semantic_loss'].append(loss.item())\n        # center heatmap loss\n        center_mask = (gt_center>0) | (torch.min(torch.unsqueeze(gt_label, 1),dim=4)[0]<255)\n        center_loss = self.center_loss_fn(center,gt_center) * center_mask\n        # safe division\n        if center_mask.sum() > 0:\n            center_loss = center_loss.sum() / center_mask.sum() * self.center_loss_weight\n        else:\n            center_loss = center_loss.sum() * 0\n        if save_loss:\n            self.lost_dict['heatmap_loss'].append(center_loss.item())\n        loss += center_loss\n        # offset loss\n        offset_mask = gt_offset != 0\n        offset_loss = self.offset_loss_fn(offset,gt_offset) * offset_mask\n        # safe division\n        if offset_mask.sum() > 0:\n            offset_loss = offset_loss.sum() / offset_mask.sum() * self.offset_loss_weight\n        else:\n            offset_loss = offset_loss.sum() * 0\n        if save_loss:\n            self.lost_dict['offset_loss'].append(offset_loss.item())\n        loss += offset_loss\n        return loss"
  },
  {
    "path": "network/lovasz_losses.py",
    "content": "\"\"\"\nLovasz-Softmax and Jaccard hinge loss in PyTorch\nMaxim Berman 2018 ESAT-PSI KU Leuven (MIT License)\n\"\"\"\n\nfrom __future__ import print_function, division\n\nimport torch\nfrom torch.autograd import Variable\nimport torch.nn.functional as F\nimport numpy as np\ntry:\n    from itertools import  ifilterfalse\nexcept ImportError: # py3k\n    from itertools import  filterfalse as ifilterfalse\n\ndef lovasz_grad(gt_sorted):\n    \"\"\"\n    Computes gradient of the Lovasz extension w.r.t sorted errors\n    See Alg. 1 in paper\n    \"\"\"\n    p = len(gt_sorted)\n    gts = gt_sorted.sum()\n    intersection = gts - gt_sorted.float().cumsum(0)\n    union = gts + (1 - gt_sorted).float().cumsum(0)\n    jaccard = 1. - intersection / union\n    if p > 1: # cover 1-pixel case\n        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]\n    return jaccard\n\n\ndef iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):\n    \"\"\"\n    IoU for foreground class\n    binary: 1 foreground, 0 background\n    \"\"\"\n    if not per_image:\n        preds, labels = (preds,), (labels,)\n    ious = []\n    for pred, label in zip(preds, labels):\n        intersection = ((label == 1) & (pred == 1)).sum()\n        union = ((label == 1) | ((pred == 1) & (label != ignore))).sum()\n        if not union:\n            iou = EMPTY\n        else:\n            iou = float(intersection) / float(union)\n        ious.append(iou)\n    iou = mean(ious)    # mean accross images if per_image\n    return 100 * iou\n\n\ndef iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):\n    \"\"\"\n    Array of IoU for each (non ignored) class\n    \"\"\"\n    if not per_image:\n        preds, labels = (preds,), (labels,)\n    ious = []\n    for pred, label in zip(preds, labels):\n        iou = []    \n        for i in range(C):\n            if i != ignore: # The ignored label is sometimes among predicted classes (ENet - CityScapes)\n                intersection = ((label == 
i) & (pred == i)).sum()\n                union = ((label == i) | ((pred == i) & (label != ignore))).sum()\n                if not union:\n                    iou.append(EMPTY)\n                else:\n                    iou.append(float(intersection) / float(union))\n        ious.append(iou)\n    ious = [mean(iou) for iou in zip(*ious)] # mean accross images if per_image\n    return 100 * np.array(ious)\n\n\n# --------------------------- BINARY LOSSES ---------------------------\n\n\ndef lovasz_hinge(logits, labels, per_image=True, ignore=None):\n    \"\"\"\n    Binary Lovasz hinge loss\n      logits: [B, H, W] Variable, logits at each pixel (between -\\infty and +\\infty)\n      labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)\n      per_image: compute the loss per image instead of per batch\n      ignore: void class id\n    \"\"\"\n    if per_image:\n        loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))\n                          for log, lab in zip(logits, labels))\n    else:\n        loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))\n    return loss\n\n\ndef lovasz_hinge_flat(logits, labels):\n    \"\"\"\n    Binary Lovasz hinge loss\n      logits: [P] Variable, logits at each prediction (between -\\infty and +\\infty)\n      labels: [P] Tensor, binary ground truth labels (0 or 1)\n      ignore: label to ignore\n    \"\"\"\n    if len(labels) == 0:\n        # only void pixels, the gradients should be 0\n        return logits.sum() * 0.\n    signs = 2. * labels.float() - 1.\n    errors = (1. 
- logits * Variable(signs))\n    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)\n    perm = perm.data\n    gt_sorted = labels[perm]\n    grad = lovasz_grad(gt_sorted)\n    loss = torch.dot(F.relu(errors_sorted), Variable(grad))\n    return loss\n\n\ndef flatten_binary_scores(scores, labels, ignore=None):\n    \"\"\"\n    Flattens predictions in the batch (binary case)\n    Remove labels equal to 'ignore'\n    \"\"\"\n    scores = scores.view(-1)\n    labels = labels.view(-1)\n    if ignore is None:\n        return scores, labels\n    valid = (labels != ignore)\n    vscores = scores[valid]\n    vlabels = labels[valid]\n    return vscores, vlabels\n\n\nclass StableBCELoss(torch.nn.modules.Module):\n    def __init__(self):\n         super(StableBCELoss, self).__init__()\n    def forward(self, input, target):\n         neg_abs = - input.abs()\n         loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()\n         return loss.mean()\n\n\ndef binary_xloss(logits, labels, ignore=None):\n    \"\"\"\n    Binary Cross entropy loss\n      logits: [B, H, W] Variable, logits at each pixel (between -\\infty and +\\infty)\n      labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)\n      ignore: void class id\n    \"\"\"\n    logits, labels = flatten_binary_scores(logits, labels, ignore)\n    loss = StableBCELoss()(logits, Variable(labels.float()))\n    return loss\n\n\n# --------------------------- MULTICLASS LOSSES ---------------------------\n\n\ndef lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None):\n    \"\"\"\n    Multi-class Lovasz-Softmax loss\n      probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).\n              Interpreted as binary (sigmoid) output with outputs of size [B, H, W].\n      labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)\n      classes: 'all' for all, 'present' for classes present in labels, or a list of classes to 
average.\n      per_image: compute the loss per image instead of per batch\n      ignore: void class labels\n    \"\"\"\n    if per_image:\n        loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes)\n                          for prob, lab in zip(probas, labels))\n    else:\n        loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes)\n    return loss\n\n\ndef lovasz_softmax_flat(probas, labels, classes='present'):\n    \"\"\"\n    Multi-class Lovasz-Softmax loss\n      probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)\n      labels: [P] Tensor, ground truth labels (between 0 and C - 1)\n      classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.\n    \"\"\"\n    if probas.numel() == 0:\n        # only void pixels, the gradients should be 0\n        return probas * 0.\n    C = probas.size(1)\n    losses = []\n    class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes\n    for c in class_to_sum:\n        fg = (labels == c).float() # foreground for class c\n        if (classes == 'present' and fg.sum() == 0):\n            continue\n        if C == 1:\n            if len(classes) > 1:\n                raise ValueError('Sigmoid output possible only with 1 class')\n            class_pred = probas[:, 0]\n        else:\n            class_pred = probas[:, c]\n        errors = (Variable(fg) - class_pred).abs()\n        errors_sorted, perm = torch.sort(errors, 0, descending=True)\n        perm = perm.data\n        fg_sorted = fg[perm]\n        losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))\n    return mean(losses)\n\n\ndef flatten_probas(probas, labels, ignore=None):\n    \"\"\"\n    Flattens predictions in the batch\n    \"\"\"\n    if probas.dim() == 3:\n        # assumes output of a sigmoid layer\n        B, H, W = probas.size()\n        probas = 
probas.view(B, 1, H, W)\n    elif probas.dim() == 5:\n        #3D segmentation\n        B, C, L, H, W = probas.size()\n        probas = probas.contiguous().view(B, C, L, H*W)\n    B, C, H, W = probas.size()\n    probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B * H * W, C = P, C\n    labels = labels.view(-1)\n    if ignore is None:\n        return probas, labels\n    valid = (labels != ignore)\n    vprobas = probas[valid.nonzero().squeeze()]\n    vlabels = labels[valid]\n    return vprobas, vlabels\n\ndef xloss(logits, labels, ignore=None):\n    \"\"\"\n    Cross entropy loss\n    \"\"\"\n    return F.cross_entropy(logits, Variable(labels), ignore_index=255)\n\ndef jaccard_loss(probas, labels,ignore=None, smooth = 100, bk_class = None):\n    \"\"\"\n    Something wrong with this loss\n    Multi-class Lovasz-Softmax loss\n      probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).\n              Interpreted as binary (sigmoid) output with outputs of size [B, H, W].\n      labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)\n      classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.\n      per_image: compute the loss per image instead of per batch\n      ignore: void class labels\n    \"\"\"\n    vprobas, vlabels = flatten_probas(probas, labels, ignore)\n    \n    \n    true_1_hot = torch.eye(vprobas.shape[1])[vlabels]\n    \n    if bk_class:\n        one_hot_assignment = torch.ones_like(vlabels)\n        one_hot_assignment[vlabels == bk_class] = 0\n        one_hot_assignment = one_hot_assignment.float().unsqueeze(1)\n        true_1_hot = true_1_hot*one_hot_assignment\n    \n    true_1_hot = true_1_hot.to(vprobas.device)\n    intersection = torch.sum(vprobas * true_1_hot)\n    cardinality = torch.sum(vprobas + true_1_hot)\n    loss = (intersection + smooth / (cardinality - intersection + smooth)).mean()\n    return (1-loss)*smooth\n\ndef 
hinge_jaccard_loss(probas, labels,ignore=None, classes = 'present', hinge = 0.1, smooth =100):\n    \"\"\"\n    Multi-class Hinge Jaccard loss\n      probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).\n              Interpreted as binary (sigmoid) output with outputs of size [B, H, W].\n      labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)\n      classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.\n      ignore: void class labels\n    \"\"\"\n    vprobas, vlabels = flatten_probas(probas, labels, ignore)\n    C = vprobas.size(1)\n    losses = []\n    class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes\n    for c in class_to_sum:\n        if c in vlabels:\n            c_sample_ind = vlabels == c\n            cprobas = vprobas[c_sample_ind,:]\n            non_c_ind =np.array([a for a in class_to_sum if a != c])\n            class_pred = cprobas[:,c]\n            max_non_class_pred = torch.max(cprobas[:,non_c_ind],dim = 1)[0]\n            TP = torch.sum(torch.clamp(class_pred - max_non_class_pred, max = hinge)+1.) 
+ smooth\n            FN = torch.sum(torch.clamp(max_non_class_pred - class_pred, min = -hinge)+hinge)\n            \n            if (~c_sample_ind).sum() == 0:\n                FP = 0\n            else:\n                nonc_probas = vprobas[~c_sample_ind,:]\n                class_pred = nonc_probas[:,c]\n                max_non_class_pred = torch.max(nonc_probas[:,non_c_ind],dim = 1)[0]\n                FP = torch.sum(torch.clamp(class_pred - max_non_class_pred, max = hinge)+1.)\n            \n            losses.append(1 - TP/(TP+FP+FN))\n    \n    if len(losses) == 0: return 0\n    return mean(losses)\n\n# --------------------------- HELPER FUNCTIONS ---------------------------\ndef isnan(x):\n    return x != x\n    \n    \ndef mean(l, ignore_nan=False, empty=0):\n    \"\"\"\n    nanmean compatible with generators.\n    \"\"\"\n    l = iter(l)\n    if ignore_nan:\n        l = ifilterfalse(isnan, l)\n    try:\n        n = 1\n        acc = next(l)\n    except StopIteration:\n        if empty == 'raise':\n            raise ValueError('Empty mean')\n        return empty\n    for n, v in enumerate(l, 2):\n        acc += v\n    if n == 1:\n        return acc\n    return acc / n\n"
  },
  {
    "path": "network/ptBEV.py",
    "content": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\nimport numba as nb\nimport multiprocessing\nimport torch_scatter\n\nclass ptBEVnet(nn.Module):\n    \n    def __init__(self, BEV_net, grid_size, pt_model = 'pointnet', fea_dim = 3, pt_pooling = 'max', kernal_size = 3,\n                 out_pt_fea_dim = 64, max_pt_per_encode = 64, cluster_num = 4, pt_selection = 'farthest', fea_compre = None):\n        super(ptBEVnet, self).__init__()\n        assert pt_pooling in ['max']\n        assert pt_selection in ['random','farthest']\n        \n        if pt_model == 'pointnet':\n            self.PPmodel = nn.Sequential(\n                nn.BatchNorm1d(fea_dim),\n                \n                nn.Linear(fea_dim, 64),\n                nn.BatchNorm1d(64),\n                nn.ReLU(inplace=True),\n                \n                nn.Linear(64, 128),\n                nn.BatchNorm1d(128),\n                nn.ReLU(inplace=True),\n                \n                nn.Linear(128, 256),\n                nn.BatchNorm1d(256),\n                nn.ReLU(inplace=True),\n                \n                nn.Linear(256, out_pt_fea_dim)\n            )\n        \n        self.pt_model = pt_model\n        self.BEV_model = BEV_net\n        self.pt_pooling = pt_pooling\n        self.max_pt = max_pt_per_encode\n        self.pt_selection = pt_selection\n        self.fea_compre = fea_compre\n        self.grid_size = grid_size\n        \n        # NN stuff\n        if kernal_size != 1:\n            if self.pt_pooling == 'max':\n                self.local_pool_op = torch.nn.MaxPool2d(kernal_size, stride=1, padding=(kernal_size-1)//2, dilation=1)\n            else: raise NotImplementedError\n        else: self.local_pool_op = None\n        \n        # parametric pooling        \n        if self.pt_pooling == 'max':\n            self.pool_dim = out_pt_fea_dim\n        \n        # point feature 
compression\n        if self.fea_compre is not None:\n            self.fea_compression = nn.Sequential(\n                    nn.Linear(self.pool_dim, self.fea_compre),\n                    nn.ReLU())\n            self.pt_fea_dim = self.fea_compre\n        else:\n            self.pt_fea_dim = self.pool_dim\n        \n    def forward(self, pt_fea, xy_ind, voxel_fea=None):\n        cur_dev = pt_fea[0].get_device()\n        \n        # concatenate everything\n        cat_pt_ind = []\n        for i_batch in range(len(xy_ind)):\n            cat_pt_ind.append(F.pad(xy_ind[i_batch],(1,0),'constant',value = i_batch))\n\n        cat_pt_fea = torch.cat(pt_fea,dim = 0)\n        cat_pt_ind = torch.cat(cat_pt_ind,dim = 0)\n        pt_num = cat_pt_ind.shape[0]\n\n        # shuffle the data\n        shuffled_ind = torch.randperm(pt_num,device = cur_dev)\n        cat_pt_fea = cat_pt_fea[shuffled_ind,:]\n        cat_pt_ind = cat_pt_ind[shuffled_ind,:]\n        \n        # unique xy grid index\n        unq, unq_inv, unq_cnt = torch.unique(cat_pt_ind,return_inverse=True, return_counts=True, dim=0)\n        unq = unq.type(torch.int64)\n        \n        # subsample pts\n        if self.pt_selection == 'random':\n            grp_ind = grp_range_torch(unq_cnt,cur_dev)[torch.argsort(torch.argsort(unq_inv))]\n            remain_ind = grp_ind < self.max_pt\n        elif self.pt_selection == 'farthest':\n            unq_ind = np.split(np.argsort(unq_inv.detach().cpu().numpy()), np.cumsum(unq_cnt.detach().cpu().numpy()[:-1]))\n            remain_ind = np.zeros((pt_num,),dtype = np.bool_)\n            np_cat_fea = cat_pt_fea.detach().cpu().numpy()[:,:3]\n            pool_in = []\n            for i_inds in unq_ind:\n                if len(i_inds) > self.max_pt:\n                    pool_in.append((np_cat_fea[i_inds,:],self.max_pt))\n            if len(pool_in) > 0:\n                pool = multiprocessing.Pool(multiprocessing.cpu_count())\n                FPS_results = pool.starmap(parallel_FPS, 
pool_in)\n                pool.close()\n                pool.join()\n            count = 0\n            for i_inds in unq_ind:\n                if len(i_inds) <= self.max_pt:\n                    remain_ind[i_inds] = True\n                else:\n                    remain_ind[i_inds[FPS_results[count]]] = True\n                    count += 1\n            \n        cat_pt_fea = cat_pt_fea[remain_ind,:]\n        cat_pt_ind = cat_pt_ind[remain_ind,:]\n        unq_inv = unq_inv[remain_ind]\n        unq_cnt = torch.clamp(unq_cnt,max=self.max_pt)\n        \n        # process feature\n        if self.pt_model == 'pointnet':\n            processed_cat_pt_fea = self.PPmodel(cat_pt_fea)\n        \n        if self.pt_pooling == 'max':\n            pooled_data = torch_scatter.scatter_max(processed_cat_pt_fea, unq_inv, dim=0)[0]\n        else: raise NotImplementedError\n        \n        if self.fea_compre:\n            processed_pooled_data = self.fea_compression(pooled_data)\n        else:\n            processed_pooled_data = pooled_data\n        \n        # stuff pooled data into 4D tensor\n        out_data_dim = [len(pt_fea),self.grid_size[0],self.grid_size[1],self.pt_fea_dim]\n        out_data = torch.zeros(out_data_dim, dtype=torch.float32).to(cur_dev)\n        out_data[unq[:,0],unq[:,1],unq[:,2],:] = processed_pooled_data\n        out_data = out_data.permute(0,3,1,2)\n        if self.local_pool_op != None:\n            out_data = self.local_pool_op(out_data)\n        if voxel_fea is not None:\n            out_data = torch.cat((out_data, voxel_fea), 1)\n        \n        # run through network\n        sem_prediction, center, offset = self.BEV_model(out_data)\n       \n        return sem_prediction, center, offset\n    \ndef grp_range_torch(a,dev):\n    idx = torch.cumsum(a,0)\n    id_arr = torch.ones(idx[-1],dtype = torch.int64,device=dev)\n    id_arr[0] = 0\n    id_arr[idx[:-1]] = -a[:-1]+1\n    return torch.cumsum(id_arr,0)\n\ndef parallel_FPS(np_cat_fea,K):\n    return 
 nb_greedy_FPS(np_cat_fea,K)\n\n@nb.jit('b1[:](f4[:,:],i4)',nopython=True,cache=True)\ndef nb_greedy_FPS(xyz,K):\n    start_element = 0\n    sample_num = xyz.shape[0]\n    sum_vec = np.zeros((sample_num,1),dtype = np.float32)\n    xyz_sq = xyz**2\n    for j in range(sample_num):\n        sum_vec[j,0] = np.sum(xyz_sq[j,:])\n    pairwise_distance = sum_vec + np.transpose(sum_vec) - 2*np.dot(xyz, np.transpose(xyz))\n    \n    candidates_ind = np.zeros((sample_num,),dtype = np.bool_)\n    candidates_ind[start_element] = True\n    remain_ind = np.ones((sample_num,),dtype = np.bool_)\n    remain_ind[start_element] = False\n    all_ind = np.arange(sample_num)\n    \n    for i in range(1,K):\n        if i == 1:\n            min_remain_pt_dis = pairwise_distance[:,start_element]\n            min_remain_pt_dis = min_remain_pt_dis[remain_ind]\n        else:\n            cur_dis = pairwise_distance[remain_ind,:]\n            cur_dis = cur_dis[:,candidates_ind]\n            min_remain_pt_dis = np.zeros((cur_dis.shape[0],),dtype = np.float32)\n            for j in range(cur_dis.shape[0]):\n                min_remain_pt_dis[j] = np.min(cur_dis[j,:])\n        next_ind_in_remain = np.argmax(min_remain_pt_dis)\n        next_ind = all_ind[remain_ind][next_ind_in_remain]\n        candidates_ind[next_ind] = True\n        remain_ind[next_ind] = False\n        \n    return candidates_ind"
  },
  {
    "path": "requirements.txt",
    "content": "numpy\ntorch>=1.7.0\ntqdm\npyyaml\nnumba>=0.39.0\ntorch_scatter>=1.3.1\nCython\nscipy\ndropblock"
  },
  {
    "path": "semantic-kitti.yaml",
    "content": "# This file is covered by the LICENSE file in the root of this project.\nlabels: \n  0 : \"unlabeled\"\n  1 : \"outlier\"\n  10: \"car\"\n  11: \"bicycle\"\n  13: \"bus\"\n  15: \"motorcycle\"\n  16: \"on-rails\"\n  18: \"truck\"\n  20: \"other-vehicle\"\n  30: \"person\"\n  31: \"bicyclist\"\n  32: \"motorcyclist\"\n  40: \"road\"\n  44: \"parking\"\n  48: \"sidewalk\"\n  49: \"other-ground\"\n  50: \"building\"\n  51: \"fence\"\n  52: \"other-structure\"\n  60: \"lane-marking\"\n  70: \"vegetation\"\n  71: \"trunk\"\n  72: \"terrain\"\n  80: \"pole\"\n  81: \"traffic-sign\"\n  99: \"other-object\"\n  252: \"moving-car\"\n  253: \"moving-bicyclist\"\n  254: \"moving-person\"\n  255: \"moving-motorcyclist\"\n  256: \"moving-on-rails\"\n  257: \"moving-bus\"\n  258: \"moving-truck\"\n  259: \"moving-other-vehicle\"\ncolor_map: # bgr\n  0 : [0, 0, 0]\n  1 : [0, 0, 255]\n  10: [245, 150, 100]\n  11: [245, 230, 100]\n  13: [250, 80, 100]\n  15: [150, 60, 30]\n  16: [255, 0, 0]\n  18: [180, 30, 80]\n  20: [255, 0, 0]\n  30: [30, 30, 255]\n  31: [200, 40, 255]\n  32: [90, 30, 150]\n  40: [255, 0, 255]\n  44: [255, 150, 255]\n  48: [75, 0, 75]\n  49: [75, 0, 175]\n  50: [0, 200, 255]\n  51: [50, 120, 255]\n  52: [0, 150, 255]\n  60: [170, 255, 150]\n  70: [0, 175, 0]\n  71: [0, 60, 135]\n  72: [80, 240, 150]\n  80: [150, 240, 255]\n  81: [0, 0, 255]\n  99: [255, 255, 50]\n  252: [245, 150, 100]\n  256: [255, 0, 0]\n  253: [200, 40, 255]\n  254: [30, 30, 255]\n  255: [90, 30, 150]\n  257: [250, 80, 100]\n  258: [180, 30, 80]\n  259: [255, 0, 0]\ncontent: # as a ratio with the total number of points\n  0: 0.018889854628292943\n  1: 0.0002937197336781505\n  10: 0.040818519255974316\n  11: 0.00016609538710764618\n  13: 2.7879693665067774e-05\n  15: 0.00039838616015114444\n  16: 0.0\n  18: 0.0020633612104619787\n  20: 0.0016218197275284021\n  30: 0.00017698551338515307\n  31: 1.1065903904919655e-08\n  32: 5.532951952459828e-09\n  40: 0.1987493871255525\n  44: 
0.014717169549888214\n  48: 0.14392298360372\n  49: 0.0039048553037472045\n  50: 0.1326861944777486\n  51: 0.0723592229456223\n  52: 0.002395131480328884\n  60: 4.7084144280367186e-05\n  70: 0.26681502148037506\n  71: 0.006035012012626033\n  72: 0.07814222006271769\n  80: 0.002855498193863172\n  81: 0.0006155958086189918\n  99: 0.009923127583046915\n  252: 0.001789309418528068\n  253: 0.00012709999297008662\n  254: 0.00016059776092534436\n  255: 3.745553104802113e-05\n  256: 0.0\n  257: 0.00011351574470342043\n  258: 0.00010157861367183268\n  259: 4.3840131989471124e-05\n# classes that are indistinguishable from single scan or inconsistent in\n# ground truth are mapped to their closest equivalent\nlearning_map:\n  0 : 0     # \"unlabeled\"\n  1 : 0     # \"outlier\" mapped to \"unlabeled\" --------------------------mapped\n  10: 1     # \"car\"\n  11: 2     # \"bicycle\"\n  13: 5     # \"bus\" mapped to \"other-vehicle\" --------------------------mapped\n  15: 3     # \"motorcycle\"\n  16: 5     # \"on-rails\" mapped to \"other-vehicle\" ---------------------mapped\n  18: 4     # \"truck\"\n  20: 5     # \"other-vehicle\"\n  30: 6     # \"person\"\n  31: 7     # \"bicyclist\"\n  32: 8     # \"motorcyclist\"\n  40: 9     # \"road\"\n  44: 10    # \"parking\"\n  48: 11    # \"sidewalk\"\n  49: 12    # \"other-ground\"\n  50: 13    # \"building\"\n  51: 14    # \"fence\"\n  52: 0     # \"other-structure\" mapped to \"unlabeled\" ------------------mapped\n  60: 9     # \"lane-marking\" to \"road\" ---------------------------------mapped\n  70: 15    # \"vegetation\"\n  71: 16    # \"trunk\"\n  72: 17    # \"terrain\"\n  80: 18    # \"pole\"\n  81: 19    # \"traffic-sign\"\n  99: 0     # \"other-object\" to \"unlabeled\" ----------------------------mapped\n  252: 1    # \"moving-car\" to \"car\" ------------------------------------mapped\n  253: 7    # \"moving-bicyclist\" to \"bicyclist\" ------------------------mapped\n  254: 6    # \"moving-person\" to \"person\" 
------------------------------mapped\n  255: 8    # \"moving-motorcyclist\" to \"motorcyclist\" ------------------mapped\n  256: 5    # \"moving-on-rails\" mapped to \"other-vehicle\" --------------mapped\n  257: 5    # \"moving-bus\" mapped to \"other-vehicle\" -------------------mapped\n  258: 4    # \"moving-truck\" to \"truck\" --------------------------------mapped\n  259: 5    # \"moving-other\"-vehicle to \"other-vehicle\" ----------------mapped\nlearning_map_inv: # inverse of previous map\n  0: 0      # \"unlabeled\", and others ignored\n  1: 10     # \"car\"\n  2: 11     # \"bicycle\"\n  3: 15     # \"motorcycle\"\n  4: 18     # \"truck\"\n  5: 20     # \"other-vehicle\"\n  6: 30     # \"person\"\n  7: 31     # \"bicyclist\"\n  8: 32     # \"motorcyclist\"\n  9: 40     # \"road\"\n  10: 44    # \"parking\"\n  11: 48    # \"sidewalk\"\n  12: 49    # \"other-ground\"\n  13: 50    # \"building\"\n  14: 51    # \"fence\"\n  15: 70    # \"vegetation\"\n  16: 71    # \"trunk\"\n  17: 72    # \"terrain\"\n  18: 80    # \"pole\"\n  19: 81    # \"traffic-sign\"\nlearning_ignore: # Ignore classes\n  0: True      # \"unlabeled\", and others ignored\n  1: False     # \"car\"\n  2: False     # \"bicycle\"\n  3: False     # \"motorcycle\"\n  4: False     # \"truck\"\n  5: False     # \"other-vehicle\"\n  6: False     # \"person\"\n  7: False     # \"bicyclist\"\n  8: False     # \"motorcyclist\"\n  9: False     # \"road\"\n  10: False    # \"parking\"\n  11: False    # \"sidewalk\"\n  12: False    # \"other-ground\"\n  13: False    # \"building\"\n  14: False    # \"fence\"\n  15: False    # \"vegetation\"\n  16: False    # \"trunk\"\n  17: False    # \"terrain\"\n  18: False    # \"pole\"\n  19: False    # \"traffic-sign\"\nthing_class: # thing class in panoptic segmentation\n  0: False      # \"unlabeled\", and others ignored\n  1: True     # \"car\"\n  2: True     # \"bicycle\"\n  3: True     # \"motorcycle\"\n  4: True     # \"truck\"\n  5: True     # 
\"other-vehicle\"\n  6: True     # \"person\"\n  7: True     # \"bicyclist\"\n  8: True     # \"motorcyclist\"\n  9: False     # \"road\"\n  10: False    # \"parking\"\n  11: False    # \"sidewalk\"\n  12: False    # \"other-ground\"\n  13: False    # \"building\"\n  14: False    # \"fence\"\n  15: False    # \"vegetation\"\n  16: False    # \"trunk\"\n  17: False    # \"terrain\"\n  18: False    # \"pole\"\n  19: False    # \"traffic-sign\"  \nsplit: # sequence numbers\n  train:\n    - 0\n    - 1\n    - 2\n    - 3\n    - 4\n    - 5\n    - 6\n    - 7\n    - 9\n    - 10\n  valid:\n    - 8\n  test:\n    - 11\n    - 12\n    - 13\n    - 14\n    - 15\n    - 16\n    - 17\n    - 18\n    - 19\n    - 20\n    - 21\n"
  },
  {
    "path": "test_pretrain.py",
    "content": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\nimport os\nimport time\nimport argparse\nimport sys\nimport yaml\nimport numpy as np\nimport torch\nimport torch.optim as optim\nfrom tqdm import tqdm\nimport errno\n\nfrom network.BEV_Unet import BEV_Unet\nfrom network.ptBEV import ptBEVnet\nfrom dataloader.dataset import collate_fn_BEV,SemKITTI,SemKITTI_label_name,spherical_dataset,voxel_dataset,collate_fn_BEV_test\nfrom network.instance_post_processing import get_panoptic_segmentation\nfrom utils.eval_pq import PanopticEval\nfrom utils.configs import merge_configs\n#ignore weird np warning\nimport warnings\nwarnings.filterwarnings(\"ignore\")\n\ndef SemKITTI2train(label):\n    if isinstance(label, list):\n        return [SemKITTI2train_single(a) for a in label]\n    else:\n        return SemKITTI2train_single(label)\n\ndef SemKITTI2train_single(label):\n    return label - 1 # uint8 trick\n\ndef main(args):\n    data_path = args['dataset']['path']\n    test_batch_size = args['model']['test_batch_size']\n    pretrained_model = args['model']['pretrained_model']\n    output_path = args['dataset']['output_path']\n    compression_model = args['dataset']['grid_size'][2]\n    grid_size = args['dataset']['grid_size']\n    visibility = args['model']['visibility']\n    pytorch_device = torch.device('cuda:0')\n    if args['model']['polar']:\n        fea_dim = 9\n        circular_padding = True\n    else:\n        fea_dim = 7\n        circular_padding = False\n\n    # prepare miou fun\n    unique_label=np.asarray(sorted(list(SemKITTI_label_name.keys())))[1:] - 1\n    unique_label_str=[SemKITTI_label_name[x] for x in unique_label+1]\n\n    # prepare model\n    my_BEV_model=BEV_Unet(n_class=len(unique_label), n_height = compression_model, input_batch_norm = True, dropout = 0.5, circular_padding = circular_padding, use_vis_fea=visibility)\n    my_model = ptBEVnet(my_BEV_model, pt_model = 'pointnet', grid_size =  grid_size, fea_dim = fea_dim, max_pt_per_encode = 
256,\n                            out_pt_fea_dim = 512, kernal_size = 1, pt_selection = 'random', fea_compre = compression_model)\n    if os.path.exists(pretrained_model):\n        my_model.load_state_dict(torch.load(pretrained_model))\n    pytorch_total_params = sum(p.numel() for p in my_model.parameters())\n    print('params: ',pytorch_total_params)\n    my_model.to(pytorch_device)\n    my_model.eval()\n\n    # prepare dataset\n    test_pt_dataset = SemKITTI(data_path + '/sequences/', imageset = 'test', return_ref = True, instance_pkl_path=args['dataset']['instance_pkl_path'])\n    val_pt_dataset = SemKITTI(data_path + '/sequences/', imageset = 'val', return_ref = True, instance_pkl_path=args['dataset']['instance_pkl_path'])\n    if args['model']['polar']:\n        test_dataset=spherical_dataset(test_pt_dataset, args['dataset'], grid_size = grid_size, ignore_label = 0, return_test= True)\n        val_dataset=spherical_dataset(val_pt_dataset, args['dataset'], grid_size = grid_size, ignore_label = 0)\n    else:\n        test_dataset=voxel_dataset(test_pt_dataset, args['dataset'], grid_size = grid_size, ignore_label = 0, return_test= True)\n        val_dataset=voxel_dataset(val_pt_dataset, args['dataset'], grid_size = grid_size, ignore_label = 0)\n    test_dataset_loader = torch.utils.data.DataLoader(dataset = test_dataset,\n                                                    batch_size = test_batch_size,\n                                                    collate_fn = collate_fn_BEV_test,\n                                                    shuffle = False,\n                                                    num_workers = 4)\n    val_dataset_loader = torch.utils.data.DataLoader(dataset = val_dataset,\n                                                    batch_size = test_batch_size,\n                                                    collate_fn = collate_fn_BEV,\n                                                    shuffle = False,\n                                
                    num_workers = 4)\n\n    # validation\n    print('*'*80)\n    print('Test network performance on validation split')\n    print('*'*80)\n    pbar = tqdm(total=len(val_dataset_loader))\n    time_list = []\n    pp_time_list = []\n    evaluator = PanopticEval(len(unique_label)+1, None, [0], min_points=50)\n    with torch.no_grad():\n        for i_iter_val,(val_vox_fea,val_vox_label,val_gt_center,val_gt_offset,val_grid,val_pt_labels,val_pt_ints,val_pt_fea) in enumerate(val_dataset_loader):\n            val_vox_fea_ten = val_vox_fea.to(pytorch_device)\n            val_vox_label = SemKITTI2train(val_vox_label)\n            val_pt_fea_ten = [torch.from_numpy(i).type(torch.FloatTensor).to(pytorch_device) for i in val_pt_fea]\n            val_grid_ten = [torch.from_numpy(i[:,:2]).to(pytorch_device) for i in val_grid]\n            val_label_tensor=val_vox_label.type(torch.LongTensor).to(pytorch_device)\n            val_gt_center_tensor = val_gt_center.to(pytorch_device)\n            val_gt_offset_tensor = val_gt_offset.to(pytorch_device)\n\n            torch.cuda.synchronize()\n            start_time = time.time()\n            if visibility:            \n                predict_labels,center,offset = my_model(val_pt_fea_ten, val_grid_ten, val_vox_fea_ten)\n            else:\n                predict_labels,center,offset = my_model(val_pt_fea_ten, val_grid_ten)\n            torch.cuda.synchronize()\n            time_list.append(time.time()-start_time)\n\n            for count,i_val_grid in enumerate(val_grid):\n                # get foreground_mask\n                for_mask = torch.zeros(1,grid_size[0],grid_size[1],grid_size[2],dtype=torch.bool).to(pytorch_device)\n                for_mask[0,val_grid[count][:,0],val_grid[count][:,1],val_grid[count][:,2]] = True\n                # post processing\n                torch.cuda.synchronize()\n                start_time = time.time()\n                panoptic_labels,center_points = 
get_panoptic_segmentation(torch.unsqueeze(predict_labels[count], 0),torch.unsqueeze(center[count], 0),torch.unsqueeze(offset[count], 0),val_pt_dataset.thing_list,\\\n                                                                          threshold=args['model']['post_proc']['threshold'], nms_kernel=args['model']['post_proc']['nms_kernel'],\\\n                                                                          top_k=args['model']['post_proc']['top_k'], polar=circular_padding,foreground_mask=for_mask)\n                torch.cuda.synchronize()\n                pp_time_list.append(time.time()-start_time)\n                panoptic_labels = panoptic_labels.cpu().detach().numpy().astype(np.uint32)\n                panoptic = panoptic_labels[0,val_grid[count][:,0],val_grid[count][:,1],val_grid[count][:,2]]\n\n                evaluator.addBatch(panoptic & 0xFFFF,panoptic,np.squeeze(val_pt_labels[count]),np.squeeze(val_pt_ints[count]))\n            del val_vox_label,val_pt_fea_ten,val_label_tensor,val_grid_ten,val_gt_center,val_gt_center_tensor,val_gt_offset,val_gt_offset_tensor,predict_labels,center,offset,panoptic_labels,center_points\n            pbar.update(1)\n    \n    class_PQ, class_SQ, class_RQ, class_all_PQ, class_all_SQ, class_all_RQ = evaluator.getPQ()\n    miou,ious = evaluator.getSemIoU()\n    print('Validation per class PQ, SQ, RQ and IoU: ')\n    for class_name, class_pq, class_sq, class_rq, class_iou in zip(unique_label_str,class_all_PQ[1:],class_all_SQ[1:],class_all_RQ[1:],ious[1:]):\n        print('%15s : %6.2f%%  %6.2f%%  %6.2f%%  %6.2f%%' % (class_name, class_pq*100, class_sq*100, class_rq*100, class_iou*100))\n    pbar.close()\n    print('Current val PQ is %.3f' %\n        (class_PQ*100))               \n    print('Current val miou is %.3f'%\n        (miou*100))\n    print('Inference time per %d is %.4f seconds\\n, postprocessing time is %.4f seconds per scan' %\n        (test_batch_size,np.mean(time_list),np.mean(pp_time_list)))\n    \n    # 
test\n    print('*'*80)\n    print('Generate predictions for test split')\n    print('*'*80)\n    pbar = tqdm(total=len(test_dataset_loader))\n    with torch.no_grad():\n        for i_iter_test,(test_vox_fea,_,_,_,test_grid,_,_,test_pt_fea,test_index) in enumerate(test_dataset_loader):\n            # predict\n            test_vox_fea_ten = test_vox_fea.to(pytorch_device)\n            test_pt_fea_ten = [torch.from_numpy(i).type(torch.FloatTensor).to(pytorch_device) for i in test_pt_fea]\n            test_grid_ten = [torch.from_numpy(i[:,:2]).to(pytorch_device) for i in test_grid]\n\n            if visibility:\n                predict_labels,center,offset = my_model(test_pt_fea_ten,test_grid_ten,test_vox_fea_ten)\n            else:\n                predict_labels,center,offset = my_model(test_pt_fea_ten,test_grid_ten)\n            # write to label file\n            for count,i_test_grid in enumerate(test_grid):\n                # get foreground_mask\n                for_mask = torch.zeros(1,grid_size[0],grid_size[1],grid_size[2],dtype=torch.bool).to(pytorch_device)\n                for_mask[0,test_grid[count][:,0],test_grid[count][:,1],test_grid[count][:,2]] = True\n                # post processing\n                panoptic_labels,center_points = get_panoptic_segmentation(torch.unsqueeze(predict_labels[count], 0),torch.unsqueeze(center[count], 0),torch.unsqueeze(offset[count], 0),test_pt_dataset.thing_list,\\\n                                                                                          threshold=args['model']['post_proc']['threshold'], nms_kernel=args['model']['post_proc']['nms_kernel'],\\\n                                                                                          top_k=args['model']['post_proc']['top_k'], polar=circular_padding,foreground_mask=for_mask)\n                panoptic_labels = panoptic_labels.cpu().detach().numpy().astype(np.uint32)\n                panoptic = 
panoptic_labels[0,test_grid[count][:,0],test_grid[count][:,1],test_grid[count][:,2]]\n                save_dir = test_pt_dataset.im_idx[test_index[count]]\n                _,dir2 = save_dir.split('/sequences/',1)\n                new_save_dir = output_path + '/sequences/' +dir2.replace('velodyne','predictions')[:-3]+'label'\n                if not os.path.exists(os.path.dirname(new_save_dir)):\n                    try:\n                        os.makedirs(os.path.dirname(new_save_dir))\n                    except OSError as exc:\n                        if exc.errno != errno.EEXIST:\n                            raise\n                panoptic.tofile(new_save_dir)\n            del test_pt_fea_ten,test_grid_ten,test_pt_fea,predict_labels,center,offset\n            pbar.update(1)\n    pbar.close()\n    print('Predicted test labels are saved in %s. They need to be shifted to the original label format before submitting to the competition website.' % output_path)\n    print('The remapping script can be found in semantic-kitti-api.')\n\nif __name__ == '__main__':\n    # Testing settings\n    parser = argparse.ArgumentParser(description='')\n    parser.add_argument('-d', '--data_dir', default='data')\n    parser.add_argument('-p', '--pretrained_model', default='pretrained_weight/Panoptic_SemKITTI_PolarNet.pt')\n    parser.add_argument('-c', '--configs', default='configs/SemanticKITTI_model/Panoptic-PolarNet.yaml')\n    \n    args = parser.parse_args()\n    with open(args.configs, 'r') as s:\n        new_args = yaml.safe_load(s)\n    args = merge_configs(args,new_args)\n\n    print(' '.join(sys.argv))\n    print(args)\n    main(args)"
  },
  {
    "path": "train.py",
    "content": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\nimport os\nimport argparse\nimport sys\nimport numpy as np\nimport yaml\nimport torch\nimport torch.optim as optim\nfrom tqdm import tqdm\n\nfrom network.BEV_Unet import BEV_Unet\nfrom network.ptBEV import ptBEVnet\nfrom dataloader.dataset import collate_fn_BEV,SemKITTI,SemKITTI_label_name,spherical_dataset,voxel_dataset\nfrom network.instance_post_processing import get_panoptic_segmentation\nfrom network.loss import panoptic_loss\nfrom utils.eval_pq import PanopticEval\nfrom utils.configs import merge_configs\n#ignore weird np warning\nimport warnings\nwarnings.filterwarnings(\"ignore\")\n\ndef SemKITTI2train(label):\n    if isinstance(label, list):\n        return [SemKITTI2train_single(a) for a in label]\n    else:\n        return SemKITTI2train_single(label)\n\ndef SemKITTI2train_single(label):\n    return label - 1 # uint8 trick\n\ndef load_pretrained_model(model,pretrained_model):\n    model_dict = model.state_dict()\n    pretrained_model = {k: v for k, v in pretrained_model.items() if k in model_dict}\n    model_dict.update(pretrained_model) \n    model.load_state_dict(model_dict)\n    return model\n\ndef main(args):\n    data_path = args['dataset']['path']\n    train_batch_size = args['model']['train_batch_size']\n    val_batch_size = args['model']['val_batch_size']\n    check_iter = args['model']['check_iter']\n    model_save_path = args['model']['model_save_path']\n    pretrained_model = args['model']['pretrained_model']\n    compression_model = args['dataset']['grid_size'][2]\n    grid_size = args['dataset']['grid_size']\n    visibility = args['model']['visibility']\n    pytorch_device = torch.device('cuda:0')\n    if args['model']['polar']:\n        fea_dim = 9\n        circular_padding = True\n    else:\n        fea_dim = 7\n        circular_padding = False\n\n    #prepare miou fun\n    unique_label=np.asarray(sorted(list(SemKITTI_label_name.keys())))[1:] - 1\n    
unique_label_str=[SemKITTI_label_name[x] for x in unique_label+1]\n\n    #prepare model\n    my_BEV_model=BEV_Unet(n_class=len(unique_label), n_height = compression_model, input_batch_norm = True, dropout = 0.5, circular_padding = circular_padding, use_vis_fea=visibility)\n    my_model = ptBEVnet(my_BEV_model, pt_model = 'pointnet', grid_size =  grid_size, fea_dim = fea_dim, max_pt_per_encode = 256,\n                            out_pt_fea_dim = 512, kernal_size = 1, pt_selection = 'random', fea_compre = compression_model)\n    if os.path.exists(model_save_path):\n        my_model = load_pretrained_model(my_model,torch.load(model_save_path))\n    elif os.path.exists(pretrained_model):\n        my_model = load_pretrained_model(my_model,torch.load(pretrained_model))\n    my_model.to(pytorch_device)\n\n    optimizer = optim.Adam(my_model.parameters())\n    loss_fn = panoptic_loss(center_loss_weight = args['model']['center_loss_weight'], offset_loss_weight = args['model']['offset_loss_weight'],\\\n                            center_loss = args['model']['center_loss'], offset_loss=args['model']['offset_loss'])\n\n    #prepare dataset\n    train_pt_dataset = SemKITTI(data_path + '/sequences/', imageset = 'train', return_ref = True, instance_pkl_path=args['dataset']['instance_pkl_path'])\n    val_pt_dataset = SemKITTI(data_path + '/sequences/', imageset = 'val', return_ref = True, instance_pkl_path=args['dataset']['instance_pkl_path'])\n    if args['model']['polar']:\n        train_dataset=spherical_dataset(train_pt_dataset, args['dataset'], grid_size = grid_size, ignore_label = 0, use_aug = True)\n        val_dataset=spherical_dataset(val_pt_dataset, args['dataset'], grid_size = grid_size, ignore_label = 0)\n    else:\n        train_dataset=voxel_dataset(train_pt_dataset, args['dataset'], grid_size = grid_size, ignore_label = 0,use_aug = True)\n        val_dataset=voxel_dataset(val_pt_dataset, args['dataset'], grid_size = grid_size, ignore_label = 0)\n    
train_dataset_loader = torch.utils.data.DataLoader(dataset = train_dataset,\n                                                    batch_size = train_batch_size,\n                                                    collate_fn = collate_fn_BEV,\n                                                    shuffle = True,\n                                                    num_workers = 4)\n    val_dataset_loader = torch.utils.data.DataLoader(dataset = val_dataset,\n                                                    batch_size = val_batch_size,\n                                                    collate_fn = collate_fn_BEV,\n                                                    shuffle = False,\n                                                    num_workers = 4)\n\n    # training\n    epoch=0\n    best_val_PQ=0\n    start_training=False\n    my_model.train()\n    global_iter = 0\n    exce_counter = 0\n    evaluator = PanopticEval(len(unique_label)+1, None, [0], min_points=50)\n\n    while epoch < args['model']['max_epoch']:\n        pbar = tqdm(total=len(train_dataset_loader))\n        for i_iter,(train_vox_fea,train_label_tensor,train_gt_center,train_gt_offset,train_grid,_,_,train_pt_fea) in enumerate(train_dataset_loader):\n            # validation\n            if global_iter % check_iter == 0:\n                my_model.eval()\n                evaluator.reset()\n                with torch.no_grad():\n                    for i_iter_val,(val_vox_fea,val_vox_label,val_gt_center,val_gt_offset,val_grid,val_pt_labels,val_pt_ints,val_pt_fea) in enumerate(val_dataset_loader):\n                        val_vox_fea_ten = val_vox_fea.to(pytorch_device)\n                        val_vox_label = SemKITTI2train(val_vox_label)\n                        val_pt_fea_ten = [torch.from_numpy(i).type(torch.FloatTensor).to(pytorch_device) for i in val_pt_fea]\n                        val_grid_ten = [torch.from_numpy(i[:,:2]).to(pytorch_device) for i in val_grid]\n                        
val_label_tensor=val_vox_label.type(torch.LongTensor).to(pytorch_device)\n                        val_gt_center_tensor = val_gt_center.to(pytorch_device)\n                        val_gt_offset_tensor = val_gt_offset.to(pytorch_device)\n\n                        if visibility:\n                            predict_labels,center,offset = my_model(val_pt_fea_ten, val_grid_ten, val_vox_fea_ten)\n                        else:\n                            predict_labels,center,offset = my_model(val_pt_fea_ten, val_grid_ten)\n\n                        for count,i_val_grid in enumerate(val_grid):\n                            # get foreground_mask\n                            for_mask = torch.zeros(1,grid_size[0],grid_size[1],grid_size[2],dtype=torch.bool).to(pytorch_device)\n                            for_mask[0,val_grid[count][:,0],val_grid[count][:,1],val_grid[count][:,2]] = True\n                            # post processing\n                            panoptic_labels,center_points = get_panoptic_segmentation(torch.unsqueeze(predict_labels[count], 0),torch.unsqueeze(center[count], 0),torch.unsqueeze(offset[count], 0),\\\n                                                                                      val_pt_dataset.thing_list, threshold=args['model']['post_proc']['threshold'], nms_kernel=args['model']['post_proc']['nms_kernel'],\\\n                                                                                      top_k=args['model']['post_proc']['top_k'], polar=circular_padding,foreground_mask=for_mask)\n                            panoptic_labels = panoptic_labels.cpu().detach().numpy().astype(np.int32)\n                            panoptic = panoptic_labels[0,val_grid[count][:,0],val_grid[count][:,1],val_grid[count][:,2]]\n                            evaluator.addBatch(panoptic & 0xFFFF,panoptic,np.squeeze(val_pt_labels[count]),np.squeeze(val_pt_ints[count]))\n                        del 
val_vox_label,val_pt_fea_ten,val_label_tensor,val_grid_ten,val_gt_center,val_gt_center_tensor,val_gt_offset,val_gt_offset_tensor,predict_labels,center,offset,panoptic_labels,center_points\n                my_model.train()\n                class_PQ, class_SQ, class_RQ, class_all_PQ, class_all_SQ, class_all_RQ = evaluator.getPQ()\n                miou,ious = evaluator.getSemIoU()\n                print('Validation per class PQ, SQ, RQ and IoU: ')\n                for class_name, class_pq, class_sq, class_rq, class_iou in zip(unique_label_str,class_all_PQ[1:],class_all_SQ[1:],class_all_RQ[1:],ious[1:]):\n                    print('%15s : %6.2f%%  %6.2f%%  %6.2f%%  %6.2f%%' % (class_name, class_pq*100, class_sq*100, class_rq*100, class_iou*100))                                  \n                # save model if performance is improved\n                if best_val_PQ<class_PQ:\n                    best_val_PQ=class_PQ\n                    torch.save(my_model.state_dict(), model_save_path)\n                print('Current val PQ is %.3f while the best val PQ is %.3f' %\n                    (class_PQ*100,best_val_PQ*100))               \n                print('Current val miou is %.3f'%\n                    (miou*100))\n\n                if start_training:\n                    sem_l ,hm_l, os_l = np.mean(loss_fn.lost_dict['semantic_loss']), np.mean(loss_fn.lost_dict['heatmap_loss']), np.mean(loss_fn.lost_dict['offset_loss'])\n                    print('epoch %d iter %5d, loss: %.3f, semantic loss: %.3f, heatmap loss: %.3f, offset loss: %.3f\\n' %\n                        (epoch, i_iter, sem_l+hm_l+os_l, sem_l, hm_l, os_l))\n                print('%d exceptions encountered during last training\\n' %\n                    exce_counter)\n                exce_counter = 0\n                loss_fn.reset_loss_dict()\n\n            # training\n            try:\n                train_vox_fea_ten = train_vox_fea.to(pytorch_device)\n                train_label_tensor = 
SemKITTI2train(train_label_tensor)\n                train_pt_fea_ten = [torch.from_numpy(i).type(torch.FloatTensor).to(pytorch_device) for i in train_pt_fea]\n                train_grid_ten = [torch.from_numpy(i[:,:2]).to(pytorch_device) for i in train_grid]\n                train_label_tensor=train_label_tensor.type(torch.LongTensor).to(pytorch_device)\n                train_gt_center_tensor = train_gt_center.to(pytorch_device)\n                train_gt_offset_tensor = train_gt_offset.to(pytorch_device)\n\n                if args['model']['enable_SAP'] and epoch>=args['model']['SAP']['start_epoch']:\n                    for fea in train_pt_fea_ten:\n                        fea.requires_grad_()\n        \n                # forward\n                if visibility:\n                    sem_prediction,center,offset = my_model(train_pt_fea_ten,train_grid_ten,train_vox_fea_ten)\n                else:\n                    sem_prediction,center,offset = my_model(train_pt_fea_ten,train_grid_ten)\n                # loss\n                loss = loss_fn(sem_prediction,center,offset,train_label_tensor,train_gt_center_tensor,train_gt_offset_tensor)\n                \n                \n                # self adversarial pruning\n                if args['model']['enable_SAP'] and epoch>=args['model']['SAP']['start_epoch']:\n                    loss.backward()\n                    for i,fea in enumerate(train_pt_fea_ten):\n                        fea_grad = torch.norm(fea.grad,dim=1)\n                        top_k_grad, _ = torch.topk(fea_grad, int(args['model']['SAP']['rate']*fea_grad.shape[0]))\n                        # delete high influential points\n                        train_pt_fea_ten[i] = train_pt_fea_ten[i][fea_grad < top_k_grad[-1]]\n                        train_grid_ten[i] = train_grid_ten[i][fea_grad < top_k_grad[-1]]\n                    optimizer.zero_grad()\n\n                    # second pass\n                    # forward\n                    if visibility:\n   
                     sem_prediction,center,offset = my_model(train_pt_fea_ten,train_grid_ten,train_vox_fea_ten)\n                    else:\n                        sem_prediction,center,offset = my_model(train_pt_fea_ten,train_grid_ten)\n                    # loss\n                    loss = loss_fn(sem_prediction,center,offset,train_label_tensor,train_gt_center_tensor,train_gt_offset_tensor)\n                    \n                # backward + optimize\n                loss.backward()\n                optimizer.step()\n            except Exception as error: \n                if exce_counter == 0:\n                    print(error)\n                exce_counter += 1\n            \n            # zero the parameter gradients\n            optimizer.zero_grad()\n            pbar.update(1)\n            start_training=True\n            global_iter += 1\n        pbar.close()\n        epoch += 1\n\nif __name__ == '__main__':\n    # Training settings\n    parser = argparse.ArgumentParser(description='')\n    parser.add_argument('-d', '--data_dir', default='data')\n    parser.add_argument('-p', '--model_save_path', default='./Panoptic_SemKITTI.pt')\n    parser.add_argument('-c', '--configs', default='configs/SemanticKITTI_model/Panoptic-PolarNet.yaml')\n    parser.add_argument('--pretrained_model', default='empty')\n\n    args = parser.parse_args()\n    with open(args.configs, 'r') as s:\n        new_args = yaml.safe_load(s)\n    args = merge_configs(args,new_args)\n\n    print(' '.join(sys.argv))\n    print(args)\n    main(args)"
  },
  {
    "path": "utils/__init__.py",
    "content": ""
  },
  {
    "path": "utils/configs.py",
    "content": "#!/usr/bin/env python3\nimport yaml\n\ndef merge_configs(cfgs,new_cfgs):\n    if hasattr(cfgs, 'data_dir'):\n        new_cfgs['dataset']['path']=cfgs.data_dir\n    if hasattr(cfgs, 'model_save_path'):\n        new_cfgs['model']['model_save_path']=cfgs.model_save_path\n    if hasattr(cfgs, 'pretrained_model'):\n        new_cfgs['model']['pretrained_model']=cfgs.pretrained_model\n    return new_cfgs"
  },
  {
    "path": "utils/eval_pq.py",
    "content": "#!/usr/bin/env python3\nimport numpy as np\nimport time\n\n\nclass PanopticEval:\n  \"\"\" Panoptic evaluation using numpy\n  \n  authors: Andres Milioto and Jens Behley\n  \"\"\"\n\n  def __init__(self, n_classes, device=None, ignore=None, offset=2**32, min_points=30):\n    self.n_classes = n_classes\n    assert (device == None)\n    self.ignore = np.array(ignore, dtype=np.int64)\n    self.include = np.array([n for n in range(self.n_classes) if n not in self.ignore], dtype=np.int64)\n\n    print(\"[PANOPTIC EVAL] IGNORE: \", self.ignore)\n    print(\"[PANOPTIC EVAL] INCLUDE: \", self.include)\n\n    self.reset()\n    self.offset = offset  # largest number of instances in a given scan\n    self.min_points = min_points  # smallest number of points to consider instances in gt\n    self.eps = 1e-15\n\n  def num_classes(self):\n    return self.n_classes\n\n  def reset(self):\n    # general things\n    # iou stuff\n    self.px_iou_conf_matrix = np.zeros((self.n_classes, self.n_classes), dtype=np.int64)\n    # panoptic stuff\n    self.pan_tp = np.zeros(self.n_classes, dtype=np.int64)\n    self.pan_iou = np.zeros(self.n_classes, dtype=np.double)\n    self.pan_fp = np.zeros(self.n_classes, dtype=np.int64)\n    self.pan_fn = np.zeros(self.n_classes, dtype=np.int64)\n\n  ################################# IoU STUFF ##################################\n  def addBatchSemIoU(self, x_sem, y_sem):\n    # idxs are labels and predictions\n    idxs = np.stack([x_sem, y_sem], axis=0)\n\n    # make confusion matrix (cols = gt, rows = pred)\n    np.add.at(self.px_iou_conf_matrix, tuple(idxs), 1)\n\n  def getSemIoUStats(self):\n    # clone to avoid modifying the real deal\n    conf = self.px_iou_conf_matrix.copy().astype(np.double)\n    # remove fp from confusion on the ignore classes predictions\n    # points that were predicted of another class, but were ignore\n    # (corresponds to zeroing the cols of those classes, since the predictions\n    # go on the rows)\n    
conf[:, self.ignore] = 0\n\n    # get the clean stats\n    tp = conf.diagonal()\n    fp = conf.sum(axis=1) - tp\n    fn = conf.sum(axis=0) - tp\n    return tp, fp, fn\n\n  def getSemIoU(self):\n    tp, fp, fn = self.getSemIoUStats()\n    # print(f\"tp={tp}\")\n    # print(f\"fp={fp}\")\n    # print(f\"fn={fn}\")\n    intersection = tp\n    union = tp + fp + fn\n    union = np.maximum(union, self.eps)\n    iou = intersection.astype(np.double) / union.astype(np.double)\n    iou_mean = (intersection[self.include].astype(np.double) / union[self.include].astype(np.double)).mean()\n\n    return iou_mean, iou  # returns \"iou mean\", \"iou per class\" ALL CLASSES\n\n  def getSemAcc(self):\n    tp, fp, fn = self.getSemIoUStats()\n    total_tp = tp.sum()\n    total = tp[self.include].sum() + fp[self.include].sum()\n    total = np.maximum(total, self.eps)\n    acc_mean = total_tp.astype(np.double) / total.astype(np.double)\n\n    return acc_mean  # returns \"acc mean\"\n\n  ################################# IoU STUFF ##################################\n  ##############################################################################\n\n  #############################  Panoptic STUFF ################################\n  def addBatchPanoptic(self, x_sem_row, x_inst_row, y_sem_row, y_inst_row):\n    # make sure instances are not zeros (it messes with my approach)\n    x_inst_row = x_inst_row + 1\n    y_inst_row = y_inst_row + 1\n\n    # only interested in points that are outside the void area (not in excluded classes)\n    for cl in self.ignore:\n      # make a mask for this class\n      gt_not_in_excl_mask = y_sem_row != cl\n      # remove all other points\n      x_sem_row = x_sem_row[gt_not_in_excl_mask]\n      y_sem_row = y_sem_row[gt_not_in_excl_mask]\n      x_inst_row = x_inst_row[gt_not_in_excl_mask]\n      y_inst_row = y_inst_row[gt_not_in_excl_mask]\n\n    # first step is to count intersections > 0.5 IoU for each class (except the ignored ones)\n    for cl in 
self.include:\n      # print(\"*\"*80)\n      # print(\"CLASS\", cl.item())\n      # get a class mask\n      x_inst_in_cl_mask = x_sem_row == cl\n      y_inst_in_cl_mask = y_sem_row == cl\n\n      # get instance points in class (makes outside stuff 0)\n      x_inst_in_cl = x_inst_row * x_inst_in_cl_mask.astype(np.int64)\n      y_inst_in_cl = y_inst_row * y_inst_in_cl_mask.astype(np.int64)\n\n      # generate the areas for each unique instance prediction\n      unique_pred, counts_pred = np.unique(x_inst_in_cl[x_inst_in_cl > 0], return_counts=True)\n      id2idx_pred = {id: idx for idx, id in enumerate(unique_pred)}\n      matched_pred = np.array([False] * unique_pred.shape[0])\n      # print(\"Unique predictions:\", unique_pred)\n\n      # generate the areas for each unique instance gt_np\n      unique_gt, counts_gt = np.unique(y_inst_in_cl[y_inst_in_cl > 0], return_counts=True)\n      id2idx_gt = {id: idx for idx, id in enumerate(unique_gt)}\n      matched_gt = np.array([False] * unique_gt.shape[0])\n      # print(\"Unique ground truth:\", unique_gt)\n\n      # generate intersection using offset\n      valid_combos = np.logical_and(x_inst_in_cl > 0, y_inst_in_cl > 0)\n      offset_combo = x_inst_in_cl[valid_combos] + self.offset * y_inst_in_cl[valid_combos]\n      unique_combo, counts_combo = np.unique(offset_combo, return_counts=True)\n\n      # generate an intersection map\n      # count the intersections with over 0.5 IoU as TP\n      gt_labels = unique_combo // self.offset\n      pred_labels = unique_combo % self.offset\n      gt_areas = np.array([counts_gt[id2idx_gt[id]] for id in gt_labels])\n      pred_areas = np.array([counts_pred[id2idx_pred[id]] for id in pred_labels])\n      intersections = counts_combo\n      unions = gt_areas + pred_areas - intersections\n      # np.float was removed in NumPy 1.24; np.double is the equivalent 64-bit float\n      ious = intersections.astype(np.double) / unions.astype(np.double)\n\n\n      tp_indexes = ious > 0.5\n      self.pan_tp[cl] += np.sum(tp_indexes)\n      self.pan_iou[cl] += 
np.sum(ious[tp_indexes])\n\n      matched_gt[[id2idx_gt[id] for id in gt_labels[tp_indexes]]] = True\n      matched_pred[[id2idx_pred[id] for id in pred_labels[tp_indexes]]] = True\n\n      # count the FN\n      self.pan_fn[cl] += np.sum(np.logical_and(counts_gt >= self.min_points, matched_gt == False))\n\n      # count the FP\n      self.pan_fp[cl] += np.sum(np.logical_and(counts_pred >= self.min_points, matched_pred == False))\n\n  def getPQ(self):\n    # first calculate for all classes\n    sq_all = self.pan_iou.astype(np.double) / np.maximum(self.pan_tp.astype(np.double), self.eps)\n    rq_all = self.pan_tp.astype(np.double) / np.maximum(\n        self.pan_tp.astype(np.double) + 0.5 * self.pan_fp.astype(np.double) + 0.5 * self.pan_fn.astype(np.double),\n        self.eps)\n    pq_all = sq_all * rq_all\n\n    # then do the REAL mean (no ignored classes)\n    SQ = sq_all[self.include].mean()\n    RQ = rq_all[self.include].mean()\n    PQ = pq_all[self.include].mean()\n\n    return PQ, SQ, RQ, pq_all, sq_all, rq_all\n\n  #############################  Panoptic STUFF ################################\n  ##############################################################################\n\n  def addBatch(self, x_sem, x_inst, y_sem, y_inst):  # x=preds, y=targets\n    ''' IMPORTANT: Inputs must be batched. 
Either [N,H,W], or [N, P]\n    '''\n    # add to IoU calculation (for checking purposes)\n    self.addBatchSemIoU(x_sem, y_sem)\n\n    # now do the panoptic stuff\n    self.addBatchPanoptic(x_sem, x_inst, y_sem, y_inst)\n\n\nif __name__ == \"__main__\":\n  # generate problem from He paper (https://arxiv.org/pdf/1801.00868.pdf)\n  classes = 5  # ignore, grass, sky, person, dog\n  cl_strings = [\"ignore\", \"grass\", \"sky\", \"person\", \"dog\"]\n  ignore = [0]  # only ignore ignore class\n  min_points = 1  # for this example we care about all points\n\n  # generate ground truth and prediction\n  sem_pred = []\n  inst_pred = []\n  sem_gt = []\n  inst_gt = []\n\n  # some ignore stuff\n  N_ignore = 50\n  sem_pred.extend([0 for i in range(N_ignore)])\n  inst_pred.extend([0 for i in range(N_ignore)])\n  sem_gt.extend([0 for i in range(N_ignore)])\n  inst_gt.extend([0 for i in range(N_ignore)])\n\n  # grass segment\n  N_grass = 50\n  N_grass_pred = 40  # rest is sky\n  sem_pred.extend([1 for i in range(N_grass_pred)])  # grass\n  sem_pred.extend([2 for i in range(N_grass - N_grass_pred)])  # sky\n  inst_pred.extend([0 for i in range(N_grass)])\n  sem_gt.extend([1 for i in range(N_grass)])  # grass\n  inst_gt.extend([0 for i in range(N_grass)])\n\n  # sky segment\n  N_sky = 50\n  N_sky_pred = 40  # rest is grass\n  sem_pred.extend([2 for i in range(N_sky_pred)])  # sky\n  sem_pred.extend([1 for i in range(N_sky - N_sky_pred)])  # grass\n  inst_pred.extend([0 for i in range(N_sky)])  # first instance\n  sem_gt.extend([2 for i in range(N_sky)])  # sky\n  inst_gt.extend([0 for i in range(N_sky)])  # first instance\n\n  # wrong dog as person prediction\n  N_dog = 50\n  N_person = N_dog\n  sem_pred.extend([3 for i in range(N_person)])\n  inst_pred.extend([35 for i in range(N_person)])\n  sem_gt.extend([4 for i in range(N_dog)])\n  inst_gt.extend([22 for i in range(N_dog)])\n\n  # two persons in prediction, but three in gt\n  N_person = 50\n  sem_pred.extend([3 for i in range(6 
* N_person)])\n  inst_pred.extend([8 for i in range(4 * N_person)])\n  inst_pred.extend([95 for i in range(2 * N_person)])\n  sem_gt.extend([3 for i in range(6 * N_person)])\n  inst_gt.extend([33 for i in range(3 * N_person)])\n  inst_gt.extend([42 for i in range(N_person)])\n  inst_gt.extend([11 for i in range(2 * N_person)])\n\n  # gt and pred to numpy\n  sem_pred = np.array(sem_pred, dtype=np.int64).reshape(1, -1)\n  inst_pred = np.array(inst_pred, dtype=np.int64).reshape(1, -1)\n  sem_gt = np.array(sem_gt, dtype=np.int64).reshape(1, -1)\n  inst_gt = np.array(inst_gt, dtype=np.int64).reshape(1, -1)\n\n  # evaluator\n  evaluator = PanopticEval(classes, ignore=ignore, min_points=1)\n  evaluator.addBatch(sem_pred, inst_pred, sem_gt, inst_gt)\n  pq, sq, rq, all_pq, all_sq, all_rq = evaluator.getPQ()\n  iou, all_iou = evaluator.getSemIoU()\n\n  # [PANOPTIC EVAL] IGNORE:  [0]\n  # [PANOPTIC EVAL] INCLUDE:  [1 2 3 4]\n  # TOTALS\n  # PQ: 0.47916666666666663\n  # SQ: 0.5520833333333333\n  # RQ: 0.6666666666666666\n  # IoU: 0.5476190476190476\n  # Class ignore \t PQ: 0.0 SQ: 0.0 RQ: 0.0 IoU: 0.0\n  # Class grass \t PQ: 0.6666666666666666 SQ: 0.6666666666666666 RQ: 1.0 IoU: 0.6666666666666666\n  # Class sky \t PQ: 0.6666666666666666 SQ: 0.6666666666666666 RQ: 1.0 IoU: 0.6666666666666666\n  # Class person \t PQ: 0.5833333333333333 SQ: 0.875 RQ: 0.6666666666666666 IoU: 0.8571428571428571\n  # Class dog \t PQ: 0.0 SQ: 0.0 RQ: 0.0 IoU: 0.0\n\n  print(\"TOTALS\")\n  print(\"PQ:\", pq.item(), pq.item() == 0.47916666666666663)\n  print(\"SQ:\", sq.item(), sq.item() == 0.5520833333333333)\n  print(\"RQ:\", rq.item(), rq.item() == 0.6666666666666666)\n  print(\"IoU:\", iou.item(), iou.item() == 0.5476190476190476)\n  for i, (pq, sq, rq, iou) in enumerate(zip(all_pq, all_sq, all_rq, all_iou)):\n    print(\"Class\", cl_strings[i], \"\\t\", \"PQ:\", pq.item(), \"SQ:\", sq.item(), \"RQ:\", rq.item(), \"IoU:\", iou.item())"
  },
  {
    "path": "utils/visual.py",
    "content": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\nimport cv2\nimport numpy as np\n\ndef flow_to_img(flow, normalize=True):\n    \"\"\"Convert flow to viewable image, using color hue to encode flow vector orientation, and color saturation to\n    encode vector length. This is similar to the OpenCV tutorial on dense optical flow, except that they map vector\n    length to the value plane of the HSV color model, instead of the saturation plane, as we do here.\n    Args:\n        flow: optical flow\n        normalize: Normalize flow to 0..255\n    Returns:\n        img: viewable representation of the dense optical flow in RGB format\n    Ref:\n        https://github.com/philferriere/tfoptflow/blob/33e8a701e34c8ce061f17297d40619afbd459ade/tfoptflow/optflow.py\n    \"\"\"\n    hsv = np.zeros((flow.shape[0], flow.shape[1], 3), dtype=np.uint8)\n    flow_magnitude, flow_angle = cv2.cartToPolar(flow[..., 0].astype(np.float32), flow[..., 1].astype(np.float32))\n\n    # A couple times, we've gotten NaNs out of the above...\n    nans = np.isnan(flow_magnitude)\n    if np.any(nans):\n        nans = np.where(nans)\n        flow_magnitude[nans] = 0.\n\n    # Normalize\n    hsv[..., 0] = flow_angle * 180 / np.pi / 2\n    if normalize is True:\n        hsv[..., 1] = cv2.normalize(flow_magnitude, None, 0, 255, cv2.NORM_MINMAX)\n    else:\n        hsv[..., 1] = flow_magnitude\n    hsv[..., 2] = 255\n    img = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)\n\n    return img"
  }
]