Full Code of hszhao/PointWeb for AI

Note: this is a truncated preview; the full dump totals 236K characters and the extract below ends mid-file.
Repository: hszhao/PointWeb
Branch: master
Commit: f31fe05616c3
Files: 61
Total size: 218.5 KB

Directory structure:
gitextract_6_2c5ysk/

├── .gitignore
├── LICENSE
├── README.md
├── data/
│   ├── s3dis/
│   │   └── s3dis_names.txt
│   └── scannet/
│       └── scannet_names.txt
├── lib/
│   ├── __init__.py
│   ├── pointops/
│   │   ├── __init__.py
│   │   ├── functions/
│   │   │   ├── __init__.py
│   │   │   └── pointops.py
│   │   ├── setup.py
│   │   └── src/
│   │       ├── __init__.py
│   │       ├── ballquery/
│   │       │   ├── ballquery_cuda.cpp
│   │       │   ├── ballquery_cuda_kernel.cu
│   │       │   └── ballquery_cuda_kernel.h
│   │       ├── cuda_utils.h
│   │       ├── featuredistribute/
│   │       │   ├── featuredistribute_cuda.cpp
│   │       │   ├── featuredistribute_cuda_kernel.cu
│   │       │   └── featuredistribute_cuda_kernel.h
│   │       ├── grouping/
│   │       │   ├── grouping_cuda.cpp
│   │       │   ├── grouping_cuda_kernel.cu
│   │       │   └── grouping_cuda_kernel.h
│   │       ├── grouping_int/
│   │       │   ├── grouping_int_cuda.cpp
│   │       │   ├── grouping_int_cuda_kernel.cu
│   │       │   └── grouping_int_cuda_kernel.h
│   │       ├── interpolation/
│   │       │   ├── interpolation_cuda.cpp
│   │       │   ├── interpolation_cuda_kernel.cu
│   │       │   └── interpolation_cuda_kernel.h
│   │       ├── knnquery/
│   │       │   ├── __init__.py
│   │       │   ├── knnquery_cuda.cpp
│   │       │   ├── knnquery_cuda_kernel.cu
│   │       │   └── knnquery_cuda_kernel.h
│   │       ├── labelstat/
│   │       │   ├── labelstat_cuda.cpp
│   │       │   ├── labelstat_cuda_kernel.cu
│   │       │   └── labelstat_cuda_kernel.h
│   │       ├── pointops_api.cpp
│   │       └── sampling/
│   │           ├── sampling_cuda.cpp
│   │           ├── sampling_cuda_kernel.cu
│   │           └── sampling_cuda_kernel.h
│   └── sync_bn/
│       ├── __init__.py
│       ├── batchnorm.py
│       ├── comm.py
│       ├── replicate.py
│       └── unittest.py
├── model/
│   ├── __init__.py
│   ├── pointnet/
│   │   └── pointnet.py
│   ├── pointnet2/
│   │   ├── pointnet2_modules.py
│   │   └── pointnet2_seg.py
│   └── pointweb/
│       ├── pointweb_module.py
│       └── pointweb_seg.py
├── tool/
│   ├── test.sh
│   ├── test_s3dis.py
│   ├── test_scannet.py
│   ├── train.py
│   └── train.sh
└── util/
    ├── config.py
    ├── dataset.py
    ├── pt_util.py
    ├── s3dis.py
    ├── scannet.py
    ├── transform.py
    └── util.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
## General

# Compiled Object files
*.slo
*.lo
*.o
*.cuo

# Compiled Dynamic libraries
*.so
*.dylib

# Compiled Static libraries
*.lai
*.la
*.a

# Compiled protocol buffers
*.pb.h
*.pb.cc
*_pb2.py

# Compiled python
*.pyc

# Compiled MATLAB
*.mex*

# IPython notebook checkpoints
.ipynb_checkpoints

# Editor temporaries
*.swp
*~

# Sublime Text settings
*.sublime-workspace
*.sublime-project

# Eclipse Project settings
*.*project
.settings

# QtCreator files
*.user

# PyCharm files
.idea

# Visual Studio Code files
.vscode

# OSX dir files
.DS_Store

# personal
*.log
*.pth
*.caffemodel
exp/
summary/
__pycache__/
# data/
back/
*.png
*.jpg
*.log
*.pth
events*
config/
initmodel/
*.ninja_deps
*.ninja_log
*.ninja
*.yaml


================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2019 Hengshuang Zhao

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# PointWeb: Enhancing Local Neighborhood Features for Point Cloud Processing

by Hengshuang Zhao\*, Li Jiang\*, Chi-Wing Fu, and Jiaya Jia; details are in the [paper](http://openaccess.thecvf.com/content_CVPR_2019/papers/Zhao_PointWeb_Enhancing_Local_Neighborhood_Features_for_Point_Cloud_Processing_CVPR_2019_paper.pdf).

### Introduction

This repository is built for PointWeb in point cloud scene understanding.

<img src="./figure/pointweb.jpg" width="900"/>

### Usage

1. Requirement:

   - Hardware: 4 GPUs (better with >=11G GPU memory)
   - Software: PyTorch>=1.0.0, Python3, CUDA>=9.0, [tensorboardX](https://github.com/lanpa/tensorboardX)

2. Clone the repository and build the ops:

   ```shell
   git clone https://github.com/hszhao/PointWeb.git
   cd PointWeb
   cd lib/pointops && python setup.py install && cd ../../
   ```

3. Train:

   - Download related [datasets](https://drive.google.com/open?id=1Jpi2IP58zHs6Ppv05kqvwJhBnl-Kge2q) and symlink the paths to them as follows (you can alternatively modify the relevant paths specified in folder `config`):

     ```
     mkdir -p dataset
     ln -s /path_to_s3dis_dataset dataset/s3dis
     ```

   - Specify the GPU(s) used in the config and then start training:

     ```shell
     sh tool/train.sh s3dis pointweb
     ```

4. Test:

   - Download the trained segmentation models and put them under the folder specified in the config, or modify the specified paths.

   - For full testing (to reproduce the listed performance):

     ```shell
     sh tool/test.sh s3dis pointweb
     ```

5. Visualization: [tensorboardX](https://github.com/lanpa/tensorboardX) is incorporated for better visualization.

   ```shell
   tensorboard --logdir=run1:$EXP1,run2:$EXP2 --port=6789
   ```

6. Other:

   - Resources: GoogleDrive [LINK](https://drive.google.com/open?id=1IFoKe5TM3ZO38LT4VXCaHKvCNkXfgtBf) contains shared models, predictions and part of the related datasets.
   - Video predictions: Youtube [LINK](https://youtu.be/CaobqpsUP_4).

### Performance

Description: **mIoU/mAcc/aAcc/voxAcc** stand for mean IoU, mean per-class accuracy, overall point accuracy, and voxel label accuracy, respectively.

mIoU/mAcc/aAcc of PointWeb on S3DIS dataset: 0.6055/0.6682/0.8658.

mIoU/mAcc/aAcc/voxAcc of PointWeb on ScanNet dataset: 0.5063/0.6061/0.8529/0.8568.

### Citation

If you find the code or trained models useful, please consider citing:

```
@inproceedings{zhao2019pointweb,
  title={{PointWeb}: Enhancing Local Neighborhood Features for Point Cloud Processing},
  author={Zhao, Hengshuang and Jiang, Li and Fu, Chi-Wing and Jia, Jiaya},
  booktitle={CVPR},
  year={2019}
}
```


================================================
FILE: data/s3dis/s3dis_names.txt
================================================
ceiling
floor
wall
beam
column
window
door
chair
table
bookcase
sofa
board
clutter


================================================
FILE: data/scannet/scannet_names.txt
================================================
bathtub
bed
bookshelf
cabinet
chair
counter
curtain
desk
door
floor
otherfurniture
picture
refrigerator
showercurtain
sink
sofa
table
toilet
wall
window


================================================
FILE: lib/__init__.py
================================================


================================================
FILE: lib/pointops/__init__.py
================================================


================================================
FILE: lib/pointops/functions/__init__.py
================================================


================================================
FILE: lib/pointops/functions/pointops.py
================================================
from typing import Tuple

import torch
from torch.autograd import Function
import torch.nn as nn

import pointops_cuda


class FurthestSampling(Function):
    @staticmethod
    def forward(ctx, xyz, m):
        """
        input: xyz: (b, n, 3) and n > m, m: int32
        output: idx: (b, m)
        """
        assert xyz.is_contiguous()
        b, n, _ = xyz.size()
        idx = torch.cuda.IntTensor(b, m)
        temp = torch.cuda.FloatTensor(b, n).fill_(1e10)
        pointops_cuda.furthestsampling_cuda(b, n, m, xyz, temp, idx)
        return idx

    @staticmethod
    def backward(ctx, a=None):
        return None, None

furthestsampling = FurthestSampling.apply


class Gathering(Function):
    @staticmethod
    def forward(ctx, features, idx):
        """
        input: features: (b, c, n), idx : (b, m) tensor
        output: (b, c, m)
        """
        assert features.is_contiguous()
        assert idx.is_contiguous()
        b, c, n = features.size()
        m = idx.size(1)
        output = torch.cuda.FloatTensor(b, c, m)
        pointops_cuda.gathering_forward_cuda(b, c, n, m, features, idx, output)
        ctx.for_backwards = (idx, c, n)
        return output

    @staticmethod
    def backward(ctx, grad_out):
        idx, c, n = ctx.for_backwards
        b, m = idx.size()
        grad_features = torch.cuda.FloatTensor(b, c, n).zero_()
        grad_out_data = grad_out.data.contiguous()
        pointops_cuda.gathering_backward_cuda(b, c, n, m, grad_out_data, idx, grad_features.data)
        return grad_features, None

gathering = Gathering.apply


class NearestNeighbor(Function):
    @staticmethod
    def forward(ctx, unknown: torch.Tensor, known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Find the three nearest neighbors of unknown in known
        input: unknown: (b, n, 3), known: (b, m, 3)
        output: dist: (b, n, 3) l2 distance to the three nearest neighbors
                idx: (b, n, 3) index of the 3 nearest neighbors
        """
        assert unknown.is_contiguous()
        assert known.is_contiguous()
        b, n, _ = unknown.size()
        m = known.size(1)
        dist2 = torch.cuda.FloatTensor(b, n, 3)
        idx = torch.cuda.IntTensor(b, n, 3)
        pointops_cuda.nearestneighbor_cuda(b, n, m, unknown, known, dist2, idx)
        return torch.sqrt(dist2), idx

    @staticmethod
    def backward(ctx, a=None, b=None):
        return None, None

nearestneighbor = NearestNeighbor.apply


class Interpolation(Function):
    @staticmethod
    def forward(ctx, features: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        """
        Performs weighted linear interpolation on 3 features
        input: features: (b, c, m) features descriptors to be interpolated from
               idx: (b, n, 3) three nearest neighbors of the target features in features
               weight: (b, n, 3) weights
        output: (b, c, n) tensor of the interpolated features
        """
        assert features.is_contiguous()
        assert idx.is_contiguous()
        assert weight.is_contiguous()
        b, c, m = features.size()
        n = idx.size(1)
        ctx.interpolation_for_backward = (idx, weight, m)
        output = torch.cuda.FloatTensor(b, c, n)
        pointops_cuda.interpolation_forward_cuda(b, c, m, n, features, idx, weight, output)
        return output

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        input: grad_out: (b, c, n)
        output: grad_features: (b, c, m), None, None
        """
        idx, weight, m = ctx.interpolation_for_backward
        b, c, n = grad_out.size()
        grad_features = torch.cuda.FloatTensor(b, c, m).zero_()
        grad_out_data = grad_out.data.contiguous()
        pointops_cuda.interpolation_backward_cuda(b, c, n, m, grad_out_data, idx, weight, grad_features.data)
        return grad_features, None, None

interpolation = Interpolation.apply


class Grouping(Function):
    @staticmethod
    def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        """
        input: features: (b, c, n), idx : (b, m, nsample) containing the indices of features to group with
        output: (b, c, m, nsample)
        """
        assert features.is_contiguous()
        assert idx.is_contiguous()
        b, c, n = features.size()
        _, m, nsample = idx.size()
        output = torch.cuda.FloatTensor(b, c, m, nsample)
        pointops_cuda.grouping_forward_cuda(b, c, n, m, nsample, features, idx, output)
        ctx.for_backwards = (idx, n)
        return output

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        input: grad_out: (b, c, m, nsample)
        output: (b, c, n), None
        """
        idx, n = ctx.for_backwards
        b, c, m, nsample = grad_out.size()
        grad_features = torch.cuda.FloatTensor(b, c, n).zero_()
        grad_out_data = grad_out.data.contiguous()
        pointops_cuda.grouping_backward_cuda(b, c, n, m, nsample, grad_out_data, idx, grad_features.data)
        return grad_features, None

grouping = Grouping.apply


class GroupingInt(Function):
    @staticmethod
    def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        """
        input: features: (b, c, n), idx : (b, m, nsample) containing the indices of features to group with
        output: (b, c, m, nsample)
        """
        assert features.is_contiguous()
        assert idx.is_contiguous()
        b, c, n = features.size()
        _, m, nsample = idx.size()
        output = torch.cuda.LongTensor(b, c, m, nsample)
        pointops_cuda.grouping_int_forward_cuda(b, c, n, m, nsample, features, idx, output)
        return output

    @staticmethod
    def backward(ctx, a=None):
        return None, None

grouping_int = GroupingInt.apply


class BallQuery(Function):
    @staticmethod
    def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor) -> torch.Tensor:
        """
        input: radius: float, radius of the balls
               nsample: int, maximum number of features in the balls
               xyz: torch.Tensor, (b, n, 3) xyz coordinates of the features
               new_xyz: torch.Tensor, (b, m, 3) centers of the ball query
        output: (b, m, nsample) tensor with the indices of the features that form the query balls
        """
        assert xyz.is_contiguous()
        assert new_xyz.is_contiguous()
        b, n, _ = xyz.size()
        m = new_xyz.size(1)
        idx = torch.cuda.IntTensor(b, m, nsample).zero_()
        pointops_cuda.ballquery_cuda(b, n, m, radius, nsample, new_xyz, xyz, idx)
        return idx

    @staticmethod
    def backward(ctx, a=None):
        return None, None, None, None

ballquery = BallQuery.apply


class FeatureDistribute(Function):
    @staticmethod
    def forward(ctx, max_xyz: torch.Tensor, xyz: torch.Tensor) -> torch.Tensor:
        """
        :param ctx:
        :param max_xyz: (b, n, 3)
        :param xyz: (b, m, 3)
        :return: distribute_idx: (b, m)
        """
        assert max_xyz.is_contiguous()
        assert xyz.is_contiguous()
        b, n, _ = max_xyz.size()
        m = xyz.size(1)
        distribute_idx = torch.cuda.IntTensor(b, m).zero_()
        pointops_cuda.featuredistribute_cuda(b, n, m, max_xyz, xyz, distribute_idx)
        return distribute_idx

    @staticmethod
    def backward(ctx, a=None):
        return None, None

featuredistribute = FeatureDistribute.apply


class FeatureGather(Function):
    @staticmethod
    def forward(ctx, max_feature: torch.Tensor, distribute_idx: torch.Tensor) -> torch.Tensor:
        '''
        :param ctx:
        :param max_feature: (b, c, n)
        :param distribute_idx: (b, m)
        :return: distribute_feature: (b, c, m)
        '''
        assert max_feature.is_contiguous()
        assert distribute_idx.is_contiguous()
        b, c, n = max_feature.size()
        m = distribute_idx.size(1)
        distribute_feature = torch.cuda.FloatTensor(b, c, m).zero_()
        pointops_cuda.featuregather_forward_cuda(b, n, m, c, max_feature, distribute_idx, distribute_feature)
        ctx.for_backwards = (distribute_idx, n)
        return distribute_feature

    @staticmethod
    def backward(ctx, grad_distribute_feature: torch.Tensor):
        '''
        :param ctx:
        :param grad_distribute_feature: (b, c, m)
        :return: grad_max_feature: (b, c, n),    None
        '''
        distribute_idx, n = ctx.for_backwards
        b, c, m = grad_distribute_feature.size()
        grad_max_feature = torch.cuda.FloatTensor(b, c, n).zero_()
        grad_distribute_feature_data = grad_distribute_feature.data.contiguous()
        pointops_cuda.featuregather_backward_cuda(b, n, m, c, grad_distribute_feature_data, distribute_idx, grad_max_feature.data)
        return grad_max_feature, None

featuregather = FeatureGather.apply


class LabelStatBallRange(Function):
    @staticmethod
    def forward(ctx, radius: float, xyz: torch.Tensor, new_xyz: torch.Tensor, label_stat: torch.Tensor) -> torch.Tensor:
        '''
        :param ctx:
        :param radius:
        :param xyz: (b, n, 3)
        :param new_xyz: (b, m, 3)
        :param label_stat: (b, n, nclass)
        :return: new_label_stat: (b, m, nclass)
        '''
        assert xyz.is_contiguous()
        assert new_xyz.is_contiguous()
        assert label_stat.is_contiguous()

        b, n, nclass = label_stat.size()
        m = new_xyz.size(1)
        new_label_stat = torch.cuda.IntTensor(b, m, nclass).zero_()
        pointops_cuda.labelstat_ballrange_cuda(b, n, m, radius, nclass, new_xyz, xyz, label_stat, new_label_stat)

        return new_label_stat

    @staticmethod
    def backward(ctx, a=None):
        return None, None, None, None

labelstat_ballrange = LabelStatBallRange.apply


class LabelStatIdx(Function):
    @staticmethod
    def forward(ctx, nsample: int, label_stat: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        '''
        :param ctx:
        :param nsample:
        :param label_stat: (b, n, nclass)
        :param idx: (b, m, nsample)
        :return: new_label_stat: (b, m, nclass)
        '''
        assert label_stat.is_contiguous()
        assert idx.is_contiguous()

        b, n, nclass = label_stat.size()
        m = idx.size(1)
        new_label_stat = torch.cuda.IntTensor(b, m, nclass).zero_()
        pointops_cuda.labelstat_idx_cuda(b, n, m, nsample, nclass, label_stat, idx, new_label_stat)

        return new_label_stat

    @staticmethod
    def backward(ctx, a=None):
        return None, None, None

labelstat_idx = LabelStatIdx.apply


class LabelStatAndBallQuery(Function):
    @staticmethod
    def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor, label_stat: torch.Tensor):
        '''
        :param ctx:
        :param radius:
        :param nsample:
        :param xyz: (b, n, 3)
        :param new_xyz: (b, m, 3)
        :param label_stat: (b, n, nclass)
        :return: new_label_stat: (b, m, nclass)  idx: (b, m, nsample)
        '''
        assert xyz.is_contiguous()
        assert new_xyz.is_contiguous()
        assert label_stat.is_contiguous()

        b, n, nclass = label_stat.size()
        m = new_xyz.size(1)
        new_label_stat = torch.cuda.IntTensor(b, m, nclass).zero_()
        idx = torch.cuda.IntTensor(b, m, nsample).zero_()

        pointops_cuda.labelstat_and_ballquery_cuda(b, n, m, radius, nsample, nclass, new_xyz, xyz, label_stat, idx, new_label_stat)

        return new_label_stat, idx

    @staticmethod
    def backward(ctx, a=None, b=None):
        return None, None, None, None, None

labelstat_and_ballquery = LabelStatAndBallQuery.apply


def pairwise_distances(x, y=None):
    '''
    Input: x is a Nxd matrix
           y is an optional Mxd matrix
    Output: dist is a NxM matrix where dist[i,j] is the squared L2 distance
            between x[i,:] and y[j,:], i.e. dist[i,j] = ||x[i,:]-y[j,:]||^2.
            If y is not given, y = x is used.
    '''
    x_norm = (x ** 2).sum(1).view(-1, 1)
    if y is not None:
        y_t = torch.transpose(y, 0, 1)
        y_norm = (y ** 2).sum(1).view(1, -1)
    else:
        y_t = torch.transpose(x, 0, 1)
        y_norm = x_norm.view(1, -1)
    dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t)
    return torch.clamp(dist, min=0.0)


class KNNQueryNaive(Function):
    @staticmethod
    def forward(ctx, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor = None) -> Tuple[torch.Tensor]:
        """
        KNN Indexing
        input: nsample: int32, number of neighbors
               xyz: (b, n, 3) coordinates of the features
               new_xyz: (b, m, 3) centroids
        output: idx: (b, m, nsample)
        """
        if new_xyz is None:
            new_xyz = xyz
        b, m, _ = new_xyz.size()
        n = xyz.size(1)

        '''
        idx = torch.zeros(b, m, nsample).int().cuda()
        for i in range(b):
            dist = pairwise_distances(new_xyz[i, :, :], xyz[i, :, :])
            [_, idxs] = torch.sort(dist, dim=1)
            idx[i, :, :] = idxs[:, 0:nsample]
        '''

        # '''
        # new_xyz_repeat = new_xyz.repeat(1, 1, n).view(b, m * n, 3)
        # xyz_repeat = xyz.repeat(1, m, 1).view(b, m * n, 3)
        # dist = (new_xyz_repeat - xyz_repeat).pow(2).sum(dim=2).view(b, m, n)
        dist = (new_xyz.repeat(1, 1, n).view(b, m * n, 3) - xyz.repeat(1, m, 1).view(b, m * n, 3)).pow(2).sum(dim=2).view(b, m, n)
        [_, idxs] = torch.sort(dist, dim=2)
        idx = idxs[:, :, 0:nsample].int()
        # '''
        return idx

    @staticmethod
    def backward(ctx, a=None):
        return None, None, None

knnquery_naive = KNNQueryNaive.apply


class KNNQuery(Function):
    @staticmethod
    def forward(ctx, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor = None) -> Tuple[torch.Tensor]:
        """
        KNN Indexing
        input: nsample: int32, number of neighbors
               xyz: (b, n, 3) coordinates of the features
               new_xyz: (b, m, 3) centroids
        output: idx: (b, m, nsample)
                (dist2: (b, m, nsample) is computed but not returned)
        """
        if new_xyz is None:
            new_xyz = xyz
        assert xyz.is_contiguous()
        assert new_xyz.is_contiguous()
        b, m, _ = new_xyz.size()
        n = xyz.size(1)
        idx = torch.cuda.IntTensor(b, m, nsample).zero_()
        dist2 = torch.cuda.FloatTensor(b, m, nsample).zero_()
        pointops_cuda.knnquery_cuda(b, n, m, nsample, xyz, new_xyz, idx, dist2)
        return idx

    @staticmethod
    def backward(ctx, a=None):
        return None, None, None

knnquery = KNNQuery.apply


class KNNQueryExclude(Function):
    @staticmethod
    def forward(ctx, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor = None) -> Tuple[torch.Tensor]:
        """
        KNN Indexing
        input: nsample: int32, number of neighbors
               xyz: (b, n, 3) coordinates of the features
               new_xyz: (b, m, 3) centroids
        output: idx: (b, m, nsample) (the query point itself is excluded)
        """
        if new_xyz is None:
            new_xyz = xyz
        b, m, _ = new_xyz.size()
        n = xyz.size(1)

        '''
        idx = torch.zeros(b, m, nsample).int().cuda()
        for i in range(b):
            dist = pairwise_distances(new_xyz[i, :, :], xyz[i, :, :])
            [_, idxs] = torch.sort(dist, dim=1)
            idx[i, :, :] = idxs[:, 0:nsample]
        '''

        # '''
        # new_xyz_repeat = new_xyz.repeat(1, 1, n).view(b, m * n, 3)
        # xyz_repeat = xyz.repeat(1, m, 1).view(b, m * n, 3)
        # dist = (new_xyz_repeat - xyz_repeat).pow(2).sum(dim=2).view(b, m, n)
        dist = (new_xyz.repeat(1, 1, n).view(b, m * n, 3) - xyz.repeat(1, m, 1).view(b, m * n, 3)).pow(2).sum(dim=2).view(b, m, n)
        [_, idxs] = torch.sort(dist, dim=2)
        idx = idxs[:, :, 1:nsample+1].int()
        # '''
        return idx

    @staticmethod
    def backward(ctx, a=None):
        return None, None, None

knnquery_exclude = KNNQueryExclude.apply


class QueryAndGroup(nn.Module):
    """
    Groups with a ball query of radius
    parameters:
        radius: float32, Radius of ball
        nsample: int32, Maximum number of features to gather in the ball
    """
    def __init__(self, radius=None, nsample=32, use_xyz=True):
        super(QueryAndGroup, self).__init__()
        self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz

    def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor = None, features: torch.Tensor = None, idx: torch.Tensor = None) -> torch.Tensor:
        """
        input: xyz: (b, n, 3) coordinates of the features
               new_xyz: (b, m, 3) centroids
               features: (b, c, n)
               idx: idx of neighbors
               # idxs: (b, n)
        output: new_features: (b, c+3, m, nsample)
              #  grouped_idxs: (b, m, nsample)
        """
        if new_xyz is None:
            new_xyz = xyz
        if idx is None:
            if self.radius is not None:
                idx = ballquery(self.radius, self.nsample, xyz, new_xyz)
            else:
                # idx = knnquery_naive(self.nsample, xyz, new_xyz)   # (b, m, nsample)
                idx = knnquery(self.nsample, xyz, new_xyz)  # (b, m, nsample)
        xyz_trans = xyz.transpose(1, 2).contiguous()
        grouped_xyz = grouping(xyz_trans, idx)  # (b, 3, m, nsample)
        # grouped_idxs = grouping(idxs.unsqueeze(1).float(), idx).squeeze(1).int()  # (b, m, nsample)

        grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)
        if features is not None:
            grouped_features = grouping(features, idx)
            if self.use_xyz:
                new_features = torch.cat([grouped_xyz, grouped_features], dim=1)  # (b, c+3, m, nsample)
            else:
                new_features = grouped_features
        else:
            assert self.use_xyz, "Cannot have features=None and use_xyz=False at the same time!"
            new_features = grouped_xyz
        return new_features


class GroupAll(nn.Module):
    """
    Groups all features
    """
    def __init__(self, use_xyz: bool = True):
        super(GroupAll, self).__init__()
        self.use_xyz = use_xyz

    def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None) -> Tuple[torch.Tensor]:
        """
        input: xyz: (b, n, 3) coordinates of the features
               new_xyz: ignored
               features: (b, c, n) descriptors of the features
        output: new_features: (b, c+3, 1, n) tensor
        """
        grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
        if features is not None:
            grouped_features = features.unsqueeze(2)
            if self.use_xyz:
                new_features = torch.cat([grouped_xyz, grouped_features], dim=1)  # (b, c+3, 1, n)
            else:
                new_features = grouped_features
        else:
            new_features = grouped_xyz
        return new_features
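
The ops in this file are thin wrappers around the compiled `pointops_cuda` extension, so none of them run without a CUDA build. As a plain reference for what `furthestsampling` computes, here is a minimal NumPy sketch of farthest point sampling for a single batch; the function name and the choice of index 0 as the seed are illustrative assumptions, not part of this library:

```python
import numpy as np

def furthest_point_sample_ref(xyz, m):
    """Reference farthest point sampling: pick m indices from xyz (n, 3)
    so that each new pick maximizes its distance to the picked set."""
    n = xyz.shape[0]
    idx = np.zeros(m, dtype=np.int64)
    dist = np.full(n, np.inf)  # distance to the nearest picked point so far
    farthest = 0               # seed with the first point
    for i in range(m):
        idx[i] = farthest
        d = ((xyz - xyz[farthest]) ** 2).sum(axis=1)
        dist = np.minimum(dist, d)       # update nearest-picked distances
        farthest = int(dist.argmax())    # next pick: point farthest from the set
    return idx
```

The CUDA kernel parallelizes the same min/argmax recurrence across batches; the `temp` buffer in `FurthestSampling.forward` plays the role of `dist` here.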




================================================
FILE: lib/pointops/setup.py
================================================
#python3 setup.py install
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='pointops',
    ext_modules=[
        CUDAExtension('pointops_cuda', [
            'src/pointops_api.cpp',

            'src/ballquery/ballquery_cuda.cpp',
            'src/ballquery/ballquery_cuda_kernel.cu',
            'src/knnquery/knnquery_cuda.cpp',
            'src/knnquery/knnquery_cuda_kernel.cu',
            'src/grouping/grouping_cuda.cpp',
            'src/grouping/grouping_cuda_kernel.cu',
            'src/grouping_int/grouping_int_cuda.cpp',
            'src/grouping_int/grouping_int_cuda_kernel.cu',
            'src/interpolation/interpolation_cuda.cpp',
            'src/interpolation/interpolation_cuda_kernel.cu',
            'src/sampling/sampling_cuda.cpp',
            'src/sampling/sampling_cuda_kernel.cu',

            'src/labelstat/labelstat_cuda.cpp',
            'src/labelstat/labelstat_cuda_kernel.cu',

            'src/featuredistribute/featuredistribute_cuda.cpp',
            'src/featuredistribute/featuredistribute_cuda_kernel.cu'
        ],
                      extra_compile_args={'cxx': ['-g'],
                                          'nvcc': ['-O2']})
    ],
    cmdclass={'build_ext': BuildExtension})


================================================
FILE: lib/pointops/src/__init__.py
================================================


================================================
FILE: lib/pointops/src/ballquery/ballquery_cuda.cpp
================================================
#include <torch/serialize/tensor.h>
#include <vector>
#include <THC/THC.h>
#include <ATen/cuda/CUDAContext.h>

#include "ballquery_cuda_kernel.h"

extern THCState *state;

#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)

void ballquery_cuda(int b, int n, int m, float radius, int nsample, at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor)
{
    const float *new_xyz = new_xyz_tensor.data<float>();
    const float *xyz = xyz_tensor.data<float>();
    int *idx = idx_tensor.data<int>();

    ballquery_cuda_launcher(b, n, m, radius, nsample, new_xyz, xyz, idx);
}


void ballquery_cuda_fast(int b, int n, int m, float radius, int nsample, at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor)
{
    CHECK_INPUT(new_xyz_tensor);
    CHECK_INPUT(xyz_tensor);

    const float *new_xyz = new_xyz_tensor.data<float>();
    const float *xyz = xyz_tensor.data<float>();
    int *idx = idx_tensor.data<int>();

    cudaStream_t stream = THCState_getCurrentStream(state);

    ballquery_cuda_launcher_fast(b, n, m, radius, nsample, new_xyz, xyz, idx, stream);
}
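
These wrappers only unpack tensor pointers and dispatch to the launchers; the padding contract lives in the kernel file that follows: when fewer than `nsample` points fall inside the ball, the remaining slots repeat the first in-radius index, and rows with no hit keep the zero initialization. A single-batch NumPy sketch of that contract (the function name is illustrative, not part of the extension):

```python
import numpy as np

def ball_query_ref(radius, nsample, xyz, new_xyz):
    """xyz: (n, 3), new_xyz: (m, 3) -> (m, nsample) neighbor indices.
    Slots beyond the number of in-radius points repeat the first hit;
    rows with no hit at all stay 0 (matching the zero-initialized idx)."""
    m = new_xyz.shape[0]
    idx = np.zeros((m, nsample), dtype=np.int64)
    for j, center in enumerate(new_xyz):
        d2 = ((xyz - center) ** 2).sum(axis=1)
        hits = np.flatnonzero(d2 < radius ** 2)[:nsample]
        if hits.size:
            idx[j, :] = hits[0]        # pad every slot with the first hit
            idx[j, :hits.size] = hits  # then overwrite with the real hits
    return idx
```

Pre-filling with the first hit before overwriting mirrors the `cnt == 0` branch in `ballquery_cuda_kernel`.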


================================================
FILE: lib/pointops/src/ballquery/ballquery_cuda_kernel.cu
================================================
#include "../cuda_utils.h"
#include "ballquery_cuda_kernel.h"

// input: new_xyz(b, m, 3) xyz(b, n, 3)
// output: idx(b, m, nsample)
__global__ void ballquery_cuda_kernel(int b, int n, int m, float radius, int nsample, const float *new_xyz, const float *xyz, int *idx)
{
    int batch_index = blockIdx.x;
    xyz += batch_index * n * 3;
    new_xyz += batch_index * m * 3;
    idx += m * nsample * batch_index;
    int index = threadIdx.x;
    int stride = blockDim.x;

    float radius2 = radius * radius;
    for (int j = index; j < m; j += stride)
    {
        float new_x = new_xyz[j * 3 + 0];
        float new_y = new_xyz[j * 3 + 1];
        float new_z = new_xyz[j * 3 + 2];
        for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k)
        {
            float x = xyz[k * 3 + 0];
            float y = xyz[k * 3 + 1];
            float z = xyz[k * 3 + 2];
            float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z);
            if (d2 < radius2)
            {
                if (cnt == 0)
                {
                    for (int l = 0; l < nsample; ++l)
                        idx[j * nsample + l] = k;
                }
                idx[j * nsample + cnt] = k;
                ++cnt;
            }
        }
    }
}

void ballquery_cuda_launcher(int b, int n, int m, float radius, int nsample, const float *new_xyz, const float *xyz, int *idx)
{
    ballquery_cuda_kernel<<<b, opt_n_threads(m), 0>>>(b, n, m, radius, nsample, new_xyz, xyz, idx);
}


__global__ void ballquery_cuda_kernel_fast(int b, int n, int m, float radius, int nsample, const float *__restrict__ new_xyz, const float *__restrict__ xyz, int *__restrict__ idx) {
    int bs_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (bs_idx >= b || pt_idx >= m) return;

    new_xyz += bs_idx * m * 3 + pt_idx * 3;
    xyz += bs_idx * n * 3;
    idx += bs_idx * m * nsample + pt_idx * nsample;

    float radius2 = radius * radius;
    float new_x = new_xyz[0];
    float new_y = new_xyz[1];
    float new_z = new_xyz[2];

    int cnt = 0;
    for (int k = 0; k < n; ++k) {
        float x = xyz[k * 3 + 0];
        float y = xyz[k * 3 + 1];
        float z = xyz[k * 3 + 2];
        float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z);
        if (d2 < radius2){
            if (cnt == 0){
                for (int l = 0; l < nsample; ++l) {
                    idx[l] = k;
                }
            }
            idx[cnt] = k;
            ++cnt;
            if (cnt >= nsample){
                break;
            }
        }
    }
}


void ballquery_cuda_launcher_fast(int b, int n, int m, float radius, int nsample, const float *new_xyz, const float *xyz, int *idx, cudaStream_t stream) {
    // param new_xyz: (B, m, 3)
    // param xyz: (B, n, 3)
    // param idx: (B, m, nsample)

    cudaError_t err;

    dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);

    ballquery_cuda_kernel_fast<<<blocks, threads, 0, stream>>>(b, n, m, radius, nsample, new_xyz, xyz, idx);
    // cudaDeviceSynchronize();  // for using printf in kernel function

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}
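A minimal CPU reference of the ball-query semantics implemented by the kernels above (a pure-Python sketch; `ball_query_ref` is a hypothetical helper, not part of this repo). It mirrors the `if (cnt == 0)` branch: once the first in-radius neighbor is found, every output slot is pre-filled with that index, so queries with fewer than `nsample` neighbors are padded with the first hit.

```python
def ball_query_ref(new_xyz, xyz, radius, nsample):
    # new_xyz: list of (x, y, z) query points; xyz: list of (x, y, z) points.
    # Returns, per query, up to nsample indices of points within `radius`,
    # padded with the first neighbor's index (0 if there is no neighbor).
    r2 = radius * radius
    idx = []
    for qx, qy, qz in new_xyz:
        row, cnt = [0] * nsample, 0
        for k, (x, y, z) in enumerate(xyz):
            d2 = (qx - x) ** 2 + (qy - y) ** 2 + (qz - z) ** 2
            if d2 < r2:
                if cnt == 0:
                    row = [k] * nsample  # pad all slots with the first neighbor
                row[cnt] = k
                cnt += 1
                if cnt >= nsample:
                    break
        idx.append(row)
    return idx
```

The CUDA kernels parallelize this per batch element and query point; the per-point loop body is otherwise identical.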


================================================
FILE: lib/pointops/src/ballquery/ballquery_cuda_kernel.h
================================================
#ifndef _BALLQUERY_CUDA_KERNEL
#define _BALLQUERY_CUDA_KERNEL
#include <torch/serialize/tensor.h>
#include <vector>
#include <ATen/cuda/CUDAContext.h>

void ballquery_cuda(int b, int n, int m, float radius, int nsample, at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor);

void ballquery_cuda_fast(int b, int n, int m, float radius, int nsample, at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor);

#ifdef __cplusplus
extern "C" {
#endif

void ballquery_cuda_launcher(int b, int n, int m, float radius, int nsample, const float *new_xyz, const float *xyz, int *idx);

void ballquery_cuda_launcher_fast(int b, int n, int m, float radius, int nsample, const float *new_xyz, const float *xyz, int *idx, cudaStream_t stream);

#ifdef __cplusplus
}
#endif

#endif


================================================
FILE: lib/pointops/src/cuda_utils.h
================================================
#ifndef _CUDA_UTILS_H
#define _CUDA_UTILS_H

#include <cmath>

#define TOTAL_THREADS 1024

#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while (0)

#define THREADS_PER_BLOCK 256
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))

inline int opt_n_threads(int work_size) {
    const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);
    return max(min(1 << pow_2, TOTAL_THREADS), 1);
}

inline dim3 opt_block_config(int x, int y) {
    const int x_threads = opt_n_threads(x);
    const int y_threads = max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1);
    dim3 block_config(x_threads, y_threads, 1);
    return block_config;
}

#endif
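A quick pure-Python sketch of the two grid-sizing helpers above (hypothetical names `divup` / `opt_n_threads`, not part of the repo): `DIVUP` is ceiling division for computing block counts, and `opt_n_threads` rounds the work size down to a power of two, clamped to `[1, TOTAL_THREADS]`, matching the double-precision `log` trick in the header.

```python
import math

TOTAL_THREADS = 1024

def divup(m, n):
    # Ceiling division: number of blocks of size n needed to cover m items.
    return m // n + (1 if m % n > 0 else 0)

def opt_n_threads(work_size):
    # Largest power of two <= work_size (via floor(log2)), clamped to
    # [1, TOTAL_THREADS], mirroring the C++ helper.
    pow_2 = int(math.log(work_size) / math.log(2.0))
    return max(min(1 << pow_2, TOTAL_THREADS), 1)
```

So a kernel over 300 points per batch launches with 256 threads, and anything above 1024 is capped at `TOTAL_THREADS`.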

================================================
FILE: lib/pointops/src/featuredistribute/featuredistribute_cuda.cpp
================================================
#include <torch/serialize/tensor.h>
#include <vector>
#include <THC/THC.h>
#include <ATen/cuda/CUDAContext.h>

#include "featuredistribute_cuda_kernel.h"

extern THCState *state;

#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while (0)


void featuredistribute_cuda(int b, int n, int m, at::Tensor max_xyz_tensor, at::Tensor xyz_tensor, at::Tensor distribute_idx_tensor)
{
    CHECK_INPUT(max_xyz_tensor);
    CHECK_INPUT(xyz_tensor);

    const float *max_xyz = max_xyz_tensor.data<float>();
    const float *xyz = xyz_tensor.data<float>();
    int *distribute_idx = distribute_idx_tensor.data<int>();

    cudaStream_t stream = THCState_getCurrentStream(state);

    featuredistribute_cuda_launcher(b, n, m, max_xyz, xyz, distribute_idx, stream);
}


void featuregather_forward_cuda(int b, int n, int m, int c, at::Tensor max_feature_tensor, at::Tensor distribute_idx_tensor, at::Tensor distribute_feature_tensor)
{
    CHECK_INPUT(max_feature_tensor);
    CHECK_INPUT(distribute_idx_tensor);

    const float *max_feature = max_feature_tensor.data<float>();
    const int *distribute_idx = distribute_idx_tensor.data<int>();
    float *distribute_feature = distribute_feature_tensor.data<float>();

    cudaStream_t stream = THCState_getCurrentStream(state);

    featuregather_forward_cuda_launcher(b, n, m, c, max_feature, distribute_idx, distribute_feature, stream);
}


void featuregather_backward_cuda(int b, int n, int m, int c, at::Tensor grad_distribute_feature_tensor, at::Tensor distribute_idx_tensor, at::Tensor grad_max_feature_tensor)
{
    CHECK_INPUT(grad_distribute_feature_tensor);
    CHECK_INPUT(distribute_idx_tensor);

    const float *grad_distribute_feature = grad_distribute_feature_tensor.data<float>();
    const int *distribute_idx = distribute_idx_tensor.data<int>();
    float *grad_max_feature = grad_max_feature_tensor.data<float>();

    cudaStream_t stream = THCState_getCurrentStream(state);

    featuregather_backward_cuda_launcher(b, n, m, c, grad_distribute_feature, distribute_idx, grad_max_feature, stream);
}

================================================
FILE: lib/pointops/src/featuredistribute/featuredistribute_cuda_kernel.cu
================================================
#include "../cuda_utils.h"
#include "featuredistribute_cuda_kernel.h"

__global__ void featuredistribute_cuda_kernel(int b, int n, int m, const float *max_xyz, const float *xyz, int *distribute_idx) {
    int bs_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (bs_idx >= b || pt_idx >= m) return;

    max_xyz += bs_idx * n * 3;
    xyz += bs_idx * m * 3 + pt_idx * 3;
    distribute_idx += bs_idx * m + pt_idx;

    float x = xyz[0];
    float y = xyz[1];
    float z = xyz[2];

    float min_dist2 = 1e30f;  // effectively +inf, so min_dist_idx is always set
    int min_dist_idx = -1;
    for (int k = 0; k < n; ++k) {
        float max_x = max_xyz[k * 3 + 0];
        float max_y = max_xyz[k * 3 + 1];
        float max_z = max_xyz[k * 3 + 2];
        float d2 = (max_x - x) * (max_x - x) + (max_y - y) * (max_y - y) + (max_z - z) * (max_z - z);
        if (d2 < min_dist2){
            min_dist_idx = k;
            min_dist2 = d2;
        }
    }
    distribute_idx[0] = min_dist_idx;
}


void featuredistribute_cuda_launcher(int b, int n, int m, const float *max_xyz, const float *xyz, int *distribute_idx, cudaStream_t stream) {
    // param max_xyz: (b, n, 3)
    // param xyz: (b, m, 3)
    // return distribute_idx: (b, m)

    cudaError_t err;

    dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);

    featuredistribute_cuda_kernel<<<blocks, threads, 0, stream>>>(b, n, m, max_xyz, xyz, distribute_idx);
    // cudaDeviceSynchronize();  // for using printf in kernel function

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}

__global__ void featuregather_forward_cuda_kernel(int b, int n, int m, int c, const float *max_feature, const int *distribute_idx, float *distribute_feature) {
    int bs_idx = blockIdx.z;
    int c_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (bs_idx >= b || c_idx >= c || pt_idx >= m) return;

    max_feature += bs_idx * c * n + c_idx * n;
    distribute_idx += bs_idx * m + pt_idx;
    distribute_feature += bs_idx * c * m + c_idx * m + pt_idx;

    int idx = distribute_idx[0];
    distribute_feature[0] = max_feature[idx];
}


void featuregather_forward_cuda_launcher(int b, int n, int m, int c, const float *max_feature, const int *distribute_idx, float *distribute_feature, cudaStream_t stream){
    // param max_feature: (b, c, n)
    // param distribute_idx: (b, m)
    // return distribute_feature: (b, c, m)

    cudaError_t err;

    dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), c, b);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);

    featuregather_forward_cuda_kernel<<<blocks, threads, 0, stream>>>(b, n, m, c, max_feature, distribute_idx, distribute_feature);
    // cudaDeviceSynchronize();  // for using printf in kernel function

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}


__global__ void featuregather_backward_cuda_kernel(int b, int n, int m, int c, const float *grad_distribute_feature, const int *distribute_idx, float *grad_max_feature){
    int bs_idx = blockIdx.z;
    int c_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if(bs_idx >= b || c_idx >= c || pt_idx >= m) return;

    grad_distribute_feature += bs_idx * c * m + c_idx * m + pt_idx;
    distribute_idx += bs_idx * m + pt_idx;
    grad_max_feature += bs_idx * c * n + c_idx * n;

    int idx = distribute_idx[0];
    atomicAdd(grad_max_feature + idx, grad_distribute_feature[0]);
}


void featuregather_backward_cuda_launcher(int b, int n, int m, int c, const float *grad_distribute_feature, const int *distribute_idx, float *grad_max_feature, cudaStream_t stream){
    // param grad_distribute_feature: (b, c, m)
    // param distribute_idx: (b, m)
    // return grad_max_feature: (b, c, n)

    cudaError_t err;

    dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), c, b);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);

    featuregather_backward_cuda_kernel<<<blocks, threads, 0, stream>>>(b, n, m, c, grad_distribute_feature, distribute_idx, grad_max_feature);
    // cudaDeviceSynchronize();  // for using printf in kernel function

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}
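The distribute step above is a plain nearest-neighbor assignment. A CPU reference sketch (pure Python; `feature_distribute_ref` is a hypothetical helper, not in the repo) makes the semantics explicit: each point in `xyz` is mapped to the index of its closest point in `max_xyz`, and the gather kernels then copy (forward) or scatter-add (backward) features along that index.

```python
def feature_distribute_ref(max_xyz, xyz):
    # max_xyz: list of (x, y, z) source points; xyz: list of (x, y, z) targets.
    # Returns, per target point, the index of its nearest source point.
    out = []
    for x, y, z in xyz:
        best, best_d2 = -1, float("inf")
        for k, (mx, my, mz) in enumerate(max_xyz):
            d2 = (mx - x) ** 2 + (my - y) ** 2 + (mz - z) ** 2
            if d2 < best_d2:
                best, best_d2 = k, d2
        out.append(best)
    return out
```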

================================================
FILE: lib/pointops/src/featuredistribute/featuredistribute_cuda_kernel.h
================================================
#ifndef _FEATUREDISTRIBUTE_CUDA_KERNEL
#define _FEATUREDISTRIBUTE_CUDA_KERNEL
#include <torch/serialize/tensor.h>
#include <vector>
#include <ATen/cuda/CUDAContext.h>

void featuredistribute_cuda(int b, int n, int m, at::Tensor max_xyz_tensor, at::Tensor xyz_tensor, at::Tensor distribute_idx_tensor);
void featuregather_forward_cuda(int b, int n, int m, int c, at::Tensor max_feature_tensor, at::Tensor distribute_idx_tensor, at::Tensor distribute_feature_tensor);
void featuregather_backward_cuda(int b, int n, int m, int c, at::Tensor grad_distribute_feature_tensor, at::Tensor distribute_idx_tensor, at::Tensor grad_max_feature_tensor);

#ifdef __cplusplus
extern "C" {
#endif

void featuredistribute_cuda_launcher(int b, int n, int m, const float *max_xyz, const float *xyz, int *distribute_idx, cudaStream_t stream);
void featuregather_forward_cuda_launcher(int b, int n, int m, int c, const float *max_feature, const int *distribute_idx, float *distribute_feature, cudaStream_t stream);
void featuregather_backward_cuda_launcher(int b, int n, int m, int c, const float *grad_distribute_feature, const int *distribute_idx, float *grad_max_feature, cudaStream_t stream);

#ifdef __cplusplus
}
#endif

#endif


================================================
FILE: lib/pointops/src/grouping/grouping_cuda.cpp
================================================
#include <torch/serialize/tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <vector>
#include <THC/THC.h>

#include "grouping_cuda_kernel.h"

extern THCState *state;

void grouping_forward_cuda(int b, int c, int n, int m, int nsample, at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor)
{
    const float *points = points_tensor.data<float>();
    const int *idx = idx_tensor.data<int>();
    float *out = out_tensor.data<float>();
    grouping_forward_cuda_launcher(b, c, n, m, nsample, points, idx, out);
}

void grouping_backward_cuda(int b, int c, int n, int m, int nsample, at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor)
{
    float *grad_points = grad_points_tensor.data<float>();
    const int *idx = idx_tensor.data<int>();
    const float *grad_out = grad_out_tensor.data<float>();
    grouping_backward_cuda_launcher(b, c, n, m, nsample, grad_out, idx, grad_points);
}

void grouping_forward_cuda_fast(int b, int c, int n, int npoints, int nsample, at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor) {

    const float *points = points_tensor.data<float>();
    const int *idx = idx_tensor.data<int>();
    float *out = out_tensor.data<float>();
    grouping_forward_cuda_launcher_fast(b, c, n, npoints, nsample, points, idx, out);
}

================================================
FILE: lib/pointops/src/grouping/grouping_cuda_kernel.cu
================================================
#include "../cuda_utils.h"
#include "grouping_cuda_kernel.h"

// input: points(b, c, n) idx(b, m, nsample)
// output: out(b, c, m, nsample)
__global__ void grouping_forward_cuda_kernel(int b, int c, int n, int m, int nsample, const float *points, const int *idx, float *out)
{
    int batch_index = blockIdx.x;
    points += batch_index * n * c;
    idx += batch_index * m * nsample;
    out += batch_index * m * nsample * c;
    const int index = threadIdx.y * blockDim.x + threadIdx.x;
    const int stride = blockDim.y * blockDim.x;
    for (int i = index; i < c * m; i += stride)
    {
        const int l = i / m;
        const int j = i % m;
        for (int k = 0; k < nsample; ++k)
        {
            int ii = idx[j * nsample + k];
            out[(l * m + j) * nsample + k] = points[l * n + ii];
        }
    }
}

// input: grad_out(b, c, m, nsample), idx(b, m, nsample)
// output: grad_points(b, c, n)
__global__ void grouping_backward_cuda_kernel(int b, int c, int n, int m, int nsample, const float *grad_out, const int *idx, float *grad_points)
{
    int batch_index = blockIdx.x;
    grad_out += batch_index * m * nsample * c;
    idx += batch_index * m * nsample;
    grad_points += batch_index * n * c;
    const int index = threadIdx.y * blockDim.x + threadIdx.x;
    const int stride = blockDim.y * blockDim.x;
    for (int i = index; i < c * m; i += stride)
    {
        const int l = i / m;
        const int j = i % m;
        for (int k = 0; k < nsample; ++k)
        {
            int ii = idx[j * nsample + k];
            atomicAdd(grad_points + l * n + ii, grad_out[(l * m + j) * nsample + k]);
        }
    }
}

void grouping_forward_cuda_launcher(int b, int c, int n, int m, int nsample, const float *points, const int *idx, float *out)
{
    grouping_forward_cuda_kernel<<<b, opt_block_config(m, c), 0>>>(b, c, n, m, nsample, points, idx, out);
}

void grouping_backward_cuda_launcher(int b, int c, int n, int m, int nsample, const float *grad_out, const int *idx, float *grad_points)
{
    grouping_backward_cuda_kernel<<<b, opt_block_config(m, c), 0>>>(b, c, n, m, nsample, grad_out, idx, grad_points);
}

// input: points(b, c, n) idx(b, npoints, nsample)
// output: out(b, c, npoints, nsample)
__global__ void grouping_forward_cuda_kernel_fast(int b, int c, int n, int npoints, int nsample, const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) {
    int bs_idx = blockIdx.z;
    int c_idx = blockIdx.y;
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int pt_idx = index / nsample;
    if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return;

    int sample_idx = index % nsample;

    idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx;
    int in_idx = bs_idx * c * n + c_idx * n + idx[0];
    int out_idx = bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx;

    out[out_idx] = points[in_idx];
}

// input: points(b, c, n) idx(b, npoints, nsample)
// output: out(b, c, npoints, nsample)
void grouping_forward_cuda_launcher_fast(int b, int c, int n, int npoints, int nsample, const float *points, const int *idx, float *out) {

    cudaError_t err;

    dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);

    grouping_forward_cuda_kernel_fast<<<blocks, threads, 0>>>(b, c, n, npoints, nsample, points, idx, out);
    // cudaDeviceSynchronize();  // for using printf in kernel function
    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}
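Both grouping kernels above implement the same gather, just with different launch geometries. For one batch element, the operation is a pure index lookup, sketched here in Python (hypothetical helper `group_points_ref`, not part of the repo): `out[l][j][k] = points[l][idx[j][k]]`, taking `points` of shape `(c, n)` and `idx` of shape `(m, nsample)` to `out` of shape `(c, m, nsample)`.

```python
def group_points_ref(points, idx):
    # points: [c][n] feature rows; idx: [m][nsample] neighbor indices.
    # Gathers each channel row at the given indices -> out: [c][m][nsample].
    return [[[row[i] for i in nb] for nb in idx] for row in points]
```

The backward pass is the transpose of this gather: an `atomicAdd` scatter of `grad_out` back into `grad_points` at the same indices.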




================================================
FILE: lib/pointops/src/grouping/grouping_cuda_kernel.h
================================================
#ifndef _GROUPING_CUDA_KERNEL
#define _GROUPING_CUDA_KERNEL
#include <torch/serialize/tensor.h>
#include <vector>
#include <ATen/cuda/CUDAContext.h>

void grouping_forward_cuda(int b, int c, int n, int m, int nsample, at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out);
void grouping_backward_cuda(int b, int c, int n, int m, int nsample, at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor);

void grouping_forward_cuda_fast(int b, int c, int n, int npoints, int nsample, at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor);

#ifdef __cplusplus
extern "C" {
#endif

void grouping_forward_cuda_launcher(int b, int c, int n, int m, int nsample, const float *points, const int *idx, float *out);
void grouping_backward_cuda_launcher(int b, int c, int n, int m, int nsample, const float *grad_out, const int *idx, float *grad_points);

void grouping_forward_cuda_launcher_fast(int b, int c, int n, int npoints, int nsample, const float *points, const int *idx, float *out);

#ifdef __cplusplus
}
#endif
#endif


================================================
FILE: lib/pointops/src/grouping_int/grouping_int_cuda.cpp
================================================
#include <torch/serialize/tensor.h>
#include <vector>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THC.h>

#include "grouping_int_cuda_kernel.h"

extern THCState *state;

void grouping_int_forward_cuda(int b, int c, int n, int m, int nsample, at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor)
{
    const long int *points = points_tensor.data<long int>();
    const int *idx = idx_tensor.data<int>();
    long int *out = out_tensor.data<long int>();
    grouping_int_forward_cuda_launcher(b, c, n, m, nsample, points, idx, out);
}

void grouping_int_forward_cuda_fast(int b, int c, int n, int m, int nsample, at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor)
{
    const long int *points = points_tensor.data<long int>();
    const int *idx = idx_tensor.data<int>();
    long int *out = out_tensor.data<long int>();
    grouping_int_forward_cuda_launcher_fast(b, c, n, m, nsample, points, idx, out);
}

================================================
FILE: lib/pointops/src/grouping_int/grouping_int_cuda_kernel.cu
================================================
#include "../cuda_utils.h"
#include "grouping_int_cuda_kernel.h"

// input: points(b, c, n) idx(b, m, nsample)
// output: out(b, c, m, nsample)
__global__ void grouping_int_forward_cuda_kernel(int b, int c, int n, int m, int nsample, const long int *points, const int *idx, long int *out)
{
    int batch_index = blockIdx.x;
    points += batch_index * n * c;
    idx += batch_index * m * nsample;
    out += batch_index * m * nsample * c;
    const int index = threadIdx.y * blockDim.x + threadIdx.x;
    const int stride = blockDim.y * blockDim.x;
    for (int i = index; i < c * m; i += stride)
    {
        const int l = i / m;
        const int j = i % m;
        for (int k = 0; k < nsample; ++k)
        {
            int ii = idx[j * nsample + k];
            out[(l * m + j) * nsample + k] = points[l * n + ii];
        }
    }
}


void grouping_int_forward_cuda_launcher(int b, int c, int n, int m, int nsample, const long int *points, const int *idx, long int *out)
{
    grouping_int_forward_cuda_kernel<<<b, opt_block_config(m, c), 0>>>(b, c, n, m, nsample, points, idx, out);
}


__global__ void grouping_int_forward_cuda_kernel_fast(int b, int c, int n, int npoints, int nsample, const long int *__restrict__ points, const int *__restrict__ idx, long int *__restrict__ out)
{
    int bs_idx = blockIdx.z;
    int c_idx = blockIdx.y;
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int pt_idx = index / nsample;
    if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return;

    int sample_idx = index % nsample;

    idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx;
    int in_idx = bs_idx * c * n + c_idx * n + idx[0];
    int out_idx = bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx;

    out[out_idx] = points[in_idx];
}


void grouping_int_forward_cuda_launcher_fast(int b, int c, int n, int npoints, int nsample, const long int *points, const int *idx, long int *out)
{
    cudaError_t err;

    dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);

    grouping_int_forward_cuda_kernel_fast<<<blocks, threads, 0>>>(b, c, n, npoints, nsample, points, idx, out);
    // cudaDeviceSynchronize();  // for using printf in kernel function
    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}

================================================
FILE: lib/pointops/src/grouping_int/grouping_int_cuda_kernel.h
================================================
#ifndef _GROUPING_INT_CUDA_KERNEL
#define _GROUPING_INT_CUDA_KERNEL
#include <torch/serialize/tensor.h>
#include <vector>
#include <ATen/cuda/CUDAContext.h>

void grouping_int_forward_cuda(int b, int c, int n, int m, int nsample, at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out);

void grouping_int_forward_cuda_fast(int b, int c, int n, int m, int nsample, at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor);

#ifdef __cplusplus
extern "C" {
#endif

void grouping_int_forward_cuda_launcher(int b, int c, int n, int m, int nsample, const long int *points, const int *idx, long int *out);

void grouping_int_forward_cuda_launcher_fast(int b, int c, int n, int npoints, int nsample, const long int *points, const int *idx, long int *out);

#ifdef __cplusplus
}
#endif
#endif


================================================
FILE: lib/pointops/src/interpolation/interpolation_cuda.cpp
================================================
#include <torch/serialize/tensor.h>
#include <vector>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THC.h>
#include "interpolation_cuda_kernel.h"

extern THCState *state;

void nearestneighbor_cuda(int b, int n, int m, at::Tensor unknown_tensor, at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor)
{
    const float *unknown = unknown_tensor.data<float>();
    const float *known = known_tensor.data<float>();
    float *dist2 = dist2_tensor.data<float>();
    int *idx = idx_tensor.data<int>();
    nearestneighbor_cuda_launcher(b, n, m, unknown, known, dist2, idx);
}

void interpolation_forward_cuda(int b, int c, int m, int n, at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor out_tensor)
{
    const float *points = points_tensor.data<float>();
    const float *weight = weight_tensor.data<float>();
    float *out = out_tensor.data<float>();
    const int *idx = idx_tensor.data<int>();
    interpolation_forward_cuda_launcher(b, c, m, n, points, idx, weight, out);
}

void interpolation_backward_cuda(int b, int c, int n, int m, at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor grad_points_tensor)
{
    const float *grad_out = grad_out_tensor.data<float>();
    const float *weight = weight_tensor.data<float>();
    float *grad_points = grad_points_tensor.data<float>();
    const int *idx = idx_tensor.data<int>();
    interpolation_backward_cuda_launcher(b, c, n, m, grad_out, idx, weight, grad_points);
}

void nearestneighbor_cuda_fast(int b, int n, int m, at::Tensor unknown_tensor, at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor) {
    const float *unknown = unknown_tensor.data<float>();
    const float *known = known_tensor.data<float>();
    float *dist2 = dist2_tensor.data<float>();
    int *idx = idx_tensor.data<int>();
    nearestneighbor_cuda_launcher_fast(b, n, m, unknown, known, dist2, idx);
}

void interpolation_forward_cuda_fast(int b, int c, int m, int n, at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor out_tensor) {

    const float *points = points_tensor.data<float>();
    const float *weight = weight_tensor.data<float>();
    float *out = out_tensor.data<float>();
    const int *idx = idx_tensor.data<int>();
    interpolation_forward_cuda_launcher_fast(b, c, m, n, points, idx, weight, out);
}

================================================
FILE: lib/pointops/src/interpolation/interpolation_cuda_kernel.cu
================================================
#include "../cuda_utils.h"
#include "interpolation_cuda_kernel.h"

// input: unknown(b, n, 3) known(b, m, 3)
// output: dist2(b, n, 3), idx(b, n, 3)
__global__ void nearestneighbor_cuda_kernel(int b, int n, int m, const float *unknown, const float *known, float *dist2, int *idx)
{
    int batch_index = blockIdx.x;
    unknown += batch_index * n * 3;
    known += batch_index * m * 3;
    dist2 += batch_index * n * 3;
    idx += batch_index * n * 3;

    int index = threadIdx.x;
    int stride = blockDim.x;
    for (int j = index; j < n; j += stride)
    {
        float ux = unknown[j * 3 + 0];
        float uy = unknown[j * 3 + 1];
        float uz = unknown[j * 3 + 2];

        double best1 = 1e40, best2 = 1e40, best3 = 1e40;
        int besti1 = 0, besti2 = 0, besti3 = 0;
        for (int k = 0; k < m; ++k)
        {
            float x = known[k * 3 + 0];
            float y = known[k * 3 + 1];
            float z = known[k * 3 + 2];
            float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
            if (d < best1)
            {
                best3 = best2;
                besti3 = besti2;
                best2 = best1;
                besti2 = besti1;
                best1 = d;
                besti1 = k;
            }
            else if (d < best2)
            {
                best3 = best2;
                besti3 = besti2;
                best2 = d;
                besti2 = k;
            }
            else if (d < best3)
            {
                best3 = d;
                besti3 = k;
            }
        }
        dist2[j * 3 + 0] = best1;
        dist2[j * 3 + 1] = best2;
        dist2[j * 3 + 2] = best3;
        idx[j * 3 + 0] = besti1;
        idx[j * 3 + 1] = besti2;
        idx[j * 3 + 2] = besti3;
    }
}

// input: points(b, c, m), idx(b, n, 3), weight(b, n, 3)
// output: out(b, c, n)
__global__ void interpolation_forward_cuda_kernel(int b, int c, int m, int n, const float *points, const int *idx, const float *weight, float *out)
{
    int batch_index = blockIdx.x;
    points += batch_index * m * c;
    idx += batch_index * n * 3;
    weight += batch_index * n * 3;
    out += batch_index * n * c;

    const int index = threadIdx.y * blockDim.x + threadIdx.x;
    const int stride = blockDim.y * blockDim.x;
    for (int i = index; i < c * n; i += stride)
    {
        const int l = i / n;
        const int j = i % n;
        float w1 = weight[j * 3 + 0];
        float w2 = weight[j * 3 + 1];
        float w3 = weight[j * 3 + 2];
        int i1 = idx[j * 3 + 0];
        int i2 = idx[j * 3 + 1];
        int i3 = idx[j * 3 + 2];
        out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 + points[l * m + i3] * w3;
    }
}

// input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3)
// output: grad_points(b, c, m)
__global__ void interpolation_backward_cuda_kernel( int b, int c, int n, int m, const float *grad_out, const int *idx, const float *weight, float *grad_points)
{
    int batch_index = blockIdx.x;
    grad_out += batch_index * n * c;
    idx += batch_index * n * 3;
    weight += batch_index * n * 3;
    grad_points += batch_index * m * c;

    const int index = threadIdx.y * blockDim.x + threadIdx.x;
    const int stride = blockDim.y * blockDim.x;
    for (int i = index; i < c * n; i += stride)
    {
        const int l = i / n;
        const int j = i % n;
        float w1 = weight[j * 3 + 0];
        float w2 = weight[j * 3 + 1];
        float w3 = weight[j * 3 + 2];
        int i1 = idx[j * 3 + 0];
        int i2 = idx[j * 3 + 1];
        int i3 = idx[j * 3 + 2];
        atomicAdd(grad_points + l * m + i1, grad_out[i] * w1);
        atomicAdd(grad_points + l * m + i2, grad_out[i] * w2);
        atomicAdd(grad_points + l * m + i3, grad_out[i] * w3);
    }
}

void nearestneighbor_cuda_launcher(int b, int n, int m, const float *unknown, const float *known, float *dist2, int *idx)
{
    nearestneighbor_cuda_kernel<<<b, opt_n_threads(n), 0>>>(b, n, m, unknown, known, dist2, idx);
}

void interpolation_forward_cuda_launcher(int b, int c, int m, int n, const float *points, const int *idx, const float *weight, float *out)
{
    interpolation_forward_cuda_kernel<<<b, opt_block_config(n, c), 0>>>(b, c, m, n, points, idx, weight, out);
}

void interpolation_backward_cuda_launcher(int b, int c, int n, int m, const float *grad_out, const int *idx, const float *weight, float *grad_points)
{
    interpolation_backward_cuda_kernel<<<b, opt_block_config(n, c), 0>>>(b, c, n, m, grad_out, idx, weight, grad_points);
}
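The three-way best1/best2/best3 bookkeeping in the nearest-neighbor kernel above can be restated as a small CPU reference (pure Python; `three_nn_ref` is a hypothetical helper, not part of the repo): for each query point it keeps a sorted list of the three smallest squared distances seen so far, with the same cascading update order as the kernel.

```python
def three_nn_ref(unknown, known):
    # unknown: list of (x, y, z) query points; known: list of (x, y, z) points.
    # Returns (dist2, idx): squared distances and indices of each query's
    # three nearest neighbors in `known`, sorted ascending by distance.
    dist2, idx = [], []
    for ux, uy, uz in unknown:
        best = [(float("inf"), 0)] * 3  # (d2, index) triples, ascending
        for k, (x, y, z) in enumerate(known):
            d = (ux - x) ** 2 + (uy - y) ** 2 + (uz - z) ** 2
            if d < best[0][0]:
                best = [(d, k), best[0], best[1]]
            elif d < best[1][0]:
                best = [best[0], (d, k), best[1]]
            elif d < best[2][0]:
                best = [best[0], best[1], (d, k)]
        dist2.append([b[0] for b in best])
        idx.append([b[1] for b in best])
    return dist2, idx
```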


// input: unknown(b, n, 3) known(b, m, 3)
// output: dist2(b, n, 3), idx(b, n, 3)
__global__ void nearestneighbor_cuda_kernel_fast(int b, int n, int m, const float *__restrict__ unknown, const float *__restrict__ known, float *__restrict__ dist2, int *__restrict__ idx) {

    int bs_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (bs_idx >= b || pt_idx >= n) return;

    unknown += bs_idx * n * 3 + pt_idx * 3;
    known += bs_idx * m * 3;
    dist2 += bs_idx * n * 3 + pt_idx * 3;
    idx += bs_idx * n * 3 + pt_idx * 3;

    float ux = unknown[0];
    float uy = unknown[1];
    float uz = unknown[2];

    double best1 = 1e40, best2 = 1e40, best3 = 1e40;
    int besti1 = 0, besti2 = 0, besti3 = 0;
    for (int k = 0; k < m; ++k) {
        float x = known[k * 3 + 0];
        float y = known[k * 3 + 1];
        float z = known[k * 3 + 2];
        float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
        if (d < best1) {
            best3 = best2; besti3 = besti2;
            best2 = best1; besti2 = besti1;
            best1 = d; besti1 = k;
        }
        else if (d < best2) {
            best3 = best2; besti3 = besti2;
            best2 = d; besti2 = k;
        }
        else if (d < best3) {
            best3 = d; besti3 = k;
        }
    }
    dist2[0] = best1;
    dist2[1] = best2;
    dist2[2] = best3;

    idx[0] = besti1;
    idx[1] = besti2;
    idx[2] = besti3;
}


// input: points(b, c, m), idx(b, n, 3), weight(b, n, 3)
// output: out(b, c, n)
__global__ void interpolation_forward_cuda_kernel_fast(int b, int c, int m, int n, const float *__restrict__ points, const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ out) {

    int bs_idx = blockIdx.z;
    int c_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;

    weight += bs_idx * n * 3 + pt_idx * 3;
    points += bs_idx * c * m + c_idx * m;
    idx += bs_idx * n * 3 + pt_idx * 3;
    out += bs_idx * c * n + c_idx * n;

    out[pt_idx] = weight[0] * points[idx[0]] + weight[1] * points[idx[1]] + weight[2] * points[idx[2]];
}


void nearestneighbor_cuda_launcher_fast(int b, int n, int m, const float *unknown, const float *known, float *dist2, int *idx)
{
    cudaError_t err;

    dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), b);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);

    nearestneighbor_cuda_kernel_fast<<<blocks, threads, 0>>>(b, n, m, unknown, known, dist2, idx);

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}

void interpolation_forward_cuda_launcher_fast(int b, int c, int m, int n, const float *points, const int *idx, const float *weight, float *out) {

    cudaError_t err;

    dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b);  // blockIdx.x(col), blockIdx.y(channel), blockIdx.z(batch)
    dim3 threads(THREADS_PER_BLOCK);
    interpolation_forward_cuda_kernel_fast<<<blocks, threads, 0>>>(b, c, m, n, points, idx, weight, out);

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}
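The two fast kernels above split three-point interpolation into a 3-nearest-neighbor search (`nearestneighbor_cuda_kernel_fast`) and a weighted gather (`interpolation_forward_cuda_kernel_fast`). A minimal pure-Python sketch of the same per-batch semantics, useful as a CPU reference; the function names `nearest_three` and `interpolate` are illustrative and not part of the pointops API, and weight normalization (done in Python in PointNet++-style pipelines) is assumed to happen before the call:

```python
def nearest_three(unknown, known):
    """For each query point in `unknown`, return the squared distances and
    indices of its three closest points in `known`, in ascending order.
    Mirrors nearestneighbor_cuda_kernel_fast for one batch element."""
    dist2, idx = [], []
    for ux, uy, uz in unknown:
        d2 = [(ux - x) ** 2 + (uy - y) ** 2 + (uz - z) ** 2 for x, y, z in known]
        order = sorted(range(len(known)), key=lambda k: d2[k])[:3]
        idx.append(order)
        dist2.append([d2[k] for k in order])
    return dist2, idx


def interpolate(points, idx, weight):
    """points: (c, m) features on the known points; idx/weight: (n, 3).
    Returns out: (c, n) with out[ci][pt] = sum_j weight[pt][j] * points[ci][idx[pt][j]],
    mirroring interpolation_forward_cuda_kernel_fast for one batch element."""
    return [[sum(w * row[i] for i, w in zip(ids, ws))
             for ids, ws in zip(idx, weight)]
            for row in points]
```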


================================================
FILE: lib/pointops/src/interpolation/interpolation_cuda_kernel.h
================================================
#ifndef _INTERPOLATION_CUDA_KERNEL
#define _INTERPOLATION_CUDA_KERNEL
#include <torch/serialize/tensor.h>
#include <vector>
#include <ATen/cuda/CUDAContext.h>

void nearestneighbor_cuda(int b, int n, int m, at::Tensor unknown_tensor, at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor);
void interpolation_forward_cuda(int b, int c, int m, int n, at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor out_tensor);
void interpolation_backward_cuda(int b, int c, int n, int m, at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor grad_points_tensor);

void nearestneighbor_cuda_fast(int b, int n, int m, at::Tensor unknown_tensor, at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor);
void interpolation_forward_cuda_fast(int b, int c, int m, int n, at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor out_tensor);

#ifdef __cplusplus
extern "C" {
#endif

void nearestneighbor_cuda_launcher(int b, int n, int m, const float *unknown, const float *known, float *dist2, int *idx);
void interpolation_forward_cuda_launcher(int b, int c, int m, int n, const float *points, const int *idx, const float *weight, float *out);
void interpolation_backward_cuda_launcher(int b, int c, int n, int m, const float *grad_out, const int *idx, const float *weight, float *grad_points);

void nearestneighbor_cuda_launcher_fast(int b, int n, int m, const float *unknown, const float *known, float *dist2, int *idx);
void interpolation_forward_cuda_launcher_fast(int b, int c, int m, int n, const float *points, const int *idx, const float *weight, float *out);

#ifdef __cplusplus
}
#endif
#endif


================================================
FILE: lib/pointops/src/knnquery/__init__.py
================================================


================================================
FILE: lib/pointops/src/knnquery/knnquery_cuda.cpp
================================================
#include <torch/serialize/tensor.h>
#include <vector>
#include <THC/THC.h>
#include <ATen/cuda/CUDAContext.h>

#include "knnquery_cuda_kernel.h"

extern THCState *state;

#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDA tensor ")
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)


void knnquery_cuda(int b, int n, int m, int nsample, at::Tensor xyz_tensor, at::Tensor new_xyz_tensor, at::Tensor idx_tensor, at::Tensor dist2_tensor)
{
    CHECK_INPUT(new_xyz_tensor);
    CHECK_INPUT(xyz_tensor);

    const float *new_xyz = new_xyz_tensor.data<float>();
    const float *xyz = xyz_tensor.data<float>();
    int *idx = idx_tensor.data<int>();
    float *dist2 = dist2_tensor.data<float>();

    cudaStream_t stream = THCState_getCurrentStream(state);

    knnquery_cuda_launcher(b, n, m, nsample, xyz, new_xyz, idx, dist2, stream);
}


================================================
FILE: lib/pointops/src/knnquery/knnquery_cuda_kernel.cu
================================================
#include "../cuda_utils.h"
#include "knnquery_cuda_kernel.h"

// input: xyz (b, n, 3) new_xyz (b, m, 3)
// output: idx (b, m, nsample) dist2 (b, m, nsample)
__global__ void knnquery_cuda_kernel(int b, int n, int m, int nsample, const float *__restrict__ xyz, const float *__restrict__ new_xyz, int *__restrict__ idx, float *__restrict__ dist2) {
    int bs_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (bs_idx >= b || pt_idx >= m) return;

    new_xyz += bs_idx * m * 3 + pt_idx * 3;
    xyz += bs_idx * n * 3;
    idx += bs_idx * m * nsample + pt_idx * nsample;
    dist2 += bs_idx * m * nsample + pt_idx * nsample;

    float new_x = new_xyz[0];
    float new_y = new_xyz[1];
    float new_z = new_xyz[2];

    // Fixed-size per-thread buffers: assumes nsample <= 200.
    double best[200];
    int besti[200];
    for(int i = 0; i < nsample; i++){
        best[i] = 1e40;
        besti[i] = 0;
    }
    for(int k = 0; k < n; k++){
        float x = xyz[k * 3 + 0];
        float y = xyz[k * 3 + 1];
        float z = xyz[k * 3 + 2];
        float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z);
        for(int j = 0; j < nsample; j++){
            if(d2 < best[j]){
                for(int i = nsample - 1; i > j; i--){
                    best[i] = best[i - 1];
                    besti[i] = besti[i - 1];
                }
                best[j] = d2;
                besti[j] = k;
                break;
            }
        }
    }
    for(int i = 0; i < nsample; i++){
        idx[i] = besti[i];
        dist2[i] = best[i];
    }
}


void knnquery_cuda_launcher(int b, int n, int m, int nsample, const float *xyz, const float *new_xyz, int *idx, float *dist2, cudaStream_t stream) {
    // param new_xyz: (B, m, 3)
    // param xyz: (B, n, 3)
    // param idx: (B, m, nsample)

    cudaError_t err;

    dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);

    knnquery_cuda_kernel<<<blocks, threads, 0, stream>>>(b, n, m, nsample, xyz, new_xyz, idx, dist2);
    // cudaDeviceSynchronize();  // for using printf in kernel function

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}
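The kernel above maintains a per-thread insertion-sorted buffer of the `nsample` best candidates (hence the fixed 200-entry stack arrays, which caps `nsample` at 200). A pure-Python sketch of the same per-batch result, usable as a CPU reference when debugging; `knn_query` is an illustrative name, not part of the pointops API:

```python
def knn_query(xyz, new_xyz, nsample):
    """xyz: list of (x, y, z) source points; new_xyz: list of query points.
    Returns (idx, dist2), each of shape (len(new_xyz), nsample): for every
    query, the nsample nearest source points ordered by increasing squared
    distance, mirroring knnquery_cuda_kernel for one batch element."""
    idx, dist2 = [], []
    for qx, qy, qz in new_xyz:
        d2 = [(qx - x) ** 2 + (qy - y) ** 2 + (qz - z) ** 2 for x, y, z in xyz]
        order = sorted(range(len(xyz)), key=lambda k: d2[k])[:nsample]
        idx.append(order)
        dist2.append([d2[k] for k in order])
    return idx, dist2
```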

================================================
FILE: lib/pointops/src/knnquery/knnquery_cuda_kernel.h
================================================
#ifndef _KNNQUERY_CUDA_KERNEL
#define _KNNQUERY_CUDA_KERNEL

#include <torch/serialize/tensor.h>
#include <vector>
#include <ATen/cuda/CUDAContext.h>

void knnquery_cuda(int b, int n, int m, int nsample, at::Tensor xyz_tensor, at::Tensor new_xyz_tensor, at::Tensor idx_tensor, at::Tensor dist2_tensor);

#ifdef __cplusplus
extern "C" {
#endif

void knnquery_cuda_launcher(int b, int n, int m, int nsample, const float *xyz, const float *new_xyz, int *idx, float *dist2, cudaStream_t stream);

#ifdef __cplusplus
}
#endif

#endif

================================================
FILE: lib/pointops/src/labelstat/labelstat_cuda.cpp
================================================
#include <torch/serialize/tensor.h>
#include <vector>
#include <THC/THC.h>
#include <ATen/cuda/CUDAContext.h>

#include "labelstat_cuda_kernel.h"

extern THCState *state;

#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDA tensor ")
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)

void labelstat_idx_cuda_fast(int b, int n, int m, int nsample, int nclass,
    at::Tensor label_stat_tensor, at::Tensor idx_tensor, at::Tensor new_label_stat_tensor)
{
    CHECK_INPUT(label_stat_tensor);
    CHECK_INPUT(idx_tensor);

    const int *label_stat = label_stat_tensor.data<int>();
    const int *idx = idx_tensor.data<int>();
    int *new_label_stat = new_label_stat_tensor.data<int>();

    cudaStream_t stream = THCState_getCurrentStream(state);

    labelstat_idx_cuda_launcher_fast(b, n, m, nsample, nclass, label_stat, idx, new_label_stat, stream);
}

void labelstat_ballrange_cuda_fast(int b, int n, int m, float radius, int nclass,
    at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor label_stat_tensor, at::Tensor new_label_stat_tensor)
{
    CHECK_INPUT(new_xyz_tensor);
    CHECK_INPUT(xyz_tensor);
    CHECK_INPUT(label_stat_tensor);

    const float *new_xyz = new_xyz_tensor.data<float>();
    const float *xyz = xyz_tensor.data<float>();
    const int *label_stat = label_stat_tensor.data<int>();
    int *new_label_stat = new_label_stat_tensor.data<int>();

    cudaStream_t stream = THCState_getCurrentStream(state);

    labelstat_ballrange_cuda_launcher_fast(b, n, m, radius, nclass, new_xyz, xyz, label_stat, new_label_stat, stream);
}

void labelstat_and_ballquery_cuda_fast(int b, int n, int m, float radius, int nsample, int nclass,
    at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor label_stat_tensor, at::Tensor idx_tensor, at::Tensor new_label_stat_tensor)
{
    CHECK_INPUT(new_xyz_tensor);
    CHECK_INPUT(xyz_tensor);
    CHECK_INPUT(label_stat_tensor);
    CHECK_INPUT(idx_tensor);

    const float *new_xyz = new_xyz_tensor.data<float>();
    const float *xyz = xyz_tensor.data<float>();
    const int *label_stat = label_stat_tensor.data<int>();
    int *idx = idx_tensor.data<int>();
    int *new_label_stat = new_label_stat_tensor.data<int>();

    cudaStream_t stream = THCState_getCurrentStream(state);

    labelstat_and_ballquery_cuda_launcher_fast(b, n, m, radius, nsample, nclass, new_xyz, xyz, label_stat, idx, new_label_stat, stream);
}


================================================
FILE: lib/pointops/src/labelstat/labelstat_cuda_kernel.cu
================================================
#include "../cuda_utils.h"
#include "labelstat_cuda_kernel.h"

// input: new_xyz(b, m, 3) xyz(b, n, 3) label_stat(b, n, nclass)
// output: idx(b, m, nsample)  new_label_stat(b, m, nclass)
__global__ void labelstat_and_ballquery_cuda_kernel_fast(int b, int n, int m, float radius, int nsample, int nclass,
    const float *new_xyz, const float *xyz, const int *label_stat, int *idx, int *new_label_stat) {
    int bs_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (bs_idx >= b || pt_idx >= m) return;

    new_xyz += bs_idx * m * 3 + pt_idx * 3;
    xyz += bs_idx * n * 3;
    idx += bs_idx * m * nsample + pt_idx * nsample;
    label_stat += bs_idx * n * nclass;
    new_label_stat += bs_idx * m * nclass + pt_idx * nclass;

    for(int i = 0; i < nclass; i++){
        new_label_stat[i] = 0;
    }

    float radius2 = radius * radius;
    float new_x = new_xyz[0];
    float new_y = new_xyz[1];
    float new_z = new_xyz[2];

    int cnt = 0;
    for (int k = 0; k < n; ++k) {
        float x = xyz[k * 3 + 0];
        float y = xyz[k * 3 + 1];
        float z = xyz[k * 3 + 2];
        float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z);
        if (d2 < radius2){
            for(int i = 0; i < nclass; i++){
                new_label_stat[i] += label_stat[k * nclass + i];
            }
            if (cnt == 0){
                for (int l = 0; l < nsample; ++l) {
                    idx[l] = k;
                }
            }
            idx[cnt] = k;
            ++cnt;
            if (cnt >= nsample){
                break;
            }
        }
    }
}

void labelstat_and_ballquery_cuda_launcher_fast(int b, int n, int m, float radius, int nsample, int nclass,
    const float *new_xyz, const float *xyz, const int *label_stat, int *idx, int *new_label_stat, cudaStream_t stream) {
    // param new_xyz: (B, m, 3)
    // param xyz: (B, n, 3)
    // param idx: (B, m, nsample)

    cudaError_t err;

    dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);

    labelstat_and_ballquery_cuda_kernel_fast<<<blocks, threads, 0, stream>>>(b, n, m, radius, nsample, nclass, new_xyz, xyz, label_stat, idx, new_label_stat);
    // cudaDeviceSynchronize();  // for using printf in kernel function

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}

// input: new_xyz(b, m, 3) xyz(b, n, 3) label_stat(b, n, nclass)
// output: new_label_stat(b, m, nclass)
__global__ void labelstat_ballrange_cuda_kernel_fast(int b, int n, int m, float radius, int nclass,
    const float *new_xyz, const float *xyz, const int *label_stat, int *new_label_stat) {
    int bs_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (bs_idx >= b || pt_idx >= m) return;

    new_xyz += bs_idx * m * 3 + pt_idx * 3;
    xyz += bs_idx * n * 3;
    label_stat += bs_idx * n * nclass;
    new_label_stat += bs_idx * m * nclass + pt_idx * nclass;

    for(int i = 0; i < nclass; i++){
        new_label_stat[i] = 0;
    }

    float radius2 = radius * radius;
    float new_x = new_xyz[0];
    float new_y = new_xyz[1];
    float new_z = new_xyz[2];

    for (int k = 0; k < n; ++k) {
        float x = xyz[k * 3 + 0];
        float y = xyz[k * 3 + 1];
        float z = xyz[k * 3 + 2];
        float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z);
        if (d2 < radius2){
            for(int i = 0; i < nclass; i++){
                new_label_stat[i] += label_stat[k * nclass + i];
            }
        }
    }
}


void labelstat_ballrange_cuda_launcher_fast(int b, int n, int m, float radius, int nclass,
    const float *new_xyz, const float *xyz, const int *label_stat, int *new_label_stat, cudaStream_t stream) {
    // param new_xyz: (B, m, 3)
    // param xyz: (B, n, 3)
    // param new_label_stat: (B, m, nclass)

    cudaError_t err;

    dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);

    labelstat_ballrange_cuda_kernel_fast<<<blocks, threads, 0, stream>>>(b, n, m, radius, nclass, new_xyz, xyz, label_stat, new_label_stat);
    // cudaDeviceSynchronize();  // for using printf in kernel function

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}

// input: idx(b, m, nsample) label_stat(b, n, nclass)
// output: new_label_stat(b, m, nclass)
__global__ void labelstat_idx_cuda_kernel_fast(int b, int n, int m, int nsample, int nclass,
    const int *label_stat, const int *idx, int *new_label_stat) {
    int bs_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (bs_idx >= b || pt_idx >= m) return;

    idx += bs_idx * m * nsample + pt_idx * nsample;
    label_stat += bs_idx * n * nclass;
    new_label_stat += bs_idx * m * nclass + pt_idx * nclass;

    for(int i = 0; i < nclass; i++){
        new_label_stat[i] = 0;
    }

    for(int k = 0; k < nsample; k++){
        const int *label_stat_k = label_stat + idx[k] * nclass;
        for(int i = 0; i < nclass; i++){
            new_label_stat[i] += label_stat_k[i];
        }
    }
}


void labelstat_idx_cuda_launcher_fast(int b, int n, int m, int nsample, int nclass,
    const int *label_stat, const int *idx, int *new_label_stat, cudaStream_t stream) {
    // param label_stat: (B, n, nclass)
    // param idx: (B, m, nsample)
    // param new_label_stat: (B, m, nclass)

    cudaError_t err;

    dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);

    labelstat_idx_cuda_kernel_fast<<<blocks, threads, 0, stream>>>(b, n, m, nsample, nclass, label_stat, idx, new_label_stat);
    // cudaDeviceSynchronize();  // for using printf in kernel function

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}
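The index-based variant above is the simplest of the three labelstat kernels: each sampled point's histogram is just the sum of the histograms of its `nsample` neighbors. A pure-Python sketch of that per-batch computation; `labelstat_idx` is an illustrative name, not part of the pointops API:

```python
def labelstat_idx(label_stat, idx):
    """label_stat: (n, nclass) per-point label histograms; idx: (m, nsample)
    neighbor indices per sampled point. Returns new_label_stat: (m, nclass),
    mirroring labelstat_idx_cuda_kernel_fast for one batch element."""
    nclass = len(label_stat[0])
    out = []
    for neighbors in idx:
        hist = [0] * nclass
        for k in neighbors:
            for i in range(nclass):
                hist[i] += label_stat[k][i]
        out.append(hist)
    return out
```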

================================================
FILE: lib/pointops/src/labelstat/labelstat_cuda_kernel.h
================================================
#ifndef _LABELSTAT_CUDA_KERNEL
#define _LABELSTAT_CUDA_KERNEL
#include <torch/serialize/tensor.h>
#include <vector>
#include <ATen/cuda/CUDAContext.h>

void labelstat_and_ballquery_cuda_fast(int b, int n, int m, float radius, int nsample, int nclass,
    at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor label_stat_tensor, at::Tensor idx_tensor, at::Tensor new_label_stat_tensor);

void labelstat_ballrange_cuda_fast(int b, int n, int m, float radius, int nclass,
    at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor label_stat_tensor, at::Tensor new_label_stat_tensor);

void labelstat_idx_cuda_fast(int b, int n, int m, int nsample, int nclass,
    at::Tensor label_stat_tensor, at::Tensor idx_tensor, at::Tensor new_label_stat_tensor);

#ifdef __cplusplus
extern "C" {
#endif

void labelstat_and_ballquery_cuda_launcher_fast(int b, int n, int m, float radius, int nsample, int nclass,
    const float *new_xyz, const float *xyz, const int *label_stat, int *idx, int *new_label_stat, cudaStream_t stream);

void labelstat_ballrange_cuda_launcher_fast(int b, int n, int m, float radius, int nclass,
    const float *new_xyz, const float *xyz, const int *label_stat, int *new_label_stat, cudaStream_t stream);

void labelstat_idx_cuda_launcher_fast(int b, int n, int m, int nsample, int nclass,
    const int *label_stat, const int *idx, int *new_label_stat, cudaStream_t stream);

#ifdef __cplusplus
}
#endif

#endif


================================================
FILE: lib/pointops/src/pointops_api.cpp
================================================
#include <torch/serialize/tensor.h>
#include <torch/extension.h>

#include "ballquery/ballquery_cuda_kernel.h"
#include "grouping/grouping_cuda_kernel.h"
#include "grouping_int/grouping_int_cuda_kernel.h"
#include "sampling/sampling_cuda_kernel.h"
#include "interpolation/interpolation_cuda_kernel.h"
#include "knnquery/knnquery_cuda_kernel.h"

#include "labelstat/labelstat_cuda_kernel.h"
#include "featuredistribute/featuredistribute_cuda_kernel.h"


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("ballquery_cuda", &ballquery_cuda_fast, "ballquery_cuda_fast");   // name in python, cpp function address, docs

    m.def("knnquery_cuda", &knnquery_cuda, "knnquery_cuda");

    m.def("grouping_forward_cuda", &grouping_forward_cuda_fast, "grouping_forward_cuda_fast");
    m.def("grouping_backward_cuda", &grouping_backward_cuda, "grouping_backward_cuda");

    m.def("grouping_int_forward_cuda", &grouping_int_forward_cuda_fast, "grouping_int_forward_cuda_fast");

    m.def("gathering_forward_cuda", &gathering_forward_cuda, "gathering_forward_cuda");
    m.def("gathering_backward_cuda", &gathering_backward_cuda, "gathering_backward_cuda");
    m.def("furthestsampling_cuda", &furthestsampling_cuda, "furthestsampling_cuda");

    m.def("nearestneighbor_cuda", &nearestneighbor_cuda_fast, "nearestneighbor_cuda_fast");
    m.def("interpolation_forward_cuda", &interpolation_forward_cuda_fast, "interpolation_forward_cuda_fast");
    m.def("interpolation_backward_cuda", &interpolation_backward_cuda, "interpolation_backward_cuda");

    m.def("labelstat_idx_cuda", &labelstat_idx_cuda_fast, "labelstat_idx_cuda_fast");
    m.def("labelstat_ballrange_cuda", &labelstat_ballrange_cuda_fast, "labelstat_ballrange_cuda_fast");
    m.def("labelstat_and_ballquery_cuda", &labelstat_and_ballquery_cuda_fast, "labelstat_and_ballquery_cuda_fast");

    m.def("featuredistribute_cuda", &featuredistribute_cuda, "featuredistribute_cuda");
    m.def("featuregather_forward_cuda", &featuregather_forward_cuda, "featuregather_forward_cuda");
    m.def("featuregather_backward_cuda", &featuregather_backward_cuda, "featuregather_backward_cuda");
}


================================================
FILE: lib/pointops/src/sampling/sampling_cuda.cpp
================================================
#include <torch/serialize/tensor.h>
#include <vector>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THC.h>
#include "sampling_cuda_kernel.h"

extern THCState *state;

void gathering_forward_cuda(int b, int c, int n, int m, at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor)
{
    const float *points = points_tensor.data<float>();
    const int *idx = idx_tensor.data<int>();
    float *out = out_tensor.data<float>();
    gathering_forward_cuda_launcher(b, c, n, m, points, idx, out);
}

void gathering_backward_cuda(int b, int c, int n, int m, at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor)
{

    const float *grad_out = grad_out_tensor.data<float>();
    const int *idx = idx_tensor.data<int>();
    float *grad_points = grad_points_tensor.data<float>();
    gathering_backward_cuda_launcher(b, c, n, m, grad_out, idx, grad_points);
}

void furthestsampling_cuda(int b, int n, int m, at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor)
{
    const float *points = points_tensor.data<float>();
    float *temp = temp_tensor.data<float>();
    int *idx = idx_tensor.data<int>();
    furthestsampling_cuda_launcher(b, n, m, points, temp, idx);
}


================================================
FILE: lib/pointops/src/sampling/sampling_cuda_kernel.cu
================================================
#include "../cuda_utils.h"
#include "sampling_cuda_kernel.h"

// input: points(b, c, n) idx(b, m)
// output: out(b, c, m)
__global__ void gathering_forward_cuda_kernel(int b, int c, int n, int m, const float *points, const int *idx, float *out)
{
    for (int i = blockIdx.x; i < b; i += gridDim.x)
    {
        for (int l = blockIdx.y; l < c; l += gridDim.y)
        {
            for (int j = threadIdx.x; j < m; j += blockDim.x)
            {
                int a = idx[i * m + j];
                out[(i * c + l) * m + j] = points[(i * c + l) * n + a];
            }
        }
    }
}

// input: grad_out(b, c, m) idx(b, m)
// output: grad_points(b, c, n)
__global__ void gathering_backward_cuda_kernel(int b, int c, int n, int m, const float *grad_out, const int *idx, float *grad_points)
{
    for (int i = blockIdx.x; i < b; i += gridDim.x)
    {
        for (int l = blockIdx.y; l < c; l += gridDim.y)
        {
            for (int j = threadIdx.x; j < m; j += blockDim.x)
            {
                int a = idx[i * m + j];
                atomicAdd(grad_points + (i * c + l) * n + a, grad_out[(i * c + l) * m + j]);
            }
        }
    }
}

void gathering_forward_cuda_launcher(int b, int c, int n, int m, const float *points, const int *idx, float *out)
{
    gathering_forward_cuda_kernel<<<dim3(b, c, 1), opt_n_threads(m), 0>>>(b, c, n, m, points, idx, out);
}

void gathering_backward_cuda_launcher(int b, int c, int n, int m, const float *grad_out, const int *idx, float *grad_points)
{
    gathering_backward_cuda_kernel<<<dim3(b, c, 1), opt_n_threads(m), 0>>>(b, c, n, m, grad_out, idx, grad_points);
}

__device__ void __update(float *dists, int *dists_i, int idx1, int idx2) {
    const float v1 = dists[idx1], v2 = dists[idx2];
    const int i1 = dists_i[idx1], i2 = dists_i[idx2];
    dists[idx1] = max(v1, v2);
    dists_i[idx1] = v2 > v1 ? i2 : i1;
}

// input: dataset(b, n, 3), temp(b, n)
// output: idxs(b, m)
template <unsigned int block_size>
__global__ void furthestsampling_cuda_kernel(int b, int n, int m, const float *dataset, float *temp, int *idxs)
{
    if (m <= 0)
	    return;
    __shared__ float dists[block_size];
    __shared__ int dists_i[block_size];

    int batch_index = blockIdx.x;
    dataset += batch_index * n * 3;
    temp += batch_index * n;
    idxs += batch_index * m;
    int tid = threadIdx.x;
    const int stride = block_size;
    int old = 0;
    if (threadIdx.x == 0)
	    idxs[0] = old;

    __syncthreads();
    for (int j = 1; j < m; j++)
    {
        int besti = 0;
        float best = -1;
        float x1 = dataset[old * 3 + 0];
        float y1 = dataset[old * 3 + 1];
        float z1 = dataset[old * 3 + 2];
        for (int k = tid; k < n; k += stride)
        {
            float x2, y2, z2;
            x2 = dataset[k * 3 + 0];
            y2 = dataset[k * 3 + 1];
            z2 = dataset[k * 3 + 2];
            //float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
            //if (mag <= 1e-3)
            //    continue;
            float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
            float d2 = min(d, temp[k]);
            temp[k] = d2;
            besti = d2 > best ? k : besti;
            best = d2 > best ? d2 : best;
        }
        dists[tid] = best;
        dists_i[tid] = besti;
        __syncthreads();

        if (block_size >= 1024) {
            if (tid < 512) {
                __update(dists, dists_i, tid, tid + 512);
            }
            __syncthreads();
        }
        if (block_size >= 512) {
            if (tid < 256) {
                __update(dists, dists_i, tid, tid + 256);
            }
            __syncthreads();
        }
        if (block_size >= 256) {
            if (tid < 128) {
                __update(dists, dists_i, tid, tid + 128);
            }
            __syncthreads();
        }
        if (block_size >= 128) {
            if (tid < 64) {
                __update(dists, dists_i, tid, tid + 64);
            }
            __syncthreads();
        }
        if (block_size >= 64) {
            if (tid < 32) {
                __update(dists, dists_i, tid, tid + 32);
            }
            __syncthreads();
        }
        if (block_size >= 32) {
            if (tid < 16) {
                __update(dists, dists_i, tid, tid + 16);
            }
            __syncthreads();
        }
        if (block_size >= 16) {
            if (tid < 8) {
                __update(dists, dists_i, tid, tid + 8);
            }
            __syncthreads();
        }
        if (block_size >= 8) {
            if (tid < 4) {
                __update(dists, dists_i, tid, tid + 4);
            }
            __syncthreads();
        }
        if (block_size >= 4) {
            if (tid < 2) {
                __update(dists, dists_i, tid, tid + 2);
            }
            __syncthreads();
        }
        if (block_size >= 2) {
            if (tid < 1) {
                __update(dists, dists_i, tid, tid + 1);
            }
            __syncthreads();
        }

        old = dists_i[0];
        if (tid == 0)
            idxs[j] = old;
    }
}

void furthestsampling_cuda_launcher(int b, int n, int m, const float *dataset, float *temp, int *idxs)
{   
    unsigned int n_threads = opt_n_threads(n);
    switch (n_threads) {
        case 1024:
            furthestsampling_cuda_kernel<1024><<<b, n_threads, 0>>>(b, n, m, dataset, temp, idxs);
            break;
        case 512:
            furthestsampling_cuda_kernel<512><<<b, n_threads, 0>>>(b, n, m, dataset, temp, idxs);
            break;
        case 256:
            furthestsampling_cuda_kernel<256><<<b, n_threads, 0>>>(b, n, m, dataset, temp, idxs);
            break;
        case 128:
            furthestsampling_cuda_kernel<128><<<b, n_threads, 0>>>(b, n, m, dataset, temp, idxs);
            break;
        case 64:
            furthestsampling_cuda_kernel<64><<<b, n_threads, 0>>>(b, n, m, dataset, temp, idxs);
            break;
        case 32:
            furthestsampling_cuda_kernel<32><<<b, n_threads, 0>>>(b, n, m, dataset, temp, idxs);
            break;
        case 16:
            furthestsampling_cuda_kernel<16><<<b, n_threads, 0>>>(b, n, m, dataset, temp, idxs);
            break;
        case 8:
            furthestsampling_cuda_kernel<8><<<b, n_threads, 0>>>(b, n, m, dataset, temp, idxs);
            break;
        case 4:
            furthestsampling_cuda_kernel<4><<<b, n_threads, 0>>>(b, n, m, dataset, temp, idxs);
            break;
        case 2:
            furthestsampling_cuda_kernel<2><<<b, n_threads, 0>>>(b, n, m, dataset, temp, idxs);
            break;
        case 1:
            furthestsampling_cuda_kernel<1><<<b, n_threads, 0>>>(b, n, m, dataset, temp, idxs);
            break;
        default:
            furthestsampling_cuda_kernel<512><<<b, n_threads, 0>>>(b, n, m, dataset, temp, idxs);
    }
}
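The kernel above implements greedy farthest point sampling: it starts from point 0 and repeatedly adds the point whose squared distance to the already-sampled set is largest, with `temp` caching each point's running minimum distance to that set so every iteration is O(n). A pure-Python sketch of the same per-batch algorithm; `furthest_point_sampling` is an illustrative name, not part of the pointops API:

```python
def furthest_point_sampling(points, m):
    """points: list of (x, y, z); returns the indices of m sampled points,
    mirroring furthestsampling_cuda_kernel for one batch element. `temp`
    plays the role of the kernel's temp buffer: the running minimum squared
    distance from each point to the sampled set."""
    n = len(points)
    temp = [1e10] * n
    idxs = [0] * m          # the kernel also seeds idxs[0] = 0
    old = 0
    for j in range(1, m):
        x1, y1, z1 = points[old]
        best, besti = -1.0, 0
        for k in range(n):
            x2, y2, z2 = points[k]
            d = (x2 - x1) ** 2 + (y2 - y1) ** 2 + (z2 - z1) ** 2
            temp[k] = min(temp[k], d)
            if temp[k] > best:
                best, besti = temp[k], k
        old = besti
        idxs[j] = old
    return idxs
```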


================================================
FILE: lib/pointops/src/sampling/sampling_cuda_kernel.h
================================================
#ifndef _SAMPLING_CUDA_KERNEL
#define _SAMPLING_CUDA_KERNEL
#include <torch/serialize/tensor.h>
#include <vector>
#include <ATen/cuda/CUDAContext.h>

void gathering_forward_cuda(int b, int c, int n, int m, at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor);
void gathering_backward_cuda(int b, int c, int n, int m, at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor);
void furthestsampling_cuda(int b, int n, int m, at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor);

#ifdef __cplusplus
extern "C" {
#endif

void gathering_forward_cuda_launcher(int b, int c, int n, int m, const float *points, const int *idx, float *out);
void gathering_backward_cuda_launcher(int b, int c, int n, int m, const float *grad_out, const int *idx, float *grad_points);
void furthestsampling_cuda_launcher(int b, int n, int m, const float *dataset, float *temp, int *idxs);

#ifdef __cplusplus
}
#endif
#endif


================================================
FILE: lib/sync_bn/__init__.py
================================================
# -*- coding: utf-8 -*-
# File   : __init__.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
# 
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
from .replicate import DataParallelWithCallback, patch_replication_callback


================================================
FILE: lib/sync_bn/batchnorm.py
================================================
# -*- coding: utf-8 -*-
# File   : batchnorm.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
# 
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import collections

import torch
import torch.nn.functional as F

from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast

from .comm import SyncMaster

__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']


def _sum_ft(tensor):
    """sum over the first and last dimention"""
    return tensor.sum(dim=0).sum(dim=-1)


def _unsqueeze_ft(tensor):
    """add new dementions at the front and the tail"""
    return tensor.unsqueeze(0).unsqueeze(-1)


_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])


class _SynchronizedBatchNorm(_BatchNorm):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)

        self._sync_master = SyncMaster(self._data_parallel_master)

        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, input):
        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
        if not (self._is_parallel and self.training):
            return F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training, self.momentum, self.eps)

        # Resize the input to (B, C, -1).
        input_shape = input.size()
        input = input.view(input.size(0), self.num_features, -1)

        # Compute the sum and square-sum.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)

        # Reduce-and-broadcast the statistics.
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
        else:
            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))

        # Compute the output.
        if self.affine:
            # MJY:: Fuse the multiplication for speed.
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        # Reshape it.
        return output.view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it."""

        # Always using same "device order" makes the ReduceAdd operation faster.
        # Thanks to:: Tete Xiao (http://tetexiao.com/)
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        sum_size = sum([i[1].sum_size for i in intermediates])
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))

        return outputs

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and standard-deviation with sum and square-sum. This method
        also maintains the moving average on the master device."""
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        mean = sum_ / size
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data

        return mean, bias_var.clamp(self.eps) ** -0.5


class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
    mini-batch.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm1d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device
    using only the statistics on that device, which accelerates computation
    and is easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics are computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly
    the same as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm

    Args:
        num_features: num_features from an expected input of size
            `batch_size x num_features [x width]`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm1d, self)._check_input_dim(input)


class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
    of 3d inputs

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm2d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device
    using only the statistics on that device, which accelerates computation
    and is easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics are computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly
    the same as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm2d, self)._check_input_dim(input)


class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
    of 4d inputs

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm3d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device
    using only the statistics on that device, which accelerates computation
    and is easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics are computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly
    the same as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
    or Spatio-temporal BatchNorm

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
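
# The statistics math inside `_SynchronizedBatchNorm._compute_mean_std` reduces to a
# few lines of arithmetic on the globally reduced sum and square-sum. The sketch below
# mirrors that arithmetic in plain Python for illustration; the helper name
# `mean_inv_std` is ours, not part of this library.

```python
def mean_inv_std(sum_, ssum, size, eps=1e-5):
    """Mirror of _compute_mean_std: derive mean, 1/std and the unbiased
    variance from the reduced sum and square-sum over all devices."""
    mean = sum_ / size
    sumvar = ssum - sum_ * mean          # sum of squared deviations
    bias_var = sumvar / size             # biased variance: used to normalize
    unbias_var = sumvar / (size - 1)     # unbiased variance: used for running stats
    inv_std = max(bias_var, eps) ** -0.5  # clamp(eps) then ** -0.5, as above
    return mean, inv_std, unbias_var

# For data [1, 2, 3, 4]: sum = 10, square-sum = 30, size = 4.
m, inv_std, var_u = mean_inv_std(10.0, 30.0, 4)
```

# Feeding these values back, mean is 2.5 and the biased variance is 1.25, matching
# what the module would compute from the same per-device partial sums.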


================================================
FILE: lib/sync_bn/comm.py
================================================
# -*- coding: utf-8 -*-
# File   : comm.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
# 
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import queue
import collections
import threading

__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']


class FutureResult(object):
    """A thread-safe future implementation. Used only as one-to-one pipe."""

    def __init__(self):
        self._result = None
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        with self._lock:
            assert self._result is None, 'Previous result hasn\'t been fetched.'
            self._result = result
            self._cond.notify()

    def get(self):
        with self._lock:
            while self._result is None:  # loop guards against spurious wakeups
                self._cond.wait()

            res = self._result
            self._result = None
            return res


_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])


class SlavePipe(_SlavePipeBase):
    """Pipe for master-slave communication."""

    def run_slave(self, msg):
        self.queue.put((self.identifier, msg))
        ret = self.result.get()
        self.queue.put(True)
        return ret


class SyncMaster(object):
    """An abstract `SyncMaster` object.

    - During replication, as data parallel triggers a callback on each module, every slave device should
    call `register_slave(id)` to obtain a `SlavePipe` for communicating with the master.
    - During the forward pass, the master device invokes `run_master`; all messages from the slave devices
    are collected and passed to a registered callback.
    - After receiving the messages, the master device gathers the information and determines the message
    to be passed back to each slave device.
    """

    def __init__(self, master_callback):
        """

        Args:
            master_callback: a callback to be invoked after having collected messages from slave devices.
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def __getstate__(self):
        return {'master_callback': self._master_callback}

    def __setstate__(self, state):
        self.__init__(state['master_callback'])

    def register_slave(self, identifier):
        """
        Register a slave device.

        Args:
            identifier: an identifier, usually the device id.

        Returns: a `SlavePipe` object which can be used to communicate with the master device.

        """
        if self._activated:
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.
        Messages are first collected from each device (including the master device), and then
        a callback is invoked to compute the message to be sent back to each device
        (including the master device).

        Args:
            master_msg: the message that the master wants to send to itself. This will be placed as the first
            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.

        Returns: the message to be sent back to the master device.

        """
        self._activated = True

        intermediates = [(0, master_msg)]
        for i in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belong to the master.'

        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        for i in range(self.nr_slaves):
            assert self._queue.get() is True

        return results[0][1]

    @property
    def nr_slaves(self):
        return len(self._registry)
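
# The collect-reduce-broadcast protocol that `SyncMaster` implements can be exercised
# without any GPU: slaves push one message each into a shared queue, the master reduces
# them and answers each slave through a one-shot future. The sketch below is a
# self-contained illustration using plain threading; the `demo` helper and its sum
# reduction are ours, not part of this library.

```python
import queue
import threading


class OneShotFuture:
    """Minimal one-to-one pipe, same idea as comm.FutureResult."""

    def __init__(self):
        self._cond = threading.Condition()
        self._value = None
        self._ready = False

    def put(self, value):
        with self._cond:
            self._value, self._ready = value, True
            self._cond.notify()

    def get(self):
        with self._cond:
            while not self._ready:
                self._cond.wait()
            self._ready = False
            return self._value


def demo(nr_slaves=2):
    q = queue.Queue()
    futures = [OneShotFuture() for _ in range(nr_slaves)]
    results = []

    def slave(i):
        q.put((i, i + 1))                 # send a partial value to the master
        results.append(futures[i].get())  # block until the reduced value arrives

    threads = [threading.Thread(target=slave, args=(i,)) for i in range(nr_slaves)]
    for t in threads:
        t.start()
    # Master: collect one message per slave, reduce, then broadcast back.
    msgs = [q.get() for _ in range(nr_slaves)]
    total = sum(m for _, m in msgs)
    for i, _ in msgs:
        futures[i].put(total)
    for t in threads:
        t.join()
    return results
```

# Every slave receives the same reduced value, mirroring how each replica of a
# synchronized batch-norm layer receives the same global mean and inv_std.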


================================================
FILE: lib/sync_bn/replicate.py
================================================
# -*- coding: utf-8 -*-
# File   : replicate.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
# 
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import functools

from torch.nn.parallel.data_parallel import DataParallel

__all__ = [
    'CallbackContext',
    'execute_replication_callbacks',
    'DataParallelWithCallback',
    'patch_replication_callback'
]


class CallbackContext(object):
    pass


def execute_replication_callbacks(modules):
    """
    Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.

    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`.

    Note that, as all replicated modules are isomorphic, we assign each sub-module a context
    (shared among multiple copies of this module on different devices).
    Through this context, different copies can share some information.

    We guarantee that the callback on the master copy (the first copy) will be called before the callbacks
    of any slave copies.
    """
    master_copy = modules[0]
    nr_modules = len(list(master_copy.modules()))
    ctxs = [CallbackContext() for _ in range(nr_modules)]

    for i, module in enumerate(modules):
        for j, m in enumerate(module.modules()):
            if hasattr(m, '__data_parallel_replicate__'):
                m.__data_parallel_replicate__(ctxs[j], i)


class DataParallelWithCallback(DataParallel):
    """
    Data Parallel with a replication callback.

    A replication callback `__data_parallel_replicate__` of each module will be invoked after the module is
    created by the original `replicate` function.
    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
        # sync_bn.__data_parallel_replicate__ will be invoked.
    """

    def replicate(self, module, device_ids):
        modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules


def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object to add the replication callback.
    Useful when you have a customized `DataParallel` implementation.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """

    assert isinstance(data_parallel, DataParallel)

    old_replicate = data_parallel.replicate

    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        modules = old_replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules

    data_parallel.replicate = new_replicate


================================================
FILE: lib/sync_bn/unittest.py
================================================
# -*- coding: utf-8 -*-
# File   : unittest.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
# 
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import unittest

import numpy as np
from torch.autograd import Variable


def as_numpy(v):
    if isinstance(v, Variable):
        v = v.data
    return v.cpu().numpy()


class TorchTestCase(unittest.TestCase):
    def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
        npa, npb = as_numpy(a), as_numpy(b)
        self.assertTrue(
                np.allclose(npa, npb, atol=atol),
                'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
        )


================================================
FILE: model/__init__.py
================================================


================================================
FILE: model/pointnet/pointnet.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F


class STN3D(nn.Module):
    def __init__(self, c):
        super(STN3D, self).__init__()
        self.c = c
        self.conv1 = nn.Conv1d(self.c, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 1024, 1)
        self.mp = nn.AdaptiveMaxPool1d(1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, self.c*self.c)

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

    def forward(self, x):
        batch_size = x.size()[0]
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.mp(x)
        x = x.view(-1, 1024)
        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        iden = torch.eye(self.c).view(1, -1).repeat(batch_size, 1)
        if x.is_cuda:
            iden = iden.cuda()
        x = x + iden
        x = x.view(-1, self.c, self.c)
        return x


class PointNetFeat(nn.Module):
    def __init__(self, c=3, global_feat=True):
        super(PointNetFeat, self).__init__()
        self.global_feat = global_feat
        self.stn1 = STN3D(c)
        self.conv1 = nn.Conv1d(c, 64, 1)
        self.conv2 = nn.Conv1d(64, 64, 1)
        self.stn2 = STN3D(64)
        self.conv3 = nn.Conv1d(64, 64, 1)
        self.conv4 = nn.Conv1d(64, 128, 1)
        self.conv5 = nn.Conv1d(128, 1024, 1)
        self.mp = nn.AdaptiveMaxPool1d(1)

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(64)
        self.bn3 = nn.BatchNorm1d(64)
        self.bn4 = nn.BatchNorm1d(128)
        self.bn5 = nn.BatchNorm1d(1024)

    def forward(self, x):
        stn1 = self.stn1(x)
        x = torch.bmm(stn1, x)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        stn2 = self.stn2(x)
        x_tmp = torch.bmm(stn2, x)
        x = F.relu(self.bn3(self.conv3(x_tmp)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = F.relu(self.bn5(self.conv5(x)))
        x = self.mp(x)
        x = x.view(-1, 1024)

        if not self.global_feat:
            x = x.view(-1, 1024, 1).repeat(1, 1, x_tmp.size()[2])
            x = torch.cat([x_tmp, x], 1)
        return x


class PointNetCls(nn.Module):
    def __init__(self, c=3, k=40, dropout=0.3, sync_bn=False):
        super(PointNetCls, self).__init__()
        self.feat = PointNetFeat(c, global_feat=True)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k)

        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        x = x.transpose(1, 2)
        x = self.feat(x)
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.relu(self.bn2(self.fc2(x)))
        x = self.dropout(x)
        x = self.fc3(x)
        return x


# Segmentation with 9-channel input (XYZ, RGB, and location normalized to the room, from 0 to 1), with STN3D on input and features
class PointNetSeg(nn.Module):
    def __init__(self, c=9, k=13, sync_bn=False):
        super(PointNetSeg, self).__init__()
        self.feat = PointNetFeat(c, global_feat=False)
        self.conv1 = nn.Conv1d(1088, 512, 1)
        self.conv2 = nn.Conv1d(512, 256, 1)
        self.conv3 = nn.Conv1d(256, 128, 1)
        self.conv4 = nn.Conv1d(128, 128, 1)
        self.conv5 = nn.Conv1d(128, k, 1)

        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(128)
        self.bn4 = nn.BatchNorm1d(128)

    def forward(self, x):
        x = x.transpose(1, 2)
        x = self.feat(x)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = self.conv5(x)
        return x


if __name__ == '__main__':
    import os
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'

    sim_data = torch.rand(16, 2048, 3)

    trans = STN3D(c=3)
    out = trans(sim_data.transpose(1, 2))
    print('stn', out.size())

    point_feat = PointNetFeat(global_feat=True)
    out = point_feat(sim_data.transpose(1, 2))
    print('global feat', out.size())

    point_feat = PointNetFeat(global_feat=False)
    out = point_feat(sim_data.transpose(1, 2))
    print('point feat', out.size())

    cls = PointNetCls(c=3, k=40)
    out = cls(sim_data)
    print('class', out.size())

    sim_data = torch.rand(16, 2048, 9)
    seg = PointNetSeg(c=9, k=13)
    out = seg(sim_data)
    print('seg', out.size())


================================================
FILE: model/pointnet2/pointnet2_modules.py
================================================
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F

from lib.pointops.functions import pointops
from util import pt_util


class _PointNet2SAModuleBase(nn.Module):
    def __init__(self):
        super().__init__()
        self.npoint = None
        self.groupers = None
        self.mlps = None

    def forward(self, xyz: torch.Tensor, features: torch.Tensor = None) -> (torch.Tensor, torch.Tensor):
        r"""
        Parameters
        ----------
        xyz : torch.Tensor
            (B, N, 3) tensor of the xyz coordinates of the features
        features : torch.Tensor
            (B, N, C) tensor of the descriptors of the features
        Returns
        -------
        new_xyz : torch.Tensor
            (B, npoint, 3) tensor of the new features' xyz
        new_features : torch.Tensor
            (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_features descriptors
        """
        new_features_list = []
        xyz_trans = xyz.transpose(1, 2).contiguous()
        new_xyz = pointops.gathering(
            xyz_trans,
            pointops.furthestsampling(xyz, self.npoint)
        ).transpose(1, 2).contiguous() if self.npoint is not None else None
        for i in range(len(self.groupers)):
            new_features = self.groupers[i](xyz, new_xyz, features)  # (B, C, npoint, nsample)
            new_features = self.mlps[i](new_features)  # (B, mlp[-1], npoint, nsample)
            new_features = F.max_pool2d(new_features, kernel_size=[1, new_features.size(3)])  # (B, mlp[-1], npoint, 1)
            new_features = new_features.squeeze(-1)  # (B, mlp[-1], npoint)
            new_features_list.append(new_features)
        return new_xyz, torch.cat(new_features_list, dim=1)


class PointNet2SAModuleMSG(_PointNet2SAModuleBase):
    r"""Pointnet set abstrction layer with multiscale grouping
    Parameters
    ----------
    npoint : int
        Number of features
    radii : list of float32
        list of radii to group with
    nsamples : list of int32
        Number of samples in each ball query
    mlps : list of list of int32
        Spec of the pointnet_old before the global max_pool for each scale
    bn : bool
        Use batchnorm
    """
    def __init__(self, *, npoint: int, radii: List[float], nsamples: List[int], mlps: List[List[int]], bn: bool = True, use_xyz: bool = True):
        super().__init__()
        assert len(radii) == len(nsamples) == len(mlps)
        self.npoint = npoint
        self.groupers = nn.ModuleList()
        self.mlps = nn.ModuleList()
        for i in range(len(radii)):
            radius = radii[i]
            nsample = nsamples[i]
            self.groupers.append(
                pointops.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
                if npoint is not None else pointops.GroupAll(use_xyz)
            )
            mlp_spec = mlps[i]
            if use_xyz:
                mlp_spec[0] += 3
            self.mlps.append(pt_util.SharedMLP(mlp_spec, bn=bn))


class PointNet2SAModule(PointNet2SAModuleMSG):
    r"""Pointnet set abstrction layer
    Parameters
    ----------
    npoint : int
        Number of features
    radius : float
        Radius of ball
    nsample : int
        Number of samples in the ball query
    mlp : list
        Spec of the pointnet_old before the global max_pool
    bn : bool
        Use batchnorm
    """
    def __init__(self, *, mlp: List[int], npoint: int = None, radius: float = None, nsample: int = None, bn: bool = True, use_xyz: bool = True):
        super().__init__(mlps=[mlp], npoint=npoint, radii=[radius], nsamples=[nsample], bn=bn, use_xyz=use_xyz)


class PointNet2FPModule(nn.Module):
    r"""Propigates the features of one set to another
    Parameters
    ----------
    mlp : list
        Pointnet module parameters
    bn : bool
        Use batchnorm
    """
    def __init__(self, *, mlp: List[int], bn: bool = True):
        super().__init__()
        self.mlp = pt_util.SharedMLP(mlp, bn=bn)

    def forward(self, unknown: torch.Tensor, known: torch.Tensor, unknow_feats: torch.Tensor, known_feats: torch.Tensor) -> torch.Tensor:
        r"""
        Parameters
        ----------
        unknown : torch.Tensor
            (B, n, 3) tensor of the xyz positions of the unknown features
        known : torch.Tensor
            (B, m, 3) tensor of the xyz positions of the known features
        unknow_feats : torch.Tensor
            (B, C1, n) tensor of the features to be propagated to
        known_feats : torch.Tensor
            (B, C2, m) tensor of features to be propagated
        Returns
        -------
        new_features : torch.Tensor
            (B, mlp[-1], n) tensor of the features of the unknown features
        """

        if known is not None:
            dist, idx = pointops.nearestneighbor(unknown, known)
            dist_recip = 1.0 / (dist + 1e-8)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm
            interpolated_feats = pointops.interpolation(known_feats, idx, weight)
        else:
            interpolated_feats = known_feats.expand(*known_feats.size()[0:2], unknown.size(1))

        if unknow_feats is not None:
            new_features = torch.cat([interpolated_feats, unknow_feats], dim=1)  # (B, C2 + C1, n)
        else:
            new_features = interpolated_feats
        return self.mlp(new_features.unsqueeze(-1)).squeeze(-1)


if __name__ == "__main__":
    torch.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    xyz = torch.randn(2, 9, 3, requires_grad=True).cuda()
    xyz_feats = torch.randn(2, 9, 6, requires_grad=True).cuda()

    test_module = PointNet2SAModuleMSG(npoint=2, radii=[5.0, 10.0], nsamples=[6, 3], mlps=[[9, 3], [9, 6]])
    test_module.cuda()
    print(test_module(xyz, xyz_feats))

    # test_module = PointNet2FPModule(mlp=[6, 6])
    # test_module.cuda()
    # from torch.autograd import gradcheck
    # inputs = (xyz, xyz, None, xyz_feats)
    # test = gradcheck(test_module, inputs, eps=1e-6, atol=1e-4)
    # print(test)

    for _ in range(1):
        _, new_features = test_module(xyz, xyz_feats)
        new_features.backward(torch.cuda.FloatTensor(*new_features.size()).fill_(1))
        print(new_features)
        print(xyz.grad)


================================================
FILE: model/pointnet2/pointnet2_seg.py
================================================
from collections import namedtuple

import torch
import torch.nn as nn

from model.pointnet2.pointnet2_modules import PointNet2SAModule, PointNet2SAModuleMSG, PointNet2FPModule
from util import pt_util


class PointNet2SSGSeg(nn.Module):
    r"""
        PointNet2 with single-scale grouping
        Semantic segmentation network that uses feature propagation layers
        Parameters
        ----------
        k: int
            Number of semantic classes to predict over -- size of the softmax classifier run for each point
        c: int = 3
            Number of input channels in the feature descriptor for each point.  If the point cloud is Nx9,
            this value should be 6: 3 of the 9 channels are xyz and the remaining 6 are feature descriptors
        use_xyz: bool = True
            Whether or not to use the xyz position of a point as a feature
    """

    def __init__(self, c=3, k=13, use_xyz=True):
        super().__init__()
        self.SA_modules = nn.ModuleList()
        self.SA_modules.append(PointNet2SAModule(npoint=1024, nsample=32, mlp=[c, 32, 32, 64], use_xyz=use_xyz))
        self.SA_modules.append(PointNet2SAModule(npoint=256, nsample=32, mlp=[64, 64, 64, 128], use_xyz=use_xyz))
        self.SA_modules.append(PointNet2SAModule(npoint=64, nsample=32, mlp=[128, 128, 128, 256], use_xyz=use_xyz))
        self.SA_modules.append(PointNet2SAModule(npoint=16, nsample=32, mlp=[256, 256, 256, 512], use_xyz=use_xyz))
        self.FP_modules = nn.ModuleList()
        self.FP_modules.append(PointNet2FPModule(mlp=[128 + c, 128, 128, 128]))
        self.FP_modules.append(PointNet2FPModule(mlp=[256 + 64, 256, 128]))
        self.FP_modules.append(PointNet2FPModule(mlp=[256 + 128, 256, 256]))
        self.FP_modules.append(PointNet2FPModule(mlp=[512 + 256, 256, 256]))
        self.FC_layer = nn.Sequential(pt_util.Conv2d(128, 128, bn=True), nn.Dropout(), pt_util.Conv2d(128, k, activation=None))

    def _break_up_pc(self, pc):
        xyz = pc[..., 0:3].contiguous()
        features = (pc[..., 3:].transpose(1, 2).contiguous() if pc.size(-1) > 3 else None)
        return xyz, features

    def forward(self, pointcloud: torch.cuda.FloatTensor):
        r"""
            Forward pass of the network
            Parameters
            ----------
            pointcloud: torch.cuda.FloatTensor
                (B, N, 3 + input_channels) tensor
                Point cloud to run predictions on.
                Each point in the point cloud MUST
                be formatted as (x, y, z, features...)
        """
        xyz, features = self._break_up_pc(pointcloud)
        l_xyz, l_features = [xyz], [features]
        for i in range(len(self.SA_modules)):
            li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
            l_xyz.append(li_xyz)
            l_features.append(li_features)
        for i in range(-1, -(len(self.FP_modules) + 1), -1):
            l_features[i - 1] = self.FP_modules[i](l_xyz[i - 1], l_xyz[i], l_features[i - 1], l_features[i])
        # return self.FC_layer(l_features[0])
        return self.FC_layer(l_features[0].unsqueeze(-1)).squeeze(-1)
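The decoder loop above walks the FP modules with negative indices, so each FP module upsamples features from the next-coarser level onto the next-finer one. A small stand-in trace (resolutions taken from the `npoint` settings of the four SA modules above, with `'N'` for the full input resolution) shows which pair of levels each iteration touches:

```python
# stand-in resolutions of l_xyz after the 4 SA modules: N, 1024, 256, 64, 16
levels = ['N', 1024, 256, 64, 16]
pairs = []
for i in range(-1, -5, -1):                   # i = -1, -2, -3, -4, as in forward()
    pairs.append((levels[i - 1], levels[i]))  # (target resolution, source resolution)
# pairs == [(64, 16), (256, 64), (1024, 256), ('N', 1024)]
```

So the coarsest features (16 points) are propagated first, and the final iteration lands back on the full-resolution cloud, whose features then feed `FC_layer`.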


class PointNet2MSGSeg(PointNet2SSGSeg):
    r"""
        PointNet2 with multi-scale grouping
        Semantic segmentation network that uses feature propagation layers
        Parameters
        ----------
        k: int
            Number of semantic classes to predict over -- size of the softmax classifier run for each point
        c: int = 6
            Number of input channels in the feature descriptor for each point.  If the point cloud is Nx9,
            this value should be 6: 3 of the 9 channels are xyz and the remaining 6 are feature descriptors
        use_xyz: bool = True
            Whether or not to use the xyz position of a point as a feature
    """

    def __init__(self, k, c=6, use_xyz=True):
        super().__init__()
        self.SA_modules = nn.ModuleList()
        c_in = c
        self.SA_modules.append(PointNet2SAModuleMSG(npoint=1024, radii=[0.05, 0.1], nsamples=[16, 32], mlps=[[c_in, 16, 16, 32], [c_in, 32, 32, 64]], use_xyz=use_xyz ))
        c_out_0 = 32 + 64
        c_in = c_out_0
        self.SA_modules.append(PointNet2SAModuleMSG(npoint=256, radii=[0.1, 0.2], nsamples=[16, 32], mlps=[[c_in, 64, 64, 128], [c_in, 64, 96, 128]], use_xyz=use_xyz))
        c_out_1 = 128 + 128
        c_in = c_out_1
        self.SA_modules.append(PointNet2SAModuleMSG(npoint=64, radii=[0.2, 0.4], nsamples=[16, 32], mlps=[[c_in, 128, 196, 256], [c_in, 128, 196, 256]], use_xyz=use_xyz))
        c_out_2 = 256 + 256
        c_in = c_out_2
        self.SA_modules.append(PointNet2SAModuleMSG(npoint=16, radii=[0.4, 0.8], nsamples=[16, 32], mlps=[[c_in, 256, 256, 512], [c_in, 256, 384, 512]], use_xyz=use_xyz))
        c_out_3 = 512 + 512
        self.FP_modules = nn.ModuleList()
        self.FP_modules.append(PointNet2FPModule(mlp=[256 + c, 128, 128]))
        self.FP_modules.append(PointNet2FPModule(mlp=[512 + c_out_0, 256, 256]))
        self.FP_modules.append(PointNet2FPModule(mlp=[512 + c_out_1, 512, 512]))
        self.FP_modules.append(PointNet2FPModule(mlp=[c_out_3 + c_out_2, 512, 512]))
        self.FC_layer = nn.Sequential(pt_util.Conv2d(128, 128, bn=True), nn.Dropout(), pt_util.Conv2d(128, k, activation=None))


def model_fn_decorator(criterion):
    ModelReturn = namedtuple("ModelReturn", ['preds', 'loss', 'acc'])

    def model_fn(model, data, eval=False):
        with torch.set_grad_enabled(not eval):
            inputs, labels = data
            inputs = inputs.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
            preds = model(inputs)
            loss = criterion(preds, labels)
            _, classes = torch.max(preds, 1)
            acc = (classes == labels).float().sum() / labels.numel()
            return ModelReturn(preds, loss, {"acc": acc.item(), 'loss': loss.item()})
    return model_fn


if __name__ == "__main__":
    import torch.optim as optim
    B, N, C, K = 2, 4096, 3, 13
    inputs = torch.randn(B, N, 6).cuda()
    labels = torch.randint(0, 3, (B, N)).cuda()

    model = PointNet2SSGSeg(c=C, k=K).cuda()
    optimizer = optim.SGD(model.parameters(), lr=5e-2, momentum=0.9, weight_decay=1e-4)
    print("Testing SSGCls with xyz")
    model_fn = model_fn_decorator(nn.CrossEntropyLoss())
    for _ in range(5):
        optimizer.zero_grad()
        _, loss, _ = model_fn(model, (inputs, labels))
        loss.backward()
        print(loss.item())
        optimizer.step()

    model = PointNet2SSGSeg(c=C, k=K, use_xyz=False).cuda()
    optimizer = optim.SGD(model.parameters(), lr=5e-2, momentum=0.9, weight_decay=1e-4)
    print("Testing SSGCls without xyz")
    model_fn = model_fn_decorator(nn.CrossEntropyLoss())
    for _ in range(5):
        optimizer.zero_grad()
        _, loss, _ = model_fn(model, (inputs, labels))
        loss.backward()
        print(loss.item())
        optimizer.step()

    model = PointNet2MSGSeg(c=C, k=K).cuda()
    optimizer = optim.SGD(model.parameters(), lr=5e-2, momentum=0.9, weight_decay=1e-4)
    print("Testing MSGCls with xyz")
    model_fn = model_fn_decorator(nn.CrossEntropyLoss())
    for _ in range(5):
        optimizer.zero_grad()
        _, loss, _ = model_fn(model, (inputs, labels))
        loss.backward()
        print(loss.item())
        optimizer.step()

    model = PointNet2MSGSeg(c=C, k=K, use_xyz=False).cuda()
    optimizer = optim.SGD(model.parameters(), lr=5e-2, momentum=0.9, weight_decay=1e-4)
    print("Testing MSGCls without xyz")
    model_fn = model_fn_decorator(nn.CrossEntropyLoss())
    for _ in range(5):
        optimizer.zero_grad()
        _, loss, _ = model_fn(model, (inputs, labels))
        loss.backward()
        print(loss.item())
        optimizer.step()



================================================
FILE: model/pointweb/pointweb_module.py
================================================
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F

from lib.pointops.functions import pointops
from util import pt_util


class _AFAModule(nn.Module):
    def __init__(self, mlp, use_softmax=False):
        r"""
        :param mlp: mlp for learning weight
               mode: transformation or aggregation
        """
        super().__init__()
        self.mlp = mlp
        self.use_softmax = use_softmax

    def forward(self, feature: torch.Tensor) -> torch.Tensor:
        r"""
        Parameters
        ----------
        features : torch.Tensor
            (B, C, N, M) or (B, C, N)
        Returns
        -------
        new_features : torch.Tensor
            transformation: (B, C, N, M) or (B, C, N)
            aggregation: (B, C, N) or (B, C)
        """
        B, C, N, M = feature.size()
        feature = feature.transpose(1, 2).contiguous().view(B * N, C, M, 1).repeat(1, 1, 1, M)  # (BN, C, M, M)
        feature = feature - feature.transpose(2, 3).contiguous() + torch.mul(feature, torch.eye(M).view(1, 1, M, M).cuda())  # (BN, C, M, M)
        weight = self.mlp(feature)
        if self.use_softmax:
            weight = F.softmax(weight, -1)
        feature = (feature * weight).sum(-1).view(B, N, C, M).transpose(1, 2).contiguous()  # (B, C, N, M)
        return feature
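The repeat/transpose/eye construction above builds, for every local region of M points, an M x M difference map whose (i, j) entry is f_i - f_j off the diagonal and f_i on it; this map is what the weight-learning MLP consumes. A single-region NumPy sketch (`afa_difference_map` is our helper name, not repo API) makes the construction easier to verify:

```python
import numpy as np

def afa_difference_map(f):
    """f: (C, M) features of one local region.

    Returns a (C, M, M) map with entry [:, i, j] = f_i - f_j for i != j
    and f_i on the diagonal, matching the tensor fed to AFA's MLP.
    """
    C, M = f.shape
    rep = np.repeat(f[:, :, None], M, axis=2)   # rep[:, i, j] = f_i
    diff = rep - rep.transpose(0, 2, 1)         # f_i - f_j (diagonal is 0)
    return diff + rep * np.eye(M)[None, :, :]   # restore f_i on the diagonal
```

The diagonal trick mirrors the `torch.mul(feature, torch.eye(M))` term in the forward pass: without it, each point's own feature would be zeroed out of its interaction map.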


class _PointWebSAModuleBase(nn.Module):
    def __init__(self):
        super().__init__()
        self.npoint = None
        self.grouper = None
        self.mlp = None
        self.afa = None

    def forward(self, xyz: torch.Tensor, features: torch.Tensor = None) -> (torch.Tensor, torch.Tensor):
        r"""
        Parameters
        ----------
        xyz : torch.Tensor
            (B, N, 3) tensor of the xyz coordinates of the features
        features : torch.Tensor
            (B, C, N) tensor of the descriptors of the the features
        Returns
        -------
        new_xyz : torch.Tensor
            (B, npoint, 3) tensor of the new features' xyz
        new_features : torch.Tensor
            (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_features descriptors
        """
        new_features_list = []
        xyz_trans = xyz.transpose(1, 2).contiguous()
        new_xyz = pointops.gathering(
            xyz_trans,
            pointops.furthestsampling(xyz, self.npoint)
        ).transpose(1, 2).contiguous() if self.npoint is not None else None
        new_features = self.grouper(xyz, new_xyz, features)  # (B, C, npoint, nsample)
        if new_features.shape[2] != 1:  # skip AFA when npoint is None (GroupAll yields a single group)
            new_features = new_features + self.afa(new_features)  # (B, C, npoint, nsample)
        new_features = self.mlp(new_features)
        new_features = F.max_pool2d(new_features, kernel_size=[1, new_features.size(3)]).squeeze(-1)  # (B, mlp[-1], npoint)
        new_features_list.append(new_features)
        return new_xyz, torch.cat(new_features_list, dim=1)
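`pointops.furthestsampling` used above is a CUDA kernel; its greedy logic, for one cloud and starting from index 0, can be sketched in NumPy (our stand-in for illustration, not the repo's implementation, so exact tie-breaking may differ):

```python
import numpy as np

def furthest_point_sample(xyz, npoint):
    """Greedy farthest point sampling for one cloud xyz: (N, 3).

    Returns npoint indices; each pick is the point farthest (in squared
    distance) from the set of points already selected.
    """
    N = xyz.shape[0]
    idx = np.zeros(npoint, dtype=np.int64)
    dist = np.full(N, np.inf)                # distance to the selected set
    farthest = 0                             # conventionally start at index 0
    for i in range(npoint):
        idx[i] = farthest
        d = ((xyz - xyz[farthest]) ** 2).sum(axis=1)
        dist = np.minimum(dist, d)           # tighten distances to the set
        farthest = int(dist.argmax())        # next pick: farthest remaining point
    return idx
```

The sampled indices are then gathered into `new_xyz`, exactly as `pointops.gathering` does in the forward pass above.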


class PointWebSAModule(_PointWebSAModuleBase):
    r"""Pointnet set abstrction layer with multiscale grouping
    Parameters
    ----------
    npoint : int
        Number of features
    nsample : int32
        Number of sample
    mlps : list of int32
        Spec of the MLP before the global max_pool
    mlps2: list of list of int32
        Spec of the MLP for AFA
    bn : bool
        Use batchnorm
    """
    def __init__(self, *, npoint: int = None, nsample: int = None, mlp: List[int] = None, mlp2: List[int] = None, bn: bool = True, use_xyz: bool = True, use_bn = True):
        super().__init__()
        self.npoint = npoint
        self.grouper = pointops.QueryAndGroup(nsample=nsample, use_xyz=use_xyz) if npoint is not None else pointops.GroupAll(use_xyz)
        if use_xyz:
            mlp[0] += 3
        if npoint is not None:
            mlp_tmp = pt_util.SharedMLP([mlp[0]] + mlp2, bn=use_bn)
            mlp_tmp.add_module('weight', (pt_util.SharedMLP([mlp2[-1], mlp[0]], bn=False, activation=None)))
            self.afa = _AFAModule(mlp=mlp_tmp)
        self.mlp = pt_util.SharedMLP(mlp, bn=bn)


if __name__ == "__main__":
    torch.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    c = 6
    xyz = torch.randn(2, 8, 3).cuda().requires_grad_()
    xyz_feats = torch.randn(2, 8, c).cuda().requires_grad_()

    test_module = PointWebSAModule(npoint=2, nsample=6, mlp=[c, 32, 32], mlp2=[16, 16], use_bn=True)
    test_module.cuda()
    xyz_feats = xyz_feats.transpose(1, 2).contiguous()
    print(test_module)
    print(test_module(xyz, xyz_feats))

    for _ in range(1):
        _, new_features = test_module(xyz, xyz_feats)
        new_features.backward(torch.cuda.FloatTensor(*new_features.size()).fill_(1))
        print(new_features)
        print(xyz.grad)


================================================
FILE: model/pointweb/pointweb_seg.py
================================================
from collections import namedtuple

import torch
import torch.nn as nn

from model.pointweb.pointweb_module import PointWebSAModule
from model.pointnet2.pointnet2_modules import PointNet2FPModule
from util import pt_util


class PointWebSeg(nn.Module):
    r"""
        PointWeb with single-scale grouping
        Semantic segmentation network that uses feature propagation layers
        Parameters
        ----------
        k: int
            Number of semantic classes to predict over -- size of the softmax classifier run for each point
        c: int = 3
            Number of input channels in the feature descriptor for each point.  If the point cloud is Nx9,
            this value should be 6: 3 of the 9 channels are xyz and the remaining 6 are feature descriptors
        use_xyz: bool = True
            Whether or not to use the xyz position of a point as a feature
    """

    def __init__(self, c=3, k=13, use_xyz=True):
        super().__init__()
        self.SA_modules = nn.ModuleList()
        self.SA_modules.append(PointWebSAModule(npoint=1024, nsample=32, mlp=[c, 32, 32, 64], mlp2=[32, 32], use_xyz=use_xyz))
        self.SA_modules.append(PointWebSAModule(npoint=256, nsample=32, mlp=[64, 64, 64, 128], mlp2=[32, 32], use_xyz=use_xyz))
        self.SA_modules.append(PointWebSAModule(npoint=64, nsample=32, mlp=[128, 128, 128, 256], mlp2=[32, 32], use_xyz=use_xyz))
        self.SA_modules.append(PointWebSAModule(npoint=16, nsample=32, mlp=[256, 256, 256, 512], mlp2=[32, 32], use_xyz=use_xyz))

        self.FP_modules = nn.ModuleList()
        self.FP_modules.append(PointNet2FPModule(mlp=[128 + c, 128, 128, 128]))
        self.FP_modules.append(PointNet2FPModule(mlp=[256 + 64, 256, 128]))
        self.FP_modules.append(PointNet2FPModule(mlp=[256 + 128, 256, 256]))
        self.FP_modules.append(PointNet2FPModule(mlp=[512 + 256, 256, 256]))
        self.FC_layer = nn.Sequential(pt_util.Conv2d(128, 128, bn=True), nn.Dropout(), pt_util.Conv2d(128, k, activation=None))

    def _break_up_pc(self, pc):
        xyz = pc[..., 0:3].contiguous()
        features = (pc[..., 3:].transpose(1, 2).contiguous() if pc.size(-1) > 3 else None)
        return xyz, features

    def forward(self, pointcloud: torch.cuda.FloatTensor):
        r"""
            Forward pass of the network
            Parameters
            ----------
            pointcloud: torch.cuda.FloatTensor
                (B, N, 3 + input_channels) tensor
                Point cloud to run predictions on.
                Each point in the point cloud MUST
                be formatted as (x, y, z, features...)
        """
        xyz, features = self._break_up_pc(pointcloud)
        l_xyz, l_features = [xyz], [features]
        for i in range(len(self.SA_modules)):
            li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
            l_xyz.append(li_xyz)
            l_features.append(li_features)
        for i in range(-1, -(len(self.FP_modules) + 1), -1):
            l_features[i - 1] = self.FP_modules[i](l_xyz[i - 1], l_xyz[i], l_features[i - 1], l_features[i])
        return self.FC_layer(l_features[0].unsqueeze(-1)).squeeze(-1)


def model_fn_decorator(criterion):
    ModelReturn = namedtuple("ModelReturn", ['preds', 'loss', 'acc'])

    def model_fn(model, data, epoch=0, eval=False):
        with torch.set_grad_enabled(not eval):
            inputs, labels = data
            inputs = inputs.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
            preds = model(inputs)
            loss = criterion(preds, labels)
            _, classes = torch.max(preds, 1)
            acc = (classes == labels).float().sum() / labels.numel()
            return ModelReturn(preds, loss, {"acc": acc.item(), 'loss': loss.item()})
    return model_fn



if __name__ == "__main__":
    import torch.optim as optim
    B, N, C, K = 2, 4096, 3, 13
    inputs = torch.randn(B, N, 6).cuda()
    labels = torch.randint(0, 3, (B, N)).cuda()

    model = PointWebSeg(c=C, k=K).cuda()
    print(model)
    optimizer = optim.SGD(model.parameters(), lr=5e-2, momentum=0.9, weight_decay=1e-4)
    print("Testing SSGCls with xyz")
    model_fn = model_fn_decorator(nn.CrossEntropyLoss())
    for _ in range(5):
        optimizer.zero_grad()
        _, loss, _ = model_fn(model, (inputs, labels))
        loss.backward()
        print(loss.item())
        optimizer.step()

    model = PointWebSeg(c=C, k=K, use_xyz=False).cuda()
    print(model)
    optimizer = optim.SGD(model.parameters(), lr=5e-2, momentum=0.9, weight_decay=1e-4)
    print("Testing SSGCls without xyz")
    model_fn = model_fn_decorator(nn.CrossEntropyLoss())
    for _ in range(5):
        optimizer.zero_grad()
        _, loss, _ = model_fn(model, (inputs, labels))
        loss.backward()
        print(loss.item())
        optimizer.step()


================================================
FILE: tool/test.sh
================================================
#!/bin/sh
export PYTHONPATH=./

PYTHON=python
dataset=$1
exp_name=$2
exp_dir=exp/${dataset}/${exp_name}
model_dir=${exp_dir}/model
config=config/${dataset}/${dataset}_${exp_name}.yaml

mkdir -p ${model_dir}
now=$(date +"%Y%m%d_%H%M%S")

if [ "${dataset}" = 's3dis' ]
then
  cp tool/test.sh tool/test_s3dis.py ${config} ${exp_dir}
  $PYTHON tool/test_s3dis.py --config=${config} 2>&1 | tee ${model_dir}/test-$now.log
elif [ "${dataset}" = 'scannet' ]
then
  cp tool/test.sh tool/test_scannet.py ${config} ${exp_dir}
  $PYTHON tool/test_scannet.py --config=${config} 2>&1 | tee ${model_dir}/test-$now.log
fi


================================================
FILE: tool/test_s3dis.py
================================================
import os
import time
import random
import numpy as np
import logging
import pickle
import argparse

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data

from util import config
from util.util import AverageMeter, intersectionAndUnion, check_makedirs

random.seed(123)
np.random.seed(123)


def get_parser():
    parser = argparse.ArgumentParser(description='PyTorch Point Cloud Classification / Semantic Segmentation')
    parser.add_argument('--config', type=str, default='config/s3dis/s3dis_pointweb.yaml', help='config file')
    parser.add_argument('opts', help='see config/s3dis/s3dis_pointweb.yaml for all options', default=None, nargs=argparse.REMAINDER)
    args = parser.parse_args()
    assert args.config is not None
    cfg = config.load_cfg_from_cfg_file(args.config)
    if args.opts is not None:
        cfg = config.merge_cfg_from_list(cfg, args.opts)
    return cfg


def get_logger():
    logger_name = "main-logger"
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)
    handler = logging.StreamHandler()
    fmt = "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s"
    handler.setFormatter(logging.Formatter(fmt))
    logger.addHandler(handler)
    return logger


def main():
    global args, logger
    args = get_parser()
    logger = get_logger()
    logger.info(args)
    assert args.classes > 1
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    if args.arch == 'pointnet_seg':
        from model.pointnet.pointnet import PointNetSeg as Model
    elif args.arch == 'pointnet2_seg':
        from model.pointnet2.pointnet2_seg import PointNet2SSGSeg as Model
    elif args.arch == 'pointweb_seg':
        from model.pointweb.pointweb_seg import PointWebSeg as Model
    else:
        raise Exception('architecture {} not supported yet'.format(args.arch))
    model = Model(c=args.fea_dim, k=args.classes, use_xyz=args.use_xyz)
    model = torch.nn.DataParallel(model.cuda())
    logger.info(model)
    criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label).cuda()
    names = [line.rstrip('\n') for line in open(args.names_path)]
    if os.path.isfile(args.model_path):
        logger.info("=> loading checkpoint '{}'".format(args.model_path))
        checkpoint = torch.load(args.model_path)
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        logger.info("=> loaded checkpoint '{}'".format(args.model_path))
    else:
        raise RuntimeError("=> no checkpoint found at '{}'".format(args.model_path))
    test(model, criterion, names)


def data_prepare(room_path):
    room_data = np.load(room_path)
    points, labels = room_data[:, 0:6], room_data[:, 6]  # xyzrgb, N*6; l, N
    coord_min, coord_max = np.amin(points, axis=0)[:3], np.amax(points, axis=0)[:3]
    stride = args.block_size * args.stride_rate
    grid_x = int(np.ceil(float(coord_max[0] - coord_min[0] - args.block_size) / stride) + 1)
    grid_y = int(np.ceil(float(coord_max[1] - coord_min[1] - args.block_size) / stride) + 1)
    data_room, label_room, index_room = np.array([]), np.array([]), np.array([])
    for index_y in range(0, grid_y):
        for index_x in range(0, grid_x):
            s_x = coord_min[0] + index_x * stride
            e_x = min(s_x + args.block_size, coord_max[0])
            s_x = e_x - args.block_size
            s_y = coord_min[1] + index_y * stride
            e_y = min(s_y + args.block_size, coord_max[1])
            s_y = e_y - args.block_size
            point_idxs = np.where((points[:, 0] >= s_x - 1e-8) & (points[:, 0] <= e_x + 1e-8) & (points[:, 1] >= s_y - 1e-8) & (points[:, 1] <= e_y + 1e-8))[0]
            if point_idxs.size == 0:
                continue
            num_batch = int(np.ceil(point_idxs.size / args.num_point))
            point_size = int(num_batch * args.num_point)
            replace = False if (point_size - point_idxs.size <= point_idxs.size) else True
            point_idxs_repeat = np.random.choice(point_idxs, point_size - point_idxs.size, replace=replace)
            point_idxs = np.concatenate((point_idxs, point_idxs_repeat))
            np.random.shuffle(point_idxs)
            data_batch = points[point_idxs, :]
            normalized_xyz = np.zeros((point_size, 3))
            normalized_xyz[:, 0] = data_batch[:, 0] / coord_max[0]
            normalized_xyz[:, 1] = data_batch[:, 1] / coord_max[1]
            normalized_xyz[:, 2] = data_batch[:, 2] / coord_max[2]
            data_batch[:, 0] = data_batch[:, 0] - (s_x + args.block_size / 2.0)
            data_batch[:, 1] = data_batch[:, 1] - (s_y + args.block_size / 2.0)
            data_batch[:, 3:6] /= 255.0
            data_batch = np.concatenate((data_batch, normalized_xyz), axis=1)
            label_batch = labels[point_idxs]
            data_room = np.vstack([data_room, data_batch]) if data_room.size else data_batch
            label_room = np.hstack([label_room, label_batch]) if label_room.size else label_batch
            index_room = np.hstack([index_room, point_idxs]) if index_room.size else point_idxs
    assert np.unique(index_room).size == labels.size
    return data_room, label_room, index_room, labels
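The nested loop above tiles the room with overlapping square blocks of side `block_size` at stride `block_size * stride_rate`, clamping the last block so it ends exactly at the room boundary. The 1-D window arithmetic can be isolated as a sketch (`block_windows` is our name, not part of the repo):

```python
import numpy as np

def block_windows(coord_min, coord_max, block_size, stride_rate):
    """1-D sliding windows matching data_prepare's grid logic.

    Yields (start, end) intervals of width block_size at stride
    block_size * stride_rate; the last window is slid back so that
    it ends exactly at coord_max.
    """
    stride = block_size * stride_rate
    n = int(np.ceil((coord_max - coord_min - block_size) / stride) + 1)
    windows = []
    for i in range(n):
        e = min(coord_min + i * stride + block_size, coord_max)
        windows.append((e - block_size, e))   # clamp: slide start back with end
    return windows
```

With `stride_rate < 1` adjacent windows overlap, which is why each original point can be predicted several times and the votes are later accumulated in `test`.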


def test(model, criterion, names):
    logger.info('>>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>')
    batch_time = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()

    model.eval()
    rooms = sorted(os.listdir(args.train_full_folder))
    rooms_split = [room for room in rooms if 'Area_{}'.format(args.test_area) in room]
    gt_all, pred_all = np.array([]), np.array([])
    check_makedirs(args.save_folder)
    pred_save, gt_save = [], []
    for idx, room_name in enumerate(rooms_split):
        data_room, label_room, index_room, gt = data_prepare(os.path.join(args.train_full_folder, room_name))
        batch_point = args.num_point * args.test_batch_size
        batch_num = int(np.ceil(label_room.size / batch_point))
        end = time.time()
        output_room = np.array([])
        for i in range(batch_num):
            s_i, e_i = i * batch_point, min((i + 1) * batch_point, label_room.size)
            input, target, index = data_room[s_i:e_i, :], label_room[s_i:e_i], index_room[s_i:e_i]
            input = torch.from_numpy(input).float().view(-1, args.num_point, input.shape[1])
            target = torch.from_numpy(target).long().view(-1, args.num_point)
            with torch.no_grad():
                output = model(input.cuda())
            loss = criterion(output, target.cuda())  # for reference
            output = output.transpose(1, 2).contiguous().view(-1, args.classes).data.cpu().numpy()
            pred = np.argmax(output, axis=1)
            intersection, union, target = intersectionAndUnion(pred, target.view(-1).data.cpu().numpy(), args.classes, args.ignore_label)
            accuracy = sum(intersection) / (sum(target) + 1e-10)
            output_room = np.vstack([output_room, output]) if output_room.size else output
            batch_time.update(time.time() - end)
            end = time.time()
            if ((i + 1) % args.print_freq == 0) or (i + 1 == batch_num):
                logger.info('Test: [{}/{}]-[{}/{}] '
                            'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                            'Loss {loss:.4f} '
                            'Accuracy {accuracy:.4f} '
                            'Points {gt.size}.'.format(idx + 1, len(rooms_split),
                                                       i + 1, batch_num,
                                                       batch_time=batch_time,
                                                       loss=loss,
                                                       accuracy=accuracy,
                                                       gt=gt))
        '''
        unq, unq_inv, unq_cnt = np.unique(index_room, return_inverse=True, return_counts=True)
        index_array = np.split(np.argsort(unq_inv), np.cumsum(unq_cnt[:-1]))
        output_room = np.vstack([output_room, np.zeros((1, args.classes))])
        index_array_fill = np.array(list(itertools.zip_longest(*index_array, fillvalue=output_room.shape[0] - 1))).T
        pred = output_room[index_array_fill].sum(1)
        pred = np.argmax(pred, axis=1)
        '''
        pred = np.zeros((gt.size, args.classes))
        for j in range(len(index_room)):
            pred[index_room[j]] += output_room[j]
        pred = np.argmax(pred, axis=1)

        # calculation 1: add per room predictions
        intersection, union, target = intersectionAndUnion(pred, gt, args.classes, args.ignore_label)
        intersection_meter.update(intersection)
        union_meter.update(union)
        target_meter.update(target)
        # calculation 2
        pred_all = np.hstack([pred_all, pred]) if pred_all.size else pred
        gt_all = np.hstack([gt_all, gt]) if gt_all.size else gt
        pred_save.append(pred), gt_save.append(gt)

    with open(os.path.join(args.save_folder, "pred_{}.pickle".format(args.test_area)), 'wb') as handle:
        pickle.dump({'pred': pred_save}, handle, protocol=pickle.HIGHEST_PROTOCOL)
    with open(os.path.join(args.save_folder, "gt_{}.pickle".format(args.test_area)), 'wb') as handle:
        pickle.dump({'gt': gt_save}, handle, protocol=pickle.HIGHEST_PROTOCOL)

    # calculation 1
    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
    mIoU1 = np.mean(iou_class)
    mAcc1 = np.mean(accuracy_class)
    allAcc1 = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)

    # calculation 2
    intersection, union, target = intersectionAndUnion(pred_all, gt_all, args.classes, args.ignore_label)
    iou_class = intersection / (union + 1e-10)
    accuracy_class = intersection / (target + 1e-10)
    mIoU = np.mean(iou_class)
    mAcc = np.mean(accuracy_class)
    allAcc = sum(intersection) / (sum(target) + 1e-10)
    logger.info('Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'.format(mIoU, mAcc, allAcc))
    logger.info('Val1 result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'.format(mIoU1, mAcc1, allAcc1))

    for i in range(args.classes):
        logger.info('Class_{} Result: iou/accuracy {:.4f}/{:.4f}, name: {}.'.format(i, iou_class[i], accuracy_class[i], names[i]))
    logger.info('<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<')
    return mIoU, mAcc, allAcc, pred_all
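Because overlapping blocks visit the same original point several times, `test` sums per-point logits before taking the argmax (the `for j in range(len(index_room))` loop above). That Python loop can be replaced by a vectorized equivalent: `np.add.at` accumulates correctly over repeated indices, where plain fancy-index `+=` would silently keep only one contribution per index. A sketch (`accumulate_votes` is our name):

```python
import numpy as np

def accumulate_votes(output_room, index_room, num_points, num_classes):
    """Sum per-block logits back onto the original points.

    output_room: (K, num_classes) logits for K sampled points;
    index_room:  (K,) original point index of each sample (may repeat).
    Returns the per-point argmax after vote accumulation.
    """
    pred = np.zeros((num_points, num_classes))
    np.add.at(pred, index_room.astype(np.int64), output_room)  # unbuffered +=
    return pred.argmax(axis=1)
```

This matches the loop's result exactly and avoids a per-point Python iteration over what can be millions of sampled points per room.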


if __name__ == '__main__':
    main()


================================================
FILE: tool/test_scannet.py
================================================
import os
import time
import random
import numpy as np
import logging
import pickle
import argparse

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data

from util import config
from util.util import AverageMeter, intersectionAndUnion, check_makedirs

random.seed(123)
np.random.seed(123)


def get_parser():
    parser = argparse.ArgumentParser(description='PyTorch Point Cloud Classification / Semantic Segmentation')
    parser.add_argument('--config', type=str, default='config/scannet/scannet_pointweb.yaml', help='config file')
    parser.add_argument('opts', help='see config/scannet/scannet_pointweb.yaml for all options', default=None, nargs=argparse.REMAINDER)
    args = parser.parse_args()
    assert args.config is not None
    cfg = config.load_cfg_from_cfg_file(args.config)
    if args.opts is not None:
        cfg = config.merge_cfg_from_list(cfg, args.opts)
    return cfg


def get_logger():
    logger_name = "main-logger"
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)
    handler = logging.StreamHandler()
    fmt = "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s"
    handler.setFormatter(logging.Formatter(fmt))
    logger.addHandler(handler)
    return logger


def main():
    global args, logger
    args = get_parser()
    logger = get_logger()
    logger.info(args)
    assert args.classes > 1
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    if args.arch == 'pointnet_seg':
        from model.pointnet.pointnet import PointNetSeg as Model
    elif args.arch == 'pointnet2_seg':
        from model.pointnet2.pointnet2_seg import PointNet2SSGSeg as Model
    elif args.arch == 'pointweb_seg':
        from model.pointweb.pointweb_seg import PointWebSeg as Model
    else:
        raise Exception('architecture {} not supported yet'.format(args.arch))
    model = Model(c=args.fea_dim, k=args.classes, use_xyz=args.use_xyz)
    model = torch.nn.DataParallel(model.cuda())
    logger.info(model)
    criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label).cuda()
    names = [line.rstrip('\n') for line in open(args.names_path)]
    if os.path.isfile(args.model_path):
        logger.info("=> loading checkpoint '{}'".format(args.model_path))
        checkpoint = torch.load(args.model_path)
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        logger.info("=> loaded checkpoint '{}'".format(args.model_path))
    else:
        raise RuntimeError("=> no checkpoint found at '{}'".format(args.model_path))
    test(model, criterion, names)


def data_prepare(points, labels):
    coord_min, coord_max = np.amin(points, axis=0)[:3], np.amax(points, axis=0)[:3]
    stride = args.block_size * args.stride_rate
    grid_x = int(np.ceil(float(coord_max[0] - coord_min[0] - args.block_size) / stride) + 1)
    grid_y = int(np.ceil(float(coord_max[1] - coord_min[1] - args.block_size) / stride) + 1)
    data_room, label_room, index_room = np.array([]), np.array([]), np.array([])
    for index_y in range(0, grid_y):
        for index_x in range(0, grid_x):
            s_x = coord_min[0] + index_x * stride
            e_x = min(s_x + args.block_size, coord_max[0])
            s_x = e_x - args.block_size
            s_y = coord_min[1] + index_y * stride
            e_y = min(s_y + args.block_size, coord_max[1])
            s_y = e_y - args.block_size
            point_idxs = np.where((points[:, 0] >= s_x - 1e-8) & (points[:, 0] <= e_x + 1e-8) & (points[:, 1] >= s_y - 1e-8) & (points[:, 1] <= e_y + 1e-8))[0]
            if point_idxs.size == 0:
                continue
            num_batch = int(np.ceil(point_idxs.size / args.num_point))
            point_size = int(num_batch * args.num_point)
            replace = (point_size - point_idxs.size) > point_idxs.size
            point_idxs_repeat = np.random.choice(point_idxs, point_size - point_idxs.size, replace=replace)
            point_idxs = np.concatenate((point_idxs, point_idxs_repeat))
            np.random.shuffle(point_idxs)
            data_batch = points[point_idxs, :]
            normalized_xyz = np.zeros((point_size, 3))
            normalized_xyz[:, 0] = data_batch[:, 0] / coord_max[0]
            normalized_xyz[:, 1] = data_batch[:, 1] / coord_max[1]
            normalized_xyz[:, 2] = data_batch[:, 2] / coord_max[2]
            data_batch[:, 0] = data_batch[:, 0] - (s_x + args.block_size / 2.0)
            data_batch[:, 1] = data_batch[:, 1] - (s_y + args.block_size / 2.0)
            data_batch = np.concatenate((data_batch, normalized_xyz), axis=1)
            label_batch = labels[point_idxs]
            data_room = np.vstack([data_room, data_batch]) if data_room.size else data_batch
            label_room = np.hstack([label_room, label_batch]) if label_room.size else label_batch
            index_room = np.hstack([index_room, point_idxs]) if index_room.size else point_idxs
    assert np.unique(index_room).size == labels.size
    return data_room, label_room, index_room


def test(model, criterion, names):
    logger.info('>>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>')
    batch_time = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()

    model.eval()
    data_file = os.path.join(args.data_root, 'scannet_{}.pickle'.format(args.split))
    file_pickle = open(data_file, 'rb')
    xyz_all = pickle.load(file_pickle, encoding='latin1')
    label_all = pickle.load(file_pickle, encoding='latin1')
    file_pickle.close()
    gt_all, pred_all = np.array([]), np.array([])
    vox_acc = []
    check_makedirs(args.save_folder)
    pred_save, gt_save = [], []
    for idx in range(len(xyz_all)):
        points, labels = xyz_all[idx], label_all[idx].astype(np.int32)
        gt = labels - 1
        gt[labels == 0] = 255
        data_room, label_room, index_room = data_prepare(points, gt)
        batch_point = args.num_point * args.test_batch_size
        batch_num = int(np.ceil(label_room.size / batch_point))
        end = time.time()
        output_room = np.array([])
        for i in range(batch_num):
            s_i, e_i = i * batch_point, min((i + 1) * batch_point, label_room.size)
            input, target, index = data_room[s_i:e_i, :], label_room[s_i:e_i], index_room[s_i:e_i]
            input = torch.from_numpy(input).float().view(-1, args.num_point, input.shape[1])
            target = torch.from_numpy(target).long().view(-1, args.num_point)
            with torch.no_grad():
                output = model(input.cuda())
            loss = criterion(output, target.cuda())  # for reference
            output = output.transpose(1, 2).contiguous().view(-1, args.classes).data.cpu().numpy()
            pred = np.argmax(output, axis=1)
            intersection, union, target = intersectionAndUnion(pred, target.view(-1).data.cpu().numpy(), args.classes,
                                                               args.ignore_label)
            accuracy = sum(intersection) / (sum(target) + 1e-10)
            output_room = np.vstack([output_room, output]) if output_room.size else output
            batch_time.update(time.time() - end)
            end = time.time()
            if ((i + 1) % args.print_freq == 0) or (i + 1 == batch_num):
                logger.info('Test: [{}/{}]-[{}/{}] '
                            'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                            'Loss {loss:.4f} '
                            'Accuracy {accuracy:.4f} '
                            'Points {gt.size}.'.format(idx + 1, len(xyz_all),
                                                       i + 1, batch_num,
                                                       batch_time=batch_time,
                                                       loss=loss,
                                                       accuracy=accuracy,
                                                       gt=gt))

        pred = np.zeros((gt.size, args.classes))
        for j in range(len(index_room)):
            pred[index_room[j]] += output_room[j]
        pred = np.argmax(pred, axis=1)

        # calculation 1: add per room predictions
        intersection, union, target = intersectionAndUnion(pred, gt, args.classes, args.ignore_label)
        intersection_meter.update(intersection)
        union_meter.update(union)
        target_meter.update(target)
        # calculation 2
        pred_all = np.hstack([pred_all, pred]) if pred_all.size else pred
        gt_all = np.hstack([gt_all, gt]) if gt_all.size else gt
        pred_save.append(pred), gt_save.append(gt)

        # compute voxel accuracy (follow scannet, pointnet++ and pointcnn)
        res = 0.0484
        coord_min, coord_max = np.min(points, axis=0), np.max(points, axis=0)
        nvox = np.ceil((coord_max - coord_min) / res)
        vidx = np.ceil((points - coord_min) / res)
        vidx = vidx[:, 0] + vidx[:, 1] * nvox[0] + vidx[:, 2] * nvox[0] * nvox[1]
        uvidx, vpidx = np.unique(vidx, return_index=True)
        # compute voxel label
        uvlabel = np.array(gt)[vpidx]
        uvpred = np.array(pred)[vpidx]
        # compute voxel accuracy (ignore label 0 which is scannet unannotated)
        c_accvox = np.sum(np.equal(uvpred, uvlabel))
        c_ignore = np.sum(np.equal(uvlabel, 255))
        vox_acc.append([c_accvox, len(uvlabel) - c_ignore])

    with open(os.path.join(args.save_folder, "pred_{}.pickle".format(args.split)), 'wb') as handle:
        pickle.dump({'pred': pred_save}, handle, protocol=pickle.HIGHEST_PROTOCOL)
    with open(os.path.join(args.save_folder, "gt_{}.pickle".format(args.split)), 'wb') as handle:
        pickle.dump({'gt': gt_save}, handle, protocol=pickle.HIGHEST_PROTOCOL)

    # calculation 1
    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
    mIoU1 = np.mean(iou_class)
    mAcc1 = np.mean(accuracy_class)
    allAcc1 = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)

    # calculation 2
    intersection, union, target = intersectionAndUnion(pred_all, gt_all, args.classes, args.ignore_label)
    iou_class = intersection / (union + 1e-10)
    accuracy_class = intersection / (target + 1e-10)
    mIoU = np.mean(iou_class)
    mAcc = np.mean(accuracy_class)
    allAcc = sum(intersection) / (sum(target) + 1e-10)
    # compute avg voxel acc
    vox_acc = np.sum(vox_acc, 0)
    voxAcc = vox_acc[0] * 1.0 / vox_acc[1]
    logger.info('Val result: mIoU/mAcc/allAcc/voxAcc {:.4f}/{:.4f}/{:.4f}/{:.4f}.'.format(mIoU, mAcc, allAcc, voxAcc))
    logger.info('Val1 result: mIoU/mAcc/allAcc/voxAcc {:.4f}/{:.4f}/{:.4f}/{:.4f}.'.format(mIoU1, mAcc1, allAcc1, voxAcc))

    for i in range(args.classes):
        logger.info('Class_{} Result: iou/accuracy {:.4f}/{:.4f}, name: {}.'.format(i, iou_class[i], accuracy_class[i],
                                                                                    names[i]))
    logger.info('<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<')
    return mIoU, mAcc, allAcc, pred_all


if __name__ == '__main__':
    main()
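
The voxel accuracy computed inside `test()` follows the ScanNet / PointNet++ / PointCNN evaluation protocol: points are binned into a regular grid, one representative point is kept per occupied voxel, and accuracy is measured over representatives whose label is not the ignore value. A minimal self-contained sketch of that computation (the function name `voxel_accuracy` and the toy inputs are illustrative, not part of the repo):

```python
import numpy as np

def voxel_accuracy(points, gt, pred, res=0.0484, ignore=255):
    # Assign each point to a voxel of size `res` on a regular grid.
    coord_min = points.min(axis=0)
    nvox = np.ceil((points.max(axis=0) - coord_min) / res)
    vidx = np.ceil((points - coord_min) / res)
    # Flatten the 3-D voxel index into a single scalar key.
    key = vidx[:, 0] + vidx[:, 1] * nvox[0] + vidx[:, 2] * nvox[0] * nvox[1]
    # Keep one representative point per occupied voxel.
    _, rep = np.unique(key, return_index=True)
    keep = gt[rep] != ignore  # drop unannotated voxels
    correct = np.sum(pred[rep][keep] == gt[rep][keep])
    return correct, keep.sum()
```

As in `test()`, per-room `(correct, total)` pairs would be summed before taking the final ratio.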


================================================
FILE: tool/train.py
================================================
import os
import time
import random
import numpy as np
import logging
import argparse

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.optim.lr_scheduler as lr_scheduler
from tensorboardX import SummaryWriter

from util import dataset, transform, config
from util.s3dis import S3DIS
from util.scannet import ScanNet
from util.util import AverageMeter, intersectionAndUnionGPU


def get_parser():
    parser = argparse.ArgumentParser(description='PyTorch Point Cloud Semantic Segmentation')
    parser.add_argument('--config', type=str, default='config/s3dis/s3dis_pointweb.yaml', help='config file')
    parser.add_argument('opts', help='see config/s3dis/s3dis_pointweb.yaml for all options', default=None, nargs=argparse.REMAINDER)
    args = parser.parse_args()
    assert args.config is not None
    cfg = config.load_cfg_from_cfg_file(args.config)
    if args.opts is not None:
        cfg = config.merge_cfg_from_list(cfg, args.opts)
    return cfg


def get_logger():
    logger_name = "main-logger"
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)
    handler = logging.StreamHandler()
    fmt = "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s"
    handler.setFormatter(logging.Formatter(fmt))
    logger.addHandler(handler)
    return logger


def worker_init_fn(worker_id):
    random.seed(args.manual_seed + worker_id)


def init():
    global args, logger, writer
    args = get_parser()
    logger = get_logger()
    writer = SummaryWriter(args.save_path)
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in args.train_gpu)
    if args.manual_seed is not None:
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.manual_seed(args.manual_seed)
        np.random.seed(args.manual_seed)
        torch.cuda.manual_seed_all(args.manual_seed)
    if len(args.train_gpu) == 1:
        args.sync_bn = False
    logger.info(args)


def main():
    init()
    if args.arch == 'pointnet_seg':
        from model.pointnet.pointnet import PointNetSeg as Model
    elif args.arch == 'pointnet2_seg':
        from model.pointnet2.pointnet2_seg import PointNet2SSGSeg as Model
    elif args.arch == 'pointweb_seg':
        from model.pointweb.pointweb_seg import PointWebSeg as Model
    else:
        raise Exception('architecture {} not supported yet'.format(args.arch))
    model = Model(c=args.fea_dim, k=args.classes, use_xyz=args.use_xyz)
    if args.sync_bn:
        from util.util import convert_to_syncbn
        convert_to_syncbn(model)
    criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=args.base_lr, momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_epoch, gamma=args.multiplier)
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))
    logger.info(model)
    model = torch.nn.DataParallel(model.cuda())
    if args.sync_bn:
        from lib.sync_bn import patch_replication_callback
        patch_replication_callback(model)
    if args.weight:
        if os.path.isfile(args.weight):
            logger.info("=> loading weight '{}'".format(args.weight))
            checkpoint = torch.load(args.weight)
            model.load_state_dict(checkpoint['state_dict'])
            logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            logger.info("=> no weight found at '{}'".format(args.weight))

    if args.resume:
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            # checkpoint = torch.load(args.resume)
            checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            logger.info("=> no checkpoint found at '{}'".format(args.resume))

    train_transform = transform.Compose([transform.ToTensor()])
    if args.data_name == 's3dis':
        train_data = S3DIS(split='train', data_root=args.train_full_folder, num_point=args.num_point, test_area=args.test_area, block_size=args.block_size, sample_rate=args.sample_rate, transform=train_transform)
        # train_data = dataset.PointData(split='train', data_root=args.data_root, data_list=args.train_list, transform=train_transform)
    elif args.data_name == 'scannet':
        train_data = ScanNet(split='train', data_root=args.data_root, num_point=args.num_point, block_size=args.block_size, sample_rate=args.sample_rate, transform=train_transform)
    elif args.data_name == 'modelnet40':
        train_data = dataset.PointData(split='train', data_root=args.data_root, data_list=args.train_list, transform=train_transform, num_point=args.num_point, random_index=True)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.train_batch_size, shuffle=True, num_workers=args.train_workers, pin_memory=True)

    val_loader = None
    if args.evaluate:
        val_transform = transform.Compose([transform.ToTensor()])
        val_data = dataset.PointData(split='val', data_root=args.data_root, data_list=args.val_list, transform=val_transform)
        val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.train_batch_size_val, shuffle=False, num_workers=args.train_workers, pin_memory=True)

    for epoch in range(args.start_epoch, args.epochs):
        loss_train, mIoU_train, mAcc_train, allAcc_train = train(train_loader, model, criterion, optimizer, epoch)
        scheduler.step()  # step once per epoch, after the optimizer updates (PyTorch >= 1.1 convention)
        epoch_log = epoch + 1
        writer.add_scalar('loss_train', loss_train, epoch_log)
        writer.add_scalar('mIoU_train', mIoU_train, epoch_log)
        writer.add_scalar('mAcc_train', mAcc_train, epoch_log)
        writer.add_scalar('allAcc_train', allAcc_train, epoch_log)

        if epoch_log % args.save_freq == 0:
            filename = args.save_path + '/train_epoch_' + str(epoch_log) + '.pth'
            logger.info('Saving checkpoint to: ' + filename)
            torch.save({'epoch': epoch_log, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(), 'scheduler': scheduler.state_dict()}, filename)
            if epoch_log / args.save_freq > 2:
                deletename = args.save_path + '/train_epoch_' + str(epoch_log - args.save_freq * 2) + '.pth'
                os.remove(deletename)
        if args.evaluate:
            loss_val, mIoU_val, mAcc_val, allAcc_val = validate(val_loader, model, criterion)
            writer.add_scalar('loss_val', loss_val, epoch_log)
            writer.add_scalar('mIoU_val', mIoU_val, epoch_log)
            writer.add_scalar('mAcc_val', mAcc_val, epoch_log)
            writer.add_scalar('allAcc_val', allAcc_val, epoch_log)


def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    loss_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()

    model.train()
    end = time.time()
    max_iter = args.epochs * len(train_loader)
    for i, (input, target) in enumerate(train_loader):
        data_time.update(time.time() - end)
        input = input.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)
        output = model(input)
        if target.shape[-1] == 1:
            target = target[:, 0]  # for cls
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        output = output.max(1)[1]
        intersection, union, target = intersectionAndUnionGPU(output, target, args.classes, args.ignore_label)
        intersection, union, target = intersection.cpu().numpy(), union.cpu().numpy(), target.cpu().numpy()
        intersection_meter.update(intersection), union_meter.update(union), target_meter.update(target)

        accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10)
        loss_meter.update(loss.item(), input.size(0))
        batch_time.update(time.time() - end)
        end = time.time()

        # calculate remain time
        current_iter = epoch * len(train_loader) + i + 1
        remain_iter = max_iter - current_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m), int(t_s))

        if (i + 1) % args.print_freq == 0:
            logger.info('Epoch: [{}/{}][{}/{}] '
                        'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                        'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                        'Remain {remain_time} '
                        'Loss {loss_meter.val:.4f} '
                        'Accuracy {accuracy:.4f}.'.format(epoch+1, args.epochs, i + 1, len(train_loader),
                                                          batch_time=batch_time, data_time=data_time,
                                                          remain_time=remain_time,
                                                          loss_meter=loss_meter,
                                                          accuracy=accuracy))
        writer.add_scalar('loss_train_batch', loss_meter.val, current_iter)
        writer.add_scalar('mIoU_train_batch', np.mean(intersection / (union + 1e-10)), current_iter)
        writer.add_scalar('mAcc_train_batch', np.mean(intersection / (target + 1e-10)), current_iter)
        writer.add_scalar('allAcc_train_batch', accuracy, current_iter)

    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
    mIoU = np.mean(iou_class)
    mAcc = np.mean(accuracy_class)
    allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)
    logger.info('Train result at epoch [{}/{}]: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'.format(epoch+1, args.epochs, mIoU, mAcc, allAcc))
    return loss_meter.avg, mIoU, mAcc, allAcc


def validate(val_loader, model, criterion):
    logger.info('>>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>')
    batch_time = AverageMeter()
    data_time = AverageMeter()
    loss_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()

    model.eval()
    end = time.time()
    for i, (input, target) in enumerate(val_loader):
        data_time.update(time.time() - end)
        input = input.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)
        if target.shape[-1] == 1:
            target = target[:, 0]  # for cls
        output = model(input)
        loss = criterion(output, target)

        output = output.max(1)[1]
        intersection, union, target = intersectionAndUnionGPU(output, target, args.classes, args.ignore_label)
        intersection, union, target = intersection.cpu().numpy(), union.cpu().numpy(), target.cpu().numpy()
        intersection_meter.update(intersection), union_meter.update(union), target_meter.update(target)

        accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10)
        loss_meter.update(loss.item(), input.size(0))
        batch_time.update(time.time() - end)
        end = time.time()
        if (i + 1) % args.print_freq == 0:
            logger.info('Test: [{}/{}] '
                        'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                        'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                        'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f}) '
                        'Accuracy {accuracy:.4f}.'.format(i + 1, len(val_loader),
                                                          data_time=data_time,
                                                          batch_time=batch_time,
                                                          loss_meter=loss_meter,
                                                          accuracy=accuracy))

    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
    mIoU = np.mean(iou_class)
    mAcc = np.mean(accuracy_class)
    allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)

    logger.info('Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'.format(mIoU, mAcc, allAcc))
    for i in range(args.classes):
        logger.info('Class_{} Result: iou/accuracy {:.4f}/{:.4f}.'.format(i, iou_class[i], accuracy_class[i]))
    logger.info('<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<')
    return loss_meter.avg, mIoU, mAcc, allAcc


if __name__ == '__main__':
    main()
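
Both `train()` and `validate()` accumulate per-class intersection/union/target counts through `intersectionAndUnionGPU` from `util.util`, which is defined outside this chunk. A NumPy sketch of the histogram-based counting those meters assume (names here are illustrative, consistent with how the counts are consumed above):

```python
import numpy as np

def intersection_and_union(pred, target, K, ignore_index=255):
    # Flatten predictions and labels, then drop ignored points,
    # matching what the training loop expects of the metric.
    pred, target = pred.flatten(), target.flatten()
    valid = target != ignore_index
    pred, target = pred[valid], target[valid]
    # Points where prediction and label agree contribute to intersection.
    inter = pred[pred == target]
    # Per-class counts via histograms over the K class bins.
    area_inter, _ = np.histogram(inter, bins=np.arange(K + 1))
    area_pred, _ = np.histogram(pred, bins=np.arange(K + 1))
    area_target, _ = np.histogram(target, bins=np.arange(K + 1))
    area_union = area_pred + area_target - area_inter
    return area_inter, area_union, area_target
```

Dividing the summed intersections by the summed unions (plus a small epsilon) per class and averaging gives the mIoU reported at the end of each epoch.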


================================================
FILE: tool/train.sh
================================================
#!/bin/sh
export PYTHONPATH=./

PYTHON=python
dataset=$1
exp_name=$2
exp_dir=exp/${dataset}/${exp_name}
model_dir=${exp_dir}/model
config=config/${dataset}/${dataset}_${exp_name}.yaml

mkdir -p ${model_dir}
now=$(date +"%Y%m%d_%H%M%S")
cp tool/train.sh tool/train.py ${config} ${exp_dir}

$PYTHON tool/train.py --config=${config} 2>&1 | tee ${model_dir}/train-$now.log

if [ ${dataset} = 's3dis' ]
then
  $PYTHON tool/test_s3dis.py --config=${config} 2>&1 | tee ${model_dir}/test-$now.log
elif [ ${dataset} = 'scannet' ]
then
  $PYTHON tool/test_scannet.py --config=${config} 2>&1 | tee ${model_dir}/test-$now.log
fi
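
The script takes the dataset and experiment name as positional arguments and derives every path from them; a condensed illustration of that path construction (the argument values are examples):

```shell
# Mirrors the variable setup at the top of train.sh.
dataset=s3dis
exp_name=pointweb
config=config/${dataset}/${dataset}_${exp_name}.yaml
echo "${config}"
```

Run from the repository root (e.g. `sh tool/train.sh s3dis pointweb`); the script sets `PYTHONPATH=./` itself.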


================================================
FILE: util/config.py
================================================
# -----------------------------------------------------------------------------
# Functions for parsing args
# -----------------------------------------------------------------------------
import yaml
import os
from ast import literal_eval
import copy


class CfgNode(dict):
    """
    CfgNode represents an internal node in the configuration tree. It's a simple
    dict-like container that allows for attribute-based access to keys.
    """

    def __init__(self, init_dict=None, key_list=None, new_allowed=False):
        # Recursively convert nested dictionaries in init_dict into CfgNodes
        init_dict = {} if init_dict is None else init_dict
        key_list = [] if key_list is None else key_list
        for k, v in init_dict.items():
            if type(v) is dict:
                # Convert dict to CfgNode
                init_dict[k] = CfgNode(v, key_list=key_list + [k])
        super(CfgNode, self).__init__(init_dict)

    def __getattr__(self, name):
        if name in self:
            return self[name]
        else:
            raise AttributeError(name)

    def __setattr__(self, name, value):
        self[name] = value

    def __str__(self):
        def _indent(s_, num_spaces):
            s = s_.split("\n")
            if len(s) == 1:
                return s_
            first = s.pop(0)
            s = [(num_spaces * " ") + line for line in s]
            s = "\n".join(s)
            s = first + "\n" + s
            return s

        r = ""
        s = []
        for k, v in sorted(self.items()):
            separator = "\n" if isinstance(v, CfgNode) else " "
            attr_str = "{}:{}{}".format(str(k), separator, str(v))
            attr_str = _indent(attr_str, 2)
            s.append(attr_str)
        r += "\n".join(s)
        return r

    def __repr__(self):
        return "{}({})".format(self.__class__.__name__, super(CfgNode, self).__repr__())


def load_cfg_from_cfg_file(file):
    cfg = {}
    assert os.path.isfile(file) and file.endswith('.yaml'), \
        '{} is not a yaml file'.format(file)

    with open(file, 'r') as f:
        cfg_from_file = yaml.safe_load(f)

    for key in cfg_from_file:
        for k, v in cfg_from_file[key].items():
            cfg[k] = v

    cfg = CfgNode(cfg)
    return cfg


def merge_cfg_from_list(cfg, cfg_list):
    new_cfg = copy.deepcopy(cfg)
    assert len(cfg_list) % 2 == 0
    for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]):
        subkey = full_key.split('.')[-1]
        assert subkey in cfg, 'Non-existent key: {}'.format(full_key)
        value = _decode_cfg_value(v)
        value = _check_and_coerce_cfg_value_type(
            value, cfg[subkey], subkey, full_key
        )
        setattr(new_cfg, subkey, value)

    return new_cfg


def _decode_cfg_value(v):
    """Decodes a raw config value (e.g., from a yaml config files or command
    line argument) into a Python object.
    """
    # All remaining processing is only applied to strings
    if not isinstance(v, str):
        return v
    # Try to interpret `v` as a:
    #   string, number, tuple, list, dict, boolean, or None
    try:
        v = literal_eval(v)
    # The following two excepts allow v to pass through when it represents a
    # string.
    #
    # Longer explanation:
    # The type of v is always a string (before calling literal_eval), but
    # sometimes it *represents* a string and other times a data structure, like
    # a list. In the case that v represents a string, what we got back from the
    # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is
    # ok with '"foo"', but will raise a ValueError if given 'foo'. In other
    # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval
    # will raise a SyntaxError.
    except ValueError:
        pass
    except SyntaxError:
        pass
    return v


def _check_and_coerce_cfg_value_type(replacement, original, key, full_key):
    """Checks that `replacement`, which is intended to replace `original` is of
    the right type. The type is correct if it matches exactly or is one of a few
    cases in which the type can be easily coerced.
    """
    original_type = type(original)
    replacement_type = type(replacement)

    # The types must match (with some exceptions)
    if replacement_type == original_type:
        return replacement

    # Cast replacement from from_type to to_type if the replacement and original
    # types match from_type and to_type
    def conditional_cast(from_type, to_type):
        if replacement_type == from_type and original_type == to_type:
            return True, to_type(replacement)
        else:
            return False, None

    # Conditionally casts
    # list <-> tuple
    casts = [(tuple, list), (list, tuple)]
    # For py2: allow converting from str (bytes) to a unicode string
    try:
        casts.append((str, unicode))  # noqa: F821
    except Exception:
        pass

    for (from_type, to_type) in casts:
        converted, converted_value = conditional_cast(from_type, to_type)
        if converted:
            return converted_value

    raise ValueError(
        "Type mismatch ({} vs. {}) with values ({} vs. {}) for config "
        "key: {}".format(
            original_type, replacement_type, original, replacement, full_key
        )
    )


def _assert_with_logging(cond, msg):
    if not cond:
        # No module-level logger is defined in this file, so fetch one here.
        import logging
        logging.getLogger(__name__).debug(msg)
    assert cond, msg
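
`_decode_cfg_value` above leans on `ast.literal_eval`'s error behavior to tell literals apart from plain strings; a condensed, self-contained illustration of the same decoding rule (the function name `decode` is illustrative):

```python
from ast import literal_eval

def decode(v):
    # Strings that parse as Python literals become those values;
    # anything else (bare words, paths) passes through unchanged.
    if not isinstance(v, str):
        return v
    try:
        return literal_eval(v)
    except (ValueError, SyntaxError):
        return v
```

This is why an override like `base_lr 0.05` on the command line arrives as a float, while `save_path exp/s3dis` stays a string.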



================================================
FILE: util/dataset.py
================================================
import os
import h5py
import numpy as np

from torch.utils.data import Dataset


def make_dataset(split='train', data_root=None, data_list=None):
    if not os.path.isfile(data_list):
        raise (RuntimeError("Point list file does not exist: " + data_list + "\n"))
    point_list = []
    list_read = open(data_list).readlines()
    print("Totally {} samples in {} set.".format(len(list_read), split))
    for line in list_read:
        point_list.append(os.path.join(data_root, line.strip()))
    return point_list


class PointData(Dataset):
    def __init__(self, split='train', data_root=None, data_list=None, transform=None, num_point=None, random_index=False):
        assert split in ['train', 'val', 'test']
        self.split = split
        self.data_list = make_dataset(split, data_root, data_list)
        self.transform = transform
        self.num_point = num_point
        self.random_index = random_index

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        data_path = self.data_list[index]
        f = h5py.File(data_path, 'r')
        data = f['data'][:]
        if self.split == 'test':
            label = np.array([255])  # placeholder; the test split has no labels
        else:
            label = f['label'][:]
        f.close()
        if self.num_point is None:
            self.num_point = data.shape[0]
        idxs = np.arange(data.shape[0])
        if self.random_index:
            np.random.shuffle(idxs)
        idxs = idxs[0:self.num_point]
        data = data[idxs, :]
        if label.size != 1:  # seg data
            label = label[idxs]
        if self.transform is not None:
            data, label = self.transform(data, label)
        return data, label


if __name__ == '__main__':
    data_root = '/mnt/sda1/hszhao/dataset/3d/s3dis'
    data_list = '/mnt/sda1/hszhao/dataset/3d/s3dis/list/train12346.txt'
    point_data = PointData('train', data_root, data_list)
    print('point data size:', point_data.__len__())
    print('point data 0 shape:', point_data.__getitem__(0)[0].shape)
    print('point label 0 shape:', point_data.__getitem__(0)[1].shape)
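
The subsampling done in `PointData.__getitem__` (optionally shuffle the point indices, take the first `num_point`, and index labels only for segmentation data) can be sketched standalone; the function name and explicit seeding below are illustrative, not part of the repo:

```python
import numpy as np

def subsample(data, label, num_point=None, random_index=False, seed=None):
    # Optionally shuffle, then keep the first num_point indices,
    # applying the same selection to per-point labels.
    rng = np.random.default_rng(seed)
    n = data.shape[0] if num_point is None else num_point
    idxs = np.arange(data.shape[0])
    if random_index:
        rng.shuffle(idxs)
    idxs = idxs[:n]
    data = data[idxs, :]
    if np.size(label) != 1:  # segmentation labels are per point
        label = label[idxs]
    return data, label
```

The per-point case keeps `data` and `label` aligned row for row; a single classification label is returned untouched.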


================================================
FILE: util/pt_util.py
================================================
import shutil, os
import tqdm
from itertools import repeat
import numpy as np
from typing import List, Tuple
# from scipy.stats import t as student_t
# import statistics as stats

import torch
import torch.nn as nn
from torch.autograd.function import InplaceFunction

BN1d, BN2d, BN3d = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d

class SharedMLP(nn.Sequential):

    def __init__(
            self,
            args: List[int],
            *,
            bn: bool = False,
            activation=nn.ReLU(inplace=True),
            preact: bool = False,
            first: bool = False,
            name: str = ""
    ):
        super().__init__()

        for i in range(len(args) - 1):
            self.add_module(
                name + 'layer{}'.format(i),
                Conv2d(
                    args[i],
                    args[i + 1],
                    bn=(not first or not preact or (i != 0)) and bn,
                    activation=activation
                    if (not first or not preact or (i != 0)) else None,
                    preact=preact
                )
            )


class _BNBase(nn.Sequential):

    def __init__(self, in_size, batch_norm=None, name=""):
        super().__init__()
        self.add_module(name + "bn", batch_norm(in_size))

        nn.init.constant_(self[0].weight, 1.0)
        nn.init.constant_(self[0].bias, 0)


class BatchNorm1d(_BNBase):

    def __init__(self, in_size: int, *, name: str = ""):
        super().__init__(in_size, batch_norm=BN1d, name=name)


class BatchNorm2d(_BNBase):

    def __init__(self, in_size: int, name: str = ""):
        super().__init__(in_size, batch_norm=BN2d, name=name)


class BatchNorm3d(_BNBase):

    def __init__(self, in_size: int, name: str = ""):
        super().__init__(in_size, batch_norm=BN3d, name=name)


class _ConvBase(nn.Sequential):

    def __init__(
            self,
            in_size,
            out_size,
            kernel_size,
            stride,
            padding,
            activation,
            bn,
            init,
            conv=None,
            batch_norm=None,
            bias=True,
            preact=False,
            name=""
    ):
        super().__init__()

        bias = bias and (not bn)
        conv_unit = conv(
            in_size,
            out_size,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias
        )
        init(conv_unit.weight)
        if bias:
            nn.init.constant_(conv_unit.bias, 0)

        if bn:
            if not preact:
                bn_unit = batch_norm(out_size)
            else:
                bn_unit = batch_norm(in_size)

        if preact:
            if bn:
                self.add_module(name + 'bn', bn_unit)

            if activation is not None:
                self.add_module(name + 'activation', activation)

        self.add_module(name + 'conv', conv_unit)

        if not preact:
            if bn:
                self.add_module(name + 'bn', bn_unit)

            if activation is not None:
                self.add_module(name + 'activation', activation)


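`_ConvBase` registers its submodules in one of two orders depending on `preact`: pre-activation puts BN and the activation before the conv (BN then sized by `in_size` rather than `out_size`), post-activation puts them after. A pure-Python sketch of that ordering rule (function name is illustrative):

```python
def module_order(bn=True, activation=True, preact=False):
    # Mirror the submodule registration order used by _ConvBase.
    extras = [name for name, enabled in (('bn', bn), ('activation', activation)) if enabled]
    return extras + ['conv'] if preact else ['conv'] + extras

print(module_order(preact=False))  # ['conv', 'bn', 'activation']
print(module_order(preact=True))   # ['bn', 'activation', 'conv']
```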
class Conv1d(_ConvBase):

    def __init__(
            self,
            in_size: int,
            out_size: int,
            *,
            kernel_size: int = 1,
            stride: int = 1,
            padding: int = 0,
            activation=nn.ReLU(inplace=True),
            bn: bool = False,
            init=nn.init.kaiming_normal_,
            bias: bool = True,
            preact: bool = False,
            name: str = ""
    ):
        super().__init__(
            in_size,
            out_size,
            kernel_size,
            stride,
            padding,
            activation,
            bn,
            init,
            conv=nn.Conv1d,
            batch_norm=BatchNorm1d,
            bias=bias,
            preact=preact,
            name=name
        )


class Conv2d(_ConvBase):

    def __init__(
            self,
            in_size: int,
            out_size: int,
            *,
            kernel_size: Tuple[int, int] = (1, 1),
            stride: Tuple[int, int] = (1, 1),
            padding: Tuple[int, int] = (0, 0),
            activation=nn.ReLU(inplace=True),
            bn: bool = False,
            init=nn.init.kaiming_normal_,
            bias: bool = True,
            preact: bool = False,
            name: str = ""
    ):
        super().__init__(
            in_size,
            out_size,
            kernel_size,
            stride,
            padding,
            activation,
            bn,
            init,
            conv=nn.Conv2d,
            batch_norm=BatchNorm2d,
            bias=bias,
            preact=preact,
            name=name
        )


class Conv3d(_ConvBase):

    def __init__(
            self,
            in_size: int,
            out_size: int,
            *,
            kernel_size: Tuple[int, int, int] = (1, 1, 1),
            stride: Tuple[int, int, int] = (1, 1, 1),
            padding: Tuple[int, int, int] = (0, 0, 0),
            activation=nn.ReLU(inplace=True),
            bn: bool = False,
            init=nn.init.kaiming_normal_,
            bias: bool = True,
            preact: bool = False,
            name: str = ""
    ):
        super().__init__(
            in_size,
            out_size,
            kernel_size,
            stride,
            padding,
            activation,
            bn,
            init,
            conv=nn.Conv3d,
            batch_norm=BatchNorm3d,
            bias=bias,
            preact=preact,
            name=name
        )


class FC(nn.Sequential):

    def __init__(
            self,
            in_size: int,
            out_size: int,
            *,
            activation=nn.ReLU(inplace=True),
            bn: bool = False,
            init=None,
            preact: bool = False,
            name: str = ""
    ):
        super().__init__()

        fc = nn.Linear(in_size, out_size, bias=not bn)
        if init is not None:
            init(fc.weight)
        if not bn:
            nn.init.constant_(fc.bias, 0)

        if preact:
            if bn:
                self.add_module(name + 'bn', BatchNorm1d(in_size))

            if activation is not None:
                self.add_module(name + 'activation', activation)

        self.add_module(name + 'fc', fc)

        if not preact:
            if bn:
                self.add_module(name + 'bn', BatchNorm1d(out_size))

            if activation is not None:
                self.add_module(name + 'activation', activation)


class _DropoutNoScaling(InplaceFunction):

    @staticmethod
    def _make_noise(input):
        return input.new().resize_as_(input)

    @staticmethod
    def symbolic(g, input, p=0.5, train=False, inplace=False):
        if inplace:
            return None
        n = g.appendNode(
            g.create("Dropout", [input]).f_("ratio",
                                            p).i_("is_test", not train)
        )
        real = g.appendNode(g.createSelect(n, 0))
        g.appendNode(g.createSelect(n, 1))
        return real

    @classmethod
    def forward(cls, ctx, input, p=0.5, train=False, inplace=False):
        if p < 0 or p > 1:
            raise ValueError(
                "dropout probability has to be between 0 and 1, "
                "but got {}".format(p)
            )
        ctx.p = p
        ctx.train = train
        ctx.inplace = inplace

        if ctx.inplace:
            ctx.mark_dirty(input)
            output = input
        else:
            output = input.clone()

        if ctx.p > 0 and ctx.train:
            ctx.noise = cls._make_noise(input)
            if ctx.p == 1:
                ctx.noise.fill_(0)
            else:
                ctx.noise.bernoulli_(1 - ctx.p)
            ctx.noise = ctx.noise.expand_as(input)
            output.mul_(ctx.noise)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.p > 0 and ctx.train:
            return grad_output.mul(ctx.noise), None, None, None
        else:
            return grad_output, None, None, None


dropout_no_scaling = _DropoutNoScaling.apply


class _FeatureDropoutNoScaling(_DropoutNoScaling):

    @staticmethod
    def symbolic(g, input, p=0.5, train=False, inplace=False):
        # signature must accept the graph argument `g`, matching the parent class
        return None

    @staticmethod
    def _make_noise(input):
        return input.new().resize_(
            input.size(0), input.size(1), *repeat(1,
                                                  input.dim() - 2)
        )


feature_dropout_no_scaling = _FeatureDropoutNoScaling.apply
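
Unlike standard inverted dropout, these variants multiply by a Bernoulli mask without rescaling the kept values by 1/(1-p); the feature variant draws one mask entry per (batch, channel) and broadcasts it over the remaining dimensions, dropping whole channels. A NumPy sketch of the forward pass (shapes and names are illustrative):

```python
import numpy as np

def dropout_no_scaling_np(x, p=0.5, feature=False, seed=0):
    rng = np.random.default_rng(seed)
    # feature variant: one mask entry per (batch, channel), broadcast over the rest
    shape = x.shape[:2] + (1,) * (x.ndim - 2) if feature else x.shape
    mask = (rng.random(shape) < (1.0 - p)).astype(x.dtype)
    return x * mask  # no 1/(1-p) rescaling, unlike standard inverted dropout

x = np.ones((2, 4, 8))  # (batch, channels, points), made-up shape
y = dropout_no_scaling_np(x, p=0.5, feature=True)
# every channel is kept or dropped as a whole, so values along the last
# axis are constant within each (batch, channel) slice
assert np.all(y.min(axis=-1) == y.max(axis=-1))
```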


def group_model_params(model: nn.Module, **kwargs):
    decay_group = []
    no_decay_group = []

    for name, param in model.named_parameters():
        if name.find("bn") != -1 or name.find("bias") != -1:
            no_decay_group.append(param)
        else:
            decay_group.append(param)

    assert len(list(model.parameters())) == len(decay_group) + len(no_decay_group)

    return [
        dict(params=decay_group, **kwargs),
        dict(params=no_decay_group, weight_decay=0.0, **kwargs)
    ]
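
The grouping above is purely name-based: any parameter whose name contains "bn" or "bias" goes into the no-weight-decay group. A self-contained sketch of the same rule over made-up parameter names:

```python
def split_decay(names):
    # Same rule as group_model_params, applied to plain name strings.
    decay, no_decay = [], []
    for name in names:
        (no_decay if ('bn' in name or 'bias' in name) else decay).append(name)
    return decay, no_decay

names = ['layer0.conv.weight', 'layer0.conv.bias', 'layer0.bn.weight', 'fc.weight']
decay, no_decay = split_decay(names)
print(decay)     # ['layer0.conv.weight', 'fc.weight']
print(no_decay)  # ['layer0.conv.bias', 'layer0.bn.weight']
```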


def checkpoint_state(
        model=None, optimizer=None, best_prec=None, epoch=None, it=None
):
    optim_state = optimizer.state_dict() if optimizer is not None else None
    if model is not None:
        if isinstance(model, torch.nn.DataParallel):
            model_state = model.module.state_dict()
        else:
            model_state = model.state_dict()
    else:
        model_state = None

    return {
        'epoch': epoch,
        'it': it,
        'best_prec': best_prec,
        'model_state': model_state,
        'optimizer_state': optim_state
    }


def save_checkpoint(
        state, is_best, filename='checkpoint', bestname='model_best'
):
    filename = '{}.pth.tar'.format(filename)
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, '{}.pth.tar'.format(bestname))


def load_checkpoint(model=None, optimizer=None, filename='checkpoint'):
    filename = "{}.pth.tar".format(filename)
    it, epoch, best_prec = 0.0, 0, None  # defaults when no checkpoint is found
    if os.path.isfile(filename):
        print("==> Loading from checkpoint '{}'".format(filename))
        checkpoint = torch.load(filename)
        epoch = checkpoint['epoch']
        it = checkpoint.get('it', 0.0)
        best_prec = checkpoint['best_prec']
        if model is not None and checkpoint['model_state'] is not None:
            model.load_state_dict(checkpoint['model_state'])
        if optimizer is not None and checkpoint['optimizer_state'] is not None:
            optimizer.load_state_dict(checkpoint['optimizer_state'])
        print("==> Done")
    else:
        print("==> Checkpoint '{}' not found".format(filename))

    return it, epoch, best_prec


def variable_size_collate(pad_val=0, use_shared_memory=True):
    import collections
    _numpy_type_map = {
        'float64': torch.DoubleTensor,
        'float32': torch.FloatTensor,
        'float16': torch.HalfTensor,
SYMBOL INDEX (276 symbols across 30 files)

FILE: lib/pointops/functions/pointops.py
  class FurthestSampling (line 10) | class FurthestSampling(Function):
    method forward (line 12) | def forward(ctx, xyz, m):
    method backward (line 25) | def backward(xyz, a=None):
  class Gathering (line 31) | class Gathering(Function):
    method forward (line 33) | def forward(ctx, features, idx):
    method backward (line 48) | def backward(ctx, grad_out):
  class NearestNeighbor (line 59) | class NearestNeighbor(Function):
    method forward (line 61) | def forward(ctx, unknown: torch.Tensor, known: torch.Tensor) -> Tuple[...
    method backward (line 78) | def backward(ctx, a=None, b=None):
  class Interpolation (line 84) | class Interpolation(Function):
    method forward (line 86) | def forward(ctx, features: torch.Tensor, idx: torch.Tensor, weight: to...
    method backward (line 105) | def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch...
  class Grouping (line 120) | class Grouping(Function):
    method forward (line 122) | def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.T...
    method backward (line 137) | def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch...
  class GroupingInt (line 152) | class GroupingInt(Function):
    method forward (line 154) | def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.T...
    method backward (line 168) | def backward(ctx, a=None):
  class BallQuery (line 174) | class BallQuery(Function):
    method forward (line 176) | def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, new_x...
    method backward (line 193) | def backward(ctx, a=None):
  class FeatureDistribute (line 199) | class FeatureDistribute(Function):
    method forward (line 201) | def forward(ctx, max_xyz: torch.Tensor, xyz: torch.Tensor) -> torch.Te...
    method backward (line 217) | def backward(ctx, a=None):
  class FeatureGather (line 223) | class FeatureGather(Function):
    method forward (line 225) | def forward(ctx, max_feature: torch.Tensor, distribute_idx: torch.Tens...
    method backward (line 242) | def backward(ctx, grad_distribute_feature: torch.Tensor):
  class LabelStatBallRange (line 258) | class LabelStatBallRange(Function):
    method forward (line 260) | def forward(ctx, radius: float, xyz: torch.Tensor, new_xyz: torch.Tens...
    method backward (line 281) | def backward(ctx, a=None):
  class LabelStatIdx (line 287) | class LabelStatIdx(Function):
    method forward (line 289) | def forward(ctx, nsample: int, label_stat: torch.Tensor, idx: torch.Te...
    method backward (line 308) | def backward(ctx, a=None):
  class LabelStatAndBallQuery (line 314) | class LabelStatAndBallQuery(Function):
    method forward (line 316) | def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, new_x...
    method backward (line 340) | def backward(ctx, a=None, b=None):
  function pairwise_distances (line 346) | def pairwise_distances(x, y=None):
  class KNNQueryNaive (line 366) | class KNNQueryNaive(Function):
    method forward (line 368) | def forward(ctx, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tenso...
    method backward (line 400) | def backward(ctx):
  class KNNQuery (line 406) | class KNNQuery(Function):
    method forward (line 408) | def forward(ctx, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tenso...
    method backward (line 429) | def backward(ctx, a=None):
  class KNNQueryExclude (line 435) | class KNNQueryExclude(Function):
    method forward (line 437) | def forward(ctx, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tenso...
    method backward (line 469) | def backward(ctx):
  class QueryAndGroup (line 475) | class QueryAndGroup(nn.Module):
    method __init__ (line 482) | def __init__(self, radius=None, nsample=32, use_xyz=True):
    method forward (line 486) | def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor = None, fea...
  class GroupAll (line 521) | class GroupAll(nn.Module):
    method __init__ (line 525) | def __init__(self, use_xyz: bool = True):
    method forward (line 529) | def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: ...

FILE: lib/pointops/src/ballquery/ballquery_cuda.cpp
  function ballquery_cuda (line 14) | void ballquery_cuda(int b, int n, int m, float radius, int nsample, at::...
  function ballquery_cuda_fast (line 24) | void ballquery_cuda_fast(int b, int n, int m, float radius, int nsample,...

FILE: lib/pointops/src/cuda_utils.h
  function opt_n_threads (line 15) | inline int opt_n_threads(int work_size) {
  function dim3 (line 20) | inline dim3 opt_block_config(int x, int y) {

FILE: lib/pointops/src/featuredistribute/featuredistribute_cuda.cpp
  function featuredistribute_cuda (line 15) | void featuredistribute_cuda(int b, int n, int m, at::Tensor max_xyz_tens...
  function featuregather_forward_cuda (line 30) | void featuregather_forward_cuda(int b, int n, int m, int c, at::Tensor m...
  function featuregather_backward_cuda (line 45) | void featuregather_backward_cuda(int b, int n, int m, int c, at::Tensor ...

FILE: lib/pointops/src/grouping/grouping_cuda.cpp
  function grouping_forward_cuda (line 10) | void grouping_forward_cuda(int b, int c, int n, int m, int nsample, at::...
  function grouping_backward_cuda (line 18) | void grouping_backward_cuda(int b, int c, int n, int m, int nsample, at:...
  function grouping_forward_cuda_fast (line 26) | void grouping_forward_cuda_fast(int b, int c, int n, int npoints, int ns...

FILE: lib/pointops/src/grouping_int/grouping_int_cuda.cpp
  function grouping_int_forward_cuda (line 10) | void grouping_int_forward_cuda(int b, int c, int n, int m, int nsample, ...
  function grouping_int_forward_cuda_fast (line 18) | void grouping_int_forward_cuda_fast(int b, int c, int n, int m, int nsam...

FILE: lib/pointops/src/interpolation/interpolation_cuda.cpp
  function nearestneighbor_cuda (line 9) | void nearestneighbor_cuda(int b, int n, int m, at::Tensor unknown_tensor...
  function interpolation_forward_cuda (line 18) | void interpolation_forward_cuda(int b, int c, int m, int n, at::Tensor p...
  function interpolation_backward_cuda (line 27) | void interpolation_backward_cuda(int b, int c, int n, int m, at::Tensor ...
  function nearestneighbor_cuda_fast (line 36) | void nearestneighbor_cuda_fast(int b, int n, int m, at::Tensor unknown_t...
  function interpolation_forward_cuda_fast (line 44) | void interpolation_forward_cuda_fast(int b, int c, int m, int n, at::Ten...

FILE: lib/pointops/src/knnquery/knnquery_cuda.cpp
  function knnquery_cuda (line 15) | void knnquery_cuda(int b, int n, int m, int nsample, at::Tensor xyz_tens...

FILE: lib/pointops/src/labelstat/labelstat_cuda.cpp
  function labelstat_idx_cuda_fast (line 14) | void labelstat_idx_cuda_fast(int b, int n, int m, int nsample, int nclass,
  function labelstat_ballrange_cuda_fast (line 29) | void labelstat_ballrange_cuda_fast(int b, int n, int m, float radius, in...
  function labelstat_and_ballquery_cuda_fast (line 46) | void labelstat_and_ballquery_cuda_fast(int b, int n, int m, float radius...

FILE: lib/pointops/src/pointops_api.cpp
  function PYBIND11_MODULE (line 15) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

FILE: lib/pointops/src/sampling/sampling_cuda.cpp
  function gathering_forward_cuda (line 9) | void gathering_forward_cuda(int b, int c, int n, int m, at::Tensor point...
  function gathering_backward_cuda (line 17) | void gathering_backward_cuda(int b, int c, int n, int m, at::Tensor grad...
  function furthestsampling_cuda (line 26) | void furthestsampling_cuda(int b, int n, int m, at::Tensor points_tensor...

FILE: lib/sync_bn/batchnorm.py
  function _sum_ft (line 24) | def _sum_ft(tensor):
  function _unsqueeze_ft (line 29) | def _unsqueeze_ft(tensor):
  class _SynchronizedBatchNorm (line 38) | class _SynchronizedBatchNorm(_BatchNorm):
    method __init__ (line 39) | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
    method forward (line 48) | def forward(self, input):
    method __data_parallel_replicate__ (line 80) | def __data_parallel_replicate__(self, ctx, copy_id):
    method _data_parallel_master (line 90) | def _data_parallel_master(self, intermediates):
    method _compute_mean_std (line 113) | def _compute_mean_std(self, sum_, ssum, size):
  class SynchronizedBatchNorm1d (line 128) | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    method _check_input_dim (line 184) | def _check_input_dim(self, input):
  class SynchronizedBatchNorm2d (line 191) | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    method _check_input_dim (line 247) | def _check_input_dim(self, input):
  class SynchronizedBatchNorm3d (line 254) | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    method _check_input_dim (line 311) | def _check_input_dim(self, input):

FILE: lib/sync_bn/comm.py
  class FutureResult (line 18) | class FutureResult(object):
    method __init__ (line 21) | def __init__(self):
    method put (line 26) | def put(self, result):
    method get (line 32) | def get(self):
  class SlavePipe (line 46) | class SlavePipe(_SlavePipeBase):
    method run_slave (line 49) | def run_slave(self, msg):
  class SyncMaster (line 56) | class SyncMaster(object):
    method __init__ (line 67) | def __init__(self, master_callback):
    method __getstate__ (line 78) | def __getstate__(self):
    method __setstate__ (line 81) | def __setstate__(self, state):
    method register_slave (line 84) | def register_slave(self, identifier):
    method run_master (line 102) | def run_master(self, master_msg):
    method nr_slaves (line 136) | def nr_slaves(self):

FILE: lib/sync_bn/replicate.py
  class CallbackContext (line 23) | class CallbackContext(object):
  function execute_replication_callbacks (line 27) | def execute_replication_callbacks(modules):
  class DataParallelWithCallback (line 50) | class DataParallelWithCallback(DataParallel):
    method replicate (line 64) | def replicate(self, module, device_ids):
  function patch_replication_callback (line 70) | def patch_replication_callback(data_parallel):

FILE: lib/sync_bn/unittest.py
  function as_numpy (line 17) | def as_numpy(v):
  class TorchTestCase (line 23) | class TorchTestCase(unittest.TestCase):
    method assertTensorClose (line 24) | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):

FILE: model/pointnet/pointnet.py
  class STN3D (line 6) | class STN3D(nn.Module):
    method __init__ (line 7) | def __init__(self, c):
    method forward (line 24) | def forward(self, x):
  class PointNetFeat (line 43) | class PointNetFeat(nn.Module):
    method __init__ (line 44) | def __init__(self, c=3, global_feat=True):
    method forward (line 62) | def forward(self, x):
  class PointNetCls (line 81) | class PointNetCls(nn.Module):
    method __init__ (line 82) | def __init__(self, c=3, k=40, dropout=0.3, sync_bn=False):
    method forward (line 93) | def forward(self, x):
  class PointNetSeg (line 104) | class PointNetSeg(nn.Module):
    method __init__ (line 105) | def __init__(self, c=9, k=13, sync_bn=False):
    method forward (line 119) | def forward(self, x):

FILE: model/pointnet2/pointnet2_modules.py
  class _PointNet2SAModuleBase (line 11) | class _PointNet2SAModuleBase(nn.Module):
    method __init__ (line 12) | def __init__(self):
    method forward (line 18) | def forward(self, xyz: torch.Tensor, features: torch.Tensor = None) ->...
  class PointNet2SAModuleMSG (line 48) | class PointNet2SAModuleMSG(_PointNet2SAModuleBase):
    method __init__ (line 63) | def __init__(self, *, npoint: int, radii: List[float], nsamples: List[...
  class PointNet2SAModule (line 82) | class PointNet2SAModule(PointNet2SAModuleMSG):
    method __init__ (line 97) | def __init__(self, *, mlp: List[int], npoint: int = None, radius: floa...
  class PointNet2FPModule (line 101) | class PointNet2FPModule(nn.Module):
    method __init__ (line 110) | def __init__(self, *, mlp: List[int], bn: bool = True):
    method forward (line 114) | def forward(self, unknown: torch.Tensor, known: torch.Tensor, unknow_f...

FILE: model/pointnet2/pointnet2_seg.py
  class PointNet2SSGSeg (line 10) | class PointNet2SSGSeg(nn.Module):
    method __init__ (line 25) | def __init__(self, c=3, k=13, use_xyz=True):
    method _break_up_pc (line 39) | def _break_up_pc(self, pc):
    method forward (line 44) | def forward(self, pointcloud: torch.cuda.FloatTensor):
  class PointNet2MSGSeg (line 67) | class PointNet2MSGSeg(PointNet2SSGSeg):
    method __init__ (line 82) | def __init__(self, k, c=6, use_xyz=True):
  function model_fn_decorator (line 105) | def model_fn_decorator(criterion):

FILE: model/pointweb/pointweb_module.py
  class _AFAModule (line 11) | class _AFAModule(nn.Module):
    method __init__ (line 12) | def __init__(self, mlp, use_softmax=False):
    method forward (line 21) | def forward(self, feature: torch.Tensor) -> torch.Tensor:
  class _PointWebSAModuleBase (line 43) | class _PointWebSAModuleBase(nn.Module):
    method __init__ (line 44) | def __init__(self):
    method forward (line 51) | def forward(self, xyz: torch.Tensor, features: torch.Tensor = None) ->...
  class PointWebSAModule (line 81) | class PointWebSAModule(_PointWebSAModuleBase):
    method __init__ (line 96) | def __init__(self, *, npoint: int = None, nsample: int = None, mlp: Li...

FILE: model/pointweb/pointweb_seg.py
  class PointWebSeg (line 11) | class PointWebSeg(nn.Module):
    method __init__ (line 26) | def __init__(self, c=3, k=13, use_xyz=True):
    method _break_up_pc (line 41) | def _break_up_pc(self, pc):
    method forward (line 46) | def forward(self, pointcloud: torch.cuda.FloatTensor):
  function model_fn_decorator (line 68) | def model_fn_decorator(criterion):
  function model_fn_decorator (line 84) | def model_fn_decorator(criterion):

FILE: tool/test_s3dis.py
  function get_parser (line 22) | def get_parser():
  function get_logger (line 34) | def get_logger():
  function main (line 45) | def main():
  function data_prepare (line 77) | def data_prepare(room_path):
  function test (line 119) | def test(model, criterion, names):

FILE: tool/test_scannet.py
  function get_parser (line 22) | def get_parser():
  function get_logger (line 34) | def get_logger():
  function main (line 45) | def main():
  function data_prepare (line 77) | def data_prepare(points, labels):
  function test (line 116) | def test(model, criterion, names):

FILE: tool/train.py
  function get_parser (line 23) | def get_parser():
  function get_logger (line 35) | def get_logger():
  function worker_init_fn (line 46) | def worker_init_fn(worker_id):
  function init (line 50) | def init():
  function main (line 68) | def main():
  function train (line 154) | def train(train_loader, model, criterion, optimizer, epoch):
  function validate (line 220) | def validate(val_loader, model, criterion):

FILE: util/config.py
  class CfgNode (line 10) | class CfgNode(dict):
    method __init__ (line 16) | def __init__(self, init_dict=None, key_list=None, new_allowed=False):
    method __getattr__ (line 26) | def __getattr__(self, name):
    method __setattr__ (line 32) | def __setattr__(self, name, value):
    method __str__ (line 35) | def __str__(self):
    method __repr__ (line 56) | def __repr__(self):
  function load_cfg_from_cfg_file (line 60) | def load_cfg_from_cfg_file(file):
  function merge_cfg_from_list (line 76) | def merge_cfg_from_list(cfg, cfg_list):
  function _decode_cfg_value (line 91) | def _decode_cfg_value(v):
  function _check_and_coerce_cfg_value_type (line 120) | def _check_and_coerce_cfg_value_type(replacement, original, key, full_key):
  function _assert_with_logging (line 162) | def _assert_with_logging(cond, msg):

FILE: util/dataset.py
  function make_dataset (line 8) | def make_dataset(split='train', data_root=None, data_list=None):
  class PointData (line 19) | class PointData(Dataset):
    method __init__ (line 20) | def __init__(self, split='train', data_root=None, data_list=None, tran...
    method __len__ (line 28) | def __len__(self):
    method __getitem__ (line 31) | def __getitem__(self, index):

FILE: util/pt_util.py
  class SharedMLP (line 15) | class SharedMLP(nn.Sequential):
    method __init__ (line 17) | def __init__(
  class _BNBase (line 43) | class _BNBase(nn.Sequential):
    method __init__ (line 45) | def __init__(self, in_size, batch_norm=None, name=""):
  class BatchNorm1d (line 53) | class BatchNorm1d(_BNBase):
    method __init__ (line 55) | def __init__(self, in_size: int, *, name: str = ""):
  class BatchNorm2d (line 59) | class BatchNorm2d(_BNBase):
    method __init__ (line 61) | def __init__(self, in_size: int, name: str = ""):
  class BatchNorm3d (line 65) | class BatchNorm3d(_BNBase):
    method __init__ (line 67) | def __init__(self, in_size: int, name: str = ""):
  class _ConvBase (line 71) | class _ConvBase(nn.Sequential):
    method __init__ (line 73) | def __init__(
  class Conv1d (line 127) | class Conv1d(_ConvBase):
    method __init__ (line 129) | def __init__(
  class Conv2d (line 161) | class Conv2d(_ConvBase):
    method __init__ (line 163) | def __init__(
  class Conv3d (line 195) | class Conv3d(_ConvBase):
    method __init__ (line 197) | def __init__(
  class FC (line 229) | class FC(nn.Sequential):
    method __init__ (line 231) | def __init__(
  class _DropoutNoScaling (line 267) | class _DropoutNoScaling(InplaceFunction):
    method _make_noise (line 270) | def _make_noise(input):
    method symbolic (line 274) | def symbolic(g, input, p=0.5, train=False, inplace=False):
    method forward (line 286) | def forward(cls, ctx, input, p=0.5, train=False, inplace=False):
    method backward (line 314) | def backward(ctx, grad_output):
  class _FeatureDropoutNoScaling (line 324) | class _FeatureDropoutNoScaling(_DropoutNoScaling):
    method symbolic (line 327) | def symbolic(input, p=0.5, train=False, inplace=False):
    method _make_noise (line 331) | def _make_noise(input):
  function group_model_params (line 341) | def group_model_params(model: nn.Module, **kwargs):
  function checkpoint_state (line 359) | def checkpoint_state(
  function save_checkpoint (line 380) | def save_checkpoint(
  function load_checkpoint (line 389) | def load_checkpoint(model=None, optimizer=None, filename='checkpoint'):
  function variable_size_collate (line 408) | def variable_size_collate(pad_val=0, use_shared_memory=True):
  class TrainValSplitter (line 478) | class TrainValSplitter():
    method __init__ (line 491) | def __init__(
  function set_bn_momentum_default (line 576) | def set_bn_momentum_default(bn_momentum):
  class BNMomentumScheduler (line 585) | class BNMomentumScheduler(object):
    method __init__ (line 587) | def __init__(
    method step (line 605) | def step(self, epoch=None):
  class Trainer (line 613) | class Trainer(object):
    method __init__ (line 638) | def __init__(
    method _decode_value (line 661) | def _decode_value(v):
    method _train_it (line 680) | def _train_it(self, it, batch):
    method eval_epoch (line 697) | def eval_epoch(self, d_loader):
    method train (line 717) | def train(
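The `set_bn_momentum_default` / `BNMomentumScheduler` pair indexed above implements the common PointNet++-style trick of decaying BatchNorm momentum as training progresses. A minimal sketch of such a schedule (the constants below are illustrative, not necessarily the repo's defaults, and the real scheduler applies the value to every BN layer via `model.apply`):

```python
def bn_momentum_schedule(epoch, bn_momentum=0.5, decay=0.5, decay_step=20, clip=0.01):
    # Exponentially decay BatchNorm momentum every `decay_step` epochs,
    # clipped at a floor so running statistics never freeze entirely.
    # Constants here are illustrative defaults, not taken from pt_util.py.
    return max(bn_momentum * decay ** (epoch // decay_step), clip)
```

In the repo's `BNMomentumScheduler`, a function like this is evaluated once per epoch and the result is written into each `nn.BatchNorm*d` module's `momentum` attribute.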

FILE: util/s3dis.py
  class S3DIS (line 7) | class S3DIS(Dataset):
    method __init__ (line 8) | def __init__(self, split='train', data_root='trainval_fullarea', num_p...
    method __getitem__ (line 38) | def __getitem__(self, idx):
    method __len__ (line 72) | def __len__(self):
  function worker_init_fn (line 90) | def worker_init_fn(worker_id):
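The `S3DIS.__getitem__` indexed above crops a block of points from a room and samples a fixed-size point set from it. A simplified sketch of that pattern, assuming (N, 6) xyz+rgb point arrays; the real loader in util/s3dis.py additionally normalizes coordinates and colors:

```python
import numpy as np

def sample_room(points, labels, num_point=4096, block_size=1.0, rng=None):
    # Pick a random point as the block center, keep all points whose
    # x/y fall inside the block, then sample exactly `num_point` of them
    # (with replacement when the block is sparse). This is a sketch of
    # the S3DIS.__getitem__ pattern, not the repo's exact code.
    rng = np.random.default_rng(rng)
    center = points[rng.integers(len(points)), :3]
    half = block_size / 2.0
    mask = np.all(np.abs(points[:, :2] - center[:2]) <= half, axis=1)
    idx = np.flatnonzero(mask)
    choice = rng.choice(idx, num_point, replace=len(idx) < num_point)
    return points[choice], labels[choice]
```

Sampling with replacement only when the block holds fewer than `num_point` points keeps every batch tensor the same shape without discarding dense blocks' diversity.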

FILE: util/scannet.py
  class ScanNet (line 8) | class ScanNet(Dataset):
    method __init__ (line 9) | def __init__(self, split='train', data_root='scannet', num_point=8192,...
    method __getitem__ (line 47) | def __getitem__(self, idx):
    method __len__ (line 84) | def __len__(self):
  function worker_init_fn (line 96) | def worker_init_fn(worker_id):

FILE: util/transform.py
  class Compose (line 6) | class Compose(object):
    method __init__ (line 7) | def __init__(self, transforms):
    method __call__ (line 10) | def __call__(self, data, label):
  class ToTensor (line 16) | class ToTensor(object):
    method __call__ (line 17) | def __call__(self, data, label):
  class RandomRotate (line 27) | class RandomRotate(object):
    method __init__ (line 28) | def __init__(self, rotate_angle=None, along_z=False):
    method __call__ (line 32) | def __call__(self, data, label):
  class RandomRotatePerturbation (line 48) | class RandomRotatePerturbation(object):
    method __init__ (line 49) | def __init__(self, angle_sigma=0.06, angle_clip=0.18):
    method __call__ (line 53) | def __call__(self, data, label):
  class RandomScale (line 71) | class RandomScale(object):
    method __init__ (line 72) | def __init__(self, scale_low=0.8, scale_high=1.25):
    method __call__ (line 76) | def __call__(self, data, label):
  class RandomShift (line 82) | class RandomShift(object):
    method __init__ (line 83) | def __init__(self, shift_range=0.1):
    method __call__ (line 86) | def __call__(self, data, label):
  class RandomJitter (line 92) | class RandomJitter(object):
    method __init__ (line 93) | def __init__(self, sigma=0.01, clip=0.05):
    method __call__ (line 97) | def __call__(self, data, label):
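The transforms indexed above all share one calling convention: each takes `(data, label)` and returns the pair, so `Compose` can chain them. A minimal sketch of `Compose` plus `RandomJitter` (the `sigma=0.01, clip=0.05` defaults match the listing; the internals are an assumption about the usual implementation):

```python
import numpy as np

class Compose:
    # Chain (data, label) -> (data, label) transforms, as in util/transform.py.
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, data, label):
        for t in self.transforms:
            data, label = t(data, label)
        return data, label

class RandomJitter:
    # Add clipped Gaussian noise to the xyz columns only, leaving
    # color/feature columns untouched. Defaults follow the symbol index.
    def __init__(self, sigma=0.01, clip=0.05):
        self.sigma, self.clip = sigma, clip

    def __call__(self, data, label):
        noise = np.clip(self.sigma * np.random.randn(*data[:, :3].shape),
                        -self.clip, self.clip)
        data = data.copy()
        data[:, :3] += noise
        return data, label
```

A pipeline like `Compose([RandomScale(), RandomShift(), RandomJitter(), ToTensor()])` then applies the whole augmentation chain with one call per sample.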

FILE: util/util.py
  class AverageMeter (line 12) | class AverageMeter(object):
    method __init__ (line 14) | def __init__(self):
    method reset (line 17) | def reset(self):
    method update (line 23) | def update(self, val, n=1):
  function step_learning_rate (line 30) | def step_learning_rate(optimizer, base_lr, epoch, step_epoch, multiplier...
  function poly_learning_rate (line 37) | def poly_learning_rate(optimizer, base_lr, curr_iter, max_iter, power=0.9):
  function intersectionAndUnion (line 44) | def intersectionAndUnion(output, target, K, ignore_index=255):
  function intersectionAndUnionGPU (line 59) | def intersectionAndUnionGPU(output, target, K, ignore_index=255):
  function check_mkdir (line 75) | def check_mkdir(dir_name):
  function check_makedirs (line 80) | def check_makedirs(dir_name):
  function init_weights (line 85) | def init_weights(model, conv='kaiming', batchnorm='normal', linear='kaim...
  function convert_to_syncbn (line 136) | def convert_to_syncbn(model):
  function colorize (line 152) | def colorize(gray, palette):
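Two of the utilities indexed above are small enough to sketch directly: `poly_learning_rate` is the standard polynomial LR decay, and `intersectionAndUnion` builds per-class histograms for mIoU. The versions below are simplified sketches, assuming integer label maps; the repo's `poly_learning_rate` also writes the value into the optimizer's param groups rather than returning it:

```python
import numpy as np

def poly_learning_rate(base_lr, curr_iter, max_iter, power=0.9):
    # Polynomial decay: lr falls from base_lr to 0 over max_iter steps.
    return base_lr * (1 - float(curr_iter) / max_iter) ** power

def intersection_and_union(output, target, K, ignore_index=255):
    # Per-class intersection and union counts for K classes, following
    # the shape of util/util.py's intersectionAndUnion. Pixels labeled
    # ignore_index fall outside the histogram bins and are dropped.
    output = output.copy()
    output[target == ignore_index] = ignore_index
    inter = output[output == target]
    area_inter, _ = np.histogram(inter, bins=np.arange(K + 1))
    area_out, _ = np.histogram(output, bins=np.arange(K + 1))
    area_tgt, _ = np.histogram(target, bins=np.arange(K + 1))
    return area_inter, area_out + area_tgt - area_inter
```

mIoU is then `(area_inter / union).mean()` accumulated over the dataset, typically via the `AverageMeter` class listed at the top of the same file.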
Condensed preview — 61 files, each showing path, character count, and a content snippet (235K chars of structured content in total).
[
  {
    "path": ".gitignore",
    "chars": 798,
    "preview": "## General\r\n\r\n# Compiled Object files\r\n*.slo\r\n*.lo\r\n*.o\r\n*.cuo\r\n\r\n# Compiled Dynamic libraries\r\n*.so\r\n*.dylib\r\n\r\n# Compi"
  },
  {
    "path": "LICENSE",
    "chars": 1072,
    "preview": "MIT License\n\nCopyright (c) 2019 Hengshuang Zhao\n\nPermission is hereby granted, free of charge, to any person obtaining a"
  },
  {
    "path": "README.md",
    "chars": 2609,
    "preview": "# PointWeb: Enhancing Local Neighborhood Features for Point Cloud Processing\n\nby Hengshuang Zhao\\*, Li Jiang*, Chi-Wing "
  },
  {
    "path": "data/s3dis/s3dis_names.txt",
    "chars": 83,
    "preview": "ceiling\nfloor\nwall\nbeam\ncolumn\nwindow\ndoor\nchair\ntable\nbookcase\nsofa\nboard\nclutter\n"
  },
  {
    "path": "data/scannet/scannet_names.txt",
    "chars": 153,
    "preview": "bathtub\nbed\nbookshelf\ncabinet\nchair\ncounter\ncurtain\ndesk\ndoor\nfloor\notherfurniture\npicture\nrefrigerator\nshowercurtain\nsi"
  },
  {
    "path": "lib/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "lib/pointops/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "lib/pointops/functions/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "lib/pointops/functions/pointops.py",
    "chars": 19125,
    "preview": "from typing import Tuple\n\nimport torch\nfrom torch.autograd import Function\nimport torch.nn as nn\n\nimport pointops_cuda\n\n"
  },
  {
    "path": "lib/pointops/setup.py",
    "chars": 1292,
    "preview": "#python3 setup.py install\nfrom setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtensi"
  },
  {
    "path": "lib/pointops/src/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "lib/pointops/src/ballquery/ballquery_cuda.cpp",
    "chars": 1269,
    "preview": "#include <torch/serialize/tensor.h>\n#include <vector>\n#include <THC/THC.h>\n#include <ATen/cuda/CUDAContext.h>\n\n#include "
  },
  {
    "path": "lib/pointops/src/ballquery/ballquery_cuda_kernel.cu",
    "chars": 3389,
    "preview": "#include \"../cuda_utils.h\"\n#include \"ballquery_cuda_kernel.h\"\n\n// input: new_xyz(b, m, 3) xyz(b, n, 3)\n// output: idx(b,"
  },
  {
    "path": "lib/pointops/src/ballquery/ballquery_cuda_kernel.h",
    "chars": 803,
    "preview": "#ifndef _BALLQUERY_CUDA_KERNEL\n#define _BALLQUERY_CUDA_KERNEL\n#include <torch/serialize/tensor.h>\n#include <vector>\n#inc"
  },
  {
    "path": "lib/pointops/src/cuda_utils.h",
    "chars": 825,
    "preview": "#ifndef _CUDA_UTILS_H\n#define _CUDA_UTILS_H\n\n#include <cmath>\n\n#define TOTAL_THREADS 1024\n\n#define CHECK_CUDA(x) AT_CHEC"
  },
  {
    "path": "lib/pointops/src/featuredistribute/featuredistribute_cuda.cpp",
    "chars": 2218,
    "preview": "#include <torch/serialize/tensor.h>\n#include <vector>\n#include <THC/THC.h>\n#include <ATen/cuda/CUDAContext.h>\n\n#include "
  },
  {
    "path": "lib/pointops/src/featuredistribute/featuredistribute_cuda_kernel.cu",
    "chars": 4533,
    "preview": "#include \"../cuda_utils.h\"\n#include \"featuredistribute_cuda_kernel.h\"\n\n__global__ void featuredistribute_cuda_kernel(int"
  },
  {
    "path": "lib/pointops/src/featuredistribute/featuredistribute_cuda_kernel.h",
    "chars": 1213,
    "preview": "#ifndef _FEATUREDISTRIBUTE_CUDA_KERNEL\n#define _FEATUREDISTRIBUTE_CUDA_KERNEL\n#include <torch/serialize/tensor.h>\n#inclu"
  },
  {
    "path": "lib/pointops/src/grouping/grouping_cuda.cpp",
    "chars": 1323,
    "preview": "#include <torch/serialize/tensor.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <vector>\n#include <THC/THC.h>\n\n#include "
  },
  {
    "path": "lib/pointops/src/grouping/grouping_cuda_kernel.cu",
    "chars": 3686,
    "preview": "#include \"../cuda_utils.h\"\n#include \"grouping_cuda_kernel.h\"\n\n// input: points(b, c, n) idx(b, m, nsample)\n// output: ou"
  },
  {
    "path": "lib/pointops/src/grouping/grouping_cuda_kernel.h",
    "chars": 1070,
    "preview": "#ifndef _GROUPING_CUDA_KERNEL\n#define _GROUPING_CUDA_KERNEL\n#include <torch/serialize/tensor.h>\n#include <vector>\n#inclu"
  },
  {
    "path": "lib/pointops/src/grouping_int/grouping_int_cuda.cpp",
    "chars": 949,
    "preview": "#include <torch/serialize/tensor.h>\n#include <vector>\n#include <ATen/cuda/CUDAContext.h>\n#include <THC/THC.h>\n\n#include "
  },
  {
    "path": "lib/pointops/src/grouping_int/grouping_int_cuda_kernel.cu",
    "chars": 2479,
    "preview": "#include \"../cuda_utils.h\"\n#include \"grouping_int_cuda_kernel.h\"\n\n// input: points(b, c, n) idx(b, m, nsample)\n// output"
  },
  {
    "path": "lib/pointops/src/grouping_int/grouping_int_cuda_kernel.h",
    "chars": 810,
    "preview": "#ifndef _GROUPING_INT_CUDA_KERNEL\n#define _GROUPING_INT_CUDA_KERNEL\n#include <torch/serialize/tensor.h>\n#include <vector"
  },
  {
    "path": "lib/pointops/src/interpolation/interpolation_cuda.cpp",
    "chars": 2396,
    "preview": "#include <torch/serialize/tensor.h>\n#include <vector>\n#include <ATen/cuda/CUDAContext.h>\n#include <THC/THC.h>\n#include \""
  },
  {
    "path": "lib/pointops/src/interpolation/interpolation_cuda_kernel.cu",
    "chars": 7415,
    "preview": "#include \"../cuda_utils.h\"\n#include \"interpolation_cuda_kernel.h\"\n\n// input: unknown(b, n, 3) known(b, m, 3)\n// output: "
  },
  {
    "path": "lib/pointops/src/interpolation/interpolation_cuda_kernel.h",
    "chars": 1720,
    "preview": "#ifndef _INTERPOLATION_CUDA_KERNEL\n#define _INTERPOLATION_CUDA_KERNEL\n#include <torch/serialize/tensor.h>\n#include <vect"
  },
  {
    "path": "lib/pointops/src/knnquery/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "lib/pointops/src/knnquery/knnquery_cuda.cpp",
    "chars": 947,
    "preview": "#include <torch/serialize/tensor.h>\n#include <vector>\n#include <THC/THC.h>\n#include <ATen/cuda/CUDAContext.h>\n\n#include "
  },
  {
    "path": "lib/pointops/src/knnquery/knnquery_cuda_kernel.cu",
    "chars": 2358,
    "preview": "#include \"../cuda_utils.h\"\n#include \"knnquery_cuda_kernel.h\"\n\n// input: xyz (b, n, 3) new_xyz (b, m, 3)\n// output: idx ("
  },
  {
    "path": "lib/pointops/src/knnquery/knnquery_cuda_kernel.h",
    "chars": 528,
    "preview": "#ifndef _KNNQUERY_CUDA_KERNEL\n#define _KNNQUERY_CUDA_KERNEL\n\n#include <torch/serialize/tensor.h>\n#include <vector>\n#incl"
  },
  {
    "path": "lib/pointops/src/labelstat/labelstat_cuda.cpp",
    "chars": 2511,
    "preview": "#include <torch/serialize/tensor.h>\n#include <vector>\n#include <THC/THC.h>\n#include <ATen/cuda/CUDAContext.h>\n\n#include "
  },
  {
    "path": "lib/pointops/src/labelstat/labelstat_cuda_kernel.cu",
    "chars": 6145,
    "preview": "#include \"../cuda_utils.h\"\n#include \"labelstat_cuda_kernel.h\"\n\n// input: new_xyz(b, m, 3) xyz(b, n, 3) label_stat(b, n, "
  },
  {
    "path": "lib/pointops/src/labelstat/labelstat_cuda_kernel.h",
    "chars": 1448,
    "preview": "#ifndef _LABELSTAT_CUDA_KERNEL\n#define _LABELSTAT_CUDA_KERNEL\n#include <torch/serialize/tensor.h>\n#include <vector>\n#inc"
  },
  {
    "path": "lib/pointops/src/pointops_api.cpp",
    "chars": 2144,
    "preview": "#include <torch/serialize/tensor.h>\n#include <torch/extension.h>\n\n#include \"ballquery/ballquery_cuda_kernel.h\"\n#include "
  },
  {
    "path": "lib/pointops/src/sampling/sampling_cuda.cpp",
    "chars": 1227,
    "preview": "#include <torch/serialize/tensor.h>\n#include <vector>\n#include <ATen/cuda/CUDAContext.h>\n#include <THC/THC.h>\n#include \""
  },
  {
    "path": "lib/pointops/src/sampling/sampling_cuda_kernel.cu",
    "chars": 6625,
    "preview": "#include \"../cuda_utils.h\"\n#include \"sampling_cuda_kernel.h\"\n\n// input: points(b, c, n) idx(b, m)\n// output: out(b, c, m"
  },
  {
    "path": "lib/pointops/src/sampling/sampling_cuda_kernel.h",
    "chars": 963,
    "preview": "#ifndef _SAMPLING_CUDA_KERNEL\n#define _SAMPLING_CUDA_KERNEL\n#include <torch/serialize/tensor.h>\n#include <vector>\n#inclu"
  },
  {
    "path": "lib/sync_bn/__init__.py",
    "chars": 449,
    "preview": "# -*- coding: utf-8 -*-\n# File   : __init__.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/2"
  },
  {
    "path": "lib/sync_bn/batchnorm.py",
    "chars": 12973,
    "preview": "# -*- coding: utf-8 -*-\n# File   : batchnorm.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/"
  },
  {
    "path": "lib/sync_bn/comm.py",
    "chars": 4449,
    "preview": "# -*- coding: utf-8 -*-\n# File   : comm.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/2018\n"
  },
  {
    "path": "lib/sync_bn/replicate.py",
    "chars": 3227,
    "preview": "# -*- coding: utf-8 -*-\n# File   : replicate.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/"
  },
  {
    "path": "lib/sync_bn/unittest.py",
    "chars": 835,
    "preview": "# -*- coding: utf-8 -*-\n# File   : unittest.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/2"
  },
  {
    "path": "model/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "model/pointnet/pointnet.py",
    "chars": 4841,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass STN3D(nn.Module):\n    def __init__(self, c):\n"
  },
  {
    "path": "model/pointnet2/pointnet2_modules.py",
    "chars": 6284,
    "preview": "from typing import List\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom lib.pointops.functions"
  },
  {
    "path": "model/pointnet2/pointnet2_seg.py",
    "chars": 7812,
    "preview": "from collections import namedtuple\n\nimport torch\nimport torch.nn as nn\n\nfrom model.pointnet2.pointnet2_modules import Po"
  },
  {
    "path": "model/pointweb/pointweb_module.py",
    "chars": 4736,
    "preview": "from typing import List\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom lib.pointops.functions"
  },
  {
    "path": "model/pointweb/pointweb_seg.py",
    "chars": 5478,
    "preview": "from collections import namedtuple\n\nimport torch\nimport torch.nn as nn\n\nfrom model.pointweb.pointweb_module import Point"
  },
  {
    "path": "tool/test.sh",
    "chars": 602,
    "preview": "#!/bin/sh\nexport PYTHONPATH=./\n\nPYTHON=python\ndataset=$1\nexp_name=$2\nexp_dir=exp/${dataset}/${exp_name}\nmodel_dir=${exp_"
  },
  {
    "path": "tool/test_s3dis.py",
    "chars": 10595,
    "preview": "import os\nimport time\nimport random\nimport numpy as np\nimport logging\nimport pickle\nimport argparse\n\nimport torch\nimport"
  },
  {
    "path": "tool/test_scannet.py",
    "chars": 11228,
    "preview": "import os\nimport time\nimport random\nimport numpy as np\nimport logging\nimport pickle\nimport argparse\n\nimport torch\nimport"
  },
  {
    "path": "tool/train.py",
    "chars": 13191,
    "preview": "import os\nimport time\nimport random\nimport numpy as np\nimport logging\nimport argparse\n\nimport torch\nimport torch.backend"
  },
  {
    "path": "tool/train.sh",
    "chars": 617,
    "preview": "#!/bin/sh\nexport PYTHONPATH=./\n\nPYTHON=python\ndataset=$1\nexp_name=$2\nexp_dir=exp/${dataset}/${exp_name}\nmodel_dir=${exp_"
  },
  {
    "path": "util/config.py",
    "chars": 5425,
    "preview": "# -----------------------------------------------------------------------------\n# Functions for parsing args\n# ---------"
  },
  {
    "path": "util/dataset.py",
    "chars": 2108,
    "preview": "import os\nimport h5py\nimport numpy as np\n\nfrom torch.utils.data import Dataset\n\n\ndef make_dataset(split='train', data_ro"
  },
  {
    "path": "util/pt_util.py",
    "chars": 23123,
    "preview": "import shutil, os\nimport tqdm\nfrom itertools import repeat\nimport numpy as np\nfrom typing import List, Tuple\n# from scip"
  },
  {
    "path": "util/s3dis.py",
    "chars": 4863,
    "preview": "import os\nimport numpy as np\n\nfrom torch.utils.data import Dataset\n\n\nclass S3DIS(Dataset):\n    def __init__(self, split="
  },
  {
    "path": "util/scannet.py",
    "chars": 5359,
    "preview": "import pickle\nimport os\nimport numpy as np\n\nfrom torch.utils.data import Dataset\n\n\nclass ScanNet(Dataset):\n    def __ini"
  },
  {
    "path": "util/transform.py",
    "chars": 3433,
    "preview": "import numpy as np\n\nimport torch\n\n\nclass Compose(object):\n    def __init__(self, transforms):\n        self.transforms = "
  },
  {
    "path": "util/util.py",
    "chars": 5996,
    "preview": "import os\nimport numpy as np\nfrom PIL import Image\n\nimport torch\nfrom torch import nn\nfrom torch.nn.modules.conv import "
  }
]

About this extraction

This page contains the full source code of the hszhao/PointWeb GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction covers 61 files (218.5 KB, roughly 61.1k tokens) plus a symbol index of 276 extracted functions, classes, methods, constants, and types. It can be used with any AI tool that accepts text input, such as OpenClaw, Claude, ChatGPT, Cursor, or Windsurf.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.