Repository: WanquanF/NeuralPoints
Branch: main
Commit: a132ddd01ce3
Files: 35
Total size: 166.5 KB

Directory structure:
NeuralPoints/

├── README.md
├── code/
│   ├── colormap.py
│   ├── mesh_operations.py
│   └── torch_tensor_functions.py
├── model/
│   └── conpu_v6/
│       ├── chamfer_distance/
│       │   ├── __init__.py
│       │   ├── chamfer_distance.cpp
│       │   ├── chamfer_distance.cu
│       │   ├── chamfer_distance.py
│       │   └── setup.py
│       ├── loss.py
│       ├── network.py
│       ├── pointnet2/
│       │   ├── __init__.py
│       │   ├── pointnet2_modules.py
│       │   ├── pointnet2_utils.py
│       │   ├── pytorch_utils.py
│       │   ├── setup.py
│       │   └── src/
│       │       ├── ball_query.cpp
│       │       ├── ball_query_gpu.cu
│       │       ├── ball_query_gpu.h
│       │       ├── cuda_utils.h
│       │       ├── group_points.cpp
│       │       ├── group_points_gpu.cu
│       │       ├── group_points_gpu.h
│       │       ├── interpolate.cpp
│       │       ├── interpolate_gpu.cu
│       │       ├── interpolate_gpu.h
│       │       ├── pointnet2_api.cpp
│       │       ├── sampling.cpp
│       │       ├── sampling_gpu.cu
│       │       └── sampling_gpu.h
│       ├── pre_trained/
│       │   └── v3.pt
│       ├── train_script101.py
│       ├── train_script101_test.py
│       └── train_view_toy.py
└── utils/
    └── config.py

================================================
FILE CONTENTS
================================================

================================================
FILE: README.md
================================================
# Neural Points
Code of the CVPR 2022 paper *Neural Points: Point Cloud Representation with Neural Fields for Arbitrary Upsampling*.

- Paper address: [https://arxiv.org/abs/2112.04148](https://arxiv.org/abs/2112.04148)
- Project webpage: [https://wanquanf.github.io/NeuralPoints.html](https://wanquanf.github.io/NeuralPoints.html)


![avatar](./utils/Pipeline_v5.png)

## Prerequisite Installation
The code has been tested on Ubuntu 18.04 with Python 3.8, PyTorch 1.6, and CUDA 10.2:

    conda create --name NePs python=3.8
    
    conda activate NePs
    
    conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.2 -c pytorch
    
    conda install -c conda-forge igl
    
Before running the code, you need to build the CUDA/C++ PyTorch extensions:

    cd [ProjectPath]/model/conpu_v6/pointnet2
    
    python setup.py install

    
## How to use the code: 
Download our dataset: [dataset](https://pan.baidu.com/s/1BLFobnIkuLqrXsdAAVqA0g) (extraction code: qiqq), and put the 'Sketchfab2' folder into [ProjectPath]/data.

First, change the working directory:

    cd [ProjectPath]/model/conpu_v6

To obtain results on the test set, run:

    python train_script101_test.py

To train our network, run:

    python train_script101.py


## Citation
Please cite this paper with the following bibtex:

    @inproceedings{feng2022np,
        author    = {Wanquan Feng and Jin Li and Hongrui Cai and Xiaonan Luo and Juyong Zhang},
        title     = {Neural Points: Point Cloud Representation with Neural Fields for Arbitrary Upsampling},
        booktitle = {{IEEE/CVF} Conference on Computer Vision and Pattern Recognition (CVPR)},
        year      = {2022}
    }


## Acknowledgement
In this repo, we borrow the backbone structure from [DGCNN](https://github.com/WangYueFt/dgcnn).


================================================
FILE: code/colormap.py
================================================
rb_colormap_list =[ 0,         0,    0.5625,
         0,         0,    0.6250,
         0,         0,    0.6875,
         0,         0,    0.7500,
         0,         0,    0.8125,
         0,         0,    0.8750,
         0,         0,    0.9375,
         0,         0,    1.0000,
         0,    0.0625,    1.0000,
         0,    0.1250,    1.0000,
         0,    0.1875,    1.0000,
         0,    0.2500,    1.0000,
         0,    0.3125,    1.0000,
         0,    0.3750,    1.0000,
         0,    0.4375,    1.0000,
         0,    0.5000,    1.0000,
         0,    0.5625,    1.0000,
         0,    0.6250,    1.0000,
         0,    0.6875,    1.0000,
         0,    0.7500,    1.0000,
         0,    0.8125,    1.0000,
         0,    0.8750,    1.0000,
         0,    0.9375,    1.0000,
         0,    1.0000,    1.0000,
    0.0625,    1.0000,    0.9375,
    0.1250,    1.0000,    0.8750,
    0.1875,    1.0000,    0.8125,
    0.2500,    1.0000,    0.7500,
    0.3125,    1.0000,    0.6875,
    0.3750,    1.0000,    0.6250,
    0.4375,    1.0000,    0.5625,
    0.5000,    1.0000,    0.5000,
    0.5625,    1.0000,    0.4375,
    0.6250,    1.0000,    0.3750,
    0.6875,    1.0000,    0.3125,
    0.7500,    1.0000,    0.2500,
    0.8125,    1.0000,    0.1875,
    0.8750,    1.0000,    0.1250,
    0.9375,    1.0000,    0.0625,
    1.0000,    1.0000,         0,
    1.0000,    0.9375,         0,
    1.0000,    0.8750,         0,
    1.0000,    0.8125,         0,
    1.0000,    0.7500,         0,
    1.0000,    0.6875,         0,
    1.0000,    0.6250,         0,
    1.0000,    0.5625,         0,
    1.0000,    0.5000,         0,
    1.0000,    0.4375,         0,
    1.0000,    0.3750,         0,
    1.0000,    0.3125,         0,
    1.0000,    0.2500,         0,
    1.0000,    0.1875,         0,
    1.0000,    0.1250,         0,
    1.0000,    0.0625,         0,
    1.0000,         0,         0,
    0.9375,         0,         0,
    0.8750,         0,         0,
    0.8125,         0,         0,
    0.7500,         0,         0,
    0.6875,         0,         0,
    0.6250,         0,         0,
    0.5625,         0,         0,
    0.5000,         0,         0]
    
rb_colormap_list_little = [ 0,         0,    0.5625,
         0,    0.1250,    1.0000,
         0,    0.6875,    1.0000,
    0.2500,    1.0000,    0.7500,
    0.8125,    1.0000,    0.1875,
    1.0000,    0.6250,         0,
    1.0000,    0.0625,         0,
    0.5000,         0,         0]
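
The flat lists above store 3 floats per color. As a quick illustration (the helper name `scalar_to_rgb` is ours, not part of the repo), a scalar in [0, 1] can be mapped onto the 8-entry "little" colormap by binning:

```python
# Values copied from rb_colormap_list_little above; scalar_to_rgb is a
# hypothetical helper for illustration, not a function defined in this repo.
rb_colormap_list_little = [0,      0,      0.5625,
                           0,      0.1250, 1.0000,
                           0,      0.6875, 1.0000,
                           0.2500, 1.0000, 0.7500,
                           0.8125, 1.0000, 0.1875,
                           1.0000, 0.6250, 0,
                           1.0000, 0.0625, 0,
                           0.5000, 0,      0]

def scalar_to_rgb(t, flat_list=rb_colormap_list_little):
    """Map t in [0, 1] to the nearest colormap entry, as an (r, g, b) tuple."""
    n = len(flat_list) // 3          # number of colors (8 here)
    i = min(int(t * n), n - 1)       # clamp t == 1.0 into the last bin
    return tuple(flat_list[3 * i : 3 * i + 3])
```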


================================================
FILE: code/mesh_operations.py
================================================
#### Author : Wanquan Feng (University of Science and Technology of China)
#### Description : Mesh/point-cloud operations on numpy arrays
#### Date : 2021-10-16

import os
import sys
import numpy
import igl


#  off format
def read_off_(off_file_name):
    v,f,_ = igl.read_off(off_file_name)
    return v,f
def write_off_(off_file_name,v,face_=numpy.zeros((1))):
    fout = open(off_file_name,'w')
    fout.write('OFF\n')
    fout.write(str(v.shape[0])+' '+str(face_.shape[0])+' 0\n')
    for i in range(v.shape[0]):
        fout.write(str(v[i][0])+' '+str(v[i][1])+' '+str(v[i][2])+'\n')
    if face_.shape[0]<2:return None
    for i in range(face_.shape[0]):
        fout.write('3 '+str(face_[i][0])+' '+str(face_[i][1])+' '+str(face_[i][2])+'\n')
    fout.close()
    return None

# obj format
def write_obj_(obj_write_name,v,face_=numpy.zeros((1)),color_=numpy.zeros((1)),normal_=numpy.zeros((1))):
    f=open(obj_write_name,'w')
    vnum = v.shape[0]
    for vid in range(vnum):
        f.write('v '+str(v[vid][0])+' '+str(v[vid][1])+' '+str(v[vid][2]))
        if color_.shape[0]<vnum: f.write('\n')
        else:f.write(' '+str(color_[vid][0])+' '+str(color_[vid][1])+' '+str(color_[vid][2])+'\n')
        if normal_.shape[0]==vnum:
            f.write('vn '+str(normal_[vid][0])+' '+str(normal_[vid][1])+' '+str(normal_[vid][2])+'\n')
    if face_.shape[0]<2:
        f.close()
        return None
    fnum = face_.shape[0]
    for fid in range(fnum):
        f.write('f '+str(face_[fid][0]+1)+' '+str(face_[fid][1]+1)+' '+str(face_[fid][2]+1)+'\n')
    f.close()
    return None
def read_obj_(obj_write_name):
    v, _, n, f, _, _ = igl.read_obj(obj_write_name)
    return v, f, n

# xyz format
def write_xyz_(xyz_write_name,v,normal_=numpy.zeros((1))):
    f = open(xyz_write_name, 'w')
    vnum = v.shape[0]
    for i in range(vnum):
        f.write(str(v[i][0])+' '+str(v[i][1])+' '+str(v[i][2]))
        if normal_.shape[0]<vnum: f.write('\n')
        else:f.write(' '+str(normal_[i][0])+' '+str(normal_[i][1])+' '+str(normal_[i][2])+'\n')
    f.close()
    return None
def read_xyz_(xyz_name):
    v_ = []
    n_ = []
    ff = open(xyz_name)
    lines = ff.readlines()
    for i, aline in enumerate(lines):
        words = aline.split(' ')
        x,y,z = float(words[0]), float(words[1]), float(words[2])
        v_.append([x,y,z])
        if len(words)>=6:
            nx,ny,nz = float(words[3]), float(words[4]), float(words[5])
            n_.append([nx,ny,nz])
    v_ = numpy.array(v_).astype(numpy.float32)
    n_ = numpy.array(n_).astype(numpy.float32)
    if n_.shape[0] < v_.shape[0]:
        n_ = None
    return v_, n_

# format converting
def convert_obj_to_off_(obj_path_in, off_path_out):
    v,face_,_ = read_obj_(obj_path_in)
    write_off_(off_path_out, v, face_)
    return None
    
# normalize the points to sphere
def normalize_points_to_sphere_(v_in):
    v_out = v_in.copy()
    center = numpy.mean(v_out,axis=0,keepdims=True)
    v_out = v_out-center
    factor = numpy.sum(v_out*v_out, axis=-1, keepdims=True).max()**0.5
    v_out /= factor
    return v_out, center, factor

# normalize the points to sphere with given center and factor
def normalize_points_to_sphere_with_given_center_and_factor_(v_in, center, factor):
    v_out = v_in.copy()
    v_out = v_out-center
    v_out /= factor
    return v_out, center, factor
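
A quick sanity check of the normalization above (the function body is reproduced inline so the sketch is self-contained): after normalization every point lies within the unit sphere, and multiplying by `factor` and adding `center` back recovers the input.

```python
import numpy

# Reproduced from normalize_points_to_sphere_ above, so this sketch runs standalone.
def normalize_points_to_sphere_(v_in):
    v_out = v_in.copy()
    center = numpy.mean(v_out, axis=0, keepdims=True)   # centroid, shape (1, 3)
    v_out = v_out - center
    factor = numpy.sum(v_out * v_out, axis=-1, keepdims=True).max() ** 0.5
    v_out /= factor                                     # max radius becomes 1
    return v_out, center, factor

points = numpy.array([[0.0, 0.0, 0.0],
                      [2.0, 0.0, 0.0],
                      [0.0, 4.0, 0.0]])
v_norm, center, factor = normalize_points_to_sphere_(points)
radii = numpy.linalg.norm(v_norm, axis=1)
assert radii.max() <= 1.0 + 1e-6                 # all points inside the unit sphere
recovered = v_norm * factor + center             # undo the normalization
assert numpy.allclose(recovered, points)
```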
                


================================================
FILE: code/torch_tensor_functions.py
================================================
#### Author : Wanquan Feng (University of Science and Technology of China)
#### Description : Point-cloud operations on PyTorch tensors
#### Date : 2021-10-16


import os
import sys
import torch
import numpy
import mesh_operations



def compute_sqrdis_map(points_x, points_y):
    ## The shape of the input and output ##
    # points_x : batchsize * M * 3
    # points_y : batchsize * N * 3
    # output   : batchsize * M * N
    thisbatchsize = points_x.size()[0]
    pn_x = points_x.size()[1]
    pn_y = points_y.size()[1]
    x_sqr = torch.sum(torch.mul(points_x, points_x), dim=-1).view(thisbatchsize, pn_x, 1).expand(-1,-1,pn_y)
    y_sqr = torch.sum(torch.mul(points_y, points_y), dim=-1).view(thisbatchsize, 1, pn_y).expand(-1,pn_x,-1)
    inner = torch.bmm(points_x, points_y.transpose(1,2))
    sqrdis = x_sqr + y_sqr - 2*inner
    return sqrdis

def draw_tensor_point_xyz_with_normal(save_path, torch_tensor_points, torch_tensor_normals=torch.ones([1])):
    ## The shape of the input ##
    # torch_tensor_points : M * 3
    # torch_tensor_normals (optional) : M * 3 
    if len(torch_tensor_points.size())!=2:
        print('The point tensor should be 2-dimensional. Exit here.')
        exit()
    if torch_tensor_points.size()[1]!=3:
        print('The dim of the point tensor is not correct. It should be (num_point, 3).')
        exit()
    numpy_points = torch_tensor_points.cpu().numpy()
    numpy_normals = torch_tensor_normals.cpu().numpy()
    mesh_operations.write_xyz_(save_path, numpy_points, numpy_normals)


def draw_tensor_point_xyz_with_normal_by_threshold(save_path, torch_tensor_points, torch_anchor, torch_tensor_normals=torch.ones([1]), threshold=0.95):
    ## The shape of the input ##
    # torch_tensor_points : M * 3
    # torch_tensor_normals (optional) : M * 3 
    # threshold : a float value < 1
    if len(torch_tensor_points.size())!=2:
        print('The point tensor should be 2-dimensional. Exit here.')
        exit()
    if torch_tensor_points.size()[1]!=3:
        print('The dim of the point tensor is not correct. It should be (num_point, 3).')
        exit()
    torch_tensor_points_norm = torch.sum(torch.mul(torch_tensor_points, torch_tensor_points), dim=1)
    # NOTE: the threshold/anchor arguments are currently unused; all points are written out.

    numpy_points = torch_tensor_points.cpu().numpy()
    numpy_normals = torch_tensor_normals.cpu().numpy()
    mesh_operations.write_xyz_(save_path, numpy_points, numpy_normals)

def draw_tensor_point_obj_with_color(save_path, torch_tensor_points, torch_tensor_color=torch.ones([1])):
    ## The shape of the input ##
    # torch_tensor_points : M * 3
    # torch_tensor_color (optional) : M * 3 
    if len(torch_tensor_points.size())!=2:
        print('The point tensor should be 2-dimensional. Exit here.')
        exit()
    if torch_tensor_points.size()[1]!=3:
        print('The dim of the point tensor is not correct. It should be (num_point, 3).')
        exit()
    numpy_points = torch_tensor_points.cpu().numpy()
    numpy_color = torch_tensor_color.cpu().numpy()
    mesh_operations.write_obj_(save_path, numpy_points, color_=numpy_color)


def draw_tensor_point_batch_xyz_with_normal(save_batch_path, torch_tensor_points_batch, torch_tensor_normals_batch=torch.ones([1])):
    ## The shape of the input ##
    # torch_tensor_points : B * M * 3
    # torch_tensor_normals (optional) : B * M * 3 
    if not os.path.exists(save_batch_path):os.mkdir(save_batch_path)
    thisbatchsize = len(torch_tensor_points_batch)
    for bi in range(thisbatchsize):
        bi_path = save_batch_path+'/'+str(bi)+'.xyz'
        torch_tensor_points = torch_tensor_points_batch[bi]
        if len(torch_tensor_normals_batch.size())==1: torch_tensor_normals = torch.ones([1])
        else:torch_tensor_normals = torch_tensor_normals_batch[bi]
        draw_tensor_point_xyz_with_normal(bi_path, torch_tensor_points, torch_tensor_normals)
    

def euler2rot(euler_angle):
    batch_size = euler_angle.shape[0]
    one = torch.ones(batch_size, 1, 1).to(euler_angle.device)
    zero = torch.zeros(batch_size, 1, 1).to(euler_angle.device)
    theta = euler_angle[:, 0].reshape(-1, 1, 1)
    phi = euler_angle[:, 1].reshape(-1, 1, 1)
    psi = euler_angle[:, 2].reshape(-1, 1, 1)
    rot_x = torch.cat((
        torch.cat((one, zero, zero), 1),
        torch.cat((zero, theta.cos(), theta.sin()), 1),
        torch.cat((zero, -theta.sin(), theta.cos()), 1),
    ), 2)
    rot_y = torch.cat((
        torch.cat((phi.cos(), zero, -phi.sin()), 1),
        torch.cat((zero, one, zero), 1),
        torch.cat((phi.sin(), zero, phi.cos()), 1),
    ), 2)
    rot_z = torch.cat((
        torch.cat((psi.cos(), -psi.sin(), zero), 1),
        torch.cat((psi.sin(), psi.cos(), zero), 1),
        torch.cat((zero, zero, one), 1)
    ), 2)
    return torch.bmm(rot_z, torch.bmm(rot_y, rot_x))



def get_neighbor_index(vertices: "(bs, vertice_num, 3)",  neighbor_num: int):
    # Return: (bs, vertice_num, neighbor_num)
    bs, v, _ = vertices.size()
    device = vertices.device
    inner = torch.bmm(vertices, vertices.transpose(1, 2)) #(bs, v, v)
    quadratic = torch.sum(vertices**2, dim= 2) #(bs, v)
    distance = inner * (-2) + quadratic.unsqueeze(1) + quadratic.unsqueeze(2)
    neighbor_index = torch.topk(distance, k= neighbor_num + 1, dim= -1, largest= False)[1]
    neighbor_index = neighbor_index[:, :, 1:]
    return neighbor_index


def indexing_neighbor(tensor: "(bs, vertice_num, dim)", index: "(bs, query_vertice_num, neighbor_num)" ):
    # Return: (bs, query_vertice_num, neighbor_num, dim)
    bs, v, n = index.size()
    id_0 = torch.arange(bs).view(-1, 1, 1)
    tensor_indexed = tensor[id_0, index]
    return tensor_indexed


def indexing_by_id(tensor: "(bs, vertice_num, dim)", index: "(bs, query_num, neighbor_num)" ):
    # Return: (bs, query_num, neighbor_num, dim)
    bs, v, n = index.size()
    id_0 = torch.arange(bs).view(-1, 1, 1)
    tensor_indexed = tensor[id_0, index]
    return tensor_indexed

================================================
FILE: model/conpu_v6/chamfer_distance/__init__.py
================================================
from .chamfer_distance import ChamferDistance


================================================
FILE: model/conpu_v6/chamfer_distance/chamfer_distance.cpp
================================================
#include <torch/torch.h>

// CUDA forward declarations
int ChamferDistanceKernelLauncher(
    const int b, const int n,
    const float* xyz,
    const int m,
    const float* xyz2,
    float* result,
    int* result_i,
    float* result2,
    int* result2_i);

int ChamferDistanceGradKernelLauncher(
    const int b, const int n,
    const float* xyz1,
    const int m,
    const float* xyz2,
    const float* grad_dist1,
    const int* idx1,
    const float* grad_dist2,
    const int* idx2,
    float* grad_xyz1,
    float* grad_xyz2);


void chamfer_distance_forward_cuda(
    const at::Tensor xyz1, 
    const at::Tensor xyz2, 
    const at::Tensor dist1, 
    const at::Tensor dist2, 
    const at::Tensor idx1, 
    const at::Tensor idx2) 
{
//    std::cout<<"here"<<std::endl;
//    std::cout<<xyz1.size(0)<<std::endl;
//    std::cout<<xyz1.size(1)<<std::endl;
//    std::cout<<dist2.device()<<std::endl;
    
    ChamferDistanceKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data<float>(),
                                            xyz2.size(1), xyz2.data<float>(),
                                            dist1.data<float>(), idx1.data<int>(),
                                            dist2.data<float>(), idx2.data<int>());
}

void chamfer_distance_backward_cuda(
    const at::Tensor xyz1,
    const at::Tensor xyz2, 
    at::Tensor gradxyz1, 
    at::Tensor gradxyz2, 
    at::Tensor graddist1, 
    at::Tensor graddist2, 
    at::Tensor idx1, 
    at::Tensor idx2)
{
    ChamferDistanceGradKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data<float>(),
                                           xyz2.size(1), xyz2.data<float>(),
                                           graddist1.data<float>(), idx1.data<int>(),
                                           graddist2.data<float>(), idx2.data<int>(),
                                           gradxyz1.data<float>(), gradxyz2.data<float>());
}


void nnsearch(
    const int b, const int n, const int m,
    const float* xyz1,
    const float* xyz2,
    float* dist,
    int* idx)
{
    for (int i = 0; i < b; i++) {
        for (int j = 0; j < n; j++) {
            const float x1 = xyz1[(i*n+j)*3+0];
            const float y1 = xyz1[(i*n+j)*3+1];
            const float z1 = xyz1[(i*n+j)*3+2];
            double best = 0;
            int besti = 0;
            for (int k = 0; k < m; k++) {
                const float x2 = xyz2[(i*m+k)*3+0] - x1;
                const float y2 = xyz2[(i*m+k)*3+1] - y1;
                const float z2 = xyz2[(i*m+k)*3+2] - z1;
                const double d=x2*x2+y2*y2+z2*z2;
                if (k==0 || d < best){
                    best = d;
                    besti = k;
                }
            }
            dist[i*n+j] = best;
            idx[i*n+j] = besti;
        }
    }
}


void chamfer_distance_forward(
    const at::Tensor xyz1, 
    const at::Tensor xyz2, 
    const at::Tensor dist1, 
    const at::Tensor dist2, 
    const at::Tensor idx1, 
    const at::Tensor idx2) 
{
    const int batchsize = xyz1.size(0);
    const int n = xyz1.size(1);
    const int m = xyz2.size(1);

    const float* xyz1_data = xyz1.data<float>();
    const float* xyz2_data = xyz2.data<float>();
    float* dist1_data = dist1.data<float>();
    float* dist2_data = dist2.data<float>();
    int* idx1_data = idx1.data<int>();
    int* idx2_data = idx2.data<int>();

    nnsearch(batchsize, n, m, xyz1_data, xyz2_data, dist1_data, idx1_data);
    nnsearch(batchsize, m, n, xyz2_data, xyz1_data, dist2_data, idx2_data);
}


void chamfer_distance_backward(
    const at::Tensor xyz1, 
    const at::Tensor xyz2, 
    at::Tensor gradxyz1, 
    at::Tensor gradxyz2, 
    at::Tensor graddist1, 
    at::Tensor graddist2, 
    at::Tensor idx1, 
    at::Tensor idx2) 
{
    const int b = xyz1.size(0);
    const int n = xyz1.size(1);
    const int m = xyz2.size(1);

    const float* xyz1_data = xyz1.data<float>();
    const float* xyz2_data = xyz2.data<float>();
    float* gradxyz1_data = gradxyz1.data<float>();
    float* gradxyz2_data = gradxyz2.data<float>();
    float* graddist1_data = graddist1.data<float>();
    float* graddist2_data = graddist2.data<float>();
    const int* idx1_data = idx1.data<int>();
    const int* idx2_data = idx2.data<int>();

    for (int i = 0; i < b*n*3; i++)
        gradxyz1_data[i] = 0;
    for (int i = 0; i < b*m*3; i++)
        gradxyz2_data[i] = 0;
    for (int i = 0;i < b; i++) {
        for (int j = 0; j < n; j++) {
            const float x1 = xyz1_data[(i*n+j)*3+0];
            const float y1 = xyz1_data[(i*n+j)*3+1];
            const float z1 = xyz1_data[(i*n+j)*3+2];
            const int j2 = idx1_data[i*n+j];

            const float x2 = xyz2_data[(i*m+j2)*3+0];
            const float y2 = xyz2_data[(i*m+j2)*3+1];
            const float z2 = xyz2_data[(i*m+j2)*3+2];
            const float g = graddist1_data[i*n+j]*2;

            gradxyz1_data[(i*n+j)*3+0] += g*(x1-x2);
            gradxyz1_data[(i*n+j)*3+1] += g*(y1-y2);
            gradxyz1_data[(i*n+j)*3+2] += g*(z1-z2);
            gradxyz2_data[(i*m+j2)*3+0] -= (g*(x1-x2));
            gradxyz2_data[(i*m+j2)*3+1] -= (g*(y1-y2));
            gradxyz2_data[(i*m+j2)*3+2] -= (g*(z1-z2));
        }
        for (int j = 0; j < m; j++) {
            const float x1 = xyz2_data[(i*m+j)*3+0];
            const float y1 = xyz2_data[(i*m+j)*3+1];
            const float z1 = xyz2_data[(i*m+j)*3+2];
            const int j2 = idx2_data[i*m+j];
            const float x2 = xyz1_data[(i*n+j2)*3+0];
            const float y2 = xyz1_data[(i*n+j2)*3+1];
            const float z2 = xyz1_data[(i*n+j2)*3+2];
            const float g = graddist2_data[i*m+j]*2;
            gradxyz2_data[(i*m+j)*3+0] += g*(x1-x2);
            gradxyz2_data[(i*m+j)*3+1] += g*(y1-y2);
            gradxyz2_data[(i*m+j)*3+2] += g*(z1-z2);
            gradxyz1_data[(i*n+j2)*3+0] -= (g*(x1-x2));
            gradxyz1_data[(i*n+j2)*3+1] -= (g*(y1-y2));
            gradxyz1_data[(i*n+j2)*3+2] -= (g*(z1-z2));
        }
    }
}


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &chamfer_distance_forward, "ChamferDistance forward");
    m.def("forward_cuda", &chamfer_distance_forward_cuda, "ChamferDistance forward (CUDA)");
    m.def("backward", &chamfer_distance_backward, "ChamferDistance backward");
    m.def("backward_cuda", &chamfer_distance_backward_cuda, "ChamferDistance backward (CUDA)");
}


================================================
FILE: model/conpu_v6/chamfer_distance/chamfer_distance.cu
================================================
#include <ATen/ATen.h>

#include <cuda.h>
#include <cuda_runtime.h>

__global__ 
void ChamferDistanceKernel(
	int b,
	int n,
	const float* xyz,
	int m,
	const float* xyz2,
	float* result,
	int* result_i)
{
	const int batch=512;
	__shared__ float buf[batch*3];
	for (int i=blockIdx.x;i<b;i+=gridDim.x){
		for (int k2=0;k2<m;k2+=batch){
			int end_k=min(m,k2+batch)-k2;
			for (int j=threadIdx.x;j<end_k*3;j+=blockDim.x){
				buf[j]=xyz2[(i*m+k2)*3+j];
			}
			__syncthreads();
			for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
				float x1=xyz[(i*n+j)*3+0];
				float y1=xyz[(i*n+j)*3+1];
				float z1=xyz[(i*n+j)*3+2];
				int best_i=0;
				float best=0;
				int end_ka=end_k-(end_k&3);
				if (end_ka==batch){
					for (int k=0;k<batch;k+=4){
						{
							float x2=buf[k*3+0]-x1;
							float y2=buf[k*3+1]-y1;
							float z2=buf[k*3+2]-z1;
							float d=x2*x2+y2*y2+z2*z2;
							if (k==0 || d<best){
								best=d;
								best_i=k+k2;
							}
						}
						{
							float x2=buf[k*3+3]-x1;
							float y2=buf[k*3+4]-y1;
							float z2=buf[k*3+5]-z1;
							float d=x2*x2+y2*y2+z2*z2;
							if (d<best){
								best=d;
								best_i=k+k2+1;
							}
						}
						{
							float x2=buf[k*3+6]-x1;
							float y2=buf[k*3+7]-y1;
							float z2=buf[k*3+8]-z1;
							float d=x2*x2+y2*y2+z2*z2;
							if (d<best){
								best=d;
								best_i=k+k2+2;
							}
						}
						{
							float x2=buf[k*3+9]-x1;
							float y2=buf[k*3+10]-y1;
							float z2=buf[k*3+11]-z1;
							float d=x2*x2+y2*y2+z2*z2;
							if (d<best){
								best=d;
								best_i=k+k2+3;
							}
						}
					}
				}else{
					for (int k=0;k<end_ka;k+=4){
						{
							float x2=buf[k*3+0]-x1;
							float y2=buf[k*3+1]-y1;
							float z2=buf[k*3+2]-z1;
							float d=x2*x2+y2*y2+z2*z2;
							if (k==0 || d<best){
								best=d;
								best_i=k+k2;
							}
						}
						{
							float x2=buf[k*3+3]-x1;
							float y2=buf[k*3+4]-y1;
							float z2=buf[k*3+5]-z1;
							float d=x2*x2+y2*y2+z2*z2;
							if (d<best){
								best=d;
								best_i=k+k2+1;
							}
						}
						{
							float x2=buf[k*3+6]-x1;
							float y2=buf[k*3+7]-y1;
							float z2=buf[k*3+8]-z1;
							float d=x2*x2+y2*y2+z2*z2;
							if (d<best){
								best=d;
								best_i=k+k2+2;
							}
						}
						{
							float x2=buf[k*3+9]-x1;
							float y2=buf[k*3+10]-y1;
							float z2=buf[k*3+11]-z1;
							float d=x2*x2+y2*y2+z2*z2;
							if (d<best){
								best=d;
								best_i=k+k2+3;
							}
						}
					}
				}
				for (int k=end_ka;k<end_k;k++){
					float x2=buf[k*3+0]-x1;
					float y2=buf[k*3+1]-y1;
					float z2=buf[k*3+2]-z1;
					float d=x2*x2+y2*y2+z2*z2;
					if (k==0 || d<best){
						best=d;
						best_i=k+k2;
					}
				}
				if (k2==0 || result[(i*n+j)]>best){
					result[(i*n+j)]=best;
					result_i[(i*n+j)]=best_i;
				}
			}
			__syncthreads();
		}
	}
}

void ChamferDistanceKernelLauncher(
    const int b, const int n,
    const float* xyz,
    const int m,
    const float* xyz2,
    float* result,
    int* result_i,
    float* result2,
    int* result2_i)
{
	ChamferDistanceKernel<<<dim3(32,16,1),512>>>(b, n, xyz, m, xyz2, result, result_i);
	ChamferDistanceKernel<<<dim3(32,16,1),512>>>(b, m, xyz2, n, xyz, result2, result2_i);

	cudaError_t err = cudaGetLastError();
	if (err != cudaSuccess)
	    printf("error in chamfer distance updateOutput: %s\n", cudaGetErrorString(err));
}


__global__ 
void ChamferDistanceGradKernel(
	int b, int n,
	const float* xyz1,
	int m,
	const float* xyz2,
	const float* grad_dist1,
	const int* idx1,
	float* grad_xyz1,
	float* grad_xyz2)
{
	for (int i = blockIdx.x; i<b; i += gridDim.x) {
		for (int j = threadIdx.x + blockIdx.y * blockDim.x; j < n; j += blockDim.x*gridDim.y) {
			float x1=xyz1[(i*n+j)*3+0];
			float y1=xyz1[(i*n+j)*3+1];
			float z1=xyz1[(i*n+j)*3+2];
			int j2=idx1[i*n+j];
			float x2=xyz2[(i*m+j2)*3+0];
			float y2=xyz2[(i*m+j2)*3+1];
			float z2=xyz2[(i*m+j2)*3+2];
			float g=grad_dist1[i*n+j]*2;
			atomicAdd(&(grad_xyz1[(i*n+j)*3+0]),g*(x1-x2));
			atomicAdd(&(grad_xyz1[(i*n+j)*3+1]),g*(y1-y2));
			atomicAdd(&(grad_xyz1[(i*n+j)*3+2]),g*(z1-z2));
			atomicAdd(&(grad_xyz2[(i*m+j2)*3+0]),-(g*(x1-x2)));
			atomicAdd(&(grad_xyz2[(i*m+j2)*3+1]),-(g*(y1-y2)));
			atomicAdd(&(grad_xyz2[(i*m+j2)*3+2]),-(g*(z1-z2)));
		}
	}
}

void ChamferDistanceGradKernelLauncher(
    const int b, const int n,
    const float* xyz1,
    const int m,
    const float* xyz2,
    const float* grad_dist1,
    const int* idx1,
    const float* grad_dist2,
    const int* idx2,
    float* grad_xyz1,
    float* grad_xyz2)
{
	cudaMemset(grad_xyz1, 0, b*n*3*4);
	cudaMemset(grad_xyz2, 0, b*m*3*4);
	ChamferDistanceGradKernel<<<dim3(1,16,1), 256>>>(b, n, xyz1, m, xyz2, grad_dist1, idx1, grad_xyz1, grad_xyz2);
	ChamferDistanceGradKernel<<<dim3(1,16,1), 256>>>(b, m, xyz2, n, xyz1, grad_dist2, idx2, grad_xyz2, grad_xyz1);

	cudaError_t err = cudaGetLastError();
  	if (err != cudaSuccess)
	    printf("error in chamfer distance get grad: %s\n", cudaGetErrorString(err));
}


================================================
FILE: model/conpu_v6/chamfer_distance/chamfer_distance.py
================================================

import torch

from torch.utils.cpp_extension import load
cd = load(name="cd",
          sources=["chamfer_distance/chamfer_distance.cpp",
                   "chamfer_distance/chamfer_distance.cu"],
                   extra_cflags=['-g'])

class ChamferDistanceFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, xyz1, xyz2):
        batchsize, n, _ = xyz1.size()
        _, m, _ = xyz2.size()
        xyz1 = xyz1.contiguous()
        xyz2 = xyz2.contiguous()
        dist1 = torch.zeros(batchsize, n)
        dist2 = torch.zeros(batchsize, m)

        idx1 = torch.zeros(batchsize, n, dtype=torch.int)
        idx2 = torch.zeros(batchsize, m, dtype=torch.int)
        if not xyz1.is_cuda:
            cd.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
        else:
            dist1 = dist1.to(xyz1.device)
            dist2 = dist2.to(xyz1.device)
            idx1 = idx1.to(xyz1.device)
            idx2 = idx2.to(xyz1.device)
            
            cd.forward_cuda(xyz1, xyz2, dist1, dist2, idx1, idx2)
        ctx.save_for_backward(xyz1, xyz2, idx1, idx2)

        return dist1, dist2

    @staticmethod
    def backward(ctx, graddist1, graddist2):
        xyz1, xyz2, idx1, idx2 = ctx.saved_tensors

        graddist1 = graddist1.contiguous()
        graddist2 = graddist2.contiguous()

        gradxyz1 = torch.zeros(xyz1.size())
        gradxyz2 = torch.zeros(xyz2.size())

        if not graddist1.is_cuda:
            cd.backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2)
        else:
            gradxyz1 = gradxyz1.to(graddist1.device)
            gradxyz2 = gradxyz2.to(graddist1.device)
            cd.backward_cuda(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2)

        return gradxyz1, gradxyz2


class ChamferDistance(torch.nn.Module):
    def forward(self, xyz1, xyz2):
        return ChamferDistanceFunction.apply(xyz1, xyz2)
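
For reference, the semantics the extension implements can be sketched in numpy without compiling anything (the helper name `chamfer_numpy` is ours; it mirrors the CPU `nnsearch` fallback in chamfer_distance.cpp): for each point in `xyz1`, the squared distance to its nearest neighbor in `xyz2`, and vice versa.

```python
import numpy

# Pure-numpy sketch of the forward pass: dist1/idx1 are per-point nearest
# squared distances and indices from xyz1 into xyz2, dist2/idx2 the reverse.
def chamfer_numpy(xyz1, xyz2):
    # Pairwise squared distances, shape B x N x M.
    d = ((xyz1[:, :, None, :] - xyz2[:, None, :, :]) ** 2).sum(-1)
    dist1, idx1 = d.min(axis=2), d.argmin(axis=2)
    dist2, idx2 = d.min(axis=1), d.argmin(axis=1)
    return dist1, dist2, idx1, idx2

a = numpy.array([[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]])   # B=1, N=2
b = numpy.array([[[0.0, 0.0, 0.5]]])                    # B=1, M=1
dist1, dist2, idx1, idx2 = chamfer_numpy(a, b)
assert numpy.allclose(dist1, [[0.25, 1.25]])
assert numpy.allclose(dist2, [[0.25]])
```

This is only a readability aid; the CUDA kernel above computes the same quantities with a tiled shared-memory search.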


================================================
FILE: model/conpu_v6/chamfer_distance/setup.py
================================================
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='chamferdis',
    ext_modules=[
        CUDAExtension('chamferdis', [
            'chamfer_distance.cpp',
            'chamfer_distance.cu',
        ],
                extra_compile_args=['-g']),
    ],
    cmdclass={
        'build_ext': BuildExtension
    })


================================================
FILE: model/conpu_v6/loss.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import scipy.ndimage
import sys
import time
sys.path.append('../')
sys.path.append('../../')
sys.path.append('../../../')
sys.path.append('../../code/')
import cv2
import cv2 as cv
from PIL import Image
from chamfer_distance import ChamferDistance
chamfer_dist = ChamferDistance()
import glob
import trimesh
import random
import math
from math import ceil
import struct
import pickle
from pointnet2 import pointnet2_utils as pn2_utils

import torch_tensor_functions

class Loss(nn.Module):
    def __init__(self, args):
        super(Loss, self).__init__()
        self.args = args
        
    def loss_on_cd(self, deformation_p, p1):
        thisbatchsize = deformation_p.size()[0]
        output = 0
        dist1, dist2 = chamfer_dist(deformation_p, p1)
        output += (torch.sum(dist1) + torch.sum(dist2))*0.5
        return output/thisbatchsize
    
    def loss_on_proj(self, p0, p1):
        # p0 : B, M, 3
        # p1 : B, N, 3
        thisbatchsize = p0.size()[0]
        output = 0
        dis_map = torch_tensor_functions.compute_sqrdis_map(p0, p1)   # B, M, N

        neighbour_id_01 = torch.topk(dis_map, k=5, dim=-1, largest= False)[1]
        neighbour_dis_01 = torch.topk(dis_map, k=5, dim=-1, largest= False)[0]
        neighbour_id_01 = neighbour_id_01[:,:,1:]
        neighbour_coor_01 = torch_tensor_functions.indexing_neighbor(p1, neighbour_id_01)
        neighbour_dis_01 = neighbour_dis_01[:,:,1:]
        neighbour_weight_01 = neighbour_dis_01.detach() * 1000
        neighbour_weight_01 = torch.exp(-1*neighbour_weight_01)
        neighbour_weight_01 = neighbour_weight_01/(torch.sum(neighbour_weight_01, dim=-1, keepdim=True)+0.00001)
        dis_01 = p0.view(thisbatchsize,-1,1,3) - neighbour_coor_01
        dis_01 = torch.sum(torch.mul(dis_01, dis_01), dim=-1, keepdim=False)
        pro_dis_01 = torch.mul(neighbour_weight_01, dis_01)
        output += 0.5 * torch.sum(pro_dis_01)

        neighbour_id_10 = torch.topk(dis_map, k=5, dim=1, largest= False)[1].transpose(2,1)
        neighbour_dis_10 = torch.topk(dis_map, k=5, dim=1, largest= False)[0].transpose(2,1)
        neighbour_id_10 = neighbour_id_10[:,:,1:]
        neighbour_coor_10 = torch_tensor_functions.indexing_neighbor(p0, neighbour_id_10)
        neighbour_dis_10 = neighbour_dis_10[:,:,1:]
        neighbour_weight_10 = neighbour_dis_10.detach() * 1000
        neighbour_weight_10 = torch.exp(-1*neighbour_weight_10)
        neighbour_weight_10 = neighbour_weight_10/(torch.sum(neighbour_weight_10, dim=-1, keepdim=True)+0.00001)
        dis_10 = p1.view(thisbatchsize,-1,1,3) - neighbour_coor_10
        dis_10 = torch.sum(torch.mul(dis_10, dis_10), dim=-1, keepdim=False)
        pro_dis_10 = torch.mul(neighbour_weight_10, dis_10)
        output += 0.5 * torch.sum(pro_dis_10)

        return output/thisbatchsize

    
    def loss_on_normal(self, p0, p1, n0, n1):
        # p0 : B, M, 3 ; n0 : B, M, 3
        # p1 : B, N, 3 ; n1 : B, N, 3
        thisbatchsize = p0.size()[0]
        output = 0
        dis_map = torch_tensor_functions.compute_sqrdis_map(p0, p1)   # B, M, N

        neighbour_id_01 = torch.topk(dis_map, k=5, dim=-1, largest= False)[1]
        neighbour_dis_01 = torch.topk(dis_map, k=5, dim=-1, largest= False)[0]
        neighbour_id_01 = neighbour_id_01[:,:,1:]
        neighbour_normal_01 = torch_tensor_functions.indexing_neighbor(n1, neighbour_id_01)
        neighbour_dis_01 = neighbour_dis_01[:,:,1:]
        neighbour_weight_01 = neighbour_dis_01.detach() * 1000
        neighbour_weight_01 = torch.exp(-1*neighbour_weight_01)
        neighbour_weight_01 = neighbour_weight_01/(torch.sum(neighbour_weight_01,   dim=-1, keepdim=True)+0.00001)
        dis_01 = n0.view(thisbatchsize,-1,1,3) - neighbour_normal_01
        dis_01 = torch.sum(torch.mul(dis_01, dis_01), dim=-1, keepdim=False)
        dis_01_ = n0.view(thisbatchsize,-1,1,3) + neighbour_normal_01
        dis_01_ = torch.sum(torch.mul(dis_01_, dis_01_), dim=-1, keepdim=False)
        bar_ = torch.sign(dis_01 - dis_01_)
        dis_01_min = torch.mul((bar_+1)*0.5, dis_01_) + torch.mul((1-bar_)*0.5, dis_01)
        dis_01_min = torch.mul(neighbour_weight_01, dis_01_min)
        output += 0.5 * torch.sum(dis_01_min)

        return output/thisbatchsize
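`loss_on_normal` compares each normal against both a neighbour normal and its flipped copy, then keeps the smaller squared difference with a branchless sign trick rather than `torch.min`. A minimal NumPy sketch of that selection (toy values and the name `signed_min` are ours):

```python
import numpy as np

def signed_min(a, b):
    """Elementwise min(a, b) without branching, as in loss_on_normal:
    bar = sign(a - b) picks b where a > b, a where a < b, their mean at a == b."""
    bar = np.sign(a - b)
    return (bar + 1) * 0.5 * b + (1 - bar) * 0.5 * a

a = np.array([0.2, 1.7, 0.5])
b = np.array([0.9, 0.3, 0.5])
print(signed_min(a, b))  # matches np.minimum(a, b)
```

This keeps the loss orientation-agnostic: a normal and its negation incur the same penalty.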
    
    def loss_on_reg(self, gen_points_batch, train_points_sparse_batch):
        thisbatchsize = gen_points_batch.size()[0]
        output = 0
        up_ratio_here = gen_points_batch.size()[1]//train_points_sparse_batch.size()[1]
        gen_points_batch_ = gen_points_batch.view(thisbatchsize,-1,up_ratio_here,3)
        train_points_sparse_batch_ = train_points_sparse_batch.view(thisbatchsize,-1,1,3)
        dis = train_points_sparse_batch_ - gen_points_batch_
        squdis = torch.sum(torch.mul(dis,dis),dim=-1,keepdim=True)
        squdis_bar = squdis.detach()*0+0.04
        squdis_sign = torch.sign(squdis.detach() - squdis_bar)*0.5+1
        squdis = torch.mul(squdis,squdis_sign)
        output += torch.sum(squdis)
        return output/thisbatchsize
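`loss_on_reg` reweights each squared distance with `sign(d - 0.04)*0.5 + 1`, which evaluates to 0.5 below the 0.04 threshold and 1.5 above it, so generated points that stray far from their seed point are penalized more heavily. A quick NumPy check of that weighting (toy values and the name `reg_weight` are ours):

```python
import numpy as np

def reg_weight(sqdist, bar=0.04):
    """Sign-based reweighting from loss_on_reg: 0.5 below the bar, 1.5 above, 1.0 exactly at it."""
    return np.sign(sqdist - bar) * 0.5 + 1

d = np.array([0.01, 0.04, 0.09])
print(reg_weight(d))  # [0.5 1.  1.5]
```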
    
    def loss_on_arap(self, gen_points_batch, uv_sampling_coors):
        thisbatchsize = gen_points_batch.size()[0]
        output = 0
        gen_points_batch_ = gen_points_batch.reshape(thisbatchsize*self.args.num_point, -1 ,3)
        uv_sampling_coors_ = uv_sampling_coors.reshape(thisbatchsize*self.args.num_point, -1 ,2).detach()
        uv_sampling_coors_ = torch.cat((uv_sampling_coors_, uv_sampling_coors_[:,:,:1]),dim=-1)
        uv_sampling_coors_[:,:,2:]*=0
        neighbour_indexes = torch_tensor_functions.get_neighbor_index(uv_sampling_coors_, 4) 
        uv_neighbour_points_ = torch_tensor_functions.indexing_neighbor(uv_sampling_coors_, neighbour_indexes)
        gen_neighbour_points_ = torch_tensor_functions.indexing_neighbor(gen_points_batch_, neighbour_indexes)
        uv_dis = uv_neighbour_points_ - uv_sampling_coors_.view(thisbatchsize*self.args.num_point, -1 ,1, 3)
        gen_dis = gen_neighbour_points_ - gen_points_batch_.view(thisbatchsize*self.args.num_point, -1 ,1, 3)
        uv_squ_dis = torch.sqrt( torch.sum(torch.mul(uv_dis, uv_dis),dim=-1) + 0.00000001 )
        gen_squ_dis = torch.sqrt( torch.sum(torch.mul(gen_dis, gen_dis),dim=-1) + 0.00000001 )
        uv_sum_dis = torch.sum(uv_squ_dis)
        gen_sum_dis = torch.sum(gen_squ_dis).detach()
        uv_squ_dis *= gen_sum_dis / uv_sum_dis
        delta = uv_squ_dis - gen_squ_dis
        output += torch.sum(torch.mul(delta, delta))
        return output/thisbatchsize

    def loss_on_overlap(self, gen_points_batch, train_points_sparse_batch):
        thisbatchsize = gen_points_batch.size()[0]
        output = 0
        gen_points_batch_ = gen_points_batch.reshape(thisbatchsize*self.args.num_point, -1 ,3)
        neighbour_indexes = torch_tensor_functions.get_neighbor_index(train_points_sparse_batch, 6) 
        sparse_neighbour_points_ = torch_tensor_functions.indexing_neighbor(train_points_sparse_batch, neighbour_indexes)
        sparse_neighbour_points_ = sparse_neighbour_points_.reshape(thisbatchsize*self.args.num_point, -1, 3)
        cross_dis = torch_tensor_functions.compute_sqrdis_map(sparse_neighbour_points_, gen_points_batch_)
        dis = torch.sum(torch.min(cross_dis,dim=-1)[0])
        output += dis
        return output/thisbatchsize


    def loss_on_ndirection(self, gen_points_batch, uv_sampling_coors, gen_normals_batch):
        thisbatchsize = gen_points_batch.size()[0]
        output = 0
        # gen_points_batch_ = gen_points_batch.reshape(thisbatchsize*self.args.num_point, -1 ,3)
        gen_normals_batch_ = gen_normals_batch.reshape(thisbatchsize*self.args.num_point, -1 ,3)
        uv_sampling_coors_ = uv_sampling_coors.reshape(thisbatchsize*self.args.num_point, -1 ,2).detach()
        uv_sampling_coors_ = torch.cat((uv_sampling_coors_, uv_sampling_coors_[:,:,:1]),dim=-1)
        uv_sampling_coors_[:,:,2:]*=0
        neighbour_indexes = torch_tensor_functions.get_neighbor_index(uv_sampling_coors_, 4) 
        uv_neighbour_points_ = torch_tensor_functions.indexing_neighbor(uv_sampling_coors_, neighbour_indexes)
        # gen_neighbour_points_ = torch_tensor_functions.indexing_neighbor(gen_points_batch_, neighbour_indexes)
        gen_neighbour_normals_ = torch_tensor_functions.indexing_neighbor(gen_normals_batch_, neighbour_indexes)
        gen_normals_batch_ = gen_normals_batch_.view(thisbatchsize*self.args.num_point, -1 ,1, 3)
        gen_neighbour_normals_delta_ = gen_neighbour_normals_ - gen_normals_batch_
        gen_neighbour_normals_delta_squ = torch.mul(gen_neighbour_normals_delta_, gen_neighbour_normals_delta_)

        normals_delta_squ_bar = gen_neighbour_normals_delta_squ.detach()*0+1
        normals_delta_squ_sign = torch.sign(gen_neighbour_normals_delta_squ.detach() - normals_delta_squ_bar)*0.5+1
        gen_neighbour_normals_delta_squ = torch.mul(gen_neighbour_normals_delta_squ, normals_delta_squ_sign)

        output += torch.sum(gen_neighbour_normals_delta_squ)
        
        return output/thisbatchsize


    
    def forward(self, gen_points_batch, gen_normals_batch, uv_sampling_coors, train_points_sparse_batch, train_normals_sparse_batch, train_points_dense_batch, train_normals_dense_batch):
        thisbatchsize = gen_points_batch.size()[0]
        loss = torch.mean(torch.zeros((1),dtype = torch.float, device=gen_points_batch.device))
        zero_tensor = torch.mean(torch.zeros((1),dtype = torch.float, device=gen_points_batch.device))
        loss_stages=[]
        
        if self.args.weight_cd > 0:
            # L^{cd}  # n*3, n*3
            loss_cd = 0 
            loss_cd += self.loss_on_cd(gen_points_batch, train_points_dense_batch)
            loss += loss_cd * self.args.weight_cd
            loss_stages.append(loss_cd)
        else:
            loss_stages.append(zero_tensor)    

        if self.args.weight_reg > 0:
            # L^{reg}  # n*3, n*3
            loss_reg = 0 
            loss_reg += self.loss_on_reg(gen_points_batch, train_points_sparse_batch)
            loss += loss_reg * self.args.weight_reg
            loss_stages.append(loss_reg)
        else:
            loss_stages.append(zero_tensor) 

        if self.args.weight_arap > 0:
            # L^{arap}  # 
            loss_arap = 0 
            loss_arap += self.loss_on_arap(gen_points_batch, uv_sampling_coors)
            loss += loss_arap * self.args.weight_arap
            loss_stages.append(loss_arap)
        else:
            loss_stages.append(zero_tensor) 


        if self.args.weight_overlap > 0:
            # L^{overlap}  # 
            loss_overlap = 0 
            loss_overlap += self.loss_on_overlap(gen_points_batch, train_points_sparse_batch)
            loss += loss_overlap * self.args.weight_overlap
            loss_stages.append(loss_overlap)
        else:
            loss_stages.append(zero_tensor) 
           
        
        if self.args.weight_proj > 0:
            # L^{proj}  # 
            loss_proj = 0 
            loss_proj += self.loss_on_proj(gen_points_batch, train_points_dense_batch)
            loss += loss_proj * self.args.weight_proj
            loss_stages.append(loss_proj)
        else:
            loss_stages.append(zero_tensor) 

        if self.args.weight_normal > 0:
            # L^{normal}  # 
            loss_normal = 0 
            loss_normal += self.loss_on_normal(gen_points_batch, train_points_dense_batch, gen_normals_batch, train_normals_dense_batch)
            loss += loss_normal * self.args.weight_normal
            loss_stages.append(loss_normal)
        else:
            loss_stages.append(zero_tensor) 


        if self.args.weight_ndirection > 0:
            # L^{ndirection}  # 
            loss_ndirection = 0 
            loss_ndirection += self.loss_on_ndirection(gen_points_batch, uv_sampling_coors, gen_normals_batch)
            loss += loss_ndirection * self.args.weight_ndirection
            loss_stages.append(loss_ndirection)
        else:
            loss_stages.append(zero_tensor) 
           
            
        return loss, loss_stages


================================================
FILE: model/conpu_v6/network.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from  torch.utils.data import DataLoader
import torch.optim.lr_scheduler as lr_scheduler
from torch.autograd import grad
import math
import numpy as np
import torch.nn.init as init
import struct
import os
import sys
import glob
import h5py
import copy
sys.path.append('../')
sys.path.append('../../')
sys.path.append('../../code')
import igl
from torch_scatter import scatter
from torch_geometric.utils import to_dense_batch
import torch_tensor_functions
import mesh_operations
from pointnet2 import pointnet2_utils as pn2_utils
#from chamfer_distance import ChamferDistance
#chamfer_dist = ChamferDistance()


######## TODO: START PART: FUNCTIONS ABOUT DGCNN. IT IS USED AS THE FEATURE EXTRACTOR IN OUR FRAMEWORK. ########
#### The DGCNN network ####
class DGCNN_multi_knn_c5(nn.Module):
    def __init__(self, emb_dims=512, args=None):
        super(DGCNN_multi_knn_c5, self).__init__()
        self.args = args
        self.conv1 = nn.Conv2d(6, 64, kernel_size=1, bias=False)
        init.xavier_normal_(self.conv1.weight, gain=1.0)
        self.conv2 = nn.Conv2d(64*2, 64, kernel_size=1, bias=False)
        init.xavier_normal_(self.conv2.weight, gain=1.0)
        self.conv3 = nn.Conv2d(64*2, 128, kernel_size=1, bias=False)
        init.xavier_normal_(self.conv3.weight, gain=1.0)
        self.conv4 = nn.Conv2d(128*2, 256, kernel_size=1, bias=False)
        init.xavier_normal_(self.conv4.weight, gain=1.0)
        self.conv5 = nn.Conv2d(512, emb_dims, kernel_size=1, bias=False)
        init.xavier_normal_(self.conv5.weight, gain=1.0)
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)
        self.bn5 = nn.BatchNorm2d(emb_dims)
    def forward(self, x, if_relu_atlast = False):
        batch_size, num_dims, num_points = x.size()
        x = get_graph_feature(x) # build graph-based edge features for the following 2D convs
        # x now has an image-like layout: (batch_size, 2*num_dims, num_points, k)
        if self.args.if_bn == True: x = F.relu(self.bn1(self.conv1(x)))
        else: x = F.relu(self.conv1(x))
        x1 = x.max(dim=-1, keepdim=False)[0]
        x = get_graph_feature(x1)
        if self.args.if_bn == True: x = F.relu(self.bn2(self.conv2(x))) 
        else: x = F.relu(self.conv2(x))
        x2 = x.max(dim=-1, keepdim=False)[0]
        x = get_graph_feature(x2)
        if self.args.if_bn == True: x = F.relu(self.bn3(self.conv3(x))) 
        else: x = F.relu(self.conv3(x))
        x3 = x.max(dim=-1, keepdim=False)[0]
        x = get_graph_feature(x3)
        if self.args.if_bn == True: x = F.relu(self.bn4(self.conv4(x))) 
        else: x = F.relu(self.conv4(x))
        x4 = x.max(dim=-1, keepdim=False)[0]
        x = torch.cat((x1, x2, x3, x4), dim=1).unsqueeze(3)
        if if_relu_atlast == False:
            return torch.tanh(self.conv5(x)).view(batch_size, -1, num_points)
        x = F.relu(self.conv5(x)).view(batch_size, -1, num_points)
        return x
#### The knn function used in graph_feature ####
def knn(x, k):
    inner = -2 * torch.matmul(x.transpose(2, 1).contiguous(), x)
    xx = torch.sum(x ** 2, dim=1, keepdim=True)
    pairwise_distance = -xx - inner - xx.transpose(2, 1).contiguous()
    idx = pairwise_distance.topk(k=k, dim=-1)[1]  # (batch_size, num_points, k)
    return idx
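`knn` above scores neighbours with the expansion `-||x_i - x_j||^2 = 2<x_i, x_j> - ||x_i||^2 - ||x_j||^2`, so `topk` over the (negated) distances returns the k nearest points without ever forming the differences explicitly. A NumPy check of that identity against brute force (toy data; the variable names are ours):

```python
import numpy as np

# x: (num_dims, num_points), matching knn's channel-first layout for one batch item
x = np.random.default_rng(0).normal(size=(3, 6))

inner = -2 * x.T @ x                      # -2 * <x_i, x_j>, shape (N, N)
xx = (x ** 2).sum(axis=0, keepdims=True)  # squared norms, shape (1, N)
neg_sqdist = -xx - inner - xx.T           # = -||x_i - x_j||^2 by the expansion

brute = -((x.T[:, None, :] - x.T[None, :, :]) ** 2).sum(-1)
print(np.abs(neg_sqdist - brute).max())   # ~0 up to float rounding
```

The matrix form costs one matrix multiply instead of an N x N x d difference tensor.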
#### The edge_feature used in DGCNN ####
def get_graph_feature(x, k=4):
    idx = knn(x, k=k)  # (batch_size, num_points, k)
    batch_size, num_points, _ = idx.size()
    device = x.device  # follow the input tensor's device instead of hard-coding CUDA
    idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points
    idx = idx + idx_base
    idx = idx.view(-1)
    _, num_dims, _ = x.size()
    x = x.transpose(2,1).contiguous()  # (batch_size, num_dims, num_points) -> (batch_size, num_points, num_dims); flattened below so idx + idx_base addresses rows of the (batch_size*num_points, num_dims) view
    feature = x.view(batch_size * num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
    feature = torch.cat((feature, x), dim=3).permute(0, 3, 1, 2)
    return feature
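`get_graph_feature` gathers each point's k neighbours by flattening the batch: adding `b * num_points` to the per-batch indices turns them into row indices of the `(batch_size*num_points, num_dims)` view, so one fancy-indexing call serves the whole batch. A minimal NumPy illustration of the offset trick (toy shapes; the names are ours):

```python
import numpy as np

batch_size, num_points, num_dims, k = 2, 4, 3, 2
x = np.arange(batch_size * num_points * num_dims, dtype=float).reshape(batch_size, num_points, num_dims)
idx = np.random.default_rng(1).integers(0, num_points, size=(batch_size, num_points, k))

# offset each batch's indices so they address rows of the flattened view
idx_base = np.arange(batch_size).reshape(-1, 1, 1) * num_points
flat = x.reshape(batch_size * num_points, num_dims)
gathered = flat[(idx + idx_base).reshape(-1)].reshape(batch_size, num_points, k, num_dims)

# sanity check against a direct per-batch gather
direct = np.stack([x[b][idx[b]] for b in range(batch_size)])
print(np.array_equal(gathered, direct))  # True
```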
######## TODO: END PART: FUNCTIONS ABOUT DGCNN. IT IS USED AS THE FEATURE EXTRACTOR IN OUR FRAMEWORK. ########

######## TODO: START PART: NEURAL IMPLICIT FUNCTION, MLP with ReLU. ########
#### Construct the neural implicit function. ####
class MLPNet_relu(torch.nn.Module):
    """ Multi-layer perception.
        [B, Cin, N] -> [B, Cout, N] or
        [B, Cin] -> [B, Cout]
    """
    def __init__(self, nch_input, nch_layers, b_shared=True, bn_momentum=0.1, dropout=0.0, if_bn = True):
        super().__init__()
        list_layers = mlp_layers_relu(nch_input, nch_layers, b_shared, bn_momentum, dropout, if_bn)
        self.layers = torch.nn.Sequential(*list_layers)
    def forward(self, inp):
        out = self.layers(inp)
        return out
#### Construct the mlp_layers of the neural implicit function. ####
def mlp_layers_relu(nch_input, nch_layers, b_shared=True, bn_momentum=0.1, dropout=0.0, if_bn=True):
    """ [B, Cin, N] -> [B, Cout, N] or
        [B, Cin] -> [B, Cout]
    """
    layers = []
    last = nch_input
    for i, outp in enumerate(nch_layers):
        if b_shared:
            weights = torch.nn.Conv1d(last, outp, 1)
            init.xavier_normal_(weights.weight, gain=1.0)
            # if i==0: init.uniform_(weights.weight, a=-(6/last)**0.5*30, b=(6/last)**0.5*30)
            # else: init.uniform_(weights.weight, a=-(6/last)**0.5, b=(6/last)**0.5)
        else:
            weights = torch.nn.Linear(last, outp)
            init.xavier_normal_(weights.weight, gain=1.0)
        layers.append(weights)
        if if_bn==True:
            layers.append(torch.nn.BatchNorm1d(outp, momentum=bn_momentum))
        layers.append(torch.nn.ReLU())
        # layers.append(Sine())
        if b_shared == False and dropout > 0.0:
            layers.append(torch.nn.Dropout(dropout))
        last = outp
    return layers
######## TODO: END PART: NEURAL IMPLICIT FUNCTION, MLP with ReLU. ########


######## TODO: START PART: NEURAL IMPLICIT FUNCTION, MLP with SIREN. ########
#### Construct the neural implicit function. ####
class MLPNet(torch.nn.Module):
    """ Multi-layer perception.
        [B, Cin, N] -> [B, Cout, N] or
        [B, Cin] -> [B, Cout]
    """
    def __init__(self, nch_input, nch_layers, b_shared=True, bn_momentum=0.1, dropout=0.0, if_bn = True):
        super().__init__()
        list_layers = mlp_layers(nch_input, nch_layers, b_shared, bn_momentum, dropout, if_bn)
        self.layers = torch.nn.Sequential(*list_layers)
    def forward(self, inp):
        out = self.layers(inp)
        return out
#### Construct the mlp_layers of the neural implicit function. ####
def mlp_layers(nch_input, nch_layers, b_shared=True, bn_momentum=0.1, dropout=0.0, if_bn=True):
    """ [B, Cin, N] -> [B, Cout, N] or
        [B, Cin] -> [B, Cout]
    """
    layers = []
    last = nch_input
    for i, outp in enumerate(nch_layers):
        if b_shared:
            weights = torch.nn.Conv1d(last, outp, 1)
            #init.xavier_normal_(weights.weight, gain=1.0)
            if i==0: init.uniform_(weights.weight, a=-(6/last)**0.5*30, b=(6/last)**0.5*30)
            else: init.uniform_(weights.weight, a=-(6/last)**0.5, b=(6/last)**0.5)
        else:
            weights = torch.nn.Linear(last, outp)
            init.xavier_normal_(weights.weight, gain=1.0)
        layers.append(weights)
        if if_bn==True:
            layers.append(torch.nn.BatchNorm1d(outp, momentum=bn_momentum))
        #layers.append(torch.nn.ReLU())
        layers.append(Sine())
        if b_shared == False and dropout > 0.0:
            layers.append(torch.nn.Dropout(dropout))
        last = outp
    return layers
#### The nn.Module Sine, used as the activation function in the neural implicit function. ####
class Sine(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, input):
        return torch.sin(input)
######## TODO: END PART: NEURAL IMPLICIT FUNCTION, MLP with SIREN. ########


######## TODO: START PART: OUR OWN NETWORK ########
#### The main network ####
class Net_conpu_v7(nn.Module):
    def __init__(self, args):
        super(Net_conpu_v7, self).__init__()
        # basic settings
        self.args = args # the args
        self.emb_dims = args.emb_dims # the dim of the embedded feature
        self.up_ratio = -1 # the upsampling factor
        self.over_sampling_up_ratio = -1 # the scale of over-sampling
        self.mlp_fitting_str = self.args.mlp_fitting_str
        self.mlp_fitting = convert_str_2_list(self.mlp_fitting_str) # the channels of the layers in the MLP
        ######################## START PART : LAYERS #########################
        ## 1. The point-wise feature extraction, DGCNN.
        self.emb_nn_sparse = DGCNN_multi_knn_c5(emb_dims=self.emb_dims, args=self.args) # the DGCNN backbone, which is shared by all the local parts
        ## 2. The Neural Field, MLP.
        if self.args.if_use_siren==True: self.fitting_mlp = MLPNet(2*self.emb_dims+(self.args.pe_out_L*4+2), self.mlp_fitting, b_shared=True, if_bn =False).layers
        else: self.fitting_mlp = MLPNet_relu(2*self.emb_dims+(self.args.pe_out_L*4+2), self.mlp_fitting, b_shared=True, if_bn =False).layers   
        self.reconstruct_out_p = torch.nn.Conv1d(self.mlp_fitting[-1], 3, 1)
        init.xavier_normal_(self.reconstruct_out_p.weight, gain=1.0)
        self.convert_feature_to_point_2to3 = torch.nn.Sequential(self.fitting_mlp, self.reconstruct_out_p)   # the Neural Field Function (MLP)
        ######################## END PART : LAYERS #########################
    
    def forward(self, points_sparse):
        # The input [points_sparse] should be in shape (thisbatchsize, self.args.num_point, 3)
        thisbatchsize = points_sparse.size()[0]
        neighbour_indexes_ = torch_tensor_functions.get_neighbor_index(points_sparse, self.args.feature_unfolding_nei_num)   # thisbatchsize, self.args.num_point, neighbor_num
        ######### How to set the uv_sampling_coors ?
        #### We do NOT need to feed uv_sampling_coors to the network; it is computed automatically. The up_ratio is training_up_ratio or testing_up_ratio, depending on self.training.
        uv_sampling_coors=torch.ones([1]).float().cuda()
        if self.training == True : self.up_ratio = self.args.training_up_ratio
        else : self.up_ratio = self.args.testing_up_ratio
        self.over_sampling_up_ratio = int(self.up_ratio * self.args.over_sampling_scale)
        if self.args.if_fix_sample == True: uv_sampling_coors = fix_sample(thisbatchsize, self.args.num_point, self.over_sampling_up_ratio)
        else: 
            uv_sampling_coors_1 = uniform_random_sample(thisbatchsize, self.args.num_point, self.over_sampling_up_ratio-4)
            uv_sampling_coors_2 = fix_sample(thisbatchsize, self.args.num_point, 4)
            uv_sampling_coors_ = torch.cat((uv_sampling_coors_1, uv_sampling_coors_2), dim=2) 
            uv_sampling_coors = copy.deepcopy(uv_sampling_coors_.detach())
        uv_sampling_coors = uv_sampling_coors.detach().contiguous()   # thisbatchsize, self.args.num_point, self.over_sampling_up_ratio, 2
        uv_sampling_coors.requires_grad=True
        ######### Set the uv_sampling_coors, Done.

        # compute the point-wise feature, updated with local pooling
        neighbour_indexes_feature_extract = torch_tensor_functions.get_neighbor_index(points_sparse, self.args.neighbor_k)   # bs, vertice_num, neighbor_num
        points_in_local_patch_form = torch_tensor_functions.indexing_by_id(points_sparse,neighbour_indexes_feature_extract)
        points_in_local_patch_form = points_in_local_patch_form - points_sparse.view(thisbatchsize,self.args.num_point,1,3)
        points_in_local_patch_form = points_in_local_patch_form.view(thisbatchsize*self.args.num_point, self.args.neighbor_k, 3)
        sparse_embedding = self.emb_nn_sparse(points_in_local_patch_form.transpose(1,2))  # B*num_point, self.emb_dims, self.neighbor_k
        sparse_embedding = torch.max(sparse_embedding,dim=-1,keepdim=False)[0].view(thisbatchsize,self.args.num_point,-1).permute(0,2,1)
        local_features_pooling = torch_tensor_functions.indexing_neighbor(sparse_embedding.transpose(1,2), neighbour_indexes_).permute(0,3,2,1)
        local_features_pooling = torch.max(local_features_pooling, dim=2, keepdim=False)[0]
        sparse_embedding = torch.cat((sparse_embedding,local_features_pooling),dim=1)
        sparse_embedding = sparse_embedding.permute(0,2,1)  # thisbatchsize, self.args.num_point, self.emb_dims*2
        

        # get the uv_sampling_coors_id_in_sparse
        uv_sampling_coors_id_in_sparse = torch.arange(self.args.num_point).view(1,-1,1).long()
        uv_sampling_coors_id_in_sparse = uv_sampling_coors_id_in_sparse.expand(thisbatchsize,-1,self.over_sampling_up_ratio).reshape(thisbatchsize,-1,1)
        upsampled_p, upsampled_np = self.convert_uv_to_xyzn(uv_sampling_coors.reshape(thisbatchsize,-1,2), uv_sampling_coors_id_in_sparse, sparse_embedding, points_sparse) # thisbatchsize, self.args.num_point*self.over_sampling_up_ratio, 3
        

        upsampled_p_fps_id = pn2_utils.furthest_point_sample(upsampled_p.contiguous(), self.up_ratio*self.args.num_point)
        querying_points_3d = pn2_utils.gather_operation(upsampled_p.permute(0, 2, 1).contiguous(), upsampled_p_fps_id)
        querying_points_n_3d = pn2_utils.gather_operation(upsampled_np.permute(0, 2, 1).contiguous(), upsampled_p_fps_id)
        querying_points_3d = querying_points_3d.permute(0,2,1).contiguous()
        querying_points_n_3d = querying_points_n_3d.permute(0,2,1).contiguous()

        # Get the final upsampled points from the 3D querying points
        glued_points, glued_normals = self.project_3d_query_point_to_patches(querying_points_3d, querying_points_n_3d, points_sparse, upsampled_p, upsampled_np)

        

        

        # Note that the returned uv_sampling_coors is detached (not differentiable); it is only used to compute the loss.
        return upsampled_p, upsampled_np, uv_sampling_coors, querying_points_3d, querying_points_n_3d, glued_points, glued_normals

    def project_3d_query_point_to_patches(self, querying_points_3d, querying_points_n_3d, points_sparse, upsampled_p, upsampled_np):
        # All3dQueryPointNum = self.args.num_point * self.up_ratio
        # All2dQueryPointNum = self.args.num_point * self.over_sampling_up_ratio
        # querying_points_3d     | should be in size : thisbatchsize, All3dQueryPointNum, 3
        # querying_points_n_3d   | should be in size : thisbatchsize, All3dQueryPointNum, 3
        # points_sparse          | should be in size : thisbatchsize, self.args.num_point, 3
        # upsampled_p            | should be in size : thisbatchsize, All2dQueryPointNum, 3
        # upsampled_np           | should be in size : thisbatchsize, All2dQueryPointNum, 3
        
        thisbatchsize = querying_points_3d.size()[0]
        All3dQueryPointNum = querying_points_3d.size()[1]
        All2dQueryPointNum = upsampled_p.size()[1]
        #### Distribute the 3d querying points to the center points. ####
        # 1. compute the distance map between the 3d querying points and the center points.
        querying_points_3d__center_p__dismap = torch_tensor_functions.compute_sqrdis_map(querying_points_3d, points_sparse)    # thisbatchsize, All3dQueryPointNum, self.args.num_point
        # 2. find the neighbour ID from the 3d querying points to the center points.
        querying_points_3d_distribute_to_centers_nei_id = torch.topk(querying_points_3d__center_p__dismap, k=self.args.glue_neighbor, dim=2, largest=False)[1] # thisbatchsize, All3dQueryPointNum, self.args.glue_neighbor
        # 3. find the nearest distance from the 3d querying points to the center points.
        querying_points_3d_distribute_to_centers_nei_dis = torch.topk(querying_points_3d__center_p__dismap, k=self.args.glue_neighbor, dim=2, largest=False)[0].detach() # thisbatchsize, All3dQueryPointNum, self.args.glue_neighbor 
        # 4. find the nearest points coordinates from the 3d querying points to the center points.
        querying_points_3d_distribute_to_centers_nei_coor = torch_tensor_functions.indexing_by_id(points_sparse, querying_points_3d_distribute_to_centers_nei_id) # thisbatchsize, All3dQueryPointNum, self.args.glue_neighbor, 3 
        # 5. compute the weight of the 3d querying points distributed to their neighbour center points. 
        Alpha_glue = 1.0/torch.mean(querying_points_3d_distribute_to_centers_nei_dis) 
        querying_points_3d_distribute_to_centers_nei_weight = torch.exp( -1 * Alpha_glue * querying_points_3d_distribute_to_centers_nei_dis )
        querying_points_3d_distribute_to_centers_nei_weight = querying_points_3d_distribute_to_centers_nei_weight / (torch.sum(querying_points_3d_distribute_to_centers_nei_weight,dim=-1,keepdim=True)+0.0000001)  # thisbatchsize, All3dQueryPointNum, self.args.glue_neighbor. The last dim should sum up to 1.

        #### Project the 3d querying points to their neighbour patches. ####
        #### In this part, we can get a (thisbatchsize, All3dQueryPointNum, self.args.glue_neighbor, 3)-shaped tensor, which should be multiplied with the weight above.
        # For each 3d querying point's each neighbour patch , find the projection points in the patch.  
        querying_points_3d_ = querying_points_3d.view(thisbatchsize, All3dQueryPointNum, 1, 3)
        querying_points_3d_ = querying_points_3d_.expand(-1, -1, self.args.glue_neighbor, -1)
        querying_points_3d_ = querying_points_3d_.reshape(thisbatchsize, All3dQueryPointNum*self.args.glue_neighbor, 1, 3)  

        upsampled_p_ = upsampled_p.view(thisbatchsize, self.args.num_point, -1, 3)
        upsampled_np_ = upsampled_np.view(thisbatchsize, self.args.num_point, -1, 3)
        up_ratio_here = upsampled_p_.size()[2]
        upsampled_p_ = upsampled_p_.reshape(thisbatchsize, self.args.num_point, -1)
        upsampled_np_ = upsampled_np_.reshape(thisbatchsize, self.args.num_point, -1)
        all_queried_patches = torch_tensor_functions.indexing_by_id(upsampled_p_, querying_points_3d_distribute_to_centers_nei_id)
        all_queried_patchesn = torch_tensor_functions.indexing_by_id(upsampled_np_, querying_points_3d_distribute_to_centers_nei_id)
        all_queried_patches = all_queried_patches.view(thisbatchsize, All3dQueryPointNum*self.args.glue_neighbor, up_ratio_here, 3)
        all_queried_patchesn = all_queried_patchesn.view(thisbatchsize, All3dQueryPointNum*self.args.glue_neighbor, up_ratio_here, 3)
        

        dis_from_3d_querying_points_to_its_corresponding_patch = querying_points_3d_ - all_queried_patches
        dis_from_3d_querying_points_to_its_corresponding_patch = torch.sum( torch.mul(dis_from_3d_querying_points_to_its_corresponding_patch, dis_from_3d_querying_points_to_its_corresponding_patch) , dim = -1, keepdim = False)
        nei_id_from_3d_querying_points_to_its_corresponding_patch = torch.topk(dis_from_3d_querying_points_to_its_corresponding_patch, dim =-1, k=self.args.proj_neighbor,largest=False)[1].reshape(thisbatchsize*All3dQueryPointNum*self.args.glue_neighbor, 1, self.args.proj_neighbor)
        nei_dis_from_3d_querying_points_to_its_corresponding_patch = torch.topk(dis_from_3d_querying_points_to_its_corresponding_patch, dim =-1, k=self.args.proj_neighbor,largest=False)[0].reshape(thisbatchsize*All3dQueryPointNum*self.args.glue_neighbor, 1, self.args.proj_neighbor)
        all_queried_patches_ = all_queried_patches.view(thisbatchsize*All3dQueryPointNum*self.args.glue_neighbor, up_ratio_here, 3)
        all_queried_patchesn_ = all_queried_patchesn.view(thisbatchsize*All3dQueryPointNum*self.args.glue_neighbor, up_ratio_here, 3)
        nei_coor_from_3d_querying_points_to_its_corresponding_patch = torch_tensor_functions.indexing_by_id(all_queried_patches_, nei_id_from_3d_querying_points_to_its_corresponding_patch)
        nei_ncoor_from_3d_querying_points_to_its_corresponding_patch = torch_tensor_functions.indexing_by_id(all_queried_patchesn_, nei_id_from_3d_querying_points_to_its_corresponding_patch)
        nei_weight_from_3d_querying_points_to_its_corresponding_patch = torch.exp( -1000 * nei_dis_from_3d_querying_points_to_its_corresponding_patch)
        nei_weight_from_3d_querying_points_to_its_corresponding_patch = nei_weight_from_3d_querying_points_to_its_corresponding_patch / (torch.sum(nei_weight_from_3d_querying_points_to_its_corresponding_patch, dim=-1, keepdim=True) +0.0000001 )
        nei_weight_from_3d_querying_points_to_its_corresponding_patch = nei_weight_from_3d_querying_points_to_its_corresponding_patch.view(thisbatchsize*All3dQueryPointNum*self.args.glue_neighbor, 1, self.args.proj_neighbor,1)
        projected_points = torch.sum(nei_weight_from_3d_querying_points_to_its_corresponding_patch * nei_coor_from_3d_querying_points_to_its_corresponding_patch, dim =2, keepdim=False )
        projected_pointsn = torch.sum(nei_weight_from_3d_querying_points_to_its_corresponding_patch * nei_ncoor_from_3d_querying_points_to_its_corresponding_patch, dim =2, keepdim=False )
        projected_points = projected_points.view(thisbatchsize, All3dQueryPointNum, self.args.glue_neighbor, 3)  # thisbatchsize, All3dQueryPointNum, self.args.glue_neighbor, 3
        projected_pointsn = projected_pointsn.view(thisbatchsize, All3dQueryPointNum, self.args.glue_neighbor, 3)  # thisbatchsize, All3dQueryPointNum, self.args.glue_neighbor, 3
        
        projected_pointsn_sign = projected_pointsn.detach()
        projected_pointsn_sign_ref = projected_pointsn_sign[:,:,0:1,:].expand(-1,-1,self.args.glue_neighbor,-1)
        projected_pointsn_sign = torch.sum(torch.mul(projected_pointsn_sign, projected_pointsn_sign_ref) ,dim=-1, keepdim=True ).expand(-1,-1,-1,3)
        projected_pointsn_sign = torch.sign(projected_pointsn_sign+0.1)
        
        # correct the direction of the normals.
        projected_pointsn = torch.mul(projected_pointsn, projected_pointsn_sign)
        #### Glue the 3d upsampled points. ####
        glued_points = torch.sum( projected_points * querying_points_3d_distribute_to_centers_nei_weight.view(thisbatchsize, All3dQueryPointNum, self.args.glue_neighbor, 1), dim = 2 , keepdim=False)
        glued_normals = torch.sum( projected_pointsn * querying_points_3d_distribute_to_centers_nei_weight.view(thisbatchsize, All3dQueryPointNum, self.args.glue_neighbor, 1), dim = 2 , keepdim=False)
        return glued_points, glued_normals
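The projection step above turns the `proj_neighbor` nearest patch samples into a convex combination using sharply decaying exponential weights (the hard-coded sharpness is 1000). A minimal CPU sketch of just that weighting on toy distances; `projection_weights` is a hypothetical helper, not part of this repo:

```python
import torch

def projection_weights(nei_dis, sharpness=1000.0, eps=1e-7):
    # Exponentially decaying weights over the neighbor dimension,
    # normalized to (approximately) sum to one, as in the soft
    # projection of querying points onto each patch.
    w = torch.exp(-sharpness * nei_dis)
    return w / (w.sum(dim=-1, keepdim=True) + eps)

# Toy distances from one query point to 4 patch samples.
d = torch.tensor([[0.0, 1e-3, 2e-3, 5e-3]])
w = projection_weights(d)
```

Closer samples dominate: the weight at distance 0 outweighs all others, and the weights form a valid convex combination.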
    
    
    def convert_uv_to_xyzn(self, uv_coor, uv_coor_idx_in_sparse, sparse_embedding, points_sparse):
        # uv_coor                | should be in size : thisbatchsize, All2dQueryPointNum, 2
        # uv_coor_idx_in_sparse  | should be in size : thisbatchsize, All2dQueryPointNum, 1
        # sparse_embedding       | should be in size : thisbatchsize, sparse_point_num, embedding_dim
        # points_sparse          | should be in size : thisbatchsize, sparse_point_num, 3
        thisbatchsize = uv_coor.size()[0]
        All2dQueryPointNum = uv_coor.size()[1]
        converted2to3_p = self.convert_uv_to_xyz(uv_coor, uv_coor_idx_in_sparse, sparse_embedding, points_sparse)
        
        converted2to3_p_x = converted2to3_p[:,:,0:1].reshape(thisbatchsize*All2dQueryPointNum,1)
        grad_x_uv = cal_grad(uv_coor, converted2to3_p_x).reshape(thisbatchsize*All2dQueryPointNum,2,1)
        converted2to3_p_y = converted2to3_p[:,:,1:2].reshape(thisbatchsize*All2dQueryPointNum,1)
        grad_y_uv = cal_grad(uv_coor, converted2to3_p_y).reshape(thisbatchsize*All2dQueryPointNum,2,1)
        converted2to3_p_z = converted2to3_p[:,:,2:3].reshape(thisbatchsize*All2dQueryPointNum,1)
        grad_z_uv = cal_grad(uv_coor, converted2to3_p_z).reshape(thisbatchsize*All2dQueryPointNum,2,1)

        grad_uv = torch.cat((grad_x_uv, grad_y_uv, grad_z_uv), dim=-1)
        grad_u = grad_uv[:,0:1,:].view(-1,3)
        grad_v = grad_uv[:,1:2,:].view(-1,3)

        converted2to3_np = torch.cross(grad_u.reshape(-1,3), grad_v.reshape(-1,3), dim=1)
        converted2to3_np_norm = torch.norm(converted2to3_np, dim=1).view(-1,1) +0.000001
        converted2to3_np = converted2to3_np/converted2to3_np_norm
        converted2to3_np = converted2to3_np.view(thisbatchsize,-1,3)

        return converted2to3_p, converted2to3_np
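`convert_uv_to_xyzn` recovers the normal as the normalized cross product of the autograd tangents dp/du and dp/dv. The same idea on a toy analytic patch p(u,v) = (u, v, u+v), whose unit normal is (-1,-1,1)/sqrt(3); `surface` is a hypothetical stand-in for the network's uv-to-xyz map:

```python
import torch

uv = torch.tensor([[0.2, 0.5]], requires_grad=True)

def surface(uv):
    # Toy parametric patch: (u, v) -> (u, v, u + v).
    u, v = uv[:, 0:1], uv[:, 1:2]
    return torch.cat((u, v, u + v), dim=-1)

p = surface(uv)
grads = []
for k in range(3):  # gradient of each output coordinate w.r.t. (u, v)
    g = torch.autograd.grad(p[:, k].sum(), uv, create_graph=True)[0]
    grads.append(g)
J = torch.stack(grads, dim=-1)            # (1, 2, 3): rows are d/du, d/dv
normal = torch.cross(J[:, 0, :], J[:, 1, :], dim=1)
normal = normal / normal.norm(dim=1, keepdim=True)
```

Here dp/du = (1, 0, 1) and dp/dv = (0, 1, 1), so the cross product gives (-1, -1, 1) before normalization, matching the analytic normal of the plane.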


    def convert_uv_to_xyz(self, uv_coor, uv_coor_idx_in_sparse, sparse_embedding, points_sparse):
        # uv_coor                | should be in size : thisbatchsize, All2dQueryPointNum, 2
        # uv_coor_idx_in_sparse  | should be in size : thisbatchsize, All2dQueryPointNum, 1
        # sparse_embedding       | should be in size : thisbatchsize, sparse_point_num, embedding_dim
        # points_sparse          | should be in size : thisbatchsize, sparse_point_num, 3
        thisbatchsize = uv_coor.size()[0]
        All2dQueryPointNum = uv_coor.size()[1]
        coding_dim = 4*self.args.pe_out_L + 2
        uv_encoded = position_encoding(uv_coor.reshape(-1,2).contiguous(), self.args.pe_out_L).view(thisbatchsize, All2dQueryPointNum, coding_dim).permute(0,2,1) # bs, coding_dim, All2dQueryPointNum
        indexed_sparse_feature = torch_tensor_functions.indexing_by_id(sparse_embedding, uv_coor_idx_in_sparse)  # bs, All2dQueryPointNum, 1, embedding_num 
        indexed_sparse_feature = indexed_sparse_feature.view(thisbatchsize, All2dQueryPointNum, -1).transpose(2,1)  # bs, embedding_num, All2dQueryPointNum
        coding_with_feature = torch.cat((indexed_sparse_feature, uv_encoded), dim=1)
        out_p = self.convert_feature_to_point_2to3(coding_with_feature).view(thisbatchsize, -1, All2dQueryPointNum).permute(0,2,1)
        indexed_center_points = torch_tensor_functions.indexing_by_id(points_sparse, uv_coor_idx_in_sparse).view(thisbatchsize, All2dQueryPointNum, 3)
        out_p = out_p + indexed_center_points
        return out_p
    
    def convert_xyz_to_uv(self, xyz_coor, xyz_coor_idx_in_sparse, sparse_embedding, points_sparse):
        # xyz_coor               | should be in size : thisbatchsize, All2dQueryPointNum, 3
        # uv_coor_idx_in_sparse  | should be in size : thisbatchsize, All2dQueryPointNum, 1
        # sparse_embedding       | should be in size : thisbatchsize, sparse_point_num, embedding_dim
        # points_sparse          | should be in size : thisbatchsize, sparse_point_num, 3
        # return : out_uv        | should be in size : thisbatchsize, All2dQueryPointNum, 2
        thisbatchsize = xyz_coor.size()[0]
        All2dQueryPointNum = xyz_coor.size()[1]
        coding_dim = 6*self.args.pe_out_L + 3
        indexed_center_points = torch_tensor_functions.indexing_by_id(points_sparse, xyz_coor_idx_in_sparse).view(thisbatchsize, All2dQueryPointNum, 3)
        xyz_coor_remove_center = xyz_coor - indexed_center_points  # note: computed but unused below; the raw xyz_coor is encoded
        xyz_encoded = position_encoding(xyz_coor.reshape(-1,3), self.args.pe_out_L).view(thisbatchsize, All2dQueryPointNum, coding_dim).permute(0,2,1) # bs, coding_dim, All2dQueryPointNum
        indexed_sparse_feature = torch_tensor_functions.indexing_by_id(sparse_embedding, xyz_coor_idx_in_sparse)  # bs, All2dQueryPointNum, 1, embedding_num 
        indexed_sparse_feature = indexed_sparse_feature.view(thisbatchsize, All2dQueryPointNum, -1).transpose(2,1)  # bs, embedding_num, All2dQueryPointNum
        coding_with_feature = torch.cat((xyz_encoded, indexed_sparse_feature), dim = 1)
        out_uv = self.convert_feature_to_point_3to2(coding_with_feature).view(thisbatchsize, -1, All2dQueryPointNum).permute(0,2,1)
        return out_uv
        

#### Convert a string to num_list ####      
def convert_str_2_list(str_):
    words = str_.split(' ')
    trt = [int(x) for x in words]
    return trt
#### Compute the position code for uv or xyz. ####
def position_encoding(input_uv, pe_out_L):
    ## The input_uv should be with shape (-1, X)
    ## The returned tensor should be with shape (-1, X+2*X*L)
    ## X = 2 if the input is uv, or 3 if the input is xyz.
    trt = input_uv
    for i in range(pe_out_L):
        trt = torch.cat((trt, torch.sin(input_uv*(2**i)*(3.14159265))) , dim=-1 )
        trt = torch.cat((trt, torch.cos(input_uv*(2**i)*(3.14159265))) , dim=-1 )
    return trt
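`position_encoding` maps an X-dimensional input to X + 2*X*L channels by appending sin/cos of the input at frequencies 2^i * pi for i = 0..L-1. A standalone CPU re-implementation checking the output shape and the all-zero input, where every sin channel is 0 and every cos channel is 1:

```python
import math
import torch

def position_encoding(x, pe_out_L):
    # x: (-1, X); returns (-1, X + 2*X*L), same layout as the network code:
    # raw input first, then alternating sin/cos blocks per frequency band.
    out = x
    for i in range(pe_out_L):
        out = torch.cat((out, torch.sin(x * (2 ** i) * math.pi)), dim=-1)
        out = torch.cat((out, torch.cos(x * (2 ** i) * math.pi)), dim=-1)
    return out

uv = torch.zeros(5, 2)
code = position_encoding(uv, pe_out_L=4)   # 2 + 2*2*4 = 18 channels
```

With X = 2 and L = 4 the row sum is exactly 8: four cos blocks of two ones each, all other channels zero.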
#### Sample uv by a fixed manner. #### 
def fix_sample(thisbatchsize, num_point, up_ratio, if_random=False):
    # return : fixed uv_coors   |   Its shape should be : thisbatchsize, num_point, up_ratio, 2
    if if_random:
        print('Random sampling mode is not supported right now.')
        exit()
    # Each supported up_ratio is a perfect square s*s, realized as an s x s grid over [-1,1]^2.
    supported_sides = {4: 2, 9: 3, 16: 4, 64: 8, 1024: 32}
    if up_ratio not in supported_sides:
        print('This up_ratio ('+str(up_ratio)+') is not supported now. You can try the random mode!')
        exit()
    side = supported_sides[up_ratio]
    axis = np.linspace(-1.0, 1.0, side)
    u, v = np.meshgrid(axis, axis, indexing='ij')
    one_point_fixed = np.stack((u, v), axis=-1).reshape(up_ratio, 2)
    one_batch_uv2d_fixed = np.tile(one_point_fixed[None, None, :, :], [thisbatchsize, num_point, 1, 1])
    return torch.from_numpy(one_batch_uv2d_fixed).cuda().float()
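Every supported `up_ratio` r = s*s corresponds to an s-by-s grid spanning [-1,1]^2 in uv space; for r = 4 the grid is just the four corners. A CPU sketch of the per-point grid, without the batch tiling or the `.cuda()` move; `uv_grid` is a hypothetical helper, not part of this repo:

```python
import numpy as np

def uv_grid(side):
    # side x side regular grid over [-1, 1]^2, flattened row-major to
    # (side*side, 2), matching the per-point layout fix_sample builds.
    axis = np.linspace(-1.0, 1.0, side)
    u, v = np.meshgrid(axis, axis, indexing='ij')
    return np.stack((u, v), axis=-1).reshape(side * side, 2)

g4 = uv_grid(2)   # up_ratio == 4: the four corners
g9 = uv_grid(3)   # up_ratio == 9: 3x3 grid, center at (0, 0)
```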
#### Sample uv uniformly in (-1,1). #### 
def uniform_random_sample(thisbatchsize, num_point, up_ratio):
    # return : randomly and uniformly sampled uv_coors   |   Its shape should be : thisbatchsize, num_point, up_ratio, 2
    res_ = torch.rand(thisbatchsize*num_point, 4*up_ratio, 3)*2-1  # oversample 4x, then thin via FPS
    res_ = res_.cuda()
    res_[:,:,2:] *= 0  # zero the third coordinate: FPS operates on 2D samples embedded in 3D
    furthest_point_index = pn2_utils.furthest_point_sample(res_,up_ratio)
    uniform_res_ = pn2_utils.gather_operation(res_.permute(0, 2, 1).contiguous(), furthest_point_index)
    uniform_res_ = uniform_res_.permute(0,2,1).contiguous()
    uniform_res_ = uniform_res_[:,:,:2].view(thisbatchsize, num_point, up_ratio, 2)
    return uniform_res_
#### Compute the grad ####
def cal_grad(inputs, outputs):
    d_points = torch.ones_like(outputs, requires_grad = False, device = outputs.device)
    points_grad = grad(
        outputs = outputs,
        inputs = inputs,
        grad_outputs = d_points,
        create_graph = True,
        retain_graph = True,
        only_inputs = True)[0]
    return points_grad
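`cal_grad` is a thin wrapper over `torch.autograd.grad` that returns d(outputs)/d(inputs) with an all-ones cotangent, keeping the graph alive for further differentiation. A quick sanity check on f(x) = sum(x^2), whose gradient is 2x:

```python
import torch
from torch.autograd import grad

def cal_grad(inputs, outputs):
    # Same wrapper as the network code: vector-Jacobian product with an
    # all-ones grad_outputs, with create_graph/retain_graph so the result
    # can itself be differentiated (used for the uv -> xyz tangents).
    d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device)
    return grad(outputs=outputs, inputs=inputs, grad_outputs=d_points,
                create_graph=True, retain_graph=True, only_inputs=True)[0]

x = torch.tensor([[1.0, 2.0, 3.0]], requires_grad=True)
g = cal_grad(x, (x ** 2).sum(dim=-1, keepdim=True))
```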



######## TODO: END PART: OUR OWN NETWORK ########


================================================
FILE: model/conpu_v6/pointnet2/__init__.py
================================================


================================================
FILE: model/conpu_v6/pointnet2/pointnet2_modules.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F

from . import pointnet2_utils
from . import pytorch_utils as pt_utils
from typing import List


class _PointnetSAModuleBase(nn.Module):

    def __init__(self):
        super().__init__()
        self.npoint = None
        self.groupers = None
        self.mlps = None
        self.pool_method = 'max_pool'

    def forward(self, xyz: torch.Tensor, features: torch.Tensor = None, npoint=None, new_xyz=None) -> (torch.Tensor, torch.Tensor):
        """
        :param xyz: (B, N, 3) tensor of the xyz coordinates of the features
        :param features: (B, N, C) tensor of the descriptors of the features
        :param new_xyz:
        :return:
            new_xyz: (B, npoint, 3) tensor of the new features' xyz
            new_features: (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_features descriptors
        """
        if npoint is not None:
            self.npoint = npoint
        new_features_list = []

        xyz_flipped = xyz.transpose(1, 2).contiguous()
        if new_xyz is None:
            new_xyz = pointnet2_utils.gather_operation(
                xyz_flipped,
                pointnet2_utils.furthest_point_sample(xyz, self.npoint)
            ).transpose(1, 2).contiguous() if self.npoint is not None else None

        for i in range(len(self.groupers)):
            new_features = self.groupers[i](xyz, new_xyz, features)  # (B, C, npoint, nsample)
            new_features = self.mlps[i](new_features)  # (B, mlp[-1], npoint, nsample)
            if self.pool_method == 'max_pool':
                new_features = F.max_pool2d(
                    new_features, kernel_size=[1, new_features.size(3)]
                )  # (B, mlp[-1], npoint, 1)
            elif self.pool_method == 'avg_pool':
                new_features = F.avg_pool2d(
                    new_features, kernel_size=[1, new_features.size(3)]
                )  # (B, mlp[-1], npoint, 1)
            else:
                raise NotImplementedError

            new_features = new_features.squeeze(-1)  # (B, mlp[-1], npoint)
            new_features_list.append(new_features)

        return new_xyz, torch.cat(new_features_list, dim=1)


class PointnetSAModuleMSG(_PointnetSAModuleBase):
    """Pointnet set abstraction layer with multiscale grouping"""

    def __init__(self, *, npoint: int, radii: List[float], nsamples: List[int], mlps: List[List[int]], bn: bool = True,
                 use_xyz: bool = True, use_res = False, pool_method='max_pool', instance_norm=False):
        """
        :param npoint: int
        :param radii: list of float, list of radii to group with
        :param nsamples: list of int, number of samples in each ball query
        :param mlps: list of list of int, spec of the pointnet before the global pooling for each scale
        :param bn: whether to use batchnorm
        :param use_xyz:
        :param pool_method: max_pool / avg_pool
        :param instance_norm: whether to use instance_norm
        """
        super().__init__()

        assert len(radii) == len(nsamples) == len(mlps)

        self.npoint = npoint
        self.groupers = nn.ModuleList()
        self.mlps = nn.ModuleList()
        for i in range(len(radii)):
            radius = radii[i]
            nsample = nsamples[i]
            self.groupers.append(
                pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
                if npoint is not None else pointnet2_utils.GroupAll(use_xyz)
            )
            mlp_spec = mlps[i]
            if use_xyz:
                mlp_spec[0] += 3

            if use_res:
                self.mlps.append(pt_utils.SharedResMLP(mlp_spec, bn=bn))
            else:
                self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn, instance_norm=instance_norm))
        self.pool_method = pool_method


class PointnetSAModule(PointnetSAModuleMSG):
    """Pointnet set abstraction layer"""

    def __init__(self, *, mlp: List[int], npoint: int = None, radius: float = None, nsample: int = None,
                 bn: bool = True, use_xyz: bool = True, use_res = False, pool_method='max_pool', instance_norm=False):
        """
        :param mlp: list of int, spec of the pointnet before the global max_pool
        :param npoint: int, number of features
        :param radius: float, radius of ball
        :param nsample: int, number of samples in the ball query
        :param bn: whether to use batchnorm
        :param use_xyz:
        :param pool_method: max_pool / avg_pool
        :param instance_norm: whether to use instance_norm
        """
        super().__init__(
            mlps=[mlp], npoint=npoint, radii=[radius], nsamples=[nsample], bn=bn, use_xyz=use_xyz, use_res=use_res,
            pool_method=pool_method, instance_norm=instance_norm
        )


class PointNetSSG_Base(PointnetSAModuleMSG):
    def __init__(self, npoint, nsample, radius, in_channel, out_channel, bn=True, use_xyz=False):
        super().__init__(
            mlps=[[in_channel, out_channel, out_channel, out_channel]], 
            npoint=npoint, radii=[radius], nsamples=[nsample], bn=bn, use_xyz=use_xyz, use_res=False)


class PointnetFPModule(nn.Module):
    r"""Propigates the features of one set to another"""

    def __init__(self, *, mlp: List[int], bn: bool = True):
        """
        :param mlp: list of int
        :param bn: whether to use batchnorm
        """
        super().__init__()
        self.mlp = pt_utils.SharedMLP(mlp, bn=bn)

    def forward(
            self, unknown: torch.Tensor, known: torch.Tensor, unknow_feats: torch.Tensor, known_feats: torch.Tensor
    ) -> torch.Tensor:
        """
        :param unknown: (B, n, 3) tensor of the xyz positions of the unknown features
        :param known: (B, m, 3) tensor of the xyz positions of the known features
        :param unknow_feats: (B, C1, n) tensor of the features to be propagated to
        :param known_feats: (B, C2, m) tensor of features to be propagated
        :return:
            new_features: (B, mlp[-1], n) tensor of the features of the unknown features
        """
        if known is not None:
            dist, idx = pointnet2_utils.three_nn(unknown, known)
            dist_recip = 1.0 / (dist + 1e-8)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm

            interpolated_feats = pointnet2_utils.three_interpolate(known_feats, idx, weight)
        else:
            interpolated_feats = known_feats.expand(*known_feats.size()[0:2], unknown.size(1))

        if unknow_feats is not None:
            new_features = torch.cat([interpolated_feats, unknow_feats], dim=1)  # (B, C2 + C1, n)
        else:
            new_features = interpolated_feats

        new_features = new_features.unsqueeze(-1)
        new_features = self.mlp(new_features)

        return new_features.squeeze(-1)
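The interpolation above weights each of the three nearest known features by reciprocal distance, normalized to sum to one. A CPU sketch of that weighting, with `torch.cdist` + `torch.topk` standing in for the CUDA `three_nn`; `three_nn_cpu` and the toy points are hypothetical, not part of this repo:

```python
import torch

def three_nn_cpu(unknown, known):
    # unknown: (B, n, 3), known: (B, m, 3) -> distances and indices of
    # the 3 nearest known points, a CPU stand-in for pointnet2_utils.three_nn.
    d = torch.cdist(unknown, known)                        # (B, n, m)
    dist, idx = torch.topk(d, k=3, dim=-1, largest=False)  # 3 smallest
    return dist, idx

unknown = torch.tensor([[[0.0, 0.0, 0.0]]])
known = torch.tensor([[[1.0, 0, 0], [0, 2.0, 0], [0, 0, 4.0], [5.0, 0, 0]]])
dist, idx = three_nn_cpu(unknown, known)
# Inverse-distance weights, exactly as in PointnetFPModule.forward.
dist_recip = 1.0 / (dist + 1e-8)
weight = dist_recip / dist_recip.sum(dim=2, keepdim=True)
```

With neighbor distances 1, 2, and 4, the reciprocal weights normalize to 4/7, 2/7, and 1/7: the closest known point contributes most.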


if __name__ == "__main__":
    pass


================================================
FILE: model/conpu_v6/pointnet2/pointnet2_utils.py
================================================
import torch
from torch.autograd import Variable
from torch.autograd import Function
import torch.nn as nn
from typing import Tuple

import pointnet2_cuda as pointnet2


class FurthestPointSampling(Function):
    @staticmethod
    def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
        """
        Uses iterative furthest point sampling to select a set of npoint features that have the largest
        minimum distance
        :param ctx:
        :param xyz: (B, N, 3) where N > npoint
        :param npoint: int, number of features in the sampled set
        :return:
             output: (B, npoint) tensor containing the set
        """
        assert xyz.is_contiguous()

        B, N, _ = xyz.size()
        output = torch.cuda.IntTensor(B, npoint)
        temp = torch.cuda.FloatTensor(B, N).fill_(1e10)

        pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)
        return output

    @staticmethod
    def backward(ctx, a=None):
        return None, None


furthest_point_sample = FurthestPointSampling.apply
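The CUDA kernel behind `furthest_point_sample` greedily adds the point that maximizes the minimum distance to the already-selected set, starting from index 0. A NumPy reference of that greedy rule; `farthest_point_sample_cpu` is a hypothetical stand-in, not the repo's API:

```python
import numpy as np

def farthest_point_sample_cpu(xyz, npoint):
    # xyz: (N, 3). Greedy FPS from point 0, mirroring the kernel's loop:
    # maintain a running min-distance-to-selected buffer and pick its
    # argmax at each step. Squared distances suffice for the argmax.
    N = xyz.shape[0]
    idx = np.zeros(npoint, dtype=np.int64)
    min_d = np.full(N, np.inf)
    for k in range(1, npoint):
        d = np.sum((xyz - xyz[idx[k - 1]]) ** 2, axis=1)
        min_d = np.minimum(min_d, d)
        idx[k] = int(np.argmax(min_d))
    return idx

pts = np.array([[0.0, 0, 0], [0.1, 0, 0], [1.0, 0, 0], [0.0, 1, 0]])
sel = farthest_point_sample_cpu(pts, 3)
```

On the toy points the near-duplicate at (0.1, 0, 0) is skipped in favor of the two well-separated points, which is the whole point of FPS.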


class GatherOperation(Function):

    @staticmethod
    def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        """
        :param ctx:
        :param features: (B, C, N)
        :param idx: (B, npoint) index tensor of the features to gather
        :return:
            output: (B, C, npoint)
        """
        assert features.is_contiguous()
        assert idx.is_contiguous()

        B, npoint = idx.size()
        _, C, N = features.size()
        output = torch.cuda.FloatTensor(B, C, npoint)

        pointnet2.gather_points_wrapper(B, C, N, npoint, features, idx, output)

        ctx.for_backwards = (idx, C, N)
        return output

    @staticmethod
    def backward(ctx, grad_out):
        idx, C, N = ctx.for_backwards
        B, npoint = idx.size()

        grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
        grad_out_data = grad_out.data.contiguous()
        pointnet2.gather_points_grad_wrapper(B, C, N, npoint, grad_out_data, idx, grad_features.data)
        return grad_features, None


gather_operation = GatherOperation.apply


class ThreeNN(Function):

    @staticmethod
    def forward(ctx, unknown: torch.Tensor, known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Find the three nearest neighbors of unknown in known
        :param ctx:
        :param unknown: (B, N, 3)
        :param known: (B, M, 3)
        :return:
            dist: (B, N, 3) l2 distance to the three nearest neighbors
            idx: (B, N, 3) index of 3 nearest neighbors
        """
        assert unknown.is_contiguous()
        assert known.is_contiguous()

        B, N, _ = unknown.size()
        m = known.size(1)
        dist2 = torch.cuda.FloatTensor(B, N, 3)
        idx = torch.cuda.IntTensor(B, N, 3)

        pointnet2.three_nn_wrapper(B, N, m, unknown, known, dist2, idx)
        return torch.sqrt(dist2), idx

    @staticmethod
    def backward(ctx, a=None, b=None):
        return None, None


three_nn = ThreeNN.apply


class ThreeInterpolate(Function):

    @staticmethod
    def forward(ctx, features: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        """
        Performs weighted linear interpolation on 3 features
        :param ctx:
        :param features: (B, C, M) Features descriptors to be interpolated from
        :param idx: (B, n, 3) three nearest neighbors of the target features in features
        :param weight: (B, n, 3) weights
        :return:
            output: (B, C, n) tensor of the interpolated features
        """
        assert features.is_contiguous()
        assert idx.is_contiguous()
        assert weight.is_contiguous()

        B, c, m = features.size()
        n = idx.size(1)
        ctx.save_for_backward(idx, weight, features)
        output = torch.cuda.FloatTensor(B, c, n)

        pointnet2.three_interpolate_wrapper(B, c, m, n, features, idx, weight, output)
        return output

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        :param ctx:
        :param grad_out: (B, C, N) tensor with gradients of outputs
        :return:
            grad_features: (B, C, M) tensor with gradients of features
            None:
            None:
        """
        idx, weight, features = ctx.saved_tensors
        B, c, m = features.size()
        n = idx.size(1)

        grad_features = Variable(torch.cuda.FloatTensor(B, c, m).zero_())
        grad_out_data = grad_out.data.contiguous()

        pointnet2.three_interpolate_grad_wrapper(B, c, n, m, grad_out_data, idx, weight, grad_features.data)
        return grad_features, None, None


three_interpolate = ThreeInterpolate.apply


class GroupingOperation(Function):

    @staticmethod
    def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        """
        :param ctx:
        :param features: (B, C, N) tensor of features to group
        :param idx: (B, npoint, nsample) tensor containing the indices of features to group with
        :return:
            output: (B, C, npoint, nsample) tensor
        """
        assert features.is_contiguous()
        assert idx.is_contiguous()

        B, nfeatures, nsample = idx.size()
        _, C, N = features.size()
        output = torch.cuda.FloatTensor(B, C, nfeatures, nsample)

        pointnet2.group_points_wrapper(B, C, N, nfeatures, nsample, features, idx, output)

        ctx.for_backwards = (idx, N)
        return output

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :param ctx:
        :param grad_out: (B, C, npoint, nsample) tensor of the gradients of the output from forward
        :return:
            grad_features: (B, C, N) gradient of the features
        """
        idx, N = ctx.for_backwards

        B, C, npoint, nsample = grad_out.size()
        grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())

        grad_out_data = grad_out.data.contiguous()
        pointnet2.group_points_grad_wrapper(B, C, N, npoint, nsample, grad_out_data, idx, grad_features.data)
        return grad_features, None


grouping_operation = GroupingOperation.apply


class BallQuery(Function):

    @staticmethod
    def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor) -> torch.Tensor:
        """
        :param ctx:
        :param radius: float, radius of the balls
        :param nsample: int, maximum number of features in the balls
        :param xyz: (B, N, 3) xyz coordinates of the features
        :param new_xyz: (B, npoint, 3) centers of the ball query
        :return:
            idx: (B, npoint, nsample) tensor with the indices of the features that form the query balls
        """
        assert new_xyz.is_contiguous()
        assert xyz.is_contiguous()

        B, N, _ = xyz.size()
        npoint = new_xyz.size(1)
        idx = torch.cuda.IntTensor(B, npoint, nsample).zero_()

        pointnet2.ball_query_wrapper(B, N, npoint, radius, nsample, new_xyz, xyz, idx)
        return idx

    @staticmethod
    def backward(ctx, a=None):
        return None, None, None, None


ball_query = BallQuery.apply
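`ball_query` returns, for each center, up to `nsample` indices of points within `radius`; common CUDA implementations pad short lists by repeating the first in-ball index. A simplified single-batch CPU analogue that does the same; `ball_query_cpu` and the toy points are hypothetical, not part of this repo:

```python
import torch

def ball_query_cpu(radius, nsample, xyz, new_xyz):
    # xyz: (N, 3) points, new_xyz: (M, 3) centers. For each center, keep
    # up to nsample indices strictly inside the ball, padding short lists
    # by repeating the first hit (0 if the ball is empty).
    d = torch.cdist(new_xyz, xyz)   # (M, N) pairwise distances
    out = torch.zeros(new_xyz.shape[0], nsample, dtype=torch.long)
    for m in range(new_xyz.shape[0]):
        inside = torch.nonzero(d[m] < radius).flatten()[:nsample]
        out[m, :len(inside)] = inside
        out[m, len(inside):] = inside[0] if len(inside) > 0 else 0
    return out

xyz = torch.tensor([[0.0, 0, 0], [0.05, 0, 0], [2.0, 0, 0]])
centers = torch.tensor([[0.0, 0, 0]])
idx = ball_query_cpu(0.1, 4, xyz, centers)
```

Only the two points within radius 0.1 qualify; the remaining slots repeat the first index, so downstream grouping always has `nsample` valid entries.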


class QueryAndGroup(nn.Module):
    def __init__(self, radius: float, nsample: int, use_xyz: bool = True):
        """
        :param radius: float, radius of ball
        :param nsample: int, maximum number of features to gather in the ball
        :param use_xyz:
        """
        super().__init__()
        self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz

    def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None) -> Tuple[torch.Tensor]:
        """
        :param xyz: (B, N, 3) xyz coordinates of the features
        :param new_xyz: (B, npoint, 3) centroids
        :param features: (B, C, N) descriptors of the features
        :return:
            new_features: (B, 3 + C, npoint, nsample)
        """
        idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
        xyz_trans = xyz.transpose(1, 2).contiguous()
        grouped_xyz = grouping_operation(xyz_trans, idx)  # (B, 3, npoint, nsample)
        grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)

        if features is not None:
            grouped_features = grouping_operation(features, idx)
            if self.use_xyz:
                new_features = torch.cat([grouped_xyz, grouped_features], dim=1)  # (B, C + 3, npoint, nsample)
            else:
                new_features = grouped_features
        else:
            assert self.use_xyz, "Cannot have features=None and use_xyz=False!"
            new_features = grouped_xyz

        return new_features


================================================
FILE: model/conpu_v6/pointnet2/pytorch_utils.py
================================================
import torch.nn as nn
from typing import List, Tuple
import torch.nn.functional as F

class EmptyModule(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x


class SharedResMLP(nn.Module):
    def __init__(
            self,
            args: List[int],
            *,
            bn: bool = False,
            activation=nn.ReLU(inplace=True)):
        super().__init__()

        self.res_convs = nn.ModuleList()
        self.short_conn = nn.ModuleList()
        for i in range(len(args) - 1):
            in_ch = args[i]
            out_ch = args[i + 1]
            mid_ch = args[i + 1] // 2
            self.res_convs.append(
                nn.Sequential(
                    Conv2d(in_ch, mid_ch, bn=bn, activation=activation),
                    Conv2d(mid_ch, mid_ch, bn=bn, activation=activation),
                    Conv2d(mid_ch, out_ch, bn=bn, activation=None)))
            self.short_conn.append(
                EmptyModule() if in_ch == out_ch else \
                Conv2d(in_ch, out_ch, bn=bn, activation=None))

    def forward(self, x):
        for k in range(len(self.res_convs)):
            out_res = self.res_convs[k](x)
            out_short = self.short_conn[k](x)
            x = F.relu(out_res + out_short)
        return x
            

class SharedMLP(nn.Sequential):

    def __init__(
            self,
            args: List[int],
            *,
            bn: bool = False,
            activation=nn.ReLU(inplace=True),
            preact: bool = False,
            first: bool = False,
            name: str = "",
            instance_norm: bool = False,):
        super().__init__()

        for i in range(len(args) - 1):
            self.add_module(
                name + 'layer{}'.format(i),
                Conv2d(
                    args[i],
                    args[i + 1],
                    bn=(not first or not preact or (i != 0)) and bn,
                    activation=activation
                    if (not first or not preact or (i != 0)) else None,
                    preact=preact,
                    instance_norm=instance_norm
                )
            )


class _ConvBase(nn.Sequential):

    def __init__(
            self,
            in_size,
            out_size,
            kernel_size,
            stride,
            padding,
            activation,
            bn,
            init,
            conv=None,
            batch_norm=None,
            bias=True,
            preact=False,
            name="",
            instance_norm=False,
            instance_norm_func=None
    ):
        super().__init__()

        bias = bias and (not bn)
        conv_unit = conv(
            in_size,
            out_size,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias
        )
        init(conv_unit.weight)
        if bias:
            nn.init.constant_(conv_unit.bias, 0)

        if bn:
            if not preact:
                bn_unit = batch_norm(out_size)
            else:
                bn_unit = batch_norm(in_size)
        if instance_norm:
            if not preact:
                in_unit = instance_norm_func(out_size, affine=False, track_running_stats=False)
            else:
                in_unit = instance_norm_func(in_size, affine=False, track_running_stats=False)

        if preact:
            if bn:
                self.add_module(name + 'bn', bn_unit)

            if activation is not None:
                self.add_module(name + 'activation', activation)

            if not bn and instance_norm:
                self.add_module(name + 'in', in_unit)

        self.add_module(name + 'conv', conv_unit)

        if not preact:
            if bn:
                self.add_module(name + 'bn', bn_unit)

            if activation is not None:
                self.add_module(name + 'activation', activation)

            if not bn and instance_norm:
                self.add_module(name + 'in', in_unit)


class _BNBase(nn.Sequential):

    def __init__(self, in_size, batch_norm=None, name=""):
        super().__init__()
        self.add_module(name + "bn", batch_norm(in_size))

        nn.init.constant_(self[0].weight, 1.0)
        nn.init.constant_(self[0].bias, 0)


class BatchNorm1d(_BNBase):

    def __init__(self, in_size: int, *, name: str = ""):
        super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name)


class BatchNorm2d(_BNBase):

    def __init__(self, in_size: int, name: str = ""):
        super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name)
        

class BatchNorm3d(_BNBase):

    def __init__(self, in_size: int, name: str = ""):
        super().__init__(in_size, batch_norm=nn.BatchNorm3d, name=name)


class Conv1d(_ConvBase):

    def __init__(
            self,
            in_size: int,
            out_size: int,
            *,
            kernel_size: int = 1,
            stride: int = 1,
            padding: int = 0,
            activation=nn.ReLU(inplace=True),
            bn: bool = False,
            init=nn.init.kaiming_normal_,
            bias: bool = True,
            preact: bool = False,
            name: str = "",
            instance_norm=False
    ):
        super().__init__(
            in_size,
            out_size,
            kernel_size,
            stride,
            padding,
            activation,
            bn,
            init,
            conv=nn.Conv1d,
            batch_norm=BatchNorm1d,
            bias=bias,
            preact=preact,
            name=name,
            instance_norm=instance_norm,
            instance_norm_func=nn.InstanceNorm1d
        )


class Conv2d(_ConvBase):

    def __init__(
            self,
            in_size: int,
            out_size: int,
            *,
            kernel_size: Tuple[int, int] = (1, 1),
            stride: Tuple[int, int] = (1, 1),
            padding: Tuple[int, int] = (0, 0),
            activation=nn.ReLU(inplace=True),
            bn: bool = False,
            init=nn.init.kaiming_normal_,
            bias: bool = True,
            preact: bool = False,
            name: str = "",
            instance_norm=False
    ):
        super().__init__(
            in_size,
            out_size,
            kernel_size,
            stride,
            padding,
            activation,
            bn,
            init,
            conv=nn.Conv2d,
            batch_norm=BatchNorm2d,
            bias=bias,
            preact=preact,
            name=name,
            instance_norm=instance_norm,
            instance_norm_func=nn.InstanceNorm2d
        )

class Conv3d(_ConvBase):

    def __init__(
            self,
            in_size: int,
            out_size: int,
            *,
            kernel_size: Tuple[int, int, int] = (1, 1, 1),
            stride: Tuple[int, int, int] = (1, 1, 1),
            padding: Tuple[int, int, int] = (0, 0, 0),
            activation=nn.ReLU(inplace=True),
            bn: bool = False,
            init=nn.init.kaiming_normal_,
            bias: bool = True,
            preact: bool = False,
            name: str = "",
            instance_norm=False
    ):
        super().__init__(
            in_size,
            out_size,
            kernel_size,
            stride,
            padding,
            activation,
            bn,
            init,
            conv=nn.Conv3d,
            batch_norm=BatchNorm3d,
            bias=bias,
            preact=preact,
            name=name,
            instance_norm=instance_norm,
            instance_norm_func=nn.InstanceNorm3d
        )


class FC(nn.Sequential):

    def __init__(
            self,
            in_size: int,
            out_size: int,
            *,
            activation=nn.ReLU(inplace=True),
            bn: bool = False,
            init=None,
            preact: bool = False,
            name: str = ""
    ):
        super().__init__()

        fc = nn.Linear(in_size, out_size, bias=not bn)
        if init is not None:
            init(fc.weight)
        if not bn:
            nn.init.constant_(fc.bias, 0)

        if preact:
            if bn:
                self.add_module(name + 'bn', BatchNorm1d(in_size))

            if activation is not None:
                self.add_module(name + 'activation', activation)

        self.add_module(name + 'fc', fc)

        if not preact:
            if bn:
                self.add_module(name + 'bn', BatchNorm1d(out_size))

            if activation is not None:
                self.add_module(name + 'activation', activation)
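The `bn`/`activation` condition inside `SharedMLP` (`(not first or not preact or (i != 0)) and bn`) is easy to misread: it simply means that when both `first` and `preact` are set, layer 0 is built without batch norm and without an activation. A minimal pure-Python sketch of that per-layer logic, using the hypothetical helper name `shared_mlp_specs` (not part of this repo, illustration only):

```python
def shared_mlp_specs(args, bn=False, preact=False, first=False):
    """Mirror SharedMLP's per-layer flags: when `first` and `preact` are
    both set, layer 0 gets neither batch norm nor an activation."""
    specs = []
    for i in range(len(args) - 1):
        # Same predicate as in SharedMLP.__init__
        keep = (not first) or (not preact) or (i != 0)
        specs.append({
            "in": args[i], "out": args[i + 1],
            "bn": keep and bn,
            "act": keep,  # SharedMLP passes `activation` iff this is True
        })
    return specs
```

All later layers keep their `bn`/`activation` settings; only the first layer of a pre-activation stack is stripped, since there is nothing before it to normalize.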



================================================
FILE: model/conpu_v6/pointnet2/setup.py
================================================
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='pointnet2',
    ext_modules=[
        CUDAExtension('pointnet2_cuda', [
            'src/pointnet2_api.cpp',
            
            'src/ball_query.cpp', 
            'src/ball_query_gpu.cu',
            'src/group_points.cpp', 
            'src/group_points_gpu.cu',
            'src/interpolate.cpp', 
            'src/interpolate_gpu.cu',
            'src/sampling.cpp', 
            'src/sampling_gpu.cu',
        ],
        extra_compile_args={'cxx': ['-g'],
                            'nvcc': ['-O2']})
    ],
    cmdclass={'build_ext': BuildExtension}
)


================================================
FILE: model/conpu_v6/pointnet2/src/ball_query.cpp
================================================
#include <torch/serialize/tensor.h>
#include <vector>
#include <THC/THC.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <ATen/cuda/CUDAContext.h>  // for c10::cuda::getCurrentCUDAStream
#include "ball_query_gpu.h"

extern THCState *state;

#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)

int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample, 
    at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor) {
    CHECK_INPUT(new_xyz_tensor);
    CHECK_INPUT(xyz_tensor);
    const float *new_xyz = new_xyz_tensor.data<float>();
    const float *xyz = xyz_tensor.data<float>();
    int *idx = idx_tensor.data<int>();
    
    cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
    ball_query_kernel_launcher_fast(b, n, m, radius, nsample, new_xyz, xyz, idx, stream);
    return 1;
}


================================================
FILE: model/conpu_v6/pointnet2/src/ball_query_gpu.cu
================================================
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "ball_query_gpu.h"
#include "cuda_utils.h"


__global__ void ball_query_kernel_fast(int b, int n, int m, float radius, int nsample, 
    const float *__restrict__ new_xyz, const float *__restrict__ xyz, int *__restrict__ idx) {
    // new_xyz: (B, M, 3)
    // xyz: (B, N, 3)
    // output:
    //      idx: (B, M, nsample)
    int bs_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (bs_idx >= b || pt_idx >= m) return;

    new_xyz += bs_idx * m * 3 + pt_idx * 3;
    xyz += bs_idx * n * 3;
    idx += bs_idx * m * nsample + pt_idx * nsample;

    float radius2 = radius * radius;
    float new_x = new_xyz[0];
    float new_y = new_xyz[1];
    float new_z = new_xyz[2];

    int cnt = 0;
    for (int k = 0; k < n; ++k) {
        float x = xyz[k * 3 + 0];
        float y = xyz[k * 3 + 1];
        float z = xyz[k * 3 + 2];
        float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z);
        if (d2 < radius2){
            if (cnt == 0){
                for (int l = 0; l < nsample; ++l) {
                    idx[l] = k;
                }
            }
            idx[cnt] = k;
            ++cnt;
            if (cnt >= nsample) break;
        }
    }
}


void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample, \
    const float *new_xyz, const float *xyz, int *idx, cudaStream_t stream) {
    // new_xyz: (B, M, 3)
    // xyz: (B, N, 3)
    // output:
    //      idx: (B, M, nsample)

    cudaError_t err;

    dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);

    ball_query_kernel_fast<<<blocks, threads, 0, stream>>>(b, n, m, radius, nsample, new_xyz, xyz, idx);
    // cudaDeviceSynchronize();  // for using printf in kernel function
    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}
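Note the padding behavior of the kernel above: on the first in-radius hit it fills the whole `idx` row with that index, then overwrites slots as more neighbors are found. Queries with fewer than `nsample` neighbors therefore come back padded with the first hit (and queries with no neighbor leave `idx` untouched). A pure-Python reference of this behavior, single batch, hypothetical name `ball_query_ref` (illustration only):

```python
def ball_query_ref(new_xyz, xyz, radius, nsample):
    """CPU reference for ball_query_kernel_fast (one batch).

    new_xyz: list of (x, y, z) query points, length M
    xyz:     list of (x, y, z) source points, length N
    Returns an M x nsample index table; a row with fewer than nsample
    in-radius neighbors is padded with the first neighbor found.
    """
    r2 = radius * radius
    out = []
    for qx, qy, qz in new_xyz:
        row, cnt = [0] * nsample, 0  # CUDA leaves no-hit rows untouched; we use zeros
        for k, (x, y, z) in enumerate(xyz):
            d2 = (qx - x) ** 2 + (qy - y) ** 2 + (qz - z) ** 2
            if d2 < r2:
                if cnt == 0:
                    row = [k] * nsample  # pre-fill the row with the first hit
                row[cnt] = k
                cnt += 1
                if cnt >= nsample:
                    break
        out.append(row)
    return out
```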

================================================
FILE: model/conpu_v6/pointnet2/src/ball_query_gpu.h
================================================
#ifndef _BALL_QUERY_GPU_H
#define _BALL_QUERY_GPU_H

#include <torch/serialize/tensor.h>
#include <vector>
#include <cuda.h>
#include <cuda_runtime_api.h>

int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample, 
	at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor);

void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample, 
	const float *new_xyz, const float *xyz, int *idx, cudaStream_t stream);

#endif


================================================
FILE: model/conpu_v6/pointnet2/src/cuda_utils.h
================================================
#ifndef _CUDA_UTILS_H
#define _CUDA_UTILS_H

#include <cmath>

#define TOTAL_THREADS 1024
#define THREADS_PER_BLOCK 256
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))

inline int opt_n_threads(int work_size) {
    const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);

    return max(min(1 << pow_2, TOTAL_THREADS), 1);
}
#endif
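`DIVUP` is ceiling division, used to size CUDA grids so that `blocks * THREADS_PER_BLOCK >= work`, and `opt_n_threads` rounds the work size down to a power of two clamped to `[1, TOTAL_THREADS]`. Equivalent Python, for reference (hypothetical function names, illustration only):

```python
import math

def divup(m, n):
    # Ceiling division, as used to size CUDA grids: divup(m, n) * n >= m.
    return m // n + (1 if m % n else 0)

def opt_n_threads(work_size, total_threads=1024):
    # Largest power of two <= work_size, clamped to [1, total_threads].
    pow2 = int(math.log(work_size) / math.log(2.0))
    return max(min(1 << pow2, total_threads), 1)
```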


================================================
FILE: model/conpu_v6/pointnet2/src/group_points.cpp
================================================
#include <torch/serialize/tensor.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <ATen/cuda/CUDAContext.h>  // for c10::cuda::getCurrentCUDAStream
#include <vector>
#include <THC/THC.h>
#include "group_points_gpu.h"

extern THCState *state;


int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample, 
    at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) {

    float *grad_points = grad_points_tensor.data<float>();
    const int *idx = idx_tensor.data<int>();
    const float *grad_out = grad_out_tensor.data<float>();

    cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

    group_points_grad_kernel_launcher_fast(b, c, n, npoints, nsample, grad_out, idx, grad_points, stream);
    return 1;
}


int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample, 
    at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor) {

    const float *points = points_tensor.data<float>();
    const int *idx = idx_tensor.data<int>();
    float *out = out_tensor.data<float>();

    cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

    group_points_kernel_launcher_fast(b, c, n, npoints, nsample, points, idx, out, stream);
    return 1;
}


================================================
FILE: model/conpu_v6/pointnet2/src/group_points_gpu.cu
================================================
#include <stdio.h>
#include <stdlib.h>

#include "cuda_utils.h"
#include "group_points_gpu.h"


__global__ void group_points_grad_kernel_fast(int b, int c, int n, int npoints, int nsample, 
    const float *__restrict__ grad_out, const int *__restrict__ idx, float *__restrict__ grad_points) {
    // grad_out: (B, C, npoints, nsample)
    // idx: (B, npoints, nsample)
    // output:
    //      grad_points: (B, C, N)
    int bs_idx = blockIdx.z;
    int c_idx = blockIdx.y;
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int pt_idx = index / nsample;
    if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return;

    int sample_idx = index % nsample;
    grad_out += bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx;
    idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx; 
    
    atomicAdd(grad_points + bs_idx * c * n + c_idx * n + idx[0] , grad_out[0]);
}

void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, 
    const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream) {
    // grad_out: (B, C, npoints, nsample)
    // idx: (B, npoints, nsample)
    // output:
    //      grad_points: (B, C, N)
    cudaError_t err;
    dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);

    group_points_grad_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, nsample, grad_out, idx, grad_points);

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}


__global__ void group_points_kernel_fast(int b, int c, int n, int npoints, int nsample, 
    const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) {
    // points: (B, C, N)
    // idx: (B, npoints, nsample)
    // output:
    //      out: (B, C, npoints, nsample)
    int bs_idx = blockIdx.z;
    int c_idx = blockIdx.y;
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int pt_idx = index / nsample;
    if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return;

    int sample_idx = index % nsample;

    idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx; 
    int in_idx = bs_idx * c * n + c_idx * n + idx[0];
    int out_idx = bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx;

    out[out_idx] = points[in_idx];
}


void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, 
    const float *points, const int *idx, float *out, cudaStream_t stream) {
    // points: (B, C, N)
    // idx: (B, npoints, nsample)
    // output:
    //      out: (B, C, npoints, nsample)
    cudaError_t err;
    dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);

    group_points_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, nsample, points, idx, out);
    // cudaDeviceSynchronize();  // for using printf in kernel function
    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}
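The two kernels above are a gather and its adjoint scatter-add: forward does `out[b, c, p, s] = points[b, c, idx[b, p, s]]`, and backward `atomicAdd`s each incoming gradient back to the source slot it was gathered from. A pure-Python reference for one batch and one channel (hypothetical names, illustration only):

```python
def group_points_ref(points, idx):
    """points: length-N feature list (one batch, one channel);
    idx: npoints x nsample index table. Returns the npoints x nsample gather."""
    return [[points[k] for k in row] for row in idx]

def group_points_grad_ref(grad_out, idx, n):
    """Scatter-add grad_out back onto a length-N gradient buffer,
    mirroring the atomicAdd in group_points_grad_kernel_fast."""
    grad_points = [0.0] * n
    for row, grow in zip(idx, grad_out):
        for k, g in zip(row, grow):
            grad_points[k] += g
    return grad_points
```

An index that appears several times in `idx` accumulates several gradient contributions, which is why the CUDA backward must use `atomicAdd` rather than a plain store.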


================================================
FILE: model/conpu_v6/pointnet2/src/group_points_gpu.h
================================================
#ifndef _GROUP_POINTS_GPU_H
#define _GROUP_POINTS_GPU_H

#include <torch/serialize/tensor.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <vector>


int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample, 
    at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor);

void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, 
    const float *points, const int *idx, float *out, cudaStream_t stream);

int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample, 
    at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor);

void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, 
    const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream);

#endif


================================================
FILE: model/conpu_v6/pointnet2/src/interpolate.cpp
================================================
#include <torch/serialize/tensor.h>
#include <vector>
#include <THC/THC.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <ATen/cuda/CUDAContext.h>  // for c10::cuda::getCurrentCUDAStream
#include "interpolate_gpu.h"

extern THCState *state;


void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor, 
    at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor) {
    const float *unknown = unknown_tensor.data<float>();
    const float *known = known_tensor.data<float>();
    float *dist2 = dist2_tensor.data<float>();
    int *idx = idx_tensor.data<int>();

    cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
    three_nn_kernel_launcher_fast(b, n, m, unknown, known, dist2, idx, stream);
}


void three_interpolate_wrapper_fast(int b, int c, int m, int n,
                         at::Tensor points_tensor,
                         at::Tensor idx_tensor,
                         at::Tensor weight_tensor,
                         at::Tensor out_tensor) {

    const float *points = points_tensor.data<float>();
    const float *weight = weight_tensor.data<float>();
    float *out = out_tensor.data<float>();
    const int *idx = idx_tensor.data<int>();

    cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
    three_interpolate_kernel_launcher_fast(b, c, m, n, points, idx, weight, out, stream);
}

void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m,
                            at::Tensor grad_out_tensor,
                            at::Tensor idx_tensor,
                            at::Tensor weight_tensor,
                            at::Tensor grad_points_tensor) {

    const float *grad_out = grad_out_tensor.data<float>();
    const float *weight = weight_tensor.data<float>();
    float *grad_points = grad_points_tensor.data<float>();
    const int *idx = idx_tensor.data<int>();

    cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
    three_interpolate_grad_kernel_launcher_fast(b, c, n, m, grad_out, idx, weight, grad_points, stream);
}


================================================
FILE: model/conpu_v6/pointnet2/src/interpolate_gpu.cu
================================================
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "cuda_utils.h"
#include "interpolate_gpu.h"


__global__ void three_nn_kernel_fast(int b, int n, int m, const float *__restrict__ unknown, 
    const float *__restrict__ known, float *__restrict__ dist2, int *__restrict__ idx) {
    // unknown: (B, N, 3)
    // known: (B, M, 3)
    // output: 
    //      dist2: (B, N, 3)
    //      idx: (B, N, 3)
    
    int bs_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (bs_idx >= b || pt_idx >= n) return;

    unknown += bs_idx * n * 3 + pt_idx * 3;
    known += bs_idx * m * 3;
    dist2 += bs_idx * n * 3 + pt_idx * 3;
    idx += bs_idx * n * 3 + pt_idx * 3;

    float ux = unknown[0];
    float uy = unknown[1];
    float uz = unknown[2];

    double best1 = 1e40, best2 = 1e40, best3 = 1e40;
    int besti1 = 0, besti2 = 0, besti3 = 0;
    for (int k = 0; k < m; ++k) {
        float x = known[k * 3 + 0];
        float y = known[k * 3 + 1];
        float z = known[k * 3 + 2];
        float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
        if (d < best1) {
            best3 = best2; besti3 = besti2;
            best2 = best1; besti2 = besti1;
            best1 = d; besti1 = k;
        } 
        else if (d < best2) {
            best3 = best2; besti3 = besti2;
            best2 = d; besti2 = k;
        } 
        else if (d < best3) {
            best3 = d; besti3 = k;
        }
    }
    dist2[0] = best1; dist2[1] = best2; dist2[2] = best3;
    idx[0] = besti1; idx[1] = besti2; idx[2] = besti3;
}


void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown, 
    const float *known, float *dist2, int *idx, cudaStream_t stream) {
    // unknown: (B, N, 3)
    // known: (B, M, 3)
    // output: 
    //      dist2: (B, N, 3)
    //      idx: (B, N, 3)

    cudaError_t err;
    dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), b);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);

    three_nn_kernel_fast<<<blocks, threads, 0, stream>>>(b, n, m, unknown, known, dist2, idx);

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}


__global__ void three_interpolate_kernel_fast(int b, int c, int m, int n, const float *__restrict__ points, 
    const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ out) {
    // points: (B, C, M)
    // idx: (B, N, 3)
    // weight: (B, N, 3)
    // output:
    //      out: (B, C, N)

    int bs_idx = blockIdx.z;
    int c_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;

    weight += bs_idx * n * 3 + pt_idx * 3;
    points += bs_idx * c * m + c_idx * m;
    idx += bs_idx * n * 3 + pt_idx * 3;
    out += bs_idx * c * n + c_idx * n;

    out[pt_idx] = weight[0] * points[idx[0]] + weight[1] * points[idx[1]] + weight[2] * points[idx[2]];
}

void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n, 
    const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream) {
    // points: (B, C, M)
    // idx: (B, N, 3)
    // weight: (B, N, 3)
    // output:
    //      out: (B, C, N)

    cudaError_t err;
    dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);
    three_interpolate_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, m, n, points, idx, weight, out);

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}


__global__ void three_interpolate_grad_kernel_fast(int b, int c, int n, int m, const float *__restrict__ grad_out, 
    const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ grad_points) {
    // grad_out: (B, C, N)
    // weight: (B, N, 3)
    // output:
    //      grad_points: (B, C, M)

    int bs_idx = blockIdx.z;
    int c_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;
    
    grad_out += bs_idx * c * n + c_idx * n + pt_idx;
    weight += bs_idx * n * 3 + pt_idx * 3;
    grad_points += bs_idx * c * m + c_idx * m;
    idx += bs_idx * n * 3 + pt_idx * 3;


    atomicAdd(grad_points + idx[0], grad_out[0] * weight[0]);
    atomicAdd(grad_points + idx[1], grad_out[0] * weight[1]);
    atomicAdd(grad_points + idx[2], grad_out[0] * weight[2]);
}

void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, const float *grad_out, 
    const int *idx, const float *weight, float *grad_points, cudaStream_t stream) {
    // grad_out: (B, C, N)
    // weight: (B, N, 3)
    // output:
    //      grad_points: (B, C, M)

    cudaError_t err;
    dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);
    three_interpolate_grad_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, m, grad_out, idx, weight, grad_points);

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}
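Together, `three_nn` and `three_interpolate` above implement inverse-distance feature propagation: for each query point find its three nearest known points (squared distances reported in ascending order), then blend their features with caller-supplied weights. A pure-Python single-batch reference (hypothetical names, illustration only):

```python
def three_nn_ref(unknown, known):
    """For each query point, return the 3 smallest squared distances
    (ascending) and their indices into `known`."""
    dist2, idx = [], []
    for u in unknown:
        best = sorted(
            (sum((a - b) ** 2 for a, b in zip(u, k)), i)
            for i, k in enumerate(known)
        )[:3]
        dist2.append([d for d, _ in best])
        idx.append([i for _, i in best])
    return dist2, idx

def three_interpolate_ref(points, idx, weight):
    """points: per-source scalar features; out[n] = sum_j weight[n][j] * points[idx[n][j]]."""
    return [sum(w * points[k] for w, k in zip(wr, ir))
            for wr, ir in zip(weight, idx)]
```

The weights themselves (typically normalized `1 / dist2`) are computed on the Python side before calling the CUDA op; the kernel only performs the weighted gather.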

================================================
FILE: model/conpu_v6/pointnet2/src/interpolate_gpu.h
================================================
#ifndef _INTERPOLATE_GPU_H
#define _INTERPOLATE_GPU_H

#include <torch/serialize/tensor.h>
#include<vector>
#include <cuda.h>
#include <cuda_runtime_api.h>


void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor, 
  at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor);

void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown,
	const float *known, float *dist2, int *idx, cudaStream_t stream);


void three_interpolate_wrapper_fast(int b, int c, int m, int n, at::Tensor points_tensor, 
    at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor out_tensor);

void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n, 
    const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream);


void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m, at::Tensor grad_out_tensor, 
    at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor grad_points_tensor);

void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, const float *grad_out, 
    const int *idx, const float *weight, float *grad_points, cudaStream_t stream);

#endif


================================================
FILE: model/conpu_v6/pointnet2/src/pointnet2_api.cpp
================================================
#include <torch/serialize/tensor.h>
#include <torch/extension.h>

#include "ball_query_gpu.h"
#include "group_points_gpu.h"
#include "sampling_gpu.h"
#include "interpolate_gpu.h"


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("ball_query_wrapper", &ball_query_wrapper_fast, "ball_query_wrapper_fast");

    m.def("group_points_wrapper", &group_points_wrapper_fast, "group_points_wrapper_fast");
    m.def("group_points_grad_wrapper", &group_points_grad_wrapper_fast, "group_points_grad_wrapper_fast");

    m.def("gather_points_wrapper", &gather_points_wrapper_fast, "gather_points_wrapper_fast");
    m.def("gather_points_grad_wrapper", &gather_points_grad_wrapper_fast, "gather_points_grad_wrapper_fast");

    m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper, "furthest_point_sampling_wrapper");
    
    m.def("three_nn_wrapper", &three_nn_wrapper_fast, "three_nn_wrapper_fast");
    m.def("three_interpolate_wrapper", &three_interpolate_wrapper_fast, "three_interpolate_wrapper_fast");
    m.def("three_interpolate_grad_wrapper", &three_interpolate_grad_wrapper_fast, "three_interpolate_grad_wrapper_fast");
}


================================================
FILE: model/conpu_v6/pointnet2/src/sampling.cpp
================================================
#include <torch/serialize/tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <vector>
#include <THC/THC.h>

#include "sampling_gpu.h"

extern THCState *state;


int gather_points_wrapper_fast(int b, int c, int n, int npoints, 
    at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor){
    const float *points = points_tensor.data<float>();
    const int *idx = idx_tensor.data<int>();
    float *out = out_tensor.data<float>();

    cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
    gather_points_kernel_launcher_fast(b, c, n, npoints, points, idx, out, stream);
    return 1;
}


int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints, 
    at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) {

    const float *grad_out = grad_out_tensor.data<float>();
    const int *idx = idx_tensor.data<int>();
    float *grad_points = grad_points_tensor.data<float>();

    cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
    gather_points_grad_kernel_launcher_fast(b, c, n, npoints, grad_out, idx, grad_points, stream);
    return 1;
}


int furthest_point_sampling_wrapper(int b, int n, int m, 
    at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) {

    const float *points = points_tensor.data<float>();
    float *temp = temp_tensor.data<float>();
    int *idx = idx_tensor.data<int>();

    cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
    furthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx, stream);
    return 1;
}


================================================
FILE: model/conpu_v6/pointnet2/src/sampling_gpu.cu
================================================
#include <stdio.h>
#include <stdlib.h>

#include "cuda_utils.h"
#include "sampling_gpu.h"


__global__ void gather_points_kernel_fast(int b, int c, int n, int m, 
    const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) {
    // points: (B, C, N)
    // idx: (B, M)
    // output:
    //      out: (B, C, M)

    int bs_idx = blockIdx.z;
    int c_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (bs_idx >= b || c_idx >= c || pt_idx >= m) return;

    out += bs_idx * c * m + c_idx * m + pt_idx;
    idx += bs_idx * m + pt_idx;
    points += bs_idx * c * n + c_idx * n;
    out[0] = points[idx[0]];
}

void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints, 
    const float *points, const int *idx, float *out, cudaStream_t stream) {
    // points: (B, C, N)
    // idx: (B, npoints)
    // output:
    //      out: (B, C, npoints)

    cudaError_t err;
    dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);

    gather_points_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, points, idx, out);

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}

__global__ void gather_points_grad_kernel_fast(int b, int c, int n, int m, const float *__restrict__ grad_out, 
    const int *__restrict__ idx, float *__restrict__ grad_points) {
    // grad_out: (B, C, M)
    // idx: (B, M)
    // output:
    //      grad_points: (B, C, N)

    int bs_idx = blockIdx.z;
    int c_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (bs_idx >= b || c_idx >= c || pt_idx >= m) return;

    grad_out += bs_idx * c * m + c_idx * m + pt_idx;
    idx += bs_idx * m + pt_idx;
    grad_points += bs_idx * c * n + c_idx * n;

    atomicAdd(grad_points + idx[0], grad_out[0]);
}

void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, 
    const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream) {
    // grad_out: (B, C, npoints)
    // idx: (B, npoints)
    // output:
    //      grad_points: (B, C, N)

    cudaError_t err;
    dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);

    gather_points_grad_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, grad_out, idx, grad_points);

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}
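The furthest-point-sampling kernel below maintains, in `temp`, each point's minimum squared distance to the set selected so far; every iteration updates `temp` against the most recently chosen point and picks the argmax of the array (the cascade of `__update` calls is just a block-wide max reduction over that argmax). A pure-Python reference of the same algorithm (hypothetical name, illustration only):

```python
def furthest_point_sampling_ref(dataset, m):
    """dataset: list of (x, y, z); returns m indices, starting from index 0,
    each maximizing the min squared distance to the points already chosen."""
    n = len(dataset)
    temp = [float("inf")] * n  # min squared distance to the selected set
    idxs = [0]                 # the kernel also seeds with index 0
    for _ in range(1, m):
        x1, y1, z1 = dataset[idxs[-1]]
        best, besti = -1.0, 0
        for k, (x2, y2, z2) in enumerate(dataset):
            d = (x2 - x1) ** 2 + (y2 - y1) ** 2 + (z2 - z1) ** 2
            temp[k] = min(temp[k], d)  # only the newest pick can shrink temp
            if temp[k] > best:
                best, besti = temp[k], k
        idxs.append(besti)
    return idxs
```

Updating `temp` against only the newest selected point keeps each iteration O(N) instead of O(N * |selected|), which is the trick the CUDA kernel exploits.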


__device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, int idx1, int idx2){
    const float v1 = dists[idx1], v2 = dists[idx2];
    const int i1 = dists_i[idx1], i2 = dists_i[idx2];
    dists[idx1] = max(v1, v2);
    dists_i[idx1] = v2 > v1 ? i2 : i1;
}

template <unsigned int block_size>
__global__ void furthest_point_sampling_kernel(int b, int n, int m, 
    const float *__restrict__ dataset, float *__restrict__ temp, int *__restrict__ idxs) {
    // dataset: (B, N, 3)
    // tmp: (B, N)
    // output:
    //      idx: (B, M)

    if (m <= 0) return;
    __shared__ float dists[block_size];
    __shared__ int dists_i[block_size];

    int batch_index = blockIdx.x;
    dataset += batch_index * n * 3;
    temp += batch_index * n;
    idxs += batch_index * m;

    int tid = threadIdx.x;
    const int stride = block_size;

    int old = 0;
    if (threadIdx.x == 0)
        idxs[0] = old;

    __syncthreads();
    for (int j = 1; j < m; j++) {
        int besti = 0;
        float best = -1;
        float x1 = dataset[old * 3 + 0];
        float y1 = dataset[old * 3 + 1];
        float z1 = dataset[old * 3 + 2];
        for (int k = tid; k < n; k += stride) {
            float x2, y2, z2;
            x2 = dataset[k * 3 + 0];
            y2 = dataset[k * 3 + 1];
            z2 = dataset[k * 3 + 2];
            // float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
            // if (mag <= 1e-3)
            // continue;

            float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
            float d2 = min(d, temp[k]);
            temp[k] = d2;
            besti = d2 > best ? k : besti;
            best = d2 > best ? d2 : best;
        }
        dists[tid] = best;
        dists_i[tid] = besti;
        __syncthreads();

        if (block_size >= 1024) {
            if (tid < 512) {
                __update(dists, dists_i, tid, tid + 512);
            }
            __syncthreads();
        }
        if (block_size >= 512) {
            if (tid < 256) {
                __update(dists, dists_i, tid, tid + 256);
            }
            __syncthreads();
        }
        if (block_size >= 256) {
            if (tid < 128) {
                __update(dists, dists_i, tid, tid + 128);
            }
            __syncthreads();
        }
        if (block_size >= 128) {
            if (tid < 64) {
                __update(dists, dists_i, tid, tid + 64);
            }
            __syncthreads();
        }
        if (block_size >= 64) {
            if (tid < 32) {
                __update(dists, dists_i, tid, tid + 32);
            }
            __syncthreads();
        }
        if (block_size >= 32) {
            if (tid < 16) {
                __update(dists, dists_i, tid, tid + 16);
            }
            __syncthreads();
        }
        if (block_size >= 16) {
            if (tid < 8) {
                __update(dists, dists_i, tid, tid + 8);
            }
            __syncthreads();
        }
        if (block_size >= 8) {
            if (tid < 4) {
                __update(dists, dists_i, tid, tid + 4);
            }
            __syncthreads();
        }
        if (block_size >= 4) {
            if (tid < 2) {
                __update(dists, dists_i, tid, tid + 2);
            }
            __syncthreads();
        }
        if (block_size >= 2) {
            if (tid < 1) {
                __update(dists, dists_i, tid, tid + 1);
            }
            __syncthreads();
        }

        old = dists_i[0];
        if (tid == 0)
            idxs[j] = old;
    }
}

void furthest_point_sampling_kernel_launcher(int b, int n, int m, 
    const float *dataset, float *temp, int *idxs, cudaStream_t stream) {
    // dataset: (B, N, 3)
    // temp: (B, N)
    // output:
    //      idx: (B, M)

    cudaError_t err;
    unsigned int n_threads = opt_n_threads(n);

    switch (n_threads) {
        case 1024:
        furthest_point_sampling_kernel<1024><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 512:
        furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 256:
        furthest_point_sampling_kernel<256><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 128:
        furthest_point_sampling_kernel<128><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 64:
        furthest_point_sampling_kernel<64><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 32:
        furthest_point_sampling_kernel<32><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 16:
        furthest_point_sampling_kernel<16><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 8:
        furthest_point_sampling_kernel<8><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 4:
        furthest_point_sampling_kernel<4><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 2:
        furthest_point_sampling_kernel<2><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 1:
        furthest_point_sampling_kernel<1><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        default:
        furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
    }

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}
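For readers following the kernel logic, here is a hedged single-threaded NumPy reference of the same farthest point sampling: seed with index 0, keep a running minimum squared distance per point (the role of the `temp` buffer), and repeatedly pick the point with the largest such distance. The function name is illustrative.

```python
import numpy as np

def furthest_point_sample_ref(points, m):
    """Reference FPS matching the kernel semantics: points (N, 3) -> m indices.
    The first pick is always index 0, and `temp` holds each point's squared
    distance to the nearest already-selected point."""
    n = points.shape[0]
    temp = np.full(n, 1e10, dtype=np.float64)  # like the temp buffer
    idxs = np.zeros(m, dtype=np.int64)
    old = 0
    for j in range(1, m):
        d = np.sum((points - points[old]) ** 2, axis=1)
        temp = np.minimum(temp, d)
        old = int(np.argmax(temp))  # the kernel finds this via tree reduction
        idxs[j] = old
    return idxs
```

The unrolled `block_size >= ...` cascade in the CUDA kernel is a parallel tree reduction that computes the same `argmax` across the thread block.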


================================================
FILE: model/conpu_v6/pointnet2/src/sampling_gpu.h
================================================
#ifndef _SAMPLING_GPU_H
#define _SAMPLING_GPU_H

#include <torch/serialize/tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include<vector>


int gather_points_wrapper_fast(int b, int c, int n, int npoints, 
    at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor);

void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints, 
    const float *points, const int *idx, float *out, cudaStream_t stream);


int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints, 
    at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor);

void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, 
    const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream);


int furthest_point_sampling_wrapper(int b, int n, int m, 
    at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor);

void furthest_point_sampling_kernel_launcher(int b, int n, int m, 
    const float *dataset, float *temp, int *idxs, cudaStream_t stream);

#endif


================================================
FILE: model/conpu_v6/train_script101.py
================================================
import os

#coarse-net

loss_weight=' '
loss_weight+=' --weight_cd 1.0'
loss_weight+=' --weight_uniform -10000000'
loss_weight+=' --weight_reg -0.1'
loss_weight+=' --weight_arap 0.03'
loss_weight+=' --weight_overlap 0.3'
loss_weight+=' --weight_proj -1'
loss_weight+=' --weight_normal -1'
loss_weight+=' --weight_cycle -1'
loss_weight+=' --weight_ndirection 0.0001'


for control_i in range(0,1):
    os.system('CUDA_VISIBLE_DEVICES=1 python train_view_toy.py \
        --training_up_ratio 16 \
        --testing_up_ratio 16 \
        --over_sampling_scale 4 \
        --visualization_while_testing 1 \
        --last_sample_id '+str(control_i*10000)+' \
        --test_blank 10000 \
        --train_max_samples '+str((control_i+1)*10000)+' \
        --learning_rate '+str(0.001* 0.9**control_i)+'  \
        --batchsize 8  \
        --out_baseline \'out_baseline_101\' \
        --num_point 256 \
        --gt_num_point 4096 \
        --pack_path \'../../data/Sketchfab2/packed_data/version_2\'  \
        --over_fitting_id 0 \
        --if_over_fitting_this_time 0 \
        --if_only_test 0 \
        --if_only_test_max_num 14 \
        --network_name \'Net_conpu_v7\'  \
        --emb_dims 512 \
        --neighbor_k 10 \
        --mlp_fitting_str \'256 128 64\' \
        --pretrained \'./pre_trained/v3.pt\' \
        --if_fix_sample 0 \
        --if_use_siren 0 \
        --feature_unfolding_nei_num 4 \
        '+loss_weight)
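The script drives training by concatenating CLI flags into one `os.system` string. An equivalent sketch that assembles the same invocation as an argument list, which avoids shell-quoting pitfalls (flag values mirror the script above; this is illustrative, not an authoritative interface):

```python
def build_train_cmd(control_i, loss_weights):
    """Assemble (a subset of) the training invocation as an argument list,
    suitable for subprocess.run(cmd). Mirrors the string-built command above."""
    cmd = [
        'python', 'train_view_toy.py',
        '--training_up_ratio', '16',
        '--last_sample_id', str(control_i * 10000),
        '--train_max_samples', str((control_i + 1) * 10000),
        '--learning_rate', str(0.001 * 0.9 ** control_i),
    ]
    for name, value in loss_weights.items():
        cmd += ['--' + name, str(value)]
    return cmd

cmd = build_train_cmd(0, {'weight_cd': 1.0})
```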
    
    
    


================================================
FILE: model/conpu_v6/train_script101_test.py
================================================
import os

#coarse-net

loss_weight=' '
loss_weight+=' --weight_cd 1.0'
loss_weight+=' --weight_uniform -10000000'
loss_weight+=' --weight_reg -0.1'
loss_weight+=' --weight_arap 0.03'
loss_weight+=' --weight_overlap 0.3'
loss_weight+=' --weight_proj -1'
loss_weight+=' --weight_normal -1'
loss_weight+=' --weight_cycle -1'
loss_weight+=' --weight_ndirection 0.0001'


for control_i in range(0,1):
    os.system('CUDA_VISIBLE_DEVICES=1 python train_view_toy.py \
        --training_up_ratio 16 \
        --testing_up_ratio 16 \
        --over_sampling_scale 4 \
        --visualization_while_testing 1 \
        --last_sample_id '+str(control_i*10000)+' \
        --test_blank 10000 \
        --train_max_samples '+str((control_i+1)*10000)+' \
        --learning_rate '+str(0.001* 0.9**control_i)+'  \
        --batchsize 8  \
        --out_baseline \'out_baseline_101_test\' \
        --num_point 256 \
        --gt_num_point 4096 \
        --pack_path \'../../data/Sketchfab2/packed_data/version_2\'  \
        --over_fitting_id 0 \
        --if_over_fitting_this_time 0 \
        --if_only_test 1 \
        --if_only_test_max_num 14 \
        --network_name \'Net_conpu_v7\'  \
        --emb_dims 512 \
        --neighbor_k 10 \
        --mlp_fitting_str \'256 128 64\' \
        --pretrained \'./pre_trained/v3.pt\' \
        --if_fix_sample 0 \
        --if_use_siren 0 \
        --feature_unfolding_nei_num 4 \
        '+loss_weight)
    
    
    


================================================
FILE: model/conpu_v6/train_view_toy.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from  torch.utils.data import DataLoader
import torch.optim.lr_scheduler as lr_scheduler
import math
import numpy as np
import torch.nn.init as init
import struct
import os
import sys
sys.path.append('../')
sys.path.append('../../')
sys.path.append('../../code/')
#import drawer
import time
import mesh_operations
import torch_tensor_functions
import colormap
import random
from pointnet2 import pointnet2_utils as pn2_utils

# from torch_geometric.data import Data
# from torch_geometric.transforms.generate_mesh_normals  import *
# from torch_scatter import scatter_add


######  The network and loss are figured out here  ###### 
from loss import Loss, chamfer_dist
###### ------ ######


from utils.config import parse_args
import time
import igl
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

# The parameter that controls the overfitting.
# over_fitting_id = 0
# if_over_fitting_this_time = False
# if_only_test = False
# if_only_test_max_num = 3




# Set the GradScaler
try:
    from torch.cuda.amp import GradScaler
except ImportError:
    # dummy GradScaler for PyTorch < 1.6
    class GradScaler:
        def __init__(self):
            pass
        def scale(self, loss):
            return loss
        def unscale_(self, optimizer):
            pass
        def step(self, optimizer):
            optimizer.step()
        def update(self):
            pass
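The fallback above keeps the AMP-style call sequence (scale, backward, unscale_, step, update) valid on PyTorch builds without `torch.cuda.amp`: every method is a no-op except `step`, which delegates to the optimizer. A standalone sketch of the same duck-typed pattern, with a stub optimizer for illustration:

```python
class DummyGradScaler:
    """No-op stand-in mirroring torch.cuda.amp.GradScaler's interface,
    like the fallback class above."""
    def scale(self, loss):
        return loss
    def unscale_(self, optimizer):
        pass
    def step(self, optimizer):
        optimizer.step()
    def update(self):
        pass

class _StubOptimizer:
    """Hypothetical optimizer stub; only used to show the delegation."""
    def __init__(self):
        self.steps = 0
    def step(self):
        self.steps += 1

scaler = DummyGradScaler()
opt = _StubOptimizer()
scaled = scaler.scale(1.5)  # identity when AMP is unavailable
scaler.unscale_(opt)
scaler.step(opt)            # delegates to optimizer.step()
scaler.update()
```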

# All of the args; they can be set from a driver script such as train_script101.py.
args = parse_args()
print('args:')
print(args)


exec('from network import '+args.network_name)

over_fitting_id = args.over_fitting_id
if_over_fitting_this_time = args.if_over_fitting_this_time
if_only_test = args.if_only_test
if_only_test_max_num = args.if_only_test_max_num

# The color map for visualization.
# The points generated from the same source point should share the same color.
rb_colormap = np.array(colormap.rb_colormap_list_little).reshape(8,3) 

batch_size=args.batchsize
# It is used to control the training process.
train_max_samples = args.train_max_samples  

# The path of the packed dataset.
pack_path=args.pack_path
print('The packed data path is : ',pack_path)

# The point number of the sparse and dense patch, respectively.
num_point = args.num_point
gt_num_point = args.gt_num_point

# The path of the training data
train_points_normals_sparse_path = pack_path+'/training_points_normals_'+str(num_point)+'.bin'
train_points_normals_dense_path = pack_path+'/training_points_normals_'+str(gt_num_point)+'.bin'

# The path of the testing data
test_points_normals_sparse_path = pack_path+'/testing_points_normals_'+str(num_point)+'.bin'
test_center_scale_sparse_path = pack_path+'/testing_center_scale_'+str(num_point)+'.bin'
test_points_normals_dense_path = pack_path+'/testing_points_normals_'+str(gt_num_point)+'.bin'

# READ in train_points_normals : sparse
train_points_normals_sparse = np.fromfile(train_points_normals_sparse_path, dtype = np.float32).reshape(-1,num_point,6)
# READ in train_points_normals : dense
train_points_normals_dense = np.fromfile(train_points_normals_dense_path, dtype = np.float32).reshape(-1,gt_num_point,6)
# READ in test_points_normals : sparse
test_points_normals_sparse = np.fromfile(test_points_normals_sparse_path, dtype = np.float32).reshape(-1,num_point,6)
test_center_scale_sparse = np.fromfile(test_center_scale_sparse_path, dtype = np.float32).reshape(-1,4)
# READ in test_points_normals : dense
test_points_normals_dense = np.fromfile(test_points_normals_dense_path, dtype = np.float32).reshape(-1,gt_num_point,6)
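The packed `.bin` files are raw little-endian float32 dumps with no header, reshaped on load to (num_patches, points_per_patch, 6), where each row is xyz followed by the normal. A round-trip sketch of this layout with toy data (the temp path and sizes are illustrative):

```python
import os
import tempfile

import numpy as np

# Toy patch set: 2 patches, 4 points each, xyz + normal = 6 floats per point.
patches = np.random.rand(2, 4, 6).astype(np.float32)

path = os.path.join(tempfile.mkdtemp(), 'training_points_normals_4.bin')
patches.tofile(path)  # raw float32 stream, no header or shape metadata

# Reading back mirrors the np.fromfile(...).reshape(-1, num_point, 6) calls above
loaded = np.fromfile(path, dtype=np.float32).reshape(-1, 4, 6)
points, normals = loaded[:, :, :3], loaded[:, :, 3:]
```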



# The number of training and testing pairs, respectively.
train_pair_num = train_points_normals_sparse.shape[0]
test_pair_num = test_points_normals_sparse.shape[0]

train_points_normals_sparse_tensor = torch.from_numpy(train_points_normals_sparse).float()
test_points_normals_sparse_tensor = torch.from_numpy(test_points_normals_sparse).float()
train_points_normals_dense_tensor = torch.from_numpy(train_points_normals_dense).float()
test_points_normals_dense_tensor = torch.from_numpy(test_points_normals_dense).float()

# All the torch-tensors used for the input and ground truth. 
train_points_sparse_tensor = train_points_normals_sparse_tensor[:,:,:3]
train_normals_sparse_tensor = train_points_normals_sparse_tensor[:,:,3:]
test_points_sparse_tensor = test_points_normals_sparse_tensor[:,:,:3]
test_normals_sparse_tensor = test_points_normals_sparse_tensor[:,:,3:]
train_points_dense_tensor = train_points_normals_dense_tensor[:,:,:3]
train_normals_dense_tensor = train_points_normals_dense_tensor[:,:,3:]
test_points_dense_tensor = test_points_normals_dense_tensor[:,:,:3]
test_normals_dense_tensor = test_points_normals_dense_tensor[:,:,3:]

# All the batch-data used for the input and ground truth. 
train_points_sparse_batch = torch.zeros([batch_size,num_point,3],dtype=torch.float,requires_grad=False).cuda()
train_normals_sparse_batch = torch.zeros([batch_size,num_point,3],dtype=torch.float,requires_grad=False).cuda()
test_points_sparse_batch = torch.zeros([batch_size,num_point,3],dtype=torch.float,requires_grad=False).cuda()
test_normals_sparse_batch = torch.zeros([batch_size,num_point,3],dtype=torch.float,requires_grad=False).cuda()
train_points_dense_batch = torch.zeros([batch_size,gt_num_point,3],dtype=torch.float,requires_grad=False).cuda()
train_normals_dense_batch = torch.zeros([batch_size,gt_num_point,3],dtype=torch.float,requires_grad=False).cuda()
test_points_dense_batch = torch.zeros([batch_size,gt_num_point,3],dtype=torch.float,requires_grad=False).cuda()
test_normals_dense_batch = torch.zeros([batch_size,gt_num_point,3],dtype=torch.float,requires_grad=False).cuda()


def update_test_cache(used_samples_num, model, loss_obj, args):
    print('updating cache for used_samples_num = ' + str(used_samples_num))
    test_cache_file='./'+args.out_baseline+'/result_cache.txt'
    loss_sum_, loss_stages_=compute_test_loss_values(model, loss_obj, args)
    print('the test loss: ',loss_sum_, loss_stages_)
    cf=open(test_cache_file,'a+')
    # The first number is the iteration count.
    cf.write(str(used_samples_num//batch_size)+' ')
    cf.write(str(loss_sum_)+' ')
    for i in range(len(loss_stages_)):
        cf.write(str(loss_stages_[i])+' ')
    cf.write('\n')
    cf.close()
    update_pics()
    if args.visualization_while_testing:
        update_visualization(model,  args)
        if if_only_test==True:exit()
    
def update_pics():
    test_cache_file='./'+args.out_baseline+'/result_cache.txt'
    cf=open(test_cache_file,'r')
    lines=cf.readlines()
    x=[]
    y_sum=[]
    y_cd=[]
    y_reg = []
    y_arap = []
    y_overlap = []
    y_proj = []
    y_normal = []
    y_ndirection = []
    for i in range(len(lines)):
        if i%1==0:
            index = int(lines[i].split(' ')[0])
            sum_loss = float(lines[i].split(' ')[1])
            cd_loss = float(lines[i].split(' ')[2])
            reg_loss = float(lines[i].split(' ')[3])
            arap_loss = float(lines[i].split(' ')[4])
            overlap_loss = float(lines[i].split(' ')[5])
            proj_loss = float(lines[i].split(' ')[6])
            normal_loss = float(lines[i].split(' ')[7])
            ndirection_loss = float(lines[i].split(' ')[8])
            iter_index=index
            x.append(iter_index)
            y_sum.append(sum_loss)
            y_cd.append(cd_loss)
            y_reg.append(reg_loss)
            y_arap.append(arap_loss)
            y_overlap.append(overlap_loss)
            y_proj.append(proj_loss)
            y_normal.append(normal_loss)
            y_ndirection.append(ndirection_loss)
    
    fig = plt.figure(0)
    fig.clear()
    plt.title('The sum loss')
    plt.xlabel('iteration')
    plt.ylabel('sum loss')
    plt.plot(x, y_sum, c='r', ls='-')
    plt.savefig('./'+args.out_baseline+'/loss_sum.png')
    
    fig = plt.figure(0)
    fig.clear()
    plt.title('The loss on cd')
    plt.xlabel('iteration')
    plt.ylabel('cd loss')
    plt.plot(x, y_cd, c='#526922', ls='-')
    plt.savefig('./'+args.out_baseline+'/loss_cd.png')

    fig = plt.figure(0)
    fig.clear()
    plt.title('The loss on reg')
    plt.xlabel('iteration')
    plt.ylabel('reg loss')
    plt.plot(x, y_reg, c='#526922', ls='-')
    plt.savefig('./'+args.out_baseline+'/loss_reg.png')

    fig = plt.figure(0)
    fig.clear()
    plt.title('The loss on arap')
    plt.xlabel('iteration')
    plt.ylabel('arap loss')
    plt.plot(x, y_arap, c='#526922', ls='-')
    plt.savefig('./'+args.out_baseline+'/loss_arap.png')

    fig = plt.figure(0)
    fig.clear()
    plt.title('The loss on overlap')
    plt.xlabel('iteration')
    plt.ylabel('overlap loss')
    plt.plot(x, y_overlap, c='#526922', ls='-')
    plt.savefig('./'+args.out_baseline+'/loss_overlap.png')

    fig = plt.figure(0)
    fig.clear()
    plt.title('The loss on proj')
    plt.xlabel('iteration')
    plt.ylabel('proj loss')
    plt.plot(x, y_proj, c='#526922', ls='-')
    plt.savefig('./'+args.out_baseline+'/loss_proj.png')

    fig = plt.figure(0)
    fig.clear()
    plt.title('The loss on normal')
    plt.xlabel('iteration')
    plt.ylabel('normal loss')
    plt.plot(x, y_normal, c='#526922', ls='-')
    plt.savefig('./'+args.out_baseline+'/loss_normal.png')

    fig = plt.figure(0)
    fig.clear()
    plt.title('The loss on ndirection')
    plt.xlabel('iteration')
    plt.ylabel('ndirection loss')
    plt.plot(x, y_ndirection, c='#526922', ls='-')
    plt.savefig('./'+args.out_baseline+'/loss_ndirection.png')
    
def update_visualization(model,  args):
    global test_center_scale_sparse

    global train_points_sparse_tensor
    global train_normals_sparse_tensor
    global test_points_sparse_tensor
    global test_normals_sparse_tensor
    global train_points_dense_tensor
    global train_normals_dense_tensor
    global test_points_dense_tensor
    global test_normals_dense_tensor

    global train_points_sparse_batch
    global train_normals_sparse_batch
    global test_points_sparse_batch
    global test_normals_sparse_batch
    global train_points_dense_batch
    global train_normals_dense_batch
    global test_points_dense_batch
    global test_normals_dense_batch
    
    test_cache_file='./'+args.out_baseline+'/result_cache.txt'
    cf=open(test_cache_file,'r')
    lines=cf.readlines()
    last_line = lines[len(lines)-1]
    iter_num = int(last_line.split(' ')[0])
    visual_folder = './'+args.out_baseline+'/visual_'+str(iter_num*batch_size)
    if not os.path.exists(visual_folder):
        os.mkdir(visual_folder)
    print('Start to visualize the results now.')
    # To be finished. Draw whatever you wanna observe here.
    visual_sample_num = min(batch_size,5)
    if if_only_test==True:
        visual_sample_num = test_pair_num
        testing_anchor_num = args.testing_anchor_num
        testing_model_num = test_pair_num // testing_anchor_num
        if if_only_test_max_num>=0 and if_only_test_max_num<testing_model_num: testing_model_num=if_only_test_max_num
        visual_sample_num = testing_model_num*testing_anchor_num
    if not if_over_fitting_this_time:over_fitting_id_here=0
    else:over_fitting_id_here=args.over_fitting_id
    for si in range(over_fitting_id_here, over_fitting_id_here+visual_sample_num):
        this_sample_path = visual_folder+'/sample_'+str(si)
        if not os.path.exists(this_sample_path):os.mkdir(this_sample_path)
        a_points_sparse_tensor = test_points_sparse_tensor[si:si+1].cuda()
        a_normals_sparse_tensor = test_normals_sparse_tensor[si:si+1].cuda()
        a_points_dense_tensor = test_points_dense_tensor[si:si+1].cuda()
        a_normals_dense_tensor = test_normals_dense_tensor[si:si+1].cuda()
        # get the generated results.
        model.eval()
        # with torch.no_grad():
        if True:
            a_points_gen_tensor, a_normals_gen_tensor, _, a_querying_points_3d, a_querying_points_n_3d, a_glued_points, a_glued_normals = model(a_points_sparse_tensor)
        
        # save the points : format-xyz, with normal.
        torch_tensor_functions.draw_tensor_point_xyz_with_normal(this_sample_path+'/query.xyz', a_points_gen_tensor[0].detach(),torch_tensor_normals=a_normals_gen_tensor[0].detach())
        torch_tensor_functions.draw_tensor_point_xyz_with_normal(this_sample_path+'/query_3d.xyz', a_querying_points_3d[0].detach(), torch_tensor_normals=a_querying_points_n_3d[0].detach())
        torch_tensor_functions.draw_tensor_point_xyz_with_normal(this_sample_path+'/glued.xyz', a_glued_points[0].detach())
        torch_tensor_functions.draw_tensor_point_xyz_with_normal(this_sample_path+'/sparse.xyz', a_points_sparse_tensor[0])
        torch_tensor_functions.draw_tensor_point_xyz_with_normal(this_sample_path+'/dense.xyz', a_points_dense_tensor[0])

        # the color tensor of the sparse points
        num_point_here = a_points_sparse_tensor.size()[1]
        a_points_sparse_color_tensor = torch.from_numpy(rb_colormap).float().cuda()
        while a_points_sparse_color_tensor.size()[0]<num_point_here:a_points_sparse_color_tensor = torch.cat((a_points_sparse_color_tensor,a_points_sparse_color_tensor),dim=0)
        a_points_sparse_color_tensor = a_points_sparse_color_tensor[:num_point_here]
        
        # the color tensor of the generated points
        up_ratio_here = a_points_gen_tensor.size()[1]//a_points_sparse_tensor.size()[1]
        a_points_gen_color_tensor = a_points_sparse_color_tensor.clone().view(1,-1,3)
        while a_points_gen_color_tensor.size()[0]<up_ratio_here:a_points_gen_color_tensor = torch.cat((a_points_gen_color_tensor,a_points_gen_color_tensor),dim=0)
        a_points_gen_color_tensor = a_points_gen_color_tensor[:up_ratio_here].transpose(1,0)
        a_points_gen_color_tensor = a_points_gen_color_tensor.reshape(-1,3)
        
        # save the points : format-obj, with color.
        torch_tensor_functions.draw_tensor_point_obj_with_color(this_sample_path+'/query.obj', a_points_gen_tensor[0].detach(),torch_tensor_color=a_points_gen_color_tensor)
        torch_tensor_functions.draw_tensor_point_obj_with_color(this_sample_path+'/sparse.obj', a_points_sparse_tensor[0],torch_tensor_color=a_points_sparse_color_tensor)
    
    # if if_only_test==True : Test all the testing models. 
    if if_only_test==True:
        tested_mesh_path = visual_folder + '/0tested_models'
        if not os.path.exists(tested_mesh_path):os.mkdir(tested_mesh_path)
        for model_i in range(testing_model_num):
            all_patches_points = []
            one_tested_mesh_obj_path = tested_mesh_path+'/test_model_'+str(model_i)+'.obj'
            for anchor_i in range(testing_anchor_num):
                this_sample_id = model_i*testing_anchor_num + anchor_i
                this_sample_path = visual_folder+'/sample_'+str(this_sample_id)
                v_, n_ = mesh_operations.read_xyz_(this_sample_path+'/glued.xyz')
                v_ = v_[:,:3]
                this_center_scale = test_center_scale_sparse[this_sample_id]
                this_center = this_center_scale[:3].reshape(1,3)
                this_scale = this_center_scale[3]
                v_ = v_ * this_scale
                v_ = v_ + this_center
                all_patches_points.append(v_)
            all_patches_points = np.concatenate(all_patches_points,axis=0)
            all_patches_points_torch = torch.from_numpy(all_patches_points).float().cuda().view(1,-1,3)
            fps_id = pn2_utils.furthest_point_sample(all_patches_points_torch.contiguous(), 2000*args.testing_up_ratio)
            new_xyz = pn2_utils.gather_operation(all_patches_points_torch.permute(0, 2, 1).contiguous(), fps_id)
            all_patches_points = new_xyz.permute(0,2,1).view(-1,3).cpu().numpy().astype(np.float32)
            mesh_operations.write_obj_(one_tested_mesh_obj_path, all_patches_points)
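Each test patch was normalized with a per-patch center and scale read from `testing_center_scale_*.bin`, so merging a full model undoes that normalization with `v * scale + center` before concatenating all patches and FPS-downsampling the union. A NumPy sketch of the denormalization step (function name illustrative):

```python
import numpy as np

def denormalize_patch(v, center_scale):
    """Undo per-patch normalization, mirroring the loop above.
    center_scale is [cx, cy, cz, s]: v (N, 3) in patch coordinates
    maps back to world coordinates as v * s + center."""
    center = center_scale[:3].reshape(1, 3)
    scale = center_scale[3]
    return v * scale + center

patch = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], dtype=np.float32)
cs = np.array([5.0, 0.0, 0.0, 2.0], dtype=np.float32)
world = denormalize_patch(patch, cs)
```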
                







    
    
def stophere():
    while True:
        continue

def run_train_val(model, optimizer, loss_obj,  args):
    global train_points_sparse_tensor
    global train_normals_sparse_tensor
    global test_points_sparse_tensor
    global test_normals_sparse_tensor
    global train_points_dense_tensor
    global train_normals_dense_tensor
    global test_points_dense_tensor
    global test_normals_dense_tensor

    global train_points_sparse_batch
    global train_normals_sparse_batch
    global test_points_sparse_batch
    global test_normals_sparse_batch
    global train_points_dense_batch
    global train_normals_dense_batch
    global test_points_dense_batch
    global test_normals_dense_batch
    
    used_samples_num=args.last_sample_id
    start_pos=used_samples_num % train_pair_num

    if used_samples_num==0 or if_only_test==True:
        update_test_cache(used_samples_num, model, loss_obj,  args)
    
    while used_samples_num<train_max_samples:
        while True:
            end_pos=start_pos+batch_size
            print('Training with pair samples: '+str(start_pos)+'~'+str(end_pos))
            train_one_batch(model, optimizer, loss_obj, start_pos, end_pos, args) ############## train one batch
            used_samples_num+=end_pos-start_pos
            if used_samples_num%(args.test_blank)==0:
                update_test_cache(used_samples_num, model, loss_obj, args) ############## test once
                print('Test here, at '+str(used_samples_num))
                torch.save(model.state_dict(), './'+args.out_baseline+'/sample_'+str(used_samples_num)+'.pt')
            if end_pos>=train_pair_num:
                start_pos=end_pos - train_pair_num
            else:
                start_pos=end_pos
            print(used_samples_num,train_max_samples)
            if used_samples_num >= train_max_samples:
                break
    
    
    
def train_one_batch(model, optimizer, loss_obj, start_pos, end_pos, args):
    global train_points_sparse_tensor
    global train_normals_sparse_tensor
    global test_points_sparse_tensor
    global test_normals_sparse_tensor
    global train_points_dense_tensor
    global train_normals_dense_tensor
    global test_points_dense_tensor
    global test_normals_dense_tensor

    global train_points_sparse_batch
    global train_normals_sparse_batch
    global test_points_sparse_batch
    global test_normals_sparse_batch
    global train_points_dense_batch
    global train_normals_dense_batch
    global test_points_dense_batch
    global test_normals_dense_batch
    
    print(start_pos, end_pos)
    if end_pos<=train_pair_num:
        train_points_sparse_batch = train_points_sparse_tensor[start_pos:end_pos].cuda()
        train_normals_sparse_batch = train_normals_sparse_tensor[start_pos:end_pos].cuda()
        train_points_dense_batch = train_points_dense_tensor[start_pos:end_pos].cuda()
        train_normals_dense_batch = train_normals_dense_tensor[start_pos:end_pos].cuda()
    else:
        bottom = train_pair_num - start_pos
        top = end_pos - train_pair_num
        
        train_points_sparse_batch[:bottom] = train_points_sparse_tensor[start_pos:].cuda()
        train_normals_sparse_batch[:bottom] = train_normals_sparse_tensor[start_pos:].cuda()
        train_points_dense_batch[:bottom] = train_points_dense_tensor[start_pos:].cuda()
        train_normals_dense_batch[:bottom] = train_normals_dense_tensor[start_pos:].cuda()
        
        
        train_points_sparse_batch[bottom:] = train_points_sparse_tensor[:top].cuda()
        train_normals_sparse_batch[bottom:] = train_normals_sparse_tensor[:top].cuda()
        train_points_dense_batch[bottom:] = train_points_dense_tensor[:top].cuda()
        train_normals_dense_batch[bottom:] = train_normals_dense_tensor[:top].cuda()
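When a batch crosses the end of the dataset, the code above splits it into a tail slice (`bottom` samples from the end) and a head slice (`top` samples from the start). The same wraparound indexing in a small self-contained sketch:

```python
import numpy as np

def wraparound_batch(data, start_pos, batch_size):
    """Take a contiguous batch that may wrap past the end of data,
    mirroring the bottom/top split used in train_one_batch."""
    n = data.shape[0]
    end_pos = start_pos + batch_size
    if end_pos <= n:
        return data[start_pos:end_pos]
    bottom = n - start_pos  # samples taken from the tail
    top = end_pos - n       # samples taken from the head
    return np.concatenate([data[start_pos:start_pos + bottom], data[:top]], axis=0)
```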
    
    # For over-fitting!!
    if if_over_fitting_this_time==True:
        train_points_sparse_batch = test_points_sparse_tensor[0+over_fitting_id:end_pos-start_pos+over_fitting_id].cuda()
        train_normals_sparse_batch = test_normals_sparse_tensor[0+over_fitting_id:end_pos-start_pos+over_fitting_id].cuda()
        train_points_dense_batch = test_points_dense_tensor[0+over_fitting_id:end_pos-start_pos+over_fitting_id].cuda()
        train_normals_dense_batch = test_normals_dense_tensor[0+over_fitting_id:end_pos-start_pos+over_fitting_id].cuda()

        # torch_tensor_functions.draw_tensor_point_batch_xyz_with_normal('./train_sparsepoint_shows', train_points_sparse_batch, train_normals_sparse_batch)
        # torch_tensor_functions.draw_tensor_point_batch_xyz_with_normal('./train_densepoint_shows', train_points_dense_batch, train_normals_dense_batch)
    if if_over_fitting_this_time==False:
        pi_ = math.pi
        all_rot_matrix_ = None
        for b in range(train_points_sparse_batch.size()[0]):
            euler_x = random.randint(0,10000)/10000
            euler_y = random.randint(0,10000)/10000
            euler_z = random.randint(0,10000)/10000
            euler_angle = torch.tensor([[-pi_+2*pi_*euler_x, -pi_+2*pi_*euler_y, -pi_+2*pi_*euler_z]], dtype=torch.float32).cuda()
            a_rot_matrix_ = torch_tensor_functions.euler2rot(euler_angle)
            if b==0:all_rot_matrix_ = a_rot_matrix_
            else:all_rot_matrix_ = torch.cat((all_rot_matrix_,a_rot_matrix_),dim=0)
        train_points_sparse_batch = torch.bmm(train_points_sparse_batch, all_rot_matrix_)
        train_normals_sparse_batch = torch.bmm(train_normals_sparse_batch, all_rot_matrix_)
        train_points_dense_batch = torch.bmm(train_points_dense_batch, all_rot_matrix_)
        train_normals_dense_batch = torch.bmm(train_normals_dense_batch, all_rot_matrix_)
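Training batches are augmented with a random rigid rotation built from Euler angles drawn uniformly from [-pi, pi) and converted via the repo's `torch_tensor_functions.euler2rot`. A standalone NumPy sketch of one common XYZ Euler-to-rotation convention (the repo's exact axis ordering may differ); any such matrix is orthogonal, so applying it preserves point norms and normal lengths:

```python
import numpy as np

def euler_to_rot(ex, ey, ez):
    """Compose rotations about the x, y, z axes (one common convention;
    torch_tensor_functions.euler2rot may order the factors differently)."""
    cx, sx = np.cos(ex), np.sin(ex)
    cy, sy = np.cos(ey), np.sin(ey)
    cz, sz = np.cos(ez), np.sin(ez)
    rx = np.array([[1, 0, 0], [0, cx, -sx], [0, sx, cx]])
    ry = np.array([[cy, 0, sy], [0, 1, 0], [-sy, 0, cy]])
    rz = np.array([[cz, -sz, 0], [sz, cz, 0], [0, 0, 1]])
    return rz @ ry @ rx

rng = np.random.default_rng(0)
angles = rng.uniform(-np.pi, np.pi, size=3)
R = euler_to_rot(*angles)
```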


    for train_times in range(1):
        optimizer.zero_grad()    
        model.train()
        gen_points_batch, gen_normals_batch, uv_sampling_coors, _, _, glued_points, glued_normals = model(train_points_sparse_batch)
        
        conpu_loss, conpu_loss_stages  = loss_obj(gen_points_batch, gen_normals_batch, uv_sampling_coors, train_points_sparse_batch, train_normals_sparse_batch, train_points_dense_batch, train_normals_dense_batch)
        print('cd:',conpu_loss_stages[0])
        print('reg:',conpu_loss_stages[1])
        print('arap:',conpu_loss_stages[2])
        print('overlap:',conpu_loss_stages[3])
        print('proj:',conpu_loss_stages[4])
        print('normal:',conpu_loss_stages[5])
        print('ndirection:',conpu_loss_stages[6])
        
        model.zero_grad()
        with torch.autograd.set_detect_anomaly(True):
            scaler.scale(conpu_loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        scaler.step(optimizer)
        print('optimizer.lr : ', optimizer.state_dict()['param_groups'][0]['lr'])
        scheduler.step()
        scaler.update()

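The update above follows the standard mixed-precision recipe: scale the loss, backward, unscale the gradients, clip, step, update. A minimal self-contained sketch of the same sequence on a toy `Linear` model (not the repo's network; scaling is enabled only when CUDA is present, so this also runs on CPU):

```python
import torch
from torch.cuda.amp import GradScaler

# Toy stand-ins for the real network/optimizer (not the repo's model).
model = torch.nn.Linear(3, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)
scaler = GradScaler(enabled=torch.cuda.is_available())  # no-op scaling on CPU
w0 = model.weight.detach().clone()

x, y = torch.randn(8, 3), torch.randn(8, 1)
optimizer.zero_grad()
loss = torch.nn.functional.mse_loss(model(x), y)

scaler.scale(loss).backward()                 # backward on the scaled loss
scaler.unscale_(optimizer)                    # gradients back in fp32 scale
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)                        # skips the step on inf/nan grads
scaler.update()                               # adjusts the scale factor
```

Calling `unscale_` before `clip_grad_norm_` matters: clipping must see gradients in their true fp32 magnitude, not the scaled ones.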

    
def test_one_batch(model, loss_obj, start_pos, end_pos, args):
    global train_points_sparse_tensor
    global train_normals_sparse_tensor
    global test_points_sparse_tensor
    global test_normals_sparse_tensor
    global train_points_dense_tensor
    global train_normals_dense_tensor
    global test_points_dense_tensor
    global test_normals_dense_tensor

    global train_points_sparse_batch
    global train_normals_sparse_batch
    global test_points_sparse_batch
    global test_normals_sparse_batch
    global train_points_dense_batch
    global train_normals_dense_batch
    global test_points_dense_batch
    global test_normals_dense_batch

    
#    model.eval()
    
#    print(start_pos, end_pos)
    if end_pos<=test_pair_num:
        test_points_sparse_batch = test_points_sparse_tensor[start_pos:end_pos].cuda()
        test_normals_sparse_batch = test_normals_sparse_tensor[start_pos:end_pos].cuda()
        test_points_dense_batch = test_points_dense_tensor[start_pos:end_pos].cuda()
        test_normals_dense_batch = test_normals_dense_tensor[start_pos:end_pos].cuda()
    else:
        bottom = test_pair_num - start_pos
        top = end_pos - test_pair_num
        
        test_points_sparse_batch[:bottom] = test_points_sparse_tensor[start_pos:].cuda()
        test_normals_sparse_batch[:bottom] = test_normals_sparse_tensor[start_pos:].cuda()
        test_points_dense_batch[:bottom] = test_points_dense_tensor[start_pos:].cuda()
        test_normals_dense_batch[:bottom] = test_normals_dense_tensor[start_pos:].cuda()
        
        
        test_points_sparse_batch[bottom:] = test_points_sparse_tensor[:top].cuda()
        test_normals_sparse_batch[bottom:] = test_normals_sparse_tensor[:top].cuda()
        test_points_dense_batch[bottom:] = test_points_dense_tensor[:top].cuda()
        test_normals_dense_batch[bottom:] = test_normals_dense_tensor[:top].cuda()
        
    # For over-fitting!!
    if if_over_fitting_this_time==True:
        test_points_sparse_batch = test_points_sparse_tensor[0+over_fitting_id:end_pos-start_pos+over_fitting_id].cuda()
        test_normals_sparse_batch = test_normals_sparse_tensor[0+over_fitting_id:end_pos-start_pos+over_fitting_id].cuda()
        test_points_dense_batch = test_points_dense_tensor[0+over_fitting_id:end_pos-start_pos+over_fitting_id].cuda()
        test_normals_dense_batch = test_normals_dense_tensor[0+over_fitting_id:end_pos-start_pos+over_fitting_id].cuda()
    
    
    model.eval()
    # with torch.no_grad():   # gradients are intentionally kept here
    gen_points_batch, gen_normals_batch, uv_sampling_coors, _, _, glued_points, glued_normals = model(test_points_sparse_batch)

    conpu_loss, conpu_loss_stages = loss_obj(gen_points_batch, gen_normals_batch, uv_sampling_coors, test_points_sparse_batch, test_normals_sparse_batch, test_points_dense_batch, test_normals_dense_batch)

    return conpu_loss, conpu_loss_stages
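When `end_pos` runs past `test_pair_num`, `test_one_batch` fills the batch by wrapping around to the start of the test tensors. The same bottom/top split, sketched on a plain Python list (`wrap_slice` is a hypothetical helper, not part of the repo):

```python
def wrap_slice(seq, start, end):
    """seq[start:end], wrapping past len(seq) back to the front.

    Mirrors the bottom/top split in test_one_batch (hypothetical helper,
    not part of the repo).
    """
    n = len(seq)
    if end <= n:
        return seq[start:end]
    tail = seq[start:]    # the 'bottom' part, taken from the tail
    head = seq[:end - n]  # the 'top' part, wrapped from the front
    return tail + head


print(wrap_slice(list(range(10)), 8, 12))  # -> [8, 9, 0, 1]
```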
    
    
    

    
def compute_test_loss_values(model, loss_obj, args):
    start_pos=0
    loss_sum=0.0
    loss_stages=[]
    batch_cnt=0.0
    print('Computing the testing loss on the testing set:')
    for s in range(0, test_pair_num, batch_size):  # iterate over the whole testing set
        start_pos = s
        end_pos = s + batch_size
        if end_pos > test_pair_num:
            end_pos = test_pair_num
        this_batch_size = end_pos - start_pos
        lsum,lstages = test_one_batch(model, loss_obj, start_pos, end_pos, args)
        if start_pos==0:
            loss_sum=lsum.item()*this_batch_size
            for i in range(len(lstages)):
                loss_stages.append(lstages[i].item()*this_batch_size)
        else:
            loss_sum+=lsum.item()*this_batch_size
            for i in range(len(lstages)):
                loss_stages[i]+=lstages[i].item()*this_batch_size
        batch_cnt += this_batch_size
    loss_sum/=batch_cnt
    for i in range(len(loss_stages)):
        loss_stages[i]/=batch_cnt
    return loss_sum, loss_stages
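`compute_test_loss_values` weights each batch's loss by its actual size, so a smaller trailing batch does not skew the mean. The accumulation reduces to a sample-weighted average; a minimal sketch with made-up numbers:

```python
def weighted_mean(batch_losses, batch_sizes):
    """Sample-weighted average over batches of unequal size,
    as accumulated in compute_test_loss_values."""
    total = sum(l * s for l, s in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)


# Two full batches of 8 samples plus a trailing batch of 4:
print(weighted_mean([0.5, 0.3, 0.9], [8, 8, 4]))  # -> 0.5
```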

def show_parameter_by_name(net_name, layer_name):
    for name, parameters in net_name.named_parameters():
        if name==layer_name:
            return parameters
    return None
    
def get_para_of_one_layer_from_another_net(net_source, net_to_be_changed, layer_name):
    a = show_parameter_by_name(net_source, layer_name)
#    print(show_parameter_by_name(net_to_be_changed, layer_name))
    for name, parameters in net_to_be_changed.named_parameters():
        if name==layer_name:
            parameters.data = a.data
            return None
    print('No matching for layer: ',layer_name)
#    print(show_parameter_by_name(net_to_be_changed, layer_name))
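`get_para_of_one_layer_from_another_net` transplants a single named parameter from one network into another. A self-contained sketch of the same idea on two toy `nn.Linear` layers (`copy_named_param` is a hypothetical name; it copies values in place instead of rebinding `.data`):

```python
import torch
import torch.nn as nn


def copy_named_param(src, dst, layer_name):
    """Copy the parameter named layer_name from src into dst, in place.

    Hypothetical sketch mirroring get_para_of_one_layer_from_another_net;
    it copies values instead of rebinding .data.
    """
    src_params = dict(src.named_parameters())
    if layer_name not in src_params:
        raise KeyError(layer_name)
    for name, p in dst.named_parameters():
        if name == layer_name:
            with torch.no_grad():
                p.copy_(src_params[layer_name])
            return
    print('No matching for layer:', layer_name)


a, b = nn.Linear(4, 2), nn.Linear(4, 2)
copy_named_param(a, b, 'weight')
```

Copying in place (rather than assigning `parameters.data = a.data`, as the original does) avoids the two networks silently sharing one storage.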
    
    

if __name__=='__main__':
    conpu_net = globals()[args.network_name](args).cuda()  # look up the network class by name
    if False:
        print('#parameters:', sum(param.numel() for param in conpu_net.parameters())*4/(1024*1024),' Mb')
        exit()
    if args.last_sample_id==0:
        if os.path.exists('./'+args.out_baseline):
            os.system('rm -rf ./'+args.out_baseline)
        os.makedirs('./'+args.out_baseline)
        if len(args.pretrained)>=1: conpu_net.load_state_dict(torch.load(args.pretrained),True)
        torch.save(conpu_net.state_dict(), './'+args.out_baseline+'/sample_0.pt')
        # os.system('cp ./out_baseline_5/sample_600000.pt ./'+args.out_baseline+'/sample_0.pt')
        
    if args.if_only_test==True: conpu_net.load_state_dict(torch.load(args.pretrained),True)
    else: conpu_net.load_state_dict(torch.load('./'+args.out_baseline+'/sample_'+str(args.last_sample_id)+'.pt'),True)
    
    # setup optimizer
    optimizer = optim.AdamW(conpu_net.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay, eps=args.epsilon)
    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr = args.learning_rate, total_steps = (args.train_max_samples - args.last_sample_id)//args.batchsize, pct_start=0.03, cycle_momentum=False, anneal_strategy='linear')
    scaler = GradScaler(enabled=args.mixed_precision)
    
    # setup loss object
    loss_obj = Loss(args)
    
    # run train and test
    run_train_val(conpu_net, optimizer, loss_obj,  args)
    print('Done.')
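The scheduler here is `OneCycleLR` with `anneal_strategy='linear'` and `pct_start=0.03`: the learning rate ramps linearly from `max_lr/div_factor` up to `max_lr` over the first ~3% of steps, then linearly down toward `max_lr/(div_factor*final_div_factor)`. A rough pure-Python sketch of that curve (an approximation of the phase boundaries, not an exact re-implementation of PyTorch's; the div factors shown are PyTorch's defaults):

```python
def one_cycle_lr(step, total_steps, max_lr, pct_start=0.03,
                 div_factor=25.0, final_div_factor=1e4):
    """Rough linear one-cycle curve; div factors are PyTorch's defaults.

    An approximation of OneCycleLR(..., anneal_strategy='linear'),
    not an exact re-implementation of its phase boundaries.
    """
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    warmup_steps = int(pct_start * total_steps)
    if step < warmup_steps:
        t = step / max(1, warmup_steps)
        return initial_lr + t * (max_lr - initial_lr)
    t = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return max_lr + t * (min_lr - max_lr)


# With the defaults here: warm up over the first 3% of steps, then decay.
print(one_cycle_lr(0, 1000, 0.005))   # -> 0.0002 (max_lr / 25)
```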


================================================
FILE: utils/config.py
================================================
import argparse
import os

def parse_args():
    # argparse argument
    parser = argparse.ArgumentParser()
    parser.add_argument('--phase', default='train', help='train or test')
    parser.add_argument('--wq_test', type=int, default=0,help='if test by wq method')
    
    parser.add_argument('--device_id',help='Specify the index of the cuda device, e.g. 0, 1 ,2',default=0, type=int)
    parser.add_argument('--num_point', type=int, default=256,help='Point Number')
    parser.add_argument('--gt_num_point', type=int, default=4096,help='Point Number of GT points')
    parser.add_argument('--training_up_ratio', type=int, default=4,help='The Upsampling Ratio during training') 
    parser.add_argument('--testing_up_ratio', type=int, default=4, help='The Upsampling Ratio during testing')  
    parser.add_argument('--over_sampling_scale', type=float, default=1.5, help='The scale for over-sampling')
    parser.add_argument('--limited_testing_model_num', type=int, default=-1, help='The max allowed num of tested model')
    parser.add_argument('--emb_dims', type=int, default=8192, metavar='N',help='Dimension of embeddings')
    parser.add_argument('--testing_anchor_num', type=int, default=114, metavar='N',help='The number of patches on the testing models')
    parser.add_argument('--pe_out_L', type=int, default=5, metavar='N',help='The parameter L in the position code')
    parser.add_argument('--feature_unfolding_nei_num', type=int, default=4, metavar='N',help='The number of neighbour points used while feature unfolding')
    parser.add_argument('--repulsion_nei_num', type=int, default=5, metavar='N',help='The number of neighbour points used in repulsion loss')

    # for phase train
    parser.add_argument('--batchsize', type=int, default=8, help='Batch Size during training')
    parser.add_argument('--max_epoch', type=int, default=400, help='Epoch to run')
    parser.add_argument('--learning_rate', type=float, default=0.005)
    parser.add_argument('--reg_normal1', type=float, default=0.1)
    parser.add_argument('--reg_normal2', type=float, default=0.1)
    parser.add_argument('--jitter_sigma', type=float, default=0.01)
    parser.add_argument('--jitter_max', type=float, default=0.03)
    parser.add_argument('--if_bn', type=int, default=0, help='If using batch normalization')
    parser.add_argument('--neighbor_k', type=int, default=5, help='The number of neighbour points used in DGCNN')
    # parser.add_argument('--mlpchanels_uv_encoder_str', type=str, default='None', metavar='None',help='mlp layers of the uv position encoding (default: None)')
    parser.add_argument('--mlp_fitting_str', type=str, default='None', metavar='None',help='mlp layers of the part surface fitting (default: None)')
    parser.add_argument('--mlp_projecting_str', type=str, default='None', metavar='None',help='mlp layers of the part surface projecting (default: None)')
    # parser.add_argument('--mlp_refining_str', type=str, default='None', metavar='None',help='mlp layers of the point-wise refining (default: None)')
    # parser.add_argument('--if_refine_by_net', type=int, default=0, help='if to use the refining module in the network')
    parser.add_argument('--glue_neighbor', type=int, default=4, help='The number of neighbour points used in glue process')
    parser.add_argument('--proj_neighbor', type=int, default=4, help='The number of neighbour points used in projection process')

    # control the training
    parser.add_argument('--last_sample_id',help='the id in the last saved trained model',default=0, type=int)    
    parser.add_argument('--train_max_samples',help='the max number of samples used in the training',default=500000, type=int)
    parser.add_argument('--test_blank',help='how often the testing process is performed',default=100, type=int)
    parser.add_argument('--visualization_while_testing', default=1, type=int, metavar='visual', help='1 if visualize; 0 if not')

    # the trained results
    parser.add_argument('--pack_path', type=str, default='None', metavar='None',help='the path of packed_data (default: None)')
    parser.add_argument('--out_baseline',help='the file of the baseline training results',default='./output_baseline', type=str)  

    #for phase test
    parser.add_argument('--pretrained', default='', help='Model stored')
    parser.add_argument('--eval_xyz', default='test_5000', help='Folder to evaluate')
    parser.add_argument('--num_shape_point', type=int, default=5000,help='Point Number per shape')
    parser.add_argument('--patch_num_ratio', type=int, default=3,help='Number of points covered by patch')

    #loss terms weights
    parser.add_argument('--weight_cd', type=float, default=-1)
    parser.add_argument('--weight_refined_cd', type=float, default=-1)
    parser.add_argument('--weight_repulsion', type=float, default=-1)
    parser.add_argument('--weight_pre', type=float, default=-1)
    parser.add_argument('--weight_center', type=float, default=-1)
    parser.add_argument('--weight_exclude', type=float, default=-1)
    parser.add_argument('--weight_uniform', type=float, default=-1)
    parser.add_argument('--weight_reg', type=float, default=-1)
    parser.add_argument('--weight_arap', type=float, default=-1)
    parser.add_argument('--weight_overlap', type=float, default=-1)
    parser.add_argument('--weight_proj', type=float, default=-1)
    parser.add_argument('--weight_normal', type=float, default=-1)
    parser.add_argument('--weight_cycle', type=float, default=-1)
    parser.add_argument('--weight_ndirection', type=float, default=-1)


    parser.add_argument('--weight_decay',default=0.00005, type=float)
    parser.add_argument('--epsilon', type=float, default=1e-8)
    parser.add_argument('--num_steps', type=int, default=100000)
    parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
    parser.add_argument('--clip', type=float, default=1.0)

    # control the using mode
    parser.add_argument('--over_fitting_id', type=int, default=0, help='The id that you want to overfit')
    parser.add_argument('--if_over_fitting_this_time', type=int, default=0, help='whether you want to overfit, default is False')
    parser.add_argument('--if_only_test', type=int, default=0, help='whether you only want to test, default is False')
    parser.add_argument('--if_only_test_max_num', type=int, default=3, help='the max number of models that you want to test on')
    parser.add_argument('--network_name', type=str, default='Net_conpu_v1', help='the name of the network that you would like to use')
    parser.add_argument('--if_fix_sample', type=int, default=0, help='whether to use fix sampling')
    parser.add_argument('--if_use_siren', type=int, default=0, help='whether to use siren activation function')


    
    
    '''
    #basic settings
    
                                     
    # arguments for training process
    

    parser.add_argument('--patch_num', default=10, type=int,
                        metavar='pn', help='number of patches')
    parser.add_argument('--point_num', default=8192, type=int,
                        metavar='pn', help='number of patches')
    parser.add_argument('--dim_k', default=1024, type=int,
                        metavar='K', help='dim. of the feature vector (default: 1024)')
    parser.add_argument('--symfn', default='max', choices=['max', 'avg'],
                        help='symmetric function (default: max)')
    parser.add_argument('--delta', default=1.0e-2, type=float,
                        metavar='D', help='step size for approx. Jacobian (default: 1.0e-2)')
    parser.add_argument('--learn_delta', dest='learn_delta', action='store_true',
                        help='flag for training step size delta')
    parser.add_argument('--neighbour_num', default=4, type=int,
                        metavar='nn', help='neighbour_num of weight smoothing term')
    
    
    parser.add_argument('--cycle', type=bool, default=False, metavar='N',
                        help='Whether to use cycle consistency')
    parser.add_argument('--n_blocks', type=int, default=1, metavar='N',
                        help='Num of blocks of encoder&decoder')
    parser.add_argument('--n_heads', type=int, default=1, metavar='N',
                        help='Num of heads in multiheadedattention')
    parser.add_argument('--ff_dims', type=int, default=1024, metavar='N',
                        help='Num of dimensions of fc in transformer')
    parser.add_argument('--dropout', type=float, default=0.0, metavar='N',
                        help='Dropout ratio in transformer')
                        
    # PointNet settings
    parser.add_argument('--radius', type=float, default=0.3, help='Neighborhood radius for computing pointnet features')
    parser.add_argument('--num_neighbors', type=int, default=64, metavar='N', help='Max num of neighbors to use')
    # RPMNet settings
    parser.add_argument('--features', type=str, choices=['ppf', 'dxyz', 'xyz'], default=['ppf', 'dxyz', 'xyz'],
                        nargs='+', help='Which features to use. Default: all')
    parser.add_argument('--feat_dim', type=int, default=96,
                        help='Feature dimension (to compute distances on). Other numbers will be scaled accordingly')
    parser.add_argument('--no_slack', action='store_true', help='If set, will not have a slack column.')
    parser.add_argument('--num_sk_iter', type=int, default=5,
                        help='Number of inner iterations used in sinkhorn normalization')
    parser.add_argument('--num_reg_iter', type=int, default=5,
                        help='Number of outer iterations used for registration (only during inference)')
    parser.add_argument('--loss_type', type=str, choices=['mse', 'mae'], default='mae',
                        help='Loss to be optimized')
    parser.add_argument('--wt_inliers', type=float, default=1e-2, help='Weight to encourage inliers')
                        
    parser.add_argument('--lambda_data', type=float, default=1.0, help='weight of depth loss')
    parser.add_argument('--lambda_reg', type=float, default=0.1, help='weight of regularization loss')
        
    parser.add_argument('--num_adja', type=int, default=8, help='number of nodes who affect a point')
    parser.add_argument('--max_num_edges', type=int, default=3000, help='number of edges')
    parser.add_argument('--max_num_nodes', type=int, default=400, help='number of nodes')
    parser.add_argument('--max_num_points', type=int, default=4096, help='number of points')
    '''             
    args = parser.parse_args()

    

    return args
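Everything in `parse_args` is a plain argparse argument; note that the boolean-like switches (`--if_only_test`, `--if_over_fitting_this_time`, `--if_fix_sample`, ...) are int-typed, so they are passed as `0`/`1` on the command line. A minimal sketch of the same pattern with a hypothetical standalone parser covering a subset of the flags:

```python
import argparse

# A subset of the flags in utils/config.py, showing the int-as-boolean
# pattern (hypothetical standalone parser, not the repo's parse_args()).
parser = argparse.ArgumentParser()
parser.add_argument('--phase', default='train', help='train or test')
parser.add_argument('--batchsize', type=int, default=8)
parser.add_argument('--if_only_test', type=int, default=0,
                    help='int-typed switch: pass 0 or 1')
args = parser.parse_args(['--phase', 'test', '--if_only_test', '1'])
print(args.phase, args.batchsize, bool(args.if_only_test))  # -> test 8 True
```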
SYMBOL INDEX (154 symbols across 17 files)

FILE: code/mesh_operations.py
  function read_off_ (line 12) | def read_off_(off_file_name):
  function write_off_ (line 15) | def write_off_(off_file_name,v,face_=numpy.zeros((1))):
  function write_obj_ (line 28) | def write_obj_(obj_write_name,v,face_=numpy.zeros((1)),color_=numpy.zero...
  function read_obj_ (line 45) | def read_obj_(obj_write_name):
  function write_xyz_ (line 50) | def write_xyz_(xyz_write_name,v,normal_=numpy.zeros((1))):
  function read_xyz_ (line 59) | def read_xyz_(xyz_name):
  function convert_obj_to_off_ (line 78) | def convert_obj_to_off_(obj_path_in, off_path_out):
  function normalize_points_to_sphere_ (line 84) | def normalize_points_to_sphere_(v_in):
  function normalize_points_to_sphere_with_given_center_and_factor_ (line 93) | def normalize_points_to_sphere_with_given_center_and_factor_(v_in, cente...

FILE: code/torch_tensor_functions.py
  function compute_sqrdis_map (line 14) | def compute_sqrdis_map(points_x, points_y):
  function draw_tensor_point_xyz_with_normal (line 28) | def draw_tensor_point_xyz_with_normal(save_path, torch_tensor_points, to...
  function draw_tensor_point_xyz_with_normal_by_threshold (line 43) | def draw_tensor_point_xyz_with_normal_by_threshold(save_path, torch_tens...
  function draw_tensor_point_obj_with_color (line 60) | def draw_tensor_point_obj_with_color(save_path, torch_tensor_points, tor...
  function draw_tensor_point_batch_xyz_with_normal (line 75) | def draw_tensor_point_batch_xyz_with_normal(save_batch_path, torch_tenso...
  function euler2rot (line 89) | def euler2rot(euler_angle):
  function get_neighbor_index (line 115) | def get_neighbor_index(vertices: "(bs, vertice_num, 3)",  neighbor_num: ...
  function indexing_neighbor (line 127) | def indexing_neighbor(tensor: "(bs, vertice_num, dim)", index: "(bs, que...
  function indexing_by_id (line 135) | def indexing_by_id(tensor: "(bs, vertice_num, dim)", index: "(bs, query_...

FILE: model/conpu_v6/chamfer_distance/chamfer_distance.cpp
  function chamfer_distance_forward_cuda (line 27) | void chamfer_distance_forward_cuda(
  function chamfer_distance_backward_cuda (line 46) | void chamfer_distance_backward_cuda(
  function nnsearch (line 64) | void nnsearch(
  function chamfer_distance_forward (line 95) | void chamfer_distance_forward(
  function chamfer_distance_backward (line 119) | void chamfer_distance_backward(
  function PYBIND11_MODULE (line 185) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

FILE: model/conpu_v6/chamfer_distance/chamfer_distance.py
  class ChamferDistanceFunction (line 10) | class ChamferDistanceFunction(torch.autograd.Function):
    method forward (line 12) | def forward(ctx, xyz1, xyz2):
    method backward (line 36) | def backward(ctx, graddist1, graddist2):
  class ChamferDistance (line 55) | class ChamferDistance(torch.nn.Module):
    method forward (line 56) | def forward(self, xyz1, xyz2):

FILE: model/conpu_v6/loss.py
  class Loss (line 34) | class Loss(nn.Module):
    method __init__ (line 35) | def __init__(self, args):
    method loss_on_cd (line 39) | def loss_on_cd(self, deformation_p, p1):
    method loss_on_proj (line 46) | def loss_on_proj(self, p0, p1):
    method loss_on_normal (line 82) | def loss_on_normal(self, p0, p1, n0, n1):
    method loss_on_reg (line 108) | def loss_on_reg(self, gen_points_batch, train_points_sparse_batch):
    method loss_on_arap (line 122) | def loss_on_arap(self, gen_points_batch, uv_sampling_coors):
    method loss_on_overlap (line 143) | def loss_on_overlap(self, gen_points_batch, train_points_sparse_batch):
    method loss_on_ndirection (line 156) | def loss_on_ndirection(self, gen_points_batch, uv_sampling_coors, gen_...
    method forward (line 182) | def forward(self, gen_points_batch, gen_normals_batch, uv_sampling_coo...

FILE: model/conpu_v6/network.py
  class DGCNN_multi_knn_c5 (line 33) | class DGCNN_multi_knn_c5(nn.Module):
    method __init__ (line 34) | def __init__(self, emb_dims=512, args=None):
    method forward (line 52) | def forward(self, x, if_relu_atlast = False):
  function knn (line 77) | def knn(x, k):
  function get_graph_feature (line 84) | def get_graph_feature(x, k=4):
  class MLPNet_relu (line 102) | class MLPNet_relu(torch.nn.Module):
    method __init__ (line 107) | def __init__(self, nch_input, nch_layers, b_shared=True, bn_momentum=0...
    method forward (line 111) | def forward(self, inp):
  function mlp_layers_relu (line 115) | def mlp_layers_relu(nch_input, nch_layers, b_shared=True, bn_momentum=0....
  class MLPNet (line 144) | class MLPNet(torch.nn.Module):
    method __init__ (line 149) | def __init__(self, nch_input, nch_layers, b_shared=True, bn_momentum=0...
    method forward (line 153) | def forward(self, inp):
  function mlp_layers (line 157) | def mlp_layers(nch_input, nch_layers, b_shared=True, bn_momentum=0.1, dr...
  class Sine (line 182) | class Sine(nn.Module):
    method __init (line 183) | def __init(self):
    method forward (line 185) | def forward(self, input):
  class Net_conpu_v7 (line 192) | class Net_conpu_v7(nn.Module):
    method __init__ (line 193) | def __init__(self, args):
    method forward (line 213) | def forward(self, points_sparse):
    method project_3d_query_point_to_patches (line 268) | def project_3d_query_point_to_patches(self, querying_points_3d, queryi...
    method convert_uv_to_xyzn (line 341) | def convert_uv_to_xyzn(self, uv_coor, uv_coor_idx_in_sparse, sparse_em...
    method convert_uv_to_xyz (line 369) | def convert_uv_to_xyz(self, uv_coor, uv_coor_idx_in_sparse, sparse_emb...
    method convert_xyz_to_uv (line 386) | def convert_xyz_to_uv(self, xyz_coor, xyz_coor_idx_in_sparse, sparse_e...
  function convert_str_2_list (line 406) | def convert_str_2_list(str_):
  function position_encoding (line 411) | def position_encoding(input_uv, pe_out_L):
  function fix_sample (line 421) | def fix_sample(thisbatchsize, num_point, up_ratio, if_random=False):
  function uniform_random_sample (line 489) | def uniform_random_sample(thisbatchsize, num_point, up_ratio):
  function cal_grad (line 500) | def cal_grad(inputs, outputs):

FILE: model/conpu_v6/pointnet2/pointnet2_modules.py
  class _PointnetSAModuleBase (line 10) | class _PointnetSAModuleBase(nn.Module):
    method __init__ (line 12) | def __init__(self):
    method forward (line 19) | def forward(self, xyz: torch.Tensor, features: torch.Tensor = None, np...
  class PointnetSAModuleMSG (line 59) | class PointnetSAModuleMSG(_PointnetSAModuleBase):
    method __init__ (line 62) | def __init__(self, *, npoint: int, radii: List[float], nsamples: List[...
  class PointnetSAModule (line 99) | class PointnetSAModule(PointnetSAModuleMSG):
    method __init__ (line 102) | def __init__(self, *, mlp: List[int], npoint: int = None, radius: floa...
  class PointNetSSG_Base (line 120) | class PointNetSSG_Base(PointnetSAModuleMSG):
    method __init__ (line 121) | def __init__(self, npoint, nsample, radius, in_channel, out_channel, b...
  class PointnetFPModule (line 127) | class PointnetFPModule(nn.Module):
    method __init__ (line 130) | def __init__(self, *, mlp: List[int], bn: bool = True):
    method forward (line 138) | def forward(

FILE: model/conpu_v6/pointnet2/pointnet2_utils.py
  class FurthestPointSampling (line 10) | class FurthestPointSampling(Function):
    method forward (line 12) | def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
    method backward (line 32) | def backward(xyz, a=None):
  class GatherOperation (line 39) | class GatherOperation(Function):
    method forward (line 42) | def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.T...
    method backward (line 63) | def backward(ctx, grad_out):
  class ThreeNN (line 76) | class ThreeNN(Function):
    method forward (line 79) | def forward(ctx, unknown: torch.Tensor, known: torch.Tensor) -> Tuple[...
    method backward (line 101) | def backward(ctx, a=None, b=None):
  class ThreeInterpolate (line 108) | class ThreeInterpolate(Function):
    method forward (line 111) | def forward(ctx, features: torch.Tensor, idx: torch.Tensor, weight: to...
    method backward (line 134) | def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch...
  class GroupingOperation (line 157) | class GroupingOperation(Function):
    method forward (line 160) | def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.T...
    method backward (line 181) | def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch...
  class BallQuery (line 201) | class BallQuery(Function):
    method forward (line 204) | def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, new_x...
    method backward (line 225) | def backward(ctx, a=None):
  class QueryAndGroup (line 232) | class QueryAndGroup(nn.Module):
    method __init__ (line 233) | def __init__(self, radius: float, nsample: int, use_xyz: bool = True):
    method forward (line 242) | def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: ...

FILE: model/conpu_v6/pointnet2/pytorch_utils.py
  class EmptyModule (line 5) | class EmptyModule(nn.Module):
    method __init__ (line 6) | def __init__(self):
    method forward (line 9) | def forward(self, x):
  class SharedResMLP (line 13) | class SharedResMLP(nn.Module):
    method __init__ (line 14) | def __init__(
    method forward (line 37) | def forward(self, x):
  class SharedMLP (line 45) | class SharedMLP(nn.Sequential):
    method __init__ (line 47) | def __init__(
  class _ConvBase (line 74) | class _ConvBase(nn.Sequential):
    method __init__ (line 76) | def __init__(
  class _BNBase (line 143) | class _BNBase(nn.Sequential):
    method __init__ (line 145) | def __init__(self, in_size, batch_norm=None, name=""):
  class BatchNorm1d (line 153) | class BatchNorm1d(_BNBase):
    method __init__ (line 155) | def __init__(self, in_size: int, *, name: str = ""):
  class BatchNorm2d (line 159) | class BatchNorm2d(_BNBase):
    method __init__ (line 161) | def __init__(self, in_size: int, name: str = ""):
  class BatchNorm3d (line 165) | class BatchNorm3d(_BNBase):
    method __init__ (line 167) | def __init__(self, in_size: int, name: str = ""):
  class Conv1d (line 171) | class Conv1d(_ConvBase):
    method __init__ (line 173) | def __init__(
  class Conv2d (line 208) | class Conv2d(_ConvBase):
    method __init__ (line 210) | def __init__(
  class Conv3d (line 244) | class Conv3d(_ConvBase):
    method __init__ (line 246) | def __init__(
  class FC (line 281) | class FC(nn.Sequential):
    method __init__ (line 283) | def __init__(

FILE: model/conpu_v6/pointnet2/src/ball_query.cpp
  function ball_query_wrapper_fast (line 14) | int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample,

FILE: model/conpu_v6/pointnet2/src/cuda_utils.h
  function opt_n_threads (line 10) | inline int opt_n_threads(int work_size) {

FILE: model/conpu_v6/pointnet2/src/group_points.cpp
  function group_points_grad_wrapper_fast (line 11) | int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int...
  function group_points_wrapper_fast (line 25) | int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsam...

FILE: model/conpu_v6/pointnet2/src/interpolate.cpp
  function three_nn_wrapper_fast (line 14) | void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor,
  function three_interpolate_wrapper_fast (line 26) | void three_interpolate_wrapper_fast(int b, int c, int m, int n,
  function three_interpolate_grad_wrapper_fast (line 41) | void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m,

FILE: model/conpu_v6/pointnet2/src/pointnet2_api.cpp
  function PYBIND11_MODULE (line 10) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

FILE: model/conpu_v6/pointnet2/src/sampling.cpp
  function gather_points_wrapper_fast (line 11) | int gather_points_wrapper_fast(int b, int c, int n, int npoints,
  function gather_points_grad_wrapper_fast (line 23) | int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints,
  function furthest_point_sampling_wrapper (line 36) | int furthest_point_sampling_wrapper(int b, int n, int m,

FILE: model/conpu_v6/train_view_toy.py
  class GradScaler (line 56) | class GradScaler:
    method __init__ (line 57) | def __init__(self):
    method scale (line 59) | def scale(self, loss):
    method unscale_ (line 61) | def unscale_(self, optimizer):
    method step (line 63) | def step(self, optimizer):
    method update (line 65) | def update(self):
  function update_test_cache (line 148) | def update_test_cache(used_samples_num, model, loss_obj, args):
  function update_pics (line 166) | def update_pics():
  function update_visualization (line 265) | def update_visualization(model,  args):
  function stophere (line 376) | def stophere():
  function run_train_val (line 380) | def run_train_val(model, optimizer, loss_obj,  args):
  function train_one_batch (line 425) | def train_one_batch(model, optimizer, loss_obj, start_pos, end_pos, args):
  function test_one_batch (line 524) | def test_one_batch(model, loss_obj, start_pos, end_pos, args):
  function compute_test_loss_values (line 588) | def compute_test_loss_values(model, loss_obj, args):
  function show_parameter_by_name (line 615) | def show_parameter_by_name(net_name, layer_name):
  function get_para_of_one_layer_from_another_net (line 621) | def get_para_of_one_layer_from_another_net(net_source, net_to_be_changed...

FILE: utils/config.py
  function parse_args (line 5) | def parse_args():
  },
  {
    "path": "model/conpu_v6/network.py",
    "chars": 33683,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom torchvision import d"
  },
  {
    "path": "model/conpu_v6/pointnet2/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "model/conpu_v6/pointnet2/pointnet2_modules.py",
    "chars": 6931,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom . import pointnet2_utils\nfrom . import pytorch_"
  },
  {
    "path": "model/conpu_v6/pointnet2/pointnet2_utils.py",
    "chars": 8820,
    "preview": "import torch\nfrom torch.autograd import Variable\nfrom torch.autograd import Function\nimport torch.nn as nn\nfrom typing i"
  },
  {
    "path": "model/conpu_v6/pointnet2/pytorch_utils.py",
    "chars": 8580,
    "preview": "import torch.nn as nn\nfrom typing import List, Tuple\nimport torch.nn.functional as F\n\nclass EmptyModule(nn.Module):\n    "
  },
  {
    "path": "model/conpu_v6/pointnet2/setup.py",
    "chars": 679,
    "preview": "from setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\nsetup(\n    name='point"
  },
  {
    "path": "model/conpu_v6/pointnet2/src/ball_query.cpp",
    "chars": 941,
    "preview": "#include <torch/serialize/tensor.h>\n#include <vector>\n#include <THC/THC.h>\n#include <cuda.h>\n#include <cuda_runtime_api."
  },
  {
    "path": "model/conpu_v6/pointnet2/src/ball_query_gpu.cu",
    "chars": 2050,
    "preview": "#include <math.h>\n#include <stdio.h>\n#include <stdlib.h>\n\n#include \"ball_query_gpu.h\"\n#include \"cuda_utils.h\"\n\n\n__global"
  },
  {
    "path": "model/conpu_v6/pointnet2/src/ball_query_gpu.h",
    "chars": 476,
    "preview": "#ifndef _BALL_QUERY_GPU_H\n#define _BALL_QUERY_GPU_H\n\n#include <torch/serialize/tensor.h>\n#include <vector>\n#include <cud"
  },
  {
    "path": "model/conpu_v6/pointnet2/src/cuda_utils.h",
    "chars": 353,
    "preview": "#ifndef _CUDA_UTILS_H\n#define _CUDA_UTILS_H\n\n#include <cmath>\n\n#define TOTAL_THREADS 1024\n#define THREADS_PER_BLOCK 256\n"
  },
  {
    "path": "model/conpu_v6/pointnet2/src/group_points.cpp",
    "chars": 1174,
    "preview": "#include <torch/serialize/tensor.h>\n#include <cuda.h>\n#include <cuda_runtime_api.h>\n#include <vector>\n#include <THC/THC."
  },
  {
    "path": "model/conpu_v6/pointnet2/src/group_points_gpu.cu",
    "chars": 3307,
    "preview": "#include <stdio.h>\n#include <stdlib.h>\n\n#include \"cuda_utils.h\"\n#include \"group_points_gpu.h\"\n\n\n__global__ void group_po"
  },
  {
    "path": "model/conpu_v6/pointnet2/src/group_points_gpu.h",
    "chars": 836,
    "preview": "#ifndef _GROUP_POINTS_GPU_H\n#define _GROUP_POINTS_GPU_H\n\n#include <torch/serialize/tensor.h>\n#include <cuda.h>\n#include "
  },
  {
    "path": "model/conpu_v6/pointnet2/src/interpolate.cpp",
    "chars": 2030,
    "preview": "#include <torch/serialize/tensor.h>\n#include <vector>\n#include <THC/THC.h>\n#include <math.h>\n#include <stdio.h>\n#include"
  },
  {
    "path": "model/conpu_v6/pointnet2/src/interpolate_gpu.cu",
    "chars": 5331,
    "preview": "#include <math.h>\n#include <stdio.h>\n#include <stdlib.h>\n\n#include \"cuda_utils.h\"\n#include \"interpolate_gpu.h\"\n\n\n__globa"
  },
  {
    "path": "model/conpu_v6/pointnet2/src/interpolate_gpu.h",
    "chars": 1174,
    "preview": "#ifndef _INTERPOLATE_GPU_H\n#define _INTERPOLATE_GPU_H\n\n#include <torch/serialize/tensor.h>\n#include<vector>\n#include <cu"
  },
  {
    "path": "model/conpu_v6/pointnet2/src/pointnet2_api.cpp",
    "chars": 1148,
    "preview": "#include <torch/serialize/tensor.h>\n#include <torch/extension.h>\n\n#include \"ball_query_gpu.h\"\n#include \"group_points_gpu"
  },
  {
    "path": "model/conpu_v6/pointnet2/src/sampling.cpp",
    "chars": 1552,
    "preview": "#include <torch/serialize/tensor.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <vector>\n#include <THC/THC.h>\n\n#include "
  },
  {
    "path": "model/conpu_v6/pointnet2/src/sampling_gpu.cu",
    "chars": 7934,
    "preview": "#include <stdio.h>\n#include <stdlib.h>\n\n#include \"cuda_utils.h\"\n#include \"sampling_gpu.h\"\n\n\n__global__ void gather_point"
  },
  {
    "path": "model/conpu_v6/pointnet2/src/sampling_gpu.h",
    "chars": 1045,
    "preview": "#ifndef _SAMPLING_GPU_H\n#define _SAMPLING_GPU_H\n\n#include <torch/serialize/tensor.h>\n#include <ATen/cuda/CUDAContext.h>\n"
  },
  {
    "path": "model/conpu_v6/train_script101.py",
    "chars": 1449,
    "preview": "import os\n\n#coarse-net\n\nloss_weight=' '\nloss_weight+=' --weight_cd 1.0'\nloss_weight+=' --weight_uniform -10000000'\nloss_"
  },
  {
    "path": "model/conpu_v6/train_script101_test.py",
    "chars": 1454,
    "preview": "import os\n\n#coarse-net\n\nloss_weight=' '\nloss_weight+=' --weight_cd 1.0'\nloss_weight+=' --weight_uniform -10000000'\nloss_"
  },
  {
    "path": "model/conpu_v6/train_view_toy.py",
    "chars": 29207,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom torchvision import d"
  },
  {
    "path": "utils/config.py",
    "chars": 10670,
    "preview": "import argparse\nimport os\nfrom configparser import SafeConfigParser\n\ndef parse_args():\n    # argparse argument\n    parse"
  }
]
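
The `chamfer_distance/` extension previewed above JIT-compiles a CUDA kernel (`cd = load(...)`) for the symmetric Chamfer distance, the repo's main reconstruction loss. For reference, here is the quantity it computes in plain NumPy — batching, reduction, and the nearest-neighbor indices the real kernel also returns for the backward pass are omitted, and the function name is an assumption:

```python
# Reference sketch of the symmetric Chamfer distance between two point sets:
# mean squared distance from each point to its nearest neighbor in the other
# set, summed over both directions. Illustrative only; the repo's CUDA
# kernel computes this with batching and returns per-point NN indices.
import numpy as np

def chamfer_distance(p: np.ndarray, q: np.ndarray) -> float:
    """p: (n, 3), q: (m, 3); returns a scalar distance."""
    # (n, m) matrix of pairwise squared distances
    d2 = np.sum((p[:, None, :] - q[None, :, :]) ** 2, axis=-1)
    return float(d2.min(axis=1).mean() + d2.min(axis=0).mean())
```

The dense (n, m) matrix is O(n·m) memory, which is exactly what the custom CUDA kernel avoids by searching nearest neighbors per thread instead of materializing all pairs.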

// ... and 1 more file (download for full content)

About this extraction

This page contains the full source code of the WanquanF/NeuralPoints GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 35 files (166.5 KB, approximately 47.8k tokens) and a symbol index of 154 extracted functions, classes, methods, constants, and types. Use it with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
