Repository: Ghostish/Open3DSOT
Branch: main
Commit: f08451ddf133
Files: 50
Total size: 128.7 MB

Directory structure:
Open3DSOT/

├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── cfgs/
│   ├── BAT_CAR_NUSCENES.yaml
│   ├── BAT_Car.yaml
│   ├── BAT_Car_Waymo.yaml
│   ├── BAT_PEDESTRIAN_NUSCENES.yaml
│   ├── BAT_Pedestrian.yaml
│   ├── M2_Track_nuscene.yaml
│   ├── M2_Track_waymo.yaml
│   ├── M2_track_kitti.yaml
│   ├── P2B_Car.yaml
│   ├── P2B_Car_NuScenes.yaml
│   └── P2B_Car_Waymo.yaml
├── datasets/
│   ├── __init__.py
│   ├── base_dataset.py
│   ├── data_classes.py
│   ├── generate_waymo_sot.py
│   ├── kitti.py
│   ├── nuscenes_data.py
│   ├── points_utils.py
│   ├── sampler.py
│   ├── searchspace.py
│   ├── utils.py
│   └── waymo_data.py
├── main.py
├── models/
│   ├── __init__.py
│   ├── backbone/
│   │   └── pointnet.py
│   ├── base_model.py
│   ├── bat.py
│   ├── head/
│   │   ├── rpn.py
│   │   └── xcorr.py
│   ├── m2track.py
│   └── p2b.py
├── pointnet2/
│   ├── __init__.py
│   └── utils/
│       ├── __init__.py
│       ├── linalg_utils.py
│       ├── pointnet2_modules.py
│       ├── pointnet2_utils.py
│       └── pytorch_utils.py
├── pretrained_models/
│   ├── bat_kitti_car.ckpt
│   ├── bat_kitti_pedestrian.ckpt
│   ├── bat_nuscenes_car.ckpt
│   ├── mmtrack_kitti_car.ckpt
│   ├── mmtrack_kitti_pedestrian.ckpt
│   └── mmtrack_nuscenes_car.ckpt
├── requirement.txt
└── utils/
    ├── __init__.py
    └── metrics.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitattributes
================================================
# Auto detect text files and perform LF normalization
* text=auto


================================================
FILE: .gitignore
================================================
.DS_Store
.idea/
*/.DS_Store
*.pyc
events.*
lightning_logs

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2021 Kangel Zenn

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# Open3DSOT
A general Python framework for single object tracking in LiDAR point clouds, based on PyTorch Lightning.

The official code release of **[BAT](https://arxiv.org/abs/2108.04728)** and **[M2 Track](https://ghostish.github.io/MM-Track/)**.


### Features
+ Modular design. It is easy to configure the model and training/testing behaviors through a single `.yaml` file.
+ DDP support for both training and testing.
+ Support for all common tracking datasets (KITTI, NuScenes, Waymo Open Dataset).
### :mega:  The extension of M2-Track is accepted by TPAMI! :point_down:
+ [An Effective Motion-Centric Paradigm for 3D Single Object Tracking in Point Clouds](https://arxiv.org/abs/2303.12535)

+ Code is coming soon.

### :mega:  One tracking paper is accepted by CVPR2022 (Oral)! :point_down:
+ [Beyond 3D Siamese Tracking: A Motion-Centric Paradigm for 3D Single Object Tracking in Point Clouds](https://arxiv.org/abs/2203.01730)

### :mega: The code for M2-Track is now available.
+ The checkpoints we provide here achieve **better** performance than those reported in our main paper. Check below for more details.
+ The supplementary material is now out. Please check [this](https://openaccess.thecvf.com/content/CVPR2022/supplemental/Zheng_Beyond_3D_Siamese_CVPR_2022_supplemental.pdf) for more implementation details.
## Trackers
This repository includes the implementation of the following models:

### M2-Track (CVPR2022 Oral)
**[[Paper]](http://arxiv.org/abs/2203.01730)** **[[Project Page]](https://ghostish.github.io/MM-Track/)**

**M2-Track** is the first **motion-centric tracker** for LiDAR SOT, robustly handling distractors and drastic appearance changes in complex driving scenes. Unlike previous methods, M2-Track is a **matching-free** two-stage tracker that localizes targets by explicitly modeling the "relative target motion" among frames.

<p align="center">
<img src="figures/mmtrack.png" width="800"/>
</p>

<p align="center">
<img src="figures/results_mmtrack.gif" width="800"/>
</p>

### BAT (ICCV2021)
**[[Paper]](https://arxiv.org/abs/2108.04728) [[Results]](./README.md#Reproduction)**

Official implementation of **BAT**. BAT uses BBox information to compensate for the information loss of incomplete scans. It augments the target template with box-aware features that efficiently and effectively improve appearance matching.

<p align="center">
<img src="figures/bat.png" width="800"/>
</p>
<p align="center">
<img src="figures/results.gif" width="800"/>
</p>

### P2B (CVPR2020)
**[[Paper]](https://arxiv.org/abs/2005.13888) [[Official implementation]](https://github.com/HaozheQi/P2B)**

Third-party implementation of **P2B**. Our implementation achieves better results than the official code release. P2B adapts SiamRPN to 3D point clouds by integrating a pointwise correlation operator with a point-based RPN (VoteNet).

<p align="center">
<img src="figures/p2b.png" width="800"/>
</p>

## Setup
Installation
+ Create the environment
  ```
  git clone https://github.com/Ghostish/Open3DSOT.git
  cd Open3DSOT
  conda create -n Open3DSOT  python=3.8
  conda activate Open3DSOT
  ```
+ Install pytorch
  ```
  conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=11.0 -c pytorch
  ```
  Our code is compatible with other PyTorch/CUDA versions; you can follow [this](https://pytorch.org/get-started/locally/) to install another version of PyTorch. **Note: In order to reproduce the reported results with the provided BAT checkpoints, please use CUDA 10.x.**

+ Install other dependencies:
  ```
  pip install -r requirement.txt
  ```


KITTI dataset
+ Download the data for [velodyne](http://www.cvlibs.net/download.php?file=data_tracking_velodyne.zip), [calib](http://www.cvlibs.net/download.php?file=data_tracking_calib.zip) and [label_02](http://www.cvlibs.net/download.php?file=data_tracking_label_2.zip) from [KITTI Tracking](http://www.cvlibs.net/datasets/kitti/eval_tracking.php).
+ Unzip the downloaded files.
+ Put the unzipped files under the same parent folder, as follows:
  ```
  [Parent Folder]
  --> [calib]
      --> {0000-0020}.txt
  --> [label_02]
      --> {0000-0020}.txt
  --> [velodyne]
      --> [0000-0020] folders with velodyne .bin files
  ```
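
A quick sanity check of this layout is to load it with the repo's `kittiDataset`; this is only a sketch (constructor arguments follow `datasets/kitti.py` included in this dump, and the data root is a placeholder):
```python
# Sketch only: instantiate the KITTI tracklet dataset on the folder above.
from datasets.kitti import kittiDataset

dataset = kittiDataset(path="/path/to/kitti/training",  # the [Parent Folder] above
                       split="train",                   # "train" / "valid" / "test"
                       category_name="Car",
                       coordinate_mode="velodyne",
                       preloading=False)
print(dataset.get_num_tracklets(), "tracklets,",
      dataset.get_num_frames_total(), "frames")
frame = dataset.get_frames(seq_id=0, frame_ids=[0])[0]  # keys: "pc", "3d_bbox", "meta"
```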

NuScenes dataset
+ Download the dataset from the [download page](https://www.nuscenes.org/download)
+ Extract the downloaded files and make sure you have the following structure:
  ```
  [Parent Folder]
    samples   -  Sensor data for keyframes.
    sweeps    -  Sensor data for intermediate frames.
    maps      -  Folder for all map files: rasterized .png images and vectorized .json files.
    v1.0-*    -  JSON tables that include all the metadata and annotations. Each split (trainval, test, mini) is provided in a separate folder.
  ```
>Note: We use the **train_track** split to train our model and test it on the **val** split. Both splits are officially provided by NuScenes. During testing, we ignore sequences whose first given bbox contains no points.
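
The split and filtering logic can be exercised directly with `NuScenesDataset`; a minimal sketch (arguments follow `datasets/nuscenes_data.py` included in this dump; the data root is a placeholder):
```python
# Sketch only: build the train_track split used for training.
from datasets.nuscenes_data import NuScenesDataset

train_set = NuScenesDataset(path="/path/to/nuscenes",
                            split="train_track",
                            category_name="Car",
                            version="v1.0-trainval",
                            key_frame_only=True,  # only keyframes can be used for training
                            min_points=-1,        # no point-count filtering at training time
                            preloading=False)
# At test time the framework passes min_points=1, which drops tracklets whose
# first bbox contains no lidar point (see datasets/__init__.py).
```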

Waymo dataset
+ Download and prepare the dataset following the instructions of [CenterPoint](https://github.com/tianweiy/CenterPoint/blob/master/docs/WAYMO.md).
  ```
  [Parent Folder]
    tfrecord_training	                    
    tfrecord_validation	                 
    train 	                                    -	all training frames and annotations 
    val   	                                    -	all validation frames and annotations 
    infos_train_01sweeps_filter_zero_gt.pkl
    infos_val_01sweeps_filter_zero_gt.pkl
  ```
+ Prepare the SOT dataset. Data from a specific category and split will be merged into a single file (e.g., `sot_infos_vehicle_train.pkl`):
```bash
  python datasets/generate_waymo_sot.py
```
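
The resulting `.pkl` groups per-frame records by object, which a quick inspection shows (a sketch; the root path is a placeholder):
```python
# Sketch only: peek into the merged SOT infos produced by generate_waymo_sot.py.
import os
import pickle

root = "/path/to/waymo"
with open(os.path.join(root, "sot_infos_vehicle_train.pkl"), "rb") as f:
    tracklets = pickle.load(f)  # dict: object name -> list of per-frame records

name, frames = next(iter(tracklets.items()))
print(name, len(frames))        # each record holds 'PC', 'Box' and 'Class'
```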

## Quick Start
### Training
To train a model, you must specify the `.yaml` file with the `--cfg` argument. The `.yaml` file contains all the configurations of the dataset and the model. We provide `.yaml` files under the [*cfgs*](./cfgs) directory. **Note:** Before running the code, you need to edit the `.yaml` file and set the `path` argument to the correct root of the dataset.
```bash
CUDA_VISIBLE_DEVICES=0,1 python main.py  --cfg cfgs/M2_track_kitti.yaml  --batch_size 64 --epoch 60 --preloading
```
For M2-Track, we use the same configuration for all categories. By default, the `.yaml` file is set up to train a Car tracker; change the `category_name` in the `.yaml` file to train for another category.
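
Since the configs are plain YAML, such a switch can also be scripted. This is only a sketch with an assumed loading scheme, since `main.py`'s actual config handling is not included in this preview; the output filename is hypothetical:
```python
# Sketch only (assumed, not main.py's code): read a cfg and switch the category.
import yaml

with open("cfgs/M2_track_kitti.yaml") as f:
    cfg = yaml.safe_load(f)       # plain dict: cfg["category_name"], cfg["lr"], ...
cfg["category_name"] = "Car"      # train a Car tracker instead

with open("cfgs/M2_track_car.yaml", "w") as f:  # hypothetical output file
    yaml.safe_dump(cfg, f)
```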

In this version, we removed the `--gpus` flag; all available GPUs are used by default. You can use `CUDA_VISIBLE_DEVICES` to select specific GPUs.

After you start training, you can launch TensorBoard to monitor the training process:
```
tensorboard --logdir=./ --port=6006
```
By default, the trainer runs a full evaluation on the full test split after every training epoch. You can set `--check_val_every_n_epoch` to a larger number to speed up training. The `--preloading` flag preloads the training samples into memory to save training time. Remove this flag if you don't have enough memory.
### Testing
To test a trained model, specify the checkpoint location with the `--checkpoint` argument and add the `--test` flag to the command.
```bash
python main.py  --cfg cfgs/M2_track_kitti.yaml  --checkpoint /path/to/checkpoint/xxx.ckpt --test
```

## Reproduction
| Model | Category | Success| Precision| Checkpoint
|--|--|--|--|--|
| BAT-KITTI | Car	|65.37 | 78.88|pretrained_models/bat_kitti_car.ckpt
| BAT-NuScenes | Car	|40.73 | 43.29|pretrained_models/bat_nuscenes_car.ckpt
| BAT-KITTI | Pedestrian | 45.74| 74.53| pretrained_models/bat_kitti_pedestrian.ckpt
| M2Track-KITTI | Car | **67.43**| **81.04**| pretrained_models/mmtrack_kitti_car.ckpt
| M2Track-KITTI | Pedestrian | **60.61**| **89.39**| pretrained_models/mmtrack_kitti_pedestrian.ckpt
| M2Track-NuScenes | Car | **57.22**| **65.72**| pretrained_models/mmtrack_nuscenes_car.ckpt

Trained models are provided in the [*pretrained_models*](./pretrained_models) directory. To reproduce the results, simply run the code with the corresponding `.yaml` file and checkpoint. For example, to reproduce the tracking results of M2-Track on KITTI Car, just run:
```bash
python main.py  --cfg cfgs/M2_track_kitti.yaml  --checkpoint ./pretrained_models/mmtrack_kitti_car.ckpt --test
```
The reported results of the M2-Track checkpoints were produced on 3090/3080Ti GPUs. Due to precision issues, there may be minor differences if you test them with other GPUs.

## Acknowledgment
+ This repo is built upon [P2B](https://github.com/HaozheQi/P2B) and [SC3D](https://github.com/SilvioGiancola/ShapeCompletion3DTracking).
+ Thanks to Erik Wijmans for his PyTorch implementation of [PointNet++](https://github.com/erikwijmans/Pointnet2_PyTorch).

## License
This repository is released under the MIT License (see the LICENSE file for details).


================================================
FILE: cfgs/BAT_CAR_NUSCENES.yaml
================================================
#data
dataset: nuscenes
path: #put data root here
version: v1.0-trainval
category_name: Car
search_bb_scale: 1.25
search_bb_offset: 2
model_bb_scale: 1.25
model_bb_offset: 0
template_size: 512
search_size: 1024
random_sample: False
sample_per_epoch: -1
degrees: True # use degrees or radians
box_aware: True
num_candidates: 4
up_axis: [ 0,0,1 ]
preload_offset: 10
key_frame_only: True
train_split: train_track
val_split: val
test_split: val
min_points: 1
train_type: train_siamese
data_limit_box: False

#model configuration
net_model: BAT
use_fps: True
normalize_xyz: False
feature_channel: 256 #the output channel of backbone
hidden_channel: 256 #the hidden channel of xcorr
out_channel: 256 #the output channel of xcorr
vote_channel: 256 #the channel for vote aggregation
num_proposal: 64
k: 4
use_search_bc: False
use_search_feature: False
bc_channel: 9

#loss configuration
objectiveness_weight: 1.5
box_weight: 0.2
vote_weight: 1.0
seg_weight: 0.2
bc_weight: 1.0

# testing config
reference_BB: previous_result
shape_aggregation: firstandprevious
use_z: False
limit_box: True
IoU_space: 3

#training
batch_size: 100
workers: 10
epoch: 60
from_epoch: 0
lr: 0.001
optimizer: Adam
lr_decay_step: 20
lr_decay_rate: 0.1
wd: 0
gradient_clip_val: 0.0

================================================
FILE: cfgs/BAT_Car.yaml
================================================
#data
dataset: kitti
path:  #put data root here
category_name: Car # [Car, Van, Pedestrian, Cyclist, All]
search_bb_scale: 1.25
search_bb_offset: 2
model_bb_scale: 1.25
model_bb_offset: 0
template_size: 512
search_size: 1024
random_sample: False
sample_per_epoch: -1
degrees: True # use degrees or radians
box_aware: True
num_candidates: 4
coordinate_mode: velodyne
up_axis: [0,0,1]
train_split: train
val_split: test
test_split: test
preload_offset: 10
train_type: train_siamese
data_limit_box: False

#model configuration
net_model: BAT
use_fps: True
normalize_xyz: False
feature_channel: 256 #the output channel of backbone
hidden_channel: 256 #the hidden channel of xcorr
out_channel: 256 #the output channel of xcorr
vote_channel: 256 #the channel for vote aggregation
num_proposal: 64
k: 4
use_search_bc: False
use_search_feature: False
bc_channel: 9

#loss configuration
objectiveness_weight: 1.5
box_weight: 0.2
vote_weight: 1.0
seg_weight: 0.2
bc_weight: 1.0

# testing config
reference_BB: previous_result
shape_aggregation: firstandprevious
use_z: True
limit_box: False
IoU_space: 3

#training
batch_size: 50 #batch_size per gpu
workers: 10
epoch: 60
from_epoch: 0
lr: 0.001
optimizer: Adam
lr_decay_step: 12
lr_decay_rate: 0.2
wd: 0
gradient_clip_val: 0.0

================================================
FILE: cfgs/BAT_Car_Waymo.yaml
================================================
#data
dataset: waymo
path:  #put the root of the dataset here
category_name: Vehicle # [Vehicle, Pedestrian, Cyclist]
search_bb_scale: 1.25
search_bb_offset: 2
model_bb_scale: 1.25
model_bb_offset: 0
template_size: 512
search_size: 1024
random_sample: False
sample_per_epoch: -1
degrees: True # use degrees or radians
box_aware: True
num_candidates: 4
coordinate_mode: velodyne
up_axis: [0,0,1]
train_split: train
val_split: test
test_split: test
preload_offset: 10
tiny: False # for debug only
train_type: train_siamese
data_limit_box: False

#model configuration
net_model: BAT
use_fps: True
normalize_xyz: False
feature_channel: 256 #the output channel of backbone
hidden_channel: 256 #the hidden channel of xcorr
out_channel: 256 #the output channel of xcorr
vote_channel: 256 #the channel for vote aggregation
num_proposal: 64
k: 4
use_search_bc: False
use_search_feature: False
bc_channel: 9

#loss configuration
objectiveness_weight: 1.5
box_weight: 0.2
vote_weight: 1.0
seg_weight: 0.2
bc_weight: 1.0

# testing config
reference_BB: previous_result
shape_aggregation: firstandprevious
use_z: False
limit_box: True
IoU_space: 3

#training
batch_size: 50 #batch_size per gpu
workers: 10
epoch: 60
from_epoch: 0
lr: 0.001
optimizer: Adam
lr_decay_step: 5
lr_decay_rate: 0.1
wd: 0
gradient_clip_val: 0.0

================================================
FILE: cfgs/BAT_PEDESTRIAN_NUSCENES.yaml
================================================
#data
dataset: nuscenes
path: #put data root here
version: v1.0-trainval
category_name: Pedestrian
search_bb_scale: 1.25
search_bb_offset: 2
model_bb_scale: 1.25
model_bb_offset: 0
template_size: 512
search_size: 1024
random_sample: False
sample_per_epoch: -1
degrees: True # use degrees or radians
box_aware: True
num_candidates: 4
up_axis: [ 0,0,1 ]
preload_offset: 10
key_frame_only: False
train_split: train_track
val_split: val
test_split: val
min_points: 1
train_type: train_siamese

#model configuration
net_model: BAT
use_fps: True
normalize_xyz: False
feature_channel: 256 #the output channel of backbone
hidden_channel: 256 #the hidden channel of xcorr
out_channel: 256 #the output channel of xcorr
vote_channel: 256 #the channel for vote aggregation
num_proposal: 64
k: 4
use_search_bc: False
use_search_feature: False
bc_channel: 9

#loss configuration
objectiveness_weight: 1.5
box_weight: 0.2
vote_weight: 1.0
seg_weight: 0.2
bc_weight: 1.0

# testing config
reference_BB: previous_result
shape_aggregation: firstandprevious
use_z: False
limit_box: True
IoU_space: 3

#training
batch_size: 100
workers: 10
epoch: 60
from_epoch: 0
lr: 0.001
optimizer: Adam
lr_decay_step: 20
lr_decay_rate: 0.1
wd: 0
gradient_clip_val: 0.0

================================================
FILE: cfgs/BAT_Pedestrian.yaml
================================================
#data
dataset: kitti
path: #put the root of the dataset here
category_name: Pedestrian # [Car, Van, Pedestrian, Cyclist, All]
search_bb_scale: 1.25
search_bb_offset: 2
model_bb_scale: 1.25
model_bb_offset: 0
template_size: 512
search_size: 1024
random_sample: False
sample_per_epoch: -1
degrees: True # use degrees or radians
box_aware: True
num_candidates: 4
coordinate_mode: velodyne
up_axis: [0,0,1]
train_split: train
val_split: test
test_split: test
preload_offset: 10
train_type: train_siamese
data_limit_box: True

#model configuration
net_model: BAT
use_fps: True
normalize_xyz: False
feature_channel: 256 #the output channel of backbone
hidden_channel: 256 #the hidden channel of xcorr
out_channel: 256 #the output channel of xcorr
vote_channel: 256 #the channel for vote aggregation
num_proposal: 64
k: 4
use_search_bc: False
use_search_feature: False
bc_channel: 9

#loss configuration
objectiveness_weight: 1.5
box_weight: 0.2
vote_weight: 1.0
seg_weight: 0.2
bc_weight: 1.0

# testing config
reference_BB: previous_result
shape_aggregation: firstandprevious
use_z: False
limit_box: True
IoU_space: 3

#training
batch_size: 50 #batch_size per gpu
workers: 10
epoch: 60
from_epoch: 0
lr: 0.001
optimizer: Adam
lr_decay_step: 12
lr_decay_rate: 0.2
wd: 0
gradient_clip_val: 0.0

================================================
FILE: cfgs/M2_Track_nuscene.yaml
================================================
#data
dataset: nuscenes
path:  #put data root here
version: v1.0-trainval
category_name: Car
bb_scale: 1.25
bb_offset: 2
point_sample_size: 1024
degrees: False
coordinate_mode: velodyne
up_axis: [0,0,1]
preload_offset: 10
data_limit_box: True
key_frame_only: True
train_split: train_track
val_split: val
test_split: val
train_type: train_motion
num_candidates: 4
motion_threshold: 0.15
use_augmentation: True


#model configuration
net_model: m2track
box_aware: True

#loss configuration
center_weight: 2
angle_weight: 10.0
seg_weight: 0.1
bc_weight: 1

motion_cls_seg_weight: 0.1


# testing config
use_z: True
limit_box: False
IoU_space: 3

#training
batch_size: 100
workers: 10
epoch: 180
from_epoch: 0
lr: 0.001
optimizer: Adam
lr_decay_step: 20
lr_decay_rate: 0.1
wd: 0
gradient_clip_val: 0.0

================================================
FILE: cfgs/M2_Track_waymo.yaml
================================================
#data
dataset: waymo
path: #put data root here
category_name: Vehicle # [Vehicle, Pedestrian, Cyclist]
bb_scale: 1.25
bb_offset: 2
point_sample_size: 1024
degrees: False
coordinate_mode: velodyne
up_axis: [ 0,0,1 ]

preload_offset: 60
data_limit_box: True
train_split: train
val_split: test
test_split: test
train_type: train_motion
num_candidates: 4
motion_threshold: 0.15
use_augmentation: True


box_aware: True

tiny: False # for debug only

#model configuration
net_model: m2track

#loss configuration
center_weight: 2
angle_weight: 10.0
seg_weight: 0.1
bc_weight: 1.0
motion_cls_seg_weight: 0.1


# testing config
use_z: True
limit_box: False
IoU_space: 3

#training
batch_size: 100
workers: 10
epoch: 180
from_epoch: 0
lr: 0.001
optimizer: Adam
lr_decay_step: 20
lr_decay_rate: 0.1
wd: 0
gradient_clip_val: 0.0

================================================
FILE: cfgs/M2_track_kitti.yaml
================================================
#data
dataset: kitti
path: #put data root here
category_name: Pedestrian # [Car, Van, Pedestrian, Cyclist, All]
bb_scale: 1.25
bb_offset: 2
point_sample_size: 1024
degrees: False
coordinate_mode: velodyne
up_axis: [ 0,0,1 ]
preload_offset: 10
data_limit_box: True
train_split: train
val_split: test
test_split: test
train_type: train_motion
num_candidates: 4
motion_threshold: 0.15
use_augmentation: True


#model configuration
net_model: m2track
box_aware: True

#loss configuration
center_weight: 2
angle_weight: 10.0
seg_weight: 0.1
bc_weight: 1

motion_cls_seg_weight: 0.1

# testing config
use_z: True
limit_box: False
IoU_space: 3

#training
batch_size: 100
workers: 10
epoch: 180
from_epoch: 0
lr: 0.001
optimizer: Adam
lr_decay_step: 20
lr_decay_rate: 0.1
wd: 0
gradient_clip_val: 0.0

================================================
FILE: cfgs/P2B_Car.yaml
================================================
#data
dataset: kitti
path:  #put the root of the dataset here
category_name: Car # [Car, Van, Pedestrian, Cyclist, All]
search_bb_scale: 1.25
search_bb_offset: 2
model_bb_scale: 1.25
model_bb_offset: 0
template_size: 512
search_size: 1024
random_sample: False
sample_per_epoch: -1
degrees: True # use degrees or radians
num_candidates: 4
coordinate_mode: camera
up_axis: [0,-1,0]
train_split: train
val_split: test
test_split: test
preload_offset: 10
train_type: train_siamese
data_limit_box: True

#model configuration
net_model: P2B
use_fps: False
normalize_xyz: False
feature_channel: 256 #the output channel of backbone
hidden_channel: 256 #the hidden channel of xcorr
out_channel: 256 #the output channel of xcorr
vote_channel: 256 #the channel for vote aggregation
num_proposal: 64

#loss configuration
objectiveness_weight: 1.5
box_weight: 0.2
vote_weight: 1.0
seg_weight: 0.2

# testing config
reference_BB: previous_result
shape_aggregation: firstandprevious
use_z: False
limit_box: True
IoU_space: 3

#training
batch_size: 50
workers: 10
epoch: 60
from_epoch: 0
lr: 0.001
optimizer: Adam
lr_decay_step: 20
lr_decay_rate: 0.1
wd: 0
gradient_clip_val: 0.0

================================================
FILE: cfgs/P2B_Car_NuScenes.yaml
================================================
#data
dataset: nuscenes
path: #put data root here
version: v1.0-trainval
category_name: Car # [Car, Van, Pedestrian, Cyclist, All]
search_bb_scale: 1.25
search_bb_offset: 2
model_bb_scale: 1.25
model_bb_offset: 0
template_size: 512
search_size: 1024
random_sample: False
sample_per_epoch: -1
degrees: True # use degrees or radians
num_candidates: 4
up_axis: [ 0,0,1 ]
preload_offset: 10
key_frame_only: True
train_split: train_track
val_split: val
test_split: val
min_points: 1
train_type: train_siamese
data_limit_box: True

#model configuration
net_model: P2B
use_fps: False
normalize_xyz: False
feature_channel: 256 #the output channel of backbone
hidden_channel: 256 #the hidden channel of xcorr
out_channel: 256 #the output channel of xcorr
vote_channel: 256 #the channel for vote aggregation
num_proposal: 64

#loss configuration
objectiveness_weight: 1.5
box_weight: 0.2
vote_weight: 1.0
seg_weight: 0.2

# testing config
reference_BB: previous_result
shape_aggregation: firstandprevious
use_z: False
limit_box: True
IoU_space: 3

#training
batch_size: 50
workers: 10
epoch: 60
from_epoch: 0
lr: 0.001
optimizer: Adam
lr_decay_step: 20
lr_decay_rate: 0.1
wd: 0
gradient_clip_val: 0.0

================================================
FILE: cfgs/P2B_Car_Waymo.yaml
================================================
#data
dataset: waymo
path:  #put the root of the dataset here
category_name: Vehicle # [Vehicle, Pedestrian, Cyclist]
search_bb_scale: 1.25
search_bb_offset: 2
model_bb_scale: 1.25
model_bb_offset: 0
template_size: 512
search_size: 1024
random_sample: False
sample_per_epoch: -1
degrees: True # use degrees or radians
box_aware: True
num_candidates: 4
coordinate_mode: velodyne
up_axis: [0,0,1]
train_split: train
val_split: test
test_split: test
preload_offset: 10
tiny: False # for debug only
train_type: train_siamese
data_limit_box: True

#model configuration
net_model: P2B
use_fps: False
normalize_xyz: False
feature_channel: 256 #the output channel of backbone
hidden_channel: 256 #the hidden channel of xcorr
out_channel: 256 #the output channel of xcorr
vote_channel: 256 #the channel for vote aggregation
num_proposal: 64
k: 4
use_search_bc: False
use_search_feature: False
bc_channel: 9

#loss configuration
objectiveness_weight: 1.5
box_weight: 0.2
vote_weight: 1.0
seg_weight: 0.2
bc_weight: 1.0

# testing config
reference_BB: previous_result
shape_aggregation: firstandprevious
use_z: False
limit_box: True
IoU_space: 3

#training
batch_size: 50 #batch_size per gpu
workers: 10
epoch: 60
from_epoch: 0
lr: 0.001
optimizer: Adam
lr_decay_step: 5
lr_decay_rate: 0.1
wd: 0
gradient_clip_val: 0.0

================================================
FILE: datasets/__init__.py
================================================
"""
__init__.py
Created by zenn at 2021/7/18 15:50
"""
from datasets import kitti, sampler, nuscenes_data, waymo_data


def get_dataset(config, type='train', **kwargs):
    if config.dataset == 'kitti':
        data = kitti.kittiDataset(path=config.path,
                                  split=kwargs.get('split', 'train'),
                                  category_name=config.category_name,
                                  coordinate_mode=config.coordinate_mode,
                                  preloading=config.preloading,
                                  preload_offset=config.preload_offset if type != 'test' else -1)
    elif config.dataset == 'nuscenes':
        data = nuscenes_data.NuScenesDataset(path=config.path,
                                             split=kwargs.get('split', 'train_track'),
                                             category_name=config.category_name,
                                             version=config.version,
                                             key_frame_only=True if type != 'test' else config.key_frame_only,
                                             # can only use keyframes for training
                                             preloading=config.preloading,
                                             preload_offset=config.preload_offset if type != 'test' else -1,
                                             min_points=1 if kwargs.get('split', 'train_track') in
                                                             [config.val_split, config.test_split] else -1)
    elif config.dataset == 'waymo':
        data = waymo_data.WaymoDataset(path=config.path,
                                       split=kwargs.get('split', 'train'),
                                       category_name=config.category_name,
                                       preloading=config.preloading,
                                       preload_offset=config.preload_offset,
                                       tiny=config.tiny)
    else:
        data = None

    if type == 'train_siamese':
        return sampler.PointTrackingSampler(dataset=data,
                                            random_sample=config.random_sample,
                                            sample_per_epoch=config.sample_per_epoch,
                                            config=config)
    elif type.lower() == 'train_motion':
        return sampler.MotionTrackingSampler(dataset=data,
                                             config=config)
    else:
        return sampler.TestTrackingSampler(dataset=data, config=config)
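
# Usage sketch (illustrative): `config` is the parsed .yaml configuration, e.g.
#   train_data = get_dataset(config, type=config.train_type, split=config.train_split)
#   test_data = get_dataset(config, type='test', split=config.test_split)
# Any `type` other than 'train_siamese'/'train_motion' falls through to the
# TestTrackingSampler branch.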


================================================
FILE: datasets/base_dataset.py
================================================
""" 
base_dataset.py
Created by zenn at 2021/9/1 22:16
"""


class BaseDataset:
    def __init__(self, path, split, category_name="Car", **kwargs):
        self.path = path
        self.split = split
        self.category_name = category_name
        self.preloading = kwargs.get('preloading', False)


    def get_num_tracklets(self):
        raise NotImplementedError

    def get_num_frames_total(self):
        raise NotImplementedError

    def get_num_frames_tracklet(self, tracklet_id):
        raise NotImplementedError

    def get_frames(self, seq_id, frame_ids):
        raise NotImplementedError
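
# Illustrative note: concrete datasets (kittiDataset, NuScenesDataset, WaymoDataset)
# subclass BaseDataset and implement the accessors above; get_frames(seq_id, frame_ids)
# returns a list of dicts keyed by "pc" (PointCloud), "3d_bbox" (Box) and
# "meta" (the raw annotation).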


================================================
FILE: datasets/data_classes.py
================================================
# nuScenes dev-kit.
# Code written by Oscar Beijbom, 2018.
# Licensed under the Creative Commons [see licence.txt]

#from __future__ import annotations
import torch
import numpy as np
from pyquaternion import Quaternion


class PointCloud:

    def __init__(self, points):
        """
        Class for manipulating and viewing point clouds.
        :param points: <np.float: 4, n>. Input point cloud matrix.
        """
        self.points = points
        if self.points.shape[0] > 3:
            self.points = self.points[0:3, :]

    @staticmethod
    def load_pcd_bin(file_name):
        """
        Loads from binary format. Data is stored as (x, y, z, intensity, ring index).
        :param file_name: <str>.
        :return: <np.float: 4, n>. Point cloud matrix (x, y, z, intensity).
        """
        scan = np.fromfile(file_name, dtype=np.float32)
        points = scan.reshape((-1, 5))[:, :4]
        return points.T

    @classmethod
    def from_file(cls, file_name):
        """
        Instantiate from a .bin or .npy file.
        :param file_name: <str>. Path of the pointcloud file on disk.
        :return: <PointCloud>.
        """

        if file_name.endswith('.bin'):
            points = cls.load_pcd_bin(file_name)
        elif file_name.endswith('.npy'):
            points = np.load(file_name)
        else:
            raise ValueError('Unsupported filetype {}'.format(file_name))

        return cls(points)

    def nbr_points(self):
        """
        Returns the number of points.
        :return: <int>. Number of points.
        """
        return self.points.shape[1]

    def subsample(self, ratio):
        """
        Sub-samples the pointcloud.
        :param ratio: <float>. Fraction to keep.
        :return: <None>.
        """
        selected_ind = np.random.choice(np.arange(0, self.nbr_points()),
                                        size=int(self.nbr_points() * ratio))
        self.points = self.points[:, selected_ind]

    def remove_close(self, radius):
        """
        Removes points too close to the origin (within the given distance along both x and y).
        :param radius: <float>.
        :return: <None>.
        """

        x_filt = np.abs(self.points[0, :]) < radius
        y_filt = np.abs(self.points[1, :]) < radius
        not_close = np.logical_not(np.logical_and(x_filt, y_filt))
        self.points = self.points[:, not_close]

    def translate(self, x):
        """
        Applies a translation to the point cloud.
        :param x: <np.float: 3, 1>. Translation in x, y, z.
        :return: <None>.
        """
        for i in range(3):
            self.points[i, :] = self.points[i, :] + x[i]

    def rotate(self, rot_matrix):
        """
        Applies a rotation.
        :param rot_matrix: <np.float: 3, 3>. Rotation matrix.
        :return: <None>.
        """
        self.points[:3, :] = np.dot(rot_matrix, self.points[:3, :])

    def transform(self, transf_matrix):
        """
        Applies a homogeneous transform.
        :param transf_matrix: <np.float: 4, 4>. Homogeneous transformation matrix.
        :return: <None>.
        """
        self.points[:3, :] = transf_matrix.dot(
            np.vstack((self.points[:3, :], np.ones(self.nbr_points()))))[:3, :]

    def convertToPytorch(self):
        """
        Converts the points to a PyTorch tensor.
        :return: <torch.Tensor>. Tensor of points.
        """
        return torch.from_numpy(self.points)

    @classmethod
    def fromPytorch(cls, pytorchTensor):
        """
        Instantiates from a PyTorch tensor.
        :param pytorchTensor: <Tensor>. Tensor of points.
        :return: <PointCloud>.
        """
        points = pytorchTensor.numpy()
        return cls(points)

    def normalize(self, wlh):
        normalizer = [wlh[1], wlh[0], wlh[2]]
        self.points = self.points / np.atleast_2d(normalizer).T


class Box:
    """ Simple data class representing a 3d box including, label, score and velocity. """

    def __init__(self, center, size, orientation, label=np.nan, score=np.nan, velocity=(np.nan, np.nan, np.nan),
                 name=None):
        """
        :param center: [<float>: 3]. Center of box given as x, y, z.
        :param size: [<float>: 3]. Size of box in width, length, height.
        :param orientation: <Quaternion>. Box orientation.
        :param label: <int>. Integer label, optional.
        :param score: <float>. Classification score, optional.
        :param velocity: [<float>: 3]. Box velocity in x, y, z direction.
        :param name: <str>. Box name, optional. Can be used e.g. to denote the category name.
        """
        assert not np.any(np.isnan(center))
        assert not np.any(np.isnan(size))
        assert len(center) == 3
        assert len(size) == 3
        # assert type(orientation) == Quaternion

        self.center = np.array(center)
        self.wlh = np.array(size)
        self.orientation = orientation
        self.label = int(label) if not np.isnan(label) else label
        self.score = float(score) if not np.isnan(score) else score
        self.velocity = np.array(velocity)
        self.name = name

    def __eq__(self, other):
        center = np.allclose(self.center, other.center)
        wlh = np.allclose(self.wlh, other.wlh)
        orientation = np.allclose(self.orientation.elements, other.orientation.elements)
        label = (self.label == other.label) or (np.isnan(self.label) and np.isnan(other.label))
        score = (self.score == other.score) or (np.isnan(self.score) and np.isnan(other.score))
        vel = (np.allclose(self.velocity, other.velocity) or
               (np.all(np.isnan(self.velocity)) and np.all(np.isnan(other.velocity))))

        return center and wlh and orientation and label and score and vel

    def __repr__(self):
        repr_str = 'label: {}, score: {:.2f}, xyz: [{:.2f}, {:.2f}, {:.2f}], wlh: [{:.2f}, {:.2f}, {:.2f}], ' \
                   'rot axis: [{:.2f}, {:.2f}, {:.2f}], ang(degrees): {:.2f}, ang(rad): {:.2f}, ' \
                   'vel: {:.2f}, {:.2f}, {:.2f}, name: {}'

        return repr_str.format(self.label, self.score, self.center[0], self.center[1], self.center[2], self.wlh[0],
                               self.wlh[1], self.wlh[2], self.orientation.axis[0], self.orientation.axis[1],
                               self.orientation.axis[2], self.orientation.degrees, self.orientation.radians,
                               self.velocity[0], self.velocity[1], self.velocity[2], self.name)

    def encode(self):
        """
        Encodes the box instance to a JSON-friendly vector representation.
        :return: [<float>: 16]. List of floats encoding the box.
        """
        return self.center.tolist() + self.wlh.tolist() + self.orientation.elements.tolist() + [self.label] + [self.score] + self.velocity.tolist() + [self.name]

    @classmethod
    def decode(cls, data):
        """
        Instantiates a Box instance from encoded vector representation.
        :param data: [<float>: 16]. Output from encode.
        :return: <Box>.
        """
        return Box(data[0:3], data[3:6], Quaternion(data[6:10]), label=data[10], score=data[11], velocity=data[12:15],
                   name=data[15])

    @property
    def rotation_matrix(self):
        """
        Return a rotation matrix.
        :return: <np.float: (3, 3)>.
        """
        return self.orientation.rotation_matrix

    def translate(self, x):
        """
        Applies a translation.
        :param x: <np.float: 3, 1>. Translation in x, y, z direction.
        :return: <None>.
        """
        self.center += x

    def rotate(self, quaternion):
        """
        Rotates box.
        :param quaternion: <Quaternion>. Rotation to apply.
        :return: <None>.
        """
        self.center = np.dot(quaternion.rotation_matrix, self.center)
        self.orientation = quaternion * self.orientation
        self.velocity = np.dot(quaternion.rotation_matrix, self.velocity)

    def transform(self, transf_matrix):
        # apply a 4x4 homogeneous transform; the rotation is composed on the left,
        # consistent with rotate()
        transformed = np.dot(transf_matrix, np.append(self.center, 1.0))
        self.center = transformed[0:3] / transformed[3]
        self.orientation = Quaternion(matrix=transf_matrix[0:3, 0:3]) * self.orientation
        self.velocity = np.dot(transf_matrix[0:3, 0:3], self.velocity)

    def corners(self, wlh_factor=1.0):
        """
        Returns the bounding box corners.
        :param wlh_factor: <float>. Multiply w, l, h by a factor to inflate or deflate the box.
        :return: <np.float: 3, 8>. First four corners are the ones facing forward.
            The last four are the ones facing backwards.
        """
        w, l, h = self.wlh * wlh_factor

        # 3D bounding box corners. (Convention: x points forward, y to the left, z up.)
        x_corners = l / 2 * np.array([1,  1,  1,  1, -1, -1, -1, -1])
        y_corners = w / 2 * np.array([1, -1, -1,  1,  1, -1, -1,  1])
        z_corners = h / 2 * np.array([1,  1, -1, -1,  1,  1, -1, -1])
        corners = np.vstack((x_corners, y_corners, z_corners))

        # Rotate
        corners = np.dot(self.orientation.rotation_matrix, corners)

        # Translate
        x, y, z = self.center
        corners[0, :] = corners[0, :] + x
        corners[1, :] = corners[1, :] + y
        corners[2, :] = corners[2, :] + z

        return corners

    def bottom_corners(self):
        """
        Returns the four bottom corners.
        :return: <np.float: 3, 4>. Bottom corners. First two face forward, last two face backwards.
        """
        return self.corners()[:, [2, 3, 7, 6]]
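
# Usage sketch (illustrative): an axis-aligned 2m x 4m x 1.5m box at the origin.
#   box = Box(center=[0, 0, 0], size=[2, 4, 1.5],
#             orientation=Quaternion(axis=[0, 0, 1], degrees=0))
#   box.corners()          # <np.float: 3, 8> corner coordinates
#   box.bottom_corners()   # <np.float: 3, 4> bottom face only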


================================================
FILE: datasets/generate_waymo_sot.py
================================================
#!/usr/bin/env python
# encoding: utf-8
'''
@author: Xu Yan
@file: generate_waymo_sot.py
@time: 2021/6/17 13:17
'''
import os
import pickle
from collections import defaultdict

from tqdm import tqdm


def load_pickle(root):
    with open(root, "rb") as f:
        file = pickle.load(f)
    return file


def generate_waymo_data(root, cla, split):
    TYPE_LIST = ['UNKNOWN', 'VEHICLE', 'PEDESTRIAN', 'SIGN', 'CYCLIST']

    print('Generate %s class for %s set' % (cla, split))
    waymo_infos_all = load_pickle(os.path.join(root, 'infos_%s_01sweeps_filter_zero_gt.pkl' % split))

    DATA = defaultdict(list)

    for idx, frame in tqdm(enumerate(waymo_infos_all), total=len(waymo_infos_all)):
        anno = load_pickle(os.path.join(root, frame['anno_path']))
        
        for obj in anno['objects']:
            if TYPE_LIST[obj['label']] == cla:
                # DATA is a defaultdict(list), so appending directly covers both
                # the first and subsequent frames of an object
                DATA[obj['name']].append(
                    {
                        'PC': frame['path'],
                        'Box': obj['box'],
                        'Class': cla
                    }
                )

    print('Save data...')
    with open(os.path.join(root, 'sot_infos_%s_%s.pkl' % (cla.lower(), split)), "wb") as f:
        pickle.dump(DATA, f)


if __name__ == '__main__':
    splits = ['train', 'val']
    classes = ['VEHICLE', 'PEDESTRIAN', 'CYCLIST']
    root = '/raid/databases/Waymo/'
    for split in splits:
        for cla in classes:
            generate_waymo_data(root, cla, split)


================================================
FILE: datasets/kitti.py
================================================
# Created by zenn at 2021/4/27

import copy
import random

from torch.utils.data import Dataset
from datasets.data_classes import PointCloud, Box
from pyquaternion import Quaternion
import numpy as np
import pandas as pd
import os
import warnings
import pickle
from collections import defaultdict
from datasets import points_utils, base_dataset


class kittiDataset(base_dataset.BaseDataset):
    def __init__(self, path, split, category_name="Car", **kwargs):
        super().__init__(path, split, category_name, **kwargs)
        self.KITTI_Folder = path
        self.KITTI_velo = os.path.join(self.KITTI_Folder, "velodyne")
        self.KITTI_image = os.path.join(self.KITTI_Folder, "image_02")
        self.KITTI_label = os.path.join(self.KITTI_Folder, "label_02")
        self.KITTI_calib = os.path.join(self.KITTI_Folder, "calib")
        self.scene_list = self._build_scene_list(split)
        self.velos = defaultdict(dict)
        self.calibs = {}
        self.tracklet_anno_list, self.tracklet_len_list = self._build_tracklet_anno()
        self.coordinate_mode = kwargs.get('coordinate_mode', 'velodyne')
        self.preload_offset = kwargs.get('preload_offset', -1)
        if self.preloading:
            self.training_samples = self._load_data()

    @staticmethod
    def _build_scene_list(split):
        if "TRAIN" in split.upper():  # Training SET
            if "TINY" in split.upper():
                scene_names = [0]
            else:
                scene_names = list(range(0, 17))
        elif "VALID" in split.upper():  # Validation Set
            if "TINY" in split.upper():
                scene_names = [18]
            else:
                scene_names = list(range(17, 19))
        elif "TEST" in split.upper():  # Testing Set
            if "TINY" in split.upper():
                scene_names = [19]
            else:
                scene_names = list(range(19, 21))

        else:  # Full Dataset
            scene_names = list(range(21))
        scene_names = ['%04d' % scene_name for scene_name in scene_names]
        return scene_names

    def _load_data(self):
        print('preloading data into memory')
        preload_data_path = os.path.join(self.KITTI_Folder,
                                         f"preload_kitti_{self.category_name}_{self.split}_{self.coordinate_mode}_{self.preload_offset}.dat")
        if os.path.isfile(preload_data_path):
            print(f'loading from saved file {preload_data_path}.')
            with open(preload_data_path, 'rb') as f:
                training_samples = pickle.load(f)
        else:
            print('reading from annos')
            training_samples = []
            for i in range(len(self.tracklet_anno_list)):
                frames = []
                for anno in self.tracklet_anno_list[i]:
                    frames.append(self._get_frame_from_anno(anno))
                training_samples.append(frames)
            with open(preload_data_path, 'wb') as f:
                print(f'saving loaded data to {preload_data_path}')
                pickle.dump(training_samples, f)
        return training_samples

    def get_num_scenes(self):
        return len(self.scene_list)

    def get_num_tracklets(self):
        return len(self.tracklet_anno_list)

    def get_num_frames_total(self):
        return sum(self.tracklet_len_list)

    def get_num_frames_tracklet(self, tracklet_id):
        return self.tracklet_len_list[tracklet_id]

    def _build_tracklet_anno(self):

        list_of_tracklet_anno = []
        list_of_tracklet_len = []
        for scene in self.scene_list:

            label_file = os.path.join(self.KITTI_label, scene + ".txt")

            df = pd.read_csv(
                label_file,
                sep=' ',
                names=[
                    "frame", "track_id", "type", "truncated", "occluded",
                    "alpha", "bbox_left", "bbox_top", "bbox_right",
                    "bbox_bottom", "height", "width", "length", "x", "y", "z",
                    "rotation_y"
                ])
            if self.category_name in ['Car', 'Van', 'Truck',
                                      'Pedestrian', 'Person_sitting', 'Cyclist', 'Tram',
                                      'Misc']:
                df = df[df["type"] == self.category_name]
            elif self.category_name == 'All':
                df = df[(df["type"] == 'Car') |
                        (df["type"] == 'Van') |
                        (df["type"] == 'Pedestrian') |
                        (df["type"] == 'Cyclist')]
            else:
                df = df[df["type"] != 'DontCare']
            df.insert(loc=0, column="scene", value=scene)
            for track_id in df.track_id.unique():
                df_tracklet = df[df["track_id"] == track_id]
                df_tracklet = df_tracklet.sort_values(by=['frame'])
                df_tracklet = df_tracklet.reset_index(drop=True)
                tracklet_anno = [anno for index, anno in df_tracklet.iterrows()]
                list_of_tracklet_anno.append(tracklet_anno)
                list_of_tracklet_len.append((len(tracklet_anno)))

        return list_of_tracklet_anno, list_of_tracklet_len

    def get_frames(self, seq_id, frame_ids):
        if self.preloading:
            frames = [self.training_samples[seq_id][f_id] for f_id in frame_ids]
        else:
            seq_annos = self.tracklet_anno_list[seq_id]
            frames = [self._get_frame_from_anno(seq_annos[f_id]) for f_id in frame_ids]

        return frames

    def _get_frame_from_anno(self, anno):
        scene_id = anno['scene']
        frame_id = anno['frame']
        try:
            calib = self.calibs[scene_id]
        except KeyError:
            calib_path = os.path.join(self.KITTI_calib, scene_id + ".txt")
            calib = self._read_calib_file(calib_path)
            self.calibs[scene_id] = calib
        velo_to_cam = np.vstack((calib["Tr_velo_cam"], np.array([0, 0, 0, 1])))

        if self.coordinate_mode == 'velodyne':
            box_center_cam = np.array([anno["x"], anno["y"] - anno["height"] / 2, anno["z"], 1])
            # transform bb from camera coordinate into velo coordinates
            box_center_velo = np.dot(np.linalg.inv(velo_to_cam), box_center_cam)
            box_center_velo = box_center_velo[:3]
            size = [anno["width"], anno["length"], anno["height"]]
            orientation = Quaternion(
                axis=[0, 0, -1], radians=anno["rotation_y"]) * Quaternion(axis=[0, 0, -1], degrees=90)
            bb = Box(box_center_velo, size, orientation)
        else:
            center = [anno["x"], anno["y"] - anno["height"] / 2, anno["z"]]
            size = [anno["width"], anno["length"], anno["height"]]
            orientation = Quaternion(
                axis=[0, 1, 0], radians=anno["rotation_y"]) * Quaternion(
                axis=[1, 0, 0], radians=np.pi / 2)
            bb = Box(center, size, orientation)

        try:
            try:
                pc = self.velos[scene_id][frame_id]
            except KeyError:
                # VELODYNE PointCloud
                velodyne_path = os.path.join(self.KITTI_velo, scene_id,
                                             '{:06}.bin'.format(frame_id))

                pc = PointCloud(
                    np.fromfile(velodyne_path, dtype=np.float32).reshape(-1, 4).T)
                if self.coordinate_mode == "camera":
                    pc.transform(velo_to_cam)
                self.velos[scene_id][frame_id] = pc
            if self.preload_offset > 0:
                pc = points_utils.crop_pc_axis_aligned(pc, bb, offset=self.preload_offset)
        except FileNotFoundError:
            # in case the Point cloud is missing
            # (0001/[000177-000180].bin)
            # msg = f"The point cloud at scene {scene_id} frame {frame_id} is missing."
            # warnings.warn(msg)
            pc = PointCloud(np.array([[0, 0, 0]]).T)
        # todo add image
        return {"pc": pc, "3d_bbox": bb, 'meta': anno}

    @staticmethod
    def _read_calib_file(filepath):
        """Read in a calibration file and parse into a dictionary."""
        data = {}
        with open(filepath, 'r') as f:
            for line in f.readlines():
                values = line.split()
                # The only non-float values in these files are dates, which
                # we don't care about anyway
                try:
                    data[values[0]] = np.array(
                        [float(x) for x in values[1:]]).reshape(3, 4)
                except ValueError:
                    pass
        return data


================================================
FILE: datasets/nuscenes_data.py
================================================
"""
nuscenes_data.py
Created by zenn at 2021/9/1 15:05
"""
import os

import numpy as np
import pickle
import nuscenes
from nuscenes.nuscenes import NuScenes
from nuscenes.utils.data_classes import LidarPointCloud, Box
from nuscenes.utils.splits import create_splits_scenes

from pyquaternion import Quaternion

from datasets import points_utils, base_dataset
from datasets.data_classes import PointCloud

general_to_tracking_class = {"animal": "void / ignore",
                             "human.pedestrian.personal_mobility": "void / ignore",
                             "human.pedestrian.stroller": "void / ignore",
                             "human.pedestrian.wheelchair": "void / ignore",
                             "movable_object.barrier": "void / ignore",
                             "movable_object.debris": "void / ignore",
                             "movable_object.pushable_pullable": "void / ignore",
                             "movable_object.trafficcone": "void / ignore",
                             "static_object.bicycle_rack": "void / ignore",
                             "vehicle.emergency.ambulance": "void / ignore",
                             "vehicle.emergency.police": "void / ignore",
                             "vehicle.construction": "void / ignore",
                             "vehicle.bicycle": "bicycle",
                             "vehicle.bus.bendy": "bus",
                             "vehicle.bus.rigid": "bus",
                             "vehicle.car": "car",
                             "vehicle.motorcycle": "motorcycle",
                             "human.pedestrian.adult": "pedestrian",
                             "human.pedestrian.child": "pedestrian",
                             "human.pedestrian.construction_worker": "pedestrian",
                             "human.pedestrian.police_officer": "pedestrian",
                             "vehicle.trailer": "trailer",
                             "vehicle.truck": "truck", }

tracking_to_general_class = {
    'void / ignore': ['animal', 'human.pedestrian.personal_mobility', 'human.pedestrian.stroller',
                      'human.pedestrian.wheelchair', 'movable_object.barrier', 'movable_object.debris',
                      'movable_object.pushable_pullable', 'movable_object.trafficcone', 'static_object.bicycle_rack',
                      'vehicle.emergency.ambulance', 'vehicle.emergency.police', 'vehicle.construction'],
    'bicycle': ['vehicle.bicycle'],
    'bus': ['vehicle.bus.bendy', 'vehicle.bus.rigid'],
    'car': ['vehicle.car'],
    'motorcycle': ['vehicle.motorcycle'],
    'pedestrian': ['human.pedestrian.adult', 'human.pedestrian.child', 'human.pedestrian.construction_worker',
                   'human.pedestrian.police_officer'],
    'trailer': ['vehicle.trailer'],
    'truck': ['vehicle.truck']}


class NuScenesDataset(base_dataset.BaseDataset):
    def __init__(self, path, split, category_name="Car", version='v1.0-trainval', **kwargs):
        super().__init__(path, split, category_name, **kwargs)
        self.nusc = NuScenes(version=version, dataroot=path, verbose=False)
        self.version = version
        self.key_frame_only = kwargs.get('key_frame_only', False)
        self.min_points = kwargs.get('min_points', -1)
        self.preload_offset = kwargs.get('preload_offset', -1)
        self.track_instances = self.filter_instance(split, category_name.lower(), self.min_points)
        self.tracklet_anno_list, self.tracklet_len_list = self._build_tracklet_anno()
        if self.preloading:
            self.training_samples = self._load_data()

    def filter_instance(self, split, category_name=None, min_points=-1):
        """
        This function is used to filter the tracklets.

        split: the dataset split
        category_name: the tracking category to keep (lowercase); None keeps all categories
        min_points: the minimum number of points in the first bbox
        """
        if category_name is not None:
            general_classes = tracking_to_general_class[category_name]
        instances = []
        scene_splits = nuscenes.utils.splits.create_splits_scenes()
        for instance in self.nusc.instance:
            anno = self.nusc.get('sample_annotation', instance['first_annotation_token'])
            sample = self.nusc.get('sample', anno['sample_token'])
            scene = self.nusc.get('scene', sample['scene_token'])
            instance_category = self.nusc.get('category', instance['category_token'])['name']
            if scene['name'] in scene_splits[split] and anno['num_lidar_pts'] >= min_points and \
                    (category_name is None or instance_category in general_classes):
                instances.append(instance)
        return instances

    def _build_tracklet_anno(self):
        list_of_tracklet_anno = []
        list_of_tracklet_len = []
        for instance in self.track_instances:
            track_anno = []
            curr_anno_token = instance['first_annotation_token']

            while curr_anno_token != '':

                ann_record = self.nusc.get('sample_annotation', curr_anno_token)
                sample = self.nusc.get('sample', ann_record['sample_token'])
                sample_data_lidar = self.nusc.get('sample_data', sample['data']['LIDAR_TOP'])

                curr_anno_token = ann_record['next']
                if self.key_frame_only and not sample_data_lidar['is_key_frame']:
                    continue
                track_anno.append({"sample_data_lidar": sample_data_lidar, "box_anno": ann_record})

            list_of_tracklet_anno.append(track_anno)
            list_of_tracklet_len.append(len(track_anno))
        return list_of_tracklet_anno, list_of_tracklet_len

    def _load_data(self):
        print('preloading data into memory')
        preload_data_path = os.path.join(self.path,
                                         f"preload_nuscenes_{self.category_name}_{self.split}_{self.version}_{self.preload_offset}_{self.min_points}.dat")
        if os.path.isfile(preload_data_path):
            print(f'loading from saved file {preload_data_path}.')
            with open(preload_data_path, 'rb') as f:
                training_samples = pickle.load(f)
        else:
            print('reading from annos')
            training_samples = []
            for i in range(len(self.tracklet_anno_list)):
                frames = []
                for anno in self.tracklet_anno_list[i]:
                    frames.append(self._get_frame_from_anno_data(anno))
                training_samples.append(frames)
            with open(preload_data_path, 'wb') as f:
                print(f'saving loaded data to {preload_data_path}')
                pickle.dump(training_samples, f)
        return training_samples

    def get_num_tracklets(self):
        return len(self.tracklet_anno_list)

    def get_num_frames_total(self):
        return sum(self.tracklet_len_list)

    def get_num_frames_tracklet(self, tracklet_id):
        return self.tracklet_len_list[tracklet_id]

    def get_frames(self, seq_id, frame_ids):
        if self.preloading:
            frames = [self.training_samples[seq_id][f_id] for f_id in frame_ids]
        else:
            seq_annos = self.tracklet_anno_list[seq_id]
            frames = [self._get_frame_from_anno_data(seq_annos[f_id]) for f_id in frame_ids]

        return frames

    def _get_frame_from_anno_data(self, anno):
        sample_data_lidar = anno['sample_data_lidar']
        box_anno = anno['box_anno']
        bb = Box(box_anno['translation'], box_anno['size'], Quaternion(box_anno['rotation']),
                 name=box_anno['category_name'], token=box_anno['token'])
        pcl_path = os.path.join(self.path, sample_data_lidar['filename'])
        pc = LidarPointCloud.from_file(pcl_path)

        cs_record = self.nusc.get('calibrated_sensor', sample_data_lidar['calibrated_sensor_token'])
        pc.rotate(Quaternion(cs_record['rotation']).rotation_matrix)
        pc.translate(np.array(cs_record['translation']))

        poserecord = self.nusc.get('ego_pose', sample_data_lidar['ego_pose_token'])
        pc.rotate(Quaternion(poserecord['rotation']).rotation_matrix)
        pc.translate(np.array(poserecord['translation']))

        pc = PointCloud(points=pc.points)
        if self.preload_offset > 0:
            pc = points_utils.crop_pc_axis_aligned(pc, bb, offset=self.preload_offset)
        return {"pc": pc, "3d_bbox": bb, 'meta': anno}


================================================
FILE: datasets/points_utils.py
================================================
import nuscenes.utils.geometry_utils
import torch
import os
import copy
import numpy as np
from pyquaternion import Quaternion
from datasets.data_classes import PointCloud, Box
from scipy.spatial.distance import cdist


def random_choice(num_samples, size, replacement=False, seed=None):
    if seed is not None:
        generator = torch.random.manual_seed(seed)
    else:
        generator = None
    return torch.multinomial(
        torch.ones((size), dtype=torch.float32),
        num_samples=num_samples,
        replacement=replacement,
        generator=generator
    )


def regularize_pc(points, sample_size, seed=None):
    # random sampling from points
    num_points = points.shape[0]
    new_pts_idx = None
    rng = np.random if seed is None else np.random.default_rng(seed)
    if num_points > 2:
        if num_points != sample_size:
            new_pts_idx = rng.choice(num_points, size=sample_size, replace=sample_size > num_points)
            # new_pts_idx = random_choice(num_samples=sample_size, size=num_points,
            #                             replacement=sample_size > num_points, seed=seed).numpy()
        else:
            new_pts_idx = np.arange(num_points)
    if new_pts_idx is not None:
        points = points[new_pts_idx, :]
    else:
        points = np.zeros((sample_size, 3), dtype='float32')
    return points, new_pts_idx
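

# Illustrative sketch (not called anywhere in this repo): `regularize_pc`
# resamples a cloud to a fixed size -- down-sampling without replacement,
# up-sampling with replacement, and zero-filling degenerate (<=2 point) clouds.
def _demo_regularize_pc():
    pts = np.random.rand(100, 3).astype('float32')
    small, _ = regularize_pc(pts, 64)                 # (64, 3), no replacement
    big, _ = regularize_pc(pts, 256)                  # (256, 3), with replacement
    empty, _ = regularize_pc(np.zeros((1, 3)), 64)    # (64, 3) of zeros
    assert small.shape == (64, 3) and big.shape == (256, 3) and empty.shape == (64, 3)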


def getOffsetBB(box, offset, degrees=True, use_z=False, limit_box=True, inplace=False):
    rot_quat = Quaternion(matrix=box.rotation_matrix)
    trans = np.array(box.center)
    if not inplace:
        new_box = copy.deepcopy(box)
    else:
        new_box = box

    new_box.translate(-trans)
    new_box.rotate(rot_quat.inverse)
    if len(offset) == 3:
        use_z = False
    # REMOVE THE PREVIOUS TRANSFORM
    if degrees:
        if len(offset) == 3:
            new_box.rotate(
                Quaternion(axis=[0, 0, 1], degrees=offset[2]))
        elif len(offset) == 4:
            new_box.rotate(
                Quaternion(axis=[0, 0, 1], degrees=offset[3]))
    else:
        if len(offset) == 3:
            new_box.rotate(
                Quaternion(axis=[0, 0, 1], radians=offset[2]))
        elif len(offset) == 4:
            new_box.rotate(
                Quaternion(axis=[0, 0, 1], radians=offset[3]))
    if limit_box:
        if offset[0] > new_box.wlh[0]:
            offset[0] = np.random.uniform(-1, 1)
        if offset[1] > min(new_box.wlh[1], 2):
            offset[1] = np.random.uniform(-1, 1)
        if use_z and offset[2] > new_box.wlh[2]:
            offset[2] = 0
    if use_z:
        new_box.translate(np.array([offset[0], offset[1], offset[2]]))
    else:
        new_box.translate(np.array([offset[0], offset[1], 0]))

    # APPLY PREVIOUS TRANSFORMATION
    new_box.rotate(rot_quat)
    new_box.translate(trans)
    return new_box
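

# Illustrative sketch (standalone): `getOffsetBB` perturbs a box in its own
# object frame -- a 3-element offset is (dx, dy, dtheta) with the rotation
# about the local z-axis. The Box arguments below mirror the nuScenes-style
# (center, size, orientation) signature used elsewhere in this repo.
def _demo_get_offset_bb():
    gt_box = Box([10.0, 5.0, 1.0], [1.8, 4.5, 1.6],
                 Quaternion(axis=[0, 0, 1], degrees=30))
    noisy_box = getOffsetBB(gt_box, [0.2, -0.1, 5.0], degrees=True)
    # noisy_box is shifted (0.2, -0.1) m in the box frame and rotated 5 degrees
    return noisy_box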


def getModel(PCs, boxes, offset=0, scale=1.0, normalize=False):
    """center and merge the object pcs in boxes"""
    if len(PCs) == 0:
        # keep the (pc, box) contract of the non-empty path
        return PointCloud(np.ones((3, 0))), None
    new_box = None
    points = [np.ones((PCs[0].points.shape[0], 0), dtype='float32')]
    for PC, box in zip(PCs, boxes):
        cropped_PC, new_box = cropAndCenterPC(PC, box, offset=offset, scale=scale, normalize=normalize)
        if cropped_PC.nbr_points() > 0:
            points.append(cropped_PC.points)

    PC = PointCloud(np.concatenate(points, axis=1))
    return PC, new_box


def cropAndCenterPC(PC, box, offset=0, scale=1.0, normalize=False):
    """
    crop and center the pc using the given box
    """
    new_PC = crop_pc_axis_aligned(PC, box, offset=2 * offset, scale=4 * scale)

    new_box = copy.deepcopy(box)

    rot_mat = np.transpose(new_box.rotation_matrix)
    trans = -new_box.center

    new_PC.translate(trans)
    new_box.translate(trans)
    new_PC.rotate(rot_mat)
    new_box.rotate(Quaternion(matrix=rot_mat))

    # crop around box
    new_PC = crop_pc_axis_aligned(new_PC, new_box, offset=offset, scale=scale)

    if normalize:
        new_PC.normalize(box.wlh)
    return new_PC, new_box


def get_point_to_box_distance(pc, box, wlh_factor=1.0):
    """
    generate the BoxCloud for the given pc and box
    :param pc: Pointcloud object or numpy array
    :param box:
    :return:
    """
    if isinstance(pc, PointCloud):
        points = pc.points.T  # N,3
    else:
        points = pc  # N,3
        assert points.shape[1] == 3
    box_corners = box.corners(wlh_factor=wlh_factor)  # 3,8
    box_centers = box.center.reshape(-1, 1)  # 3,1
    box_points = np.concatenate([box_centers, box_corners], axis=1)  # 3,9
    points2cc_dist = cdist(points, box_points.T)  # N,9
    return points2cc_dist
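

# Illustrative sketch (standalone): the "BoxCloud" used by the box-aware
# models is the N x 9 matrix of distances from each point to the box center
# plus its 8 corners; the box below is a hypothetical car-sized example.
def _demo_boxcloud():
    pts = np.random.rand(128, 3)
    box = Box([0.0, 0.0, 0.0], [1.8, 4.5, 1.6], Quaternion())
    bc = get_point_to_box_distance(pts, box)
    assert bc.shape == (128, 9)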


def crop_pc_axis_aligned(PC, box, offset=0, scale=1.0, return_mask=False):
    """
    crop the pc using the box in the axis-aligned manner
    """
    box_tmp = copy.deepcopy(box)
    box_tmp.wlh = box_tmp.wlh * scale
    maxi = np.max(box_tmp.corners(), 1) + offset
    mini = np.min(box_tmp.corners(), 1) - offset

    x_filt_max = PC.points[0, :] < maxi[0]
    x_filt_min = PC.points[0, :] > mini[0]
    y_filt_max = PC.points[1, :] < maxi[1]
    y_filt_min = PC.points[1, :] > mini[1]
    z_filt_max = PC.points[2, :] < maxi[2]
    z_filt_min = PC.points[2, :] > mini[2]

    close = np.logical_and(x_filt_min, x_filt_max)
    close = np.logical_and(close, y_filt_min)
    close = np.logical_and(close, y_filt_max)
    close = np.logical_and(close, z_filt_min)
    close = np.logical_and(close, z_filt_max)

    new_PC = PointCloud(PC.points[:, close])
    if return_mask:
        return new_PC, close
    return new_PC


def crop_pc_oriented(PC, box, offset=0, scale=1.0, return_mask=False):
    """
    crop the pc using the exact box.
    slower than 'crop_pc_axis_aligned' but more accurate
    """

    box_tmp = copy.deepcopy(box)
    new_PC = PointCloud(PC.points.copy())
    rot_mat = np.transpose(box_tmp.rotation_matrix)
    trans = -box_tmp.center

    # align data
    new_PC.translate(trans)
    box_tmp.translate(trans)
    new_PC.rotate(rot_mat)
    box_tmp.rotate(Quaternion(matrix=rot_mat))

    box_tmp.wlh = box_tmp.wlh * scale
    maxi = np.max(box_tmp.corners(), 1) + offset
    mini = np.min(box_tmp.corners(), 1) - offset

    x_filt_max = new_PC.points[0, :] < maxi[0]
    x_filt_min = new_PC.points[0, :] > mini[0]
    y_filt_max = new_PC.points[1, :] < maxi[1]
    y_filt_min = new_PC.points[1, :] > mini[1]
    z_filt_max = new_PC.points[2, :] < maxi[2]
    z_filt_min = new_PC.points[2, :] > mini[2]

    close = np.logical_and(x_filt_min, x_filt_max)
    close = np.logical_and(close, y_filt_min)
    close = np.logical_and(close, y_filt_max)
    close = np.logical_and(close, z_filt_min)
    close = np.logical_and(close, z_filt_max)

    new_PC = PointCloud(new_PC.points[:, close])

    # transform back to the original coordinate system
    new_PC.rotate(np.transpose(rot_mat))
    new_PC.translate(-trans)
    if return_mask:
        return new_PC, close
    return new_PC


def generate_subwindow(pc, sample_bb, scale, offset=2, oriented=True):
    """
    generating the search area using the sample_bb

    :param pc:
    :param sample_bb:
    :param scale:
    :param offset:
    :param oriented: use oriented or axis-aligned cropping
    :return:
    """
    rot_mat = np.transpose(sample_bb.rotation_matrix)
    trans = -sample_bb.center
    if oriented:
        new_pc = PointCloud(pc.points.copy())
        box_tmp = copy.deepcopy(sample_bb)

        # transform to the coordinate system of sample_bb
        new_pc.translate(trans)
        box_tmp.translate(trans)
        new_pc.rotate(rot_mat)
        box_tmp.rotate(Quaternion(matrix=rot_mat))
        new_pc = crop_pc_axis_aligned(new_pc, box_tmp, scale=scale, offset=offset)

    else:
        new_pc = crop_pc_axis_aligned(pc, sample_bb, scale=scale, offset=offset)

        # transform to the coordinate system of sample_bb
        new_pc.translate(trans)
        new_pc.rotate(rot_mat)

    return new_pc
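

# Illustrative sketch (standalone): the search area is the cloud cropped
# around a (possibly perturbed) box and expressed in that box's coordinate
# frame, so a well-tracked target ends up near the origin. Values are made up.
def _demo_generate_subwindow():
    pc = PointCloud(np.random.rand(3, 2048) * 20.0)
    sample_bb = Box([10.0, 5.0, 1.0], [1.8, 4.5, 1.6], Quaternion())
    search_area = generate_subwindow(pc, sample_bb, scale=1.25, offset=2)
    return search_area  # PointCloud in the coordinate system of sample_bb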


def transform_box(box, ref_box, inplace=False):
    if not inplace:
        box = copy.deepcopy(box)
    box.translate(-ref_box.center)
    box.rotate(Quaternion(matrix=ref_box.rotation_matrix.T))
    return box


def transform_pc(pc, ref_box, inplace=False):
    if not inplace:
        pc = copy.deepcopy(pc)
    pc.translate(-ref_box.center)
    pc.rotate(ref_box.rotation_matrix.T)
    return pc


def get_in_box_mask(PC, box):
    """check which points of PC are inside the box"""
    box_tmp = copy.deepcopy(box)
    new_PC = PointCloud(PC.points.copy())
    rot_mat = np.transpose(box_tmp.rotation_matrix)
    trans = -box_tmp.center

    # align data
    new_PC.translate(trans)
    box_tmp.translate(trans)
    new_PC.rotate(rot_mat)
    box_tmp.rotate(Quaternion(matrix=rot_mat))
    maxi = np.max(box_tmp.corners(), 1)
    mini = np.min(box_tmp.corners(), 1)

    x_filt_max = new_PC.points[0, :] < maxi[0]
    x_filt_min = new_PC.points[0, :] > mini[0]
    y_filt_max = new_PC.points[1, :] < maxi[1]
    y_filt_min = new_PC.points[1, :] > mini[1]
    z_filt_max = new_PC.points[2, :] < maxi[2]
    z_filt_min = new_PC.points[2, :] > mini[2]

    close = np.logical_and(x_filt_min, x_filt_max)
    close = np.logical_and(close, y_filt_min)
    close = np.logical_and(close, y_filt_max)
    close = np.logical_and(close, z_filt_min)
    close = np.logical_and(close, z_filt_max)
    return close


def apply_transform(in_box_pc, box, translation, rotation, flip_x, flip_y, rotation_axis=(0, 0, 1)):
    """
    Apply transformation to the box and its pc insides. pc should be inside the given box.
    :param in_box_pc: PointCloud object
    :param box: Box object
    :param flip_y: boolean
    :param flip_x: boolean
    :param rotation_axis: 3-element tuple. The rotation axis
    :param translation: <np.float: 3, 1>. Translation in x, y, z direction.
    :param rotation: float. rotation in degrees
    :return:
    """

    # get inverse transform
    rot_mat = box.rotation_matrix
    trans = box.center

    new_box = copy.deepcopy(box)
    new_pc = copy.deepcopy(in_box_pc)

    new_pc.translate(-trans)
    new_box.translate(-trans)
    new_pc.rotate(rot_mat.T)
    new_box.rotate(Quaternion(matrix=rot_mat.T))

    if flip_x:
        new_pc.points[0, :] = -new_pc.points[0, :]
        # rotate the box so that its x-axis still points to the heading direction
        new_box.rotate(Quaternion(axis=[0, 0, 1], degrees=180))
    if flip_y:
        new_pc.points[1, :] = -new_pc.points[1, :]

    # apply rotation
    rot_quat = Quaternion(axis=rotation_axis, degrees=rotation)
    new_box.rotate(rot_quat)
    new_pc.rotate(rot_quat.rotation_matrix)

    # apply translation
    new_box.translate(translation)
    new_pc.translate(translation)

    # transform back
    new_box.rotate(Quaternion(matrix=rot_mat))
    new_pc.rotate(rot_mat)
    new_box.translate(trans)
    new_pc.translate(trans)
    return new_pc, new_box


def apply_augmentation(pc, box, wlh_factor=1.25):
    in_box_mask = nuscenes.utils.geometry_utils.points_in_box(box, pc.points, wlh_factor=wlh_factor)
    # in_box_mask = get_in_box_mask(pc, box)
    in_box_pc = PointCloud(pc.points[:, in_box_mask])

    rand_trans = np.random.uniform(low=-0.3, high=0.3, size=3)
    rand_rot = np.random.uniform(low=-10, high=10)
    flip_x, flip_y = np.random.choice([True, False], size=2, replace=True)

    new_in_box_pc, new_box = apply_transform(in_box_pc, box, rand_trans, rand_rot, flip_x, flip_y)

    new_pc = copy.deepcopy(pc)
    new_pc.points[:, in_box_mask] = new_in_box_pc.points
    return new_pc, new_box
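

# Illustrative sketch (standalone): the augmentation jitters only the points
# inside the box, together with the box itself (random translation in
# [-0.3, 0.3] m, rotation in [-10, 10] degrees, optional flips); background
# points are left untouched.
def _demo_apply_augmentation():
    pc = PointCloud(np.random.rand(3, 1024) * 4.0)
    box = Box([2.0, 2.0, 0.5], [1.8, 4.5, 1.6], Quaternion())
    aug_pc, aug_box = apply_augmentation(pc, box)
    return aug_pc, aug_box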


def roty_batch_tensor(t):
    input_shape = t.shape
    output = torch.zeros(tuple(list(input_shape) + [3, 3]), dtype=torch.float32, device=t.device)
    c = torch.cos(t)
    s = torch.sin(t)
    output[..., 0, 0] = c
    output[..., 0, 2] = s
    output[..., 1, 1] = 1
    output[..., 2, 0] = -s
    output[..., 2, 2] = c
    return output


def rotz_batch_tensor(t):
    input_shape = t.shape
    output = torch.zeros(tuple(list(input_shape) + [3, 3]), dtype=torch.float32, device=t.device)
    c = torch.cos(t)
    s = torch.sin(t)
    output[..., 0, 0] = c
    output[..., 0, 1] = -s
    output[..., 1, 0] = s
    output[..., 1, 1] = c
    output[..., 2, 2] = 1
    return output


def get_offset_points_tensor(points, ref_box_params, offset_box_params):
    """

    :param points: B,N,3
    :param ref_box_params: B,4
    :return:
    """
    ref_center = ref_box_params[:, :3]
    ref_rot_angles = ref_box_params[:, -1]
    offset_center = offset_box_params[:, :3]
    offset_rot_angles = offset_box_params[:, -1]

    # transform to object coordinate system defined by the ref_box_params
    rot_mat = rotz_batch_tensor(-ref_rot_angles)  # B,3,3
    points -= ref_center[:, None, :]  # B,N,3
    points = torch.matmul(points, rot_mat.transpose(1, 2))

    # apply the offset
    rot_mat_offset = rotz_batch_tensor(offset_rot_angles)
    points = torch.matmul(points, rot_mat_offset.transpose(1, 2))
    points += offset_center[:, None, :]

    # transform back to the world coordinate system
    points = torch.matmul(points, rot_mat)
    points += ref_center[:, None, :]
    return points


def get_offset_box_tensor(ref_box_params, offset_box_params):
    """
    transform the ref_box with the give offset
    :param ref_box_params: B,4
    :param offset_box_params: B,4
    :return: B,4
    """
    ref_center = ref_box_params[:, :3]  # B,3
    ref_rot_angles = ref_box_params[:, -1]  # B,
    offset_center = offset_box_params[:, :3]
    offset_rot_angles = offset_box_params[:, -1]
    rot_mat = rotz_batch_tensor(ref_rot_angles)  # B,3,3

    new_center = torch.matmul(rot_mat, offset_center[..., None]).squeeze(dim=-1)  # B,3
    new_center += ref_center
    new_angle = ref_rot_angles + offset_rot_angles
    return torch.cat([new_center, new_angle[:, None]], dim=-1)
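

# Illustrative sanity check (standalone): composing (x, y, z, theta) box
# parameters with a zero offset must return the reference box unchanged.
def _demo_get_offset_box_tensor():
    ref = torch.tensor([[1.0, 2.0, 0.0, 0.5]])   # B,4
    assert torch.allclose(get_offset_box_tensor(ref, torch.zeros(1, 4)), ref)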


def remove_transform_points_tensor(points, ref_box_params):
    """

    :param points: B,N,3
    :param ref_box_params: B,4
    :return:
    """
    ref_center = ref_box_params[:, :3]
    ref_rot_angles = ref_box_params[:, -1]

    # transform to object coordinate system defined by the ref_box_params
    rot_mat = rotz_batch_tensor(-ref_rot_angles)  # B,3,3
    points -= ref_center[:, None, :]  # B,N,3
    points = torch.matmul(points, rot_mat.transpose(1, 2))
    return points


def np_to_torch_tensor(data, device=None):
    return torch.tensor(data, device=device).unsqueeze(dim=0)



================================================
FILE: datasets/sampler.py
================================================
# Created by zenn at 2021/4/27

import numpy as np
import torch
from easydict import EasyDict
from nuscenes.utils import geometry_utils

import datasets.points_utils as points_utils
from datasets.searchspace import KalmanFiltering


def no_processing(data, *args):
    return data


def siamese_processing(data, config, template_transform=None, search_transform=None):
    """

    :param data:
    :param config: {model_bb_scale,model_bb_offset,search_bb_scale, search_bb_offset}
    :return:
    """
    first_frame = data['first_frame']
    template_frame = data['template_frame']
    search_frame = data['search_frame']
    candidate_id = data['candidate_id']
    first_pc, first_box = first_frame['pc'], first_frame['3d_bbox']
    template_pc, template_box = template_frame['pc'], template_frame['3d_bbox']
    search_pc, search_box = search_frame['pc'], search_frame['3d_bbox']
    if template_transform is not None:
        template_pc, template_box = template_transform(template_pc, template_box)
        first_pc, first_box = template_transform(first_pc, first_box)
    if search_transform is not None:
        search_pc, search_box = search_transform(search_pc, search_box)
    # generating template. Merging the object from previous and the first frames.
    if candidate_id == 0:
        samplegt_offsets = np.zeros(3)
    else:
        samplegt_offsets = np.random.uniform(low=-0.3, high=0.3, size=3)
        samplegt_offsets[2] = samplegt_offsets[2] * (5 if config.degrees else np.deg2rad(5))
    template_box = points_utils.getOffsetBB(template_box, samplegt_offsets, limit_box=config.data_limit_box,
                                            degrees=config.degrees)
    model_pc, model_box = points_utils.getModel([first_pc, template_pc], [first_box, template_box],
                                                scale=config.model_bb_scale, offset=config.model_bb_offset)

    assert model_pc.nbr_points() > 20, 'not enough template points'

    # generating search area. Use the current gt box to select the nearby region as the search area.

    if candidate_id == 0 and config.num_candidates > 1:
        sample_offset = np.zeros(3)
    else:
        gaussian = KalmanFiltering(bnd=[1, 1, (5 if config.degrees else np.deg2rad(5))])
        sample_offset = gaussian.sample(1)[0]
    sample_bb = points_utils.getOffsetBB(search_box, sample_offset, limit_box=config.data_limit_box,
                                         degrees=config.degrees)
    search_pc_crop = points_utils.generate_subwindow(search_pc, sample_bb,
                                                     scale=config.search_bb_scale, offset=config.search_bb_offset)
    assert search_pc_crop.nbr_points() > 20, 'not enough search points'
    search_box = points_utils.transform_box(search_box, sample_bb)
    seg_label = points_utils.get_in_box_mask(search_pc_crop, search_box).astype(int)
    search_bbox_reg = [search_box.center[0], search_box.center[1], search_box.center[2], -sample_offset[2]]

    template_points, idx_t = points_utils.regularize_pc(model_pc.points.T, config.template_size)
    search_points, idx_s = points_utils.regularize_pc(search_pc_crop.points.T, config.search_size)
    seg_label = seg_label[idx_s]
    data_dict = {
        'template_points': template_points.astype('float32'),
        'search_points': search_points.astype('float32'),
        'box_label': np.array(search_bbox_reg).astype('float32'),
        'bbox_size': search_box.wlh,
        'seg_label': seg_label.astype('float32'),
    }
    if getattr(config, 'box_aware', False):
        template_bc = points_utils.get_point_to_box_distance(template_points, model_box)
        search_bc = points_utils.get_point_to_box_distance(search_points, search_box)
        data_dict.update({'points2cc_dist_t': template_bc.astype('float32'),
                          'points2cc_dist_s': search_bc.astype('float32'), })
    return data_dict
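

# Illustrative sketch of the config fields `siamese_processing` reads; the
# values below are placeholders, not the ones shipped in cfgs/ (see e.g.
# cfgs/BAT_Car.yaml for real settings).
_siamese_cfg_example = EasyDict(degrees=True, data_limit_box=True,
                                num_candidates=4, model_bb_scale=1.25,
                                model_bb_offset=0.0, search_bb_scale=1.25,
                                search_bb_offset=2.0, template_size=512,
                                search_size=1024, box_aware=True)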


def motion_processing(data, config, template_transform=None, search_transform=None):
    """

    :param data:
    :param config: {model_bb_scale,model_bb_offset,search_bb_scale, search_bb_offset}
    :return:
    point_sample_size
    bb_scale
    bb_offset
    """
    prev_frame = data['prev_frame']
    this_frame = data['this_frame']
    candidate_id = data['candidate_id']
    prev_pc, prev_box = prev_frame['pc'], prev_frame['3d_bbox']
    this_pc, this_box = this_frame['pc'], this_frame['3d_bbox']

    num_points_in_prev_box = geometry_utils.points_in_box(prev_box, prev_pc.points).sum()
    assert num_points_in_prev_box > 10, 'not enough target points'

    if template_transform is not None:
        prev_pc, prev_box = template_transform(prev_pc, prev_box)
    if search_transform is not None:
        this_pc, this_box = search_transform(this_pc, this_box)

    if candidate_id == 0:
        sample_offsets = np.zeros(3)
    else:
        sample_offsets = np.random.uniform(low=-0.3, high=0.3, size=3)
        sample_offsets[2] = sample_offsets[2] * (5 if config.degrees else np.deg2rad(5))
    ref_box = points_utils.getOffsetBB(prev_box, sample_offsets, limit_box=config.data_limit_box,
                                       degrees=config.degrees)
    prev_frame_pc = points_utils.generate_subwindow(prev_pc, ref_box,
                                                    scale=config.bb_scale,
                                                    offset=config.bb_offset)

    this_frame_pc = points_utils.generate_subwindow(this_pc, ref_box,
                                                    scale=config.bb_scale,
                                                    offset=config.bb_offset)
    assert this_frame_pc.nbr_points() > 20, 'not enough search points'

    this_box = points_utils.transform_box(this_box, ref_box)
    prev_box = points_utils.transform_box(prev_box, ref_box)
    ref_box = points_utils.transform_box(ref_box, ref_box)
    motion_box = points_utils.transform_box(this_box, prev_box)

    prev_points, idx_prev = points_utils.regularize_pc(prev_frame_pc.points.T, config.point_sample_size)
    this_points, idx_this = points_utils.regularize_pc(this_frame_pc.points.T, config.point_sample_size)

    seg_label_this = geometry_utils.points_in_box(this_box, this_points.T, 1.25).astype(int)
    seg_label_prev = geometry_utils.points_in_box(prev_box, prev_points.T, 1.25).astype(int)
    seg_mask_prev = geometry_utils.points_in_box(ref_box, prev_points.T, 1.25).astype(float)
    if candidate_id != 0:
        # Here we use 0.2/0.8 instead of 0/1 to indicate that the previous box is not GT.
        # When boxcloud is used, the actual value of prior-targetness mask doesn't really matter.
        seg_mask_prev[seg_mask_prev == 0] = 0.2
        seg_mask_prev[seg_mask_prev == 1] = 0.8
    seg_mask_this = np.full(seg_mask_prev.shape, fill_value=0.5)

    timestamp_prev = np.full((config.point_sample_size, 1), fill_value=0)
    timestamp_this = np.full((config.point_sample_size, 1), fill_value=0.1)

    prev_points = np.concatenate([prev_points, timestamp_prev, seg_mask_prev[:, None]], axis=-1)
    this_points = np.concatenate([this_points, timestamp_this, seg_mask_this[:, None]], axis=-1)

    stack_points = np.concatenate([prev_points, this_points], axis=0)
    stack_seg_label = np.hstack([seg_label_prev, seg_label_this])
    theta_this = this_box.orientation.degrees * this_box.orientation.axis[-1] if config.degrees else \
        this_box.orientation.radians * this_box.orientation.axis[-1]
    box_label = np.append(this_box.center, theta_this).astype('float32')
    theta_prev = prev_box.orientation.degrees * prev_box.orientation.axis[-1] if config.degrees else \
        prev_box.orientation.radians * prev_box.orientation.axis[-1]
    box_label_prev = np.append(prev_box.center, theta_prev).astype('float32')
    theta_motion = motion_box.orientation.degrees * motion_box.orientation.axis[-1] if config.degrees else \
        motion_box.orientation.radians * motion_box.orientation.axis[-1]
    motion_label = np.append(motion_box.center, theta_motion).astype('float32')

    motion_state_label = np.sqrt(np.sum((this_box.center - prev_box.center) ** 2)) > config.motion_threshold

    data_dict = {
        'points': stack_points.astype('float32'),
        'box_label': box_label,
        'box_label_prev': box_label_prev,
        'motion_label': motion_label,
        'motion_state_label': motion_state_label.astype('int'),
        'bbox_size': this_box.wlh,
        'seg_label': stack_seg_label.astype('int'),
    }

    if getattr(config, 'box_aware', False):
        prev_bc = points_utils.get_point_to_box_distance(stack_points[:config.point_sample_size, :3], prev_box)
        this_bc = points_utils.get_point_to_box_distance(stack_points[config.point_sample_size:, :3], this_box)
        candidate_bc_prev = points_utils.get_point_to_box_distance(stack_points[:config.point_sample_size, :3], ref_box)
        candidate_bc_this = np.zeros_like(candidate_bc_prev)
        candidate_bc = np.concatenate([candidate_bc_prev, candidate_bc_this], axis=0)

        data_dict.update({'prev_bc': prev_bc.astype('float32'),
                          'this_bc': this_bc.astype('float32'),
                          'candidate_bc': candidate_bc.astype('float32')})
    return data_dict
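

# Illustrative sketch of the config fields `motion_processing` reads; the
# values below are placeholders (see cfgs/M2_track_kitti.yaml for the real
# settings).
_motion_cfg_example = EasyDict(degrees=False, data_limit_box=True,
                               num_candidates=4, bb_scale=1.25, bb_offset=2.0,
                               point_sample_size=1024, motion_threshold=0.15,
                               box_aware=True)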


class PointTrackingSampler(torch.utils.data.Dataset):
    def __init__(self, dataset, random_sample, sample_per_epoch=10000, processing=siamese_processing, config=None,
                 **kwargs):
        if config is None:
            config = EasyDict(kwargs)
        self.sample_per_epoch = sample_per_epoch
        self.dataset = dataset
        self.processing = processing
        self.config = config
        self.random_sample = random_sample
        self.num_candidates = getattr(config, 'num_candidates', 1)
        if getattr(self.config, "use_augmentation", False):
            print('using augmentation')
            self.transform = points_utils.apply_augmentation
        else:
            self.transform = None
        if not self.random_sample:
            num_frames_total = 0
            self.tracklet_start_ids = [num_frames_total]
            for i in range(dataset.get_num_tracklets()):
                num_frames_total += dataset.get_num_frames_tracklet(i)
                self.tracklet_start_ids.append(num_frames_total)

    def get_anno_index(self, index):
        return index // self.num_candidates

    def get_candidate_index(self, index):
        return index % self.num_candidates

    def __len__(self):
        if self.random_sample:
            return self.sample_per_epoch * self.num_candidates
        else:
            return self.dataset.get_num_frames_total() * self.num_candidates

    def __getitem__(self, index):
        anno_id = self.get_anno_index(index)
        candidate_id = self.get_candidate_index(index)
        try:
            if self.random_sample:
                tracklet_id = torch.randint(0, self.dataset.get_num_tracklets(), size=(1,)).item()
                tracklet_annos = self.dataset.tracklet_anno_list[tracklet_id]
                frame_ids = [0] + points_utils.random_choice(num_samples=2, size=len(tracklet_annos)).tolist()
            else:
                for i in range(0, self.dataset.get_num_tracklets()):
                    if self.tracklet_start_ids[i] <= anno_id < self.tracklet_start_ids[i + 1]:
                        tracklet_id = i
                        this_frame_id = anno_id - self.tracklet_start_ids[i]
                        prev_frame_id = max(this_frame_id - 1, 0)
                        frame_ids = (0, prev_frame_id, this_frame_id)
                        break
            first_frame, template_frame, search_frame = self.dataset.get_frames(tracklet_id, frame_ids=frame_ids)
            data = {"first_frame": first_frame,
                    "template_frame": template_frame,
                    "search_frame": search_frame,
                    "candidate_id": candidate_id}

            return self.processing(data, self.config,
                                   template_transform=None,
                                   search_transform=self.transform)
        except AssertionError:
            return self[torch.randint(0, len(self), size=(1,)).item()]


class TestTrackingSampler(torch.utils.data.Dataset):
    def __init__(self, dataset, config=None, **kwargs):
        if config is None:
            config = EasyDict(kwargs)
        self.dataset = dataset
        self.config = config

    def __len__(self):
        return self.dataset.get_num_tracklets()

    def __getitem__(self, index):
        tracklet_annos = self.dataset.tracklet_anno_list[index]
        frame_ids = list(range(len(tracklet_annos)))
        return self.dataset.get_frames(index, frame_ids)


class MotionTrackingSampler(PointTrackingSampler):
    def __init__(self, dataset, config=None, **kwargs):
        super().__init__(dataset, random_sample=False, config=config, **kwargs)
        self.processing = motion_processing

    def __getitem__(self, index):
        anno_id = self.get_anno_index(index)
        candidate_id = self.get_candidate_index(index)
        try:

            for i in range(0, self.dataset.get_num_tracklets()):
                if self.tracklet_start_ids[i] <= anno_id < self.tracklet_start_ids[i + 1]:
                    tracklet_id = i
                    this_frame_id = anno_id - self.tracklet_start_ids[i]
                    prev_frame_id = max(this_frame_id - 1, 0)
                    frame_ids = (0, prev_frame_id, this_frame_id)
                    break
            first_frame, prev_frame, this_frame = self.dataset.get_frames(tracklet_id, frame_ids=frame_ids)
            data = {
                "first_frame": first_frame,
                "prev_frame": prev_frame,
                "this_frame": this_frame,
                "candidate_id": candidate_id}
            return self.processing(data, self.config,
                                   template_transform=self.transform,
                                   search_transform=self.transform)
        except AssertionError:
            return self[torch.randint(0, len(self), size=(1,)).item()]


================================================
FILE: datasets/searchspace.py
================================================
import numpy as np
from pomegranate import MultivariateGaussianDistribution, GeneralMixtureModel
import logging


class SearchSpace(object):

    def reset(self):
        raise NotImplementedError

    def sample(self):
        raise NotImplementedError

    def addData(self, data, score):
        return


class ExhaustiveSearch(SearchSpace):

    def __init__(self,
                 search_space=[[-3.0, 3.0], [-3.0, 3.0], [-10.0, 10.0]],
                 search_dims=[7, 7, 3]):

        x_space = np.linspace(
            search_space[0][0], search_space[0][1],
            search_dims[0])

        y_space = np.linspace(
            search_space[1][0], search_space[1][1],
            search_dims[1])

        a_space = np.linspace(
            search_space[2][0], search_space[2][1],
            search_dims[2])

        X, Y, A = np.meshgrid(x_space, y_space, a_space)  # create mesh grid

        self.search_grid = np.array([X.flatten(), Y.flatten(), A.flatten()]).T

        self.reset()

    def reset(self):
        return

    def sample(self, n=0):
        return self.search_grid


class ParticleFiltering(SearchSpace):
    def __init__(self, bnd=[1, 1, 10]):
        self.bnd = bnd
        self.reset()

    def sample(self, n=10):
        samples = []
        for i in range(n):
            if len(self.data) > 0:
                i_mean = np.random.choice(
                    list(range(len(self.data))),
                    p=self.score / np.linalg.norm(self.score, ord=1))
                sample = np.random.multivariate_normal(
                    mean=self.data[i_mean], cov=np.diag(np.array(self.bnd)))
            else:
                sample = np.random.multivariate_normal(
                    mean=np.zeros(len(self.bnd)),
                    cov=np.diag(np.array(self.bnd) * 3))

            samples.append(sample)
        return np.array(samples)

    def addData(self, data, score):
        score = score.clip(min=1e-5)  # prevent sum=0 in case of bad scores
        self.data = data
        self.score = score

    def reset(self):
        if len(self.bnd) == 2:
            self.data = np.array([[], []]).T
        else:
            self.data = np.array([[], [], []]).T
        self.score = np.ones(np.shape(self.data)[0])
        self.score = self.score / np.linalg.norm(self.score, ord=1)


class KalmanFiltering(SearchSpace):
    def __init__(self, bnd=[1, 1, 10]):
        self.bnd = bnd
        self.reset()

    def sample(self, n=10):
        return np.random.multivariate_normal(self.mean, self.cov, size=n)

    def addData(self, data, score):
        score = score.clip(min=1e-5)  # prevent sum=0 in case of bad scores
        self.data = np.concatenate((self.data, data))
        self.score = np.concatenate((self.score, score))
        self.mean = np.average(self.data, weights=self.score, axis=0)
        self.cov = np.cov(self.data.T, ddof=0, aweights=self.score)

    def reset(self):
        self.mean = np.zeros(len(self.bnd))
        self.cov = np.diag(self.bnd)
        if len(self.bnd) == 2:
            self.data = np.array([[], []]).T
        else:
            self.data = np.array([[], [], []]).T
        self.score = np.array([])
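

# Illustrative sketch (standalone): KalmanFiltering starts from an independent
# Gaussian over the offset dimensions and re-fits its mean/covariance from
# scored samples, so later draws concentrate around high-scoring offsets.
def _demo_kalman_filtering():
    ks = KalmanFiltering(bnd=[1, 1, 10])
    offsets = ks.sample(5)           # 5 x 3 candidate (dx, dy, dtheta) offsets
    ks.addData(offsets, np.ones(5))  # equal scores -> refit mean/cov
    return ks.sample(5)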


class GaussianMixtureModel(SearchSpace):

    def __init__(self, n_comp=5, dim=3):
        self.dim = dim
        self.reset(n_comp)

    def sample(self, n=10):
        try:
            X1 = np.stack(self.model.sample(int(np.round(0.8 * n))))
            if self.dim == 2:
                mean = np.mean(X1, axis=0)
                std = np.diag([1.0, 1.0])
                gmm = MultivariateGaussianDistribution(mean, std)
                X2 = np.stack(gmm.sample(int(np.round(0.1 * n))))

                mean = np.mean(X1, axis=0)
                std = np.diag([1e-3, 1e-3])
                gmm = MultivariateGaussianDistribution(mean, std)
                X3 = np.stack(gmm.sample(int(np.round(0.1 * n))))

            else:
                mean = np.mean(X1, axis=0)
                std = np.diag([1.0, 1.0, 1e-3])
                gmm = MultivariateGaussianDistribution(mean, std)
                X2 = np.stack(gmm.sample(int(np.round(0.1 * n))))

                mean = np.mean(X1, axis=0)
                std = np.diag([1e-3, 1e-3, 10.0])
                gmm = MultivariateGaussianDistribution(mean, std)
                X3 = np.stack(gmm.sample(int(np.round(0.1 * n))))

            X = np.concatenate((X1, X2, X3))

        except ValueError:
            print("exception caught on sampling")
            if self.dim == 2:
                mean = np.zeros(self.dim)
                std = np.diag([1.0, 1.0])
                gmm = MultivariateGaussianDistribution(mean, std)
                X = gmm.sample(int(n))
            else:
                mean = np.zeros(self.dim)
                std = np.diag([1.0, 1.0, 5.0])
                gmm = MultivariateGaussianDistribution(mean, std)
                X = gmm.sample(int(n))
        return X

    def addData(self, data, score):
        score = score.clip(min=1e-5)
        self.data = data
        self.score = score

        score_normed = self.score / np.linalg.norm(self.score, ord=1)
        try:
            model = GeneralMixtureModel.from_samples(
                MultivariateGaussianDistribution,
                n_components=self.n_comp,
                X=self.data,
                weights=score_normed)
            self.model = model
        except Exception:
            logging.info("caught an exception")

    def reset(self, n_comp=5):
        self.n_comp = n_comp

        if self.dim == 2:
            self.data = np.array([[], []]).T
        else:
            self.data = np.array([[], [], []]).T
        self.score = np.ones(np.shape(self.data)[0])
        self.score = self.score / np.linalg.norm(self.score, ord=1)
        if self.dim == 2:
            self.model = MultivariateGaussianDistribution(
                np.zeros(self.dim), np.diag([1.0, 1.0]))
        else:
            self.model = MultivariateGaussianDistribution(
                np.zeros(self.dim), np.diag([1.0, 1.0, 5.0]))


================================================
FILE: datasets/utils.py
================================================
#!/usr/bin/env python
# encoding: utf-8
'''
@author: Xu Yan
@file: utils.py
@time: 2021/10/21 21:45
'''
import numpy as np

def roty(t):
    """Rotation about the y-axis."""
    c = np.cos(t)
    s = np.sin(t)
    return np.array([[c, -s,  0],
                     [s,  c,  0],
                     [0,  0,  1]])

def get_3d_box(box_size, heading_angle, center):
    ''' box_size is array(l,w,h), heading_angle is in radians, clockwise from the positive x axis, center is xyz of box center
        output (8,3) array for 3D box corners
        Similar to utils/compute_orientation_3d
    '''
    R = roty(heading_angle)
    l,w,h = box_size
    # x_corners = [l/2,l/2,-l/2,-l/2,l/2,l/2,-l/2,-l/2]
    # y_corners = [h/2,h/2,h/2,h/2,-h/2,-h/2,-h/2,-h/2]
    # z_corners = [w/2,-w/2,-w/2,w/2,w/2,-w/2,-w/2,w/2]
    x_corners = [l/2,l/2,-l/2,-l/2,l/2,l/2,-l/2,-l/2]
    y_corners = [w/2,-w/2,-w/2,w/2,w/2,-w/2,-w/2,w/2]
    z_corners = [h/2,h/2,h/2,h/2,-h/2,-h/2,-h/2,-h/2]
    corners_3d = np.dot(R, np.vstack([x_corners,y_corners,z_corners]))
    corners_3d[0,:] = corners_3d[0,:] + center[0]
    corners_3d[1,:] = corners_3d[1,:] + center[1]
    corners_3d[2,:] = corners_3d[2,:] + center[2]
    corners_3d = np.transpose(corners_3d)
    return corners_3d
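

# Illustrative sanity check (standalone): an axis-aligned 2 x 1 x 1 box at the
# origin with zero heading.
def _demo_get_3d_box():
    corners = get_3d_box((2.0, 1.0, 1.0), 0.0, (0.0, 0.0, 0.0))
    assert corners.shape == (8, 3)   # x in +/-1.0, y in +/-0.5, z in +/-0.5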


def write_ply(verts, colors, indices, output_file):
    if colors is None:
        colors = np.zeros_like(verts)
    if indices is None:
        indices = []

    file = open(output_file, 'w')
    file.write('ply \n')
    file.write('format ascii 1.0\n')
    file.write('element vertex {:d}\n'.format(len(verts)))
    file.write('property float x\n')
    file.write('property float y\n')
    file.write('property float z\n')
    file.write('property uchar red\n')
    file.write('property uchar green\n')
    file.write('property uchar blue\n')
    file.write('element face {:d}\n'.format(len(indices)))
    file.write('property list uchar uint vertex_indices\n')
    file.write('end_header\n')
    for vert, color in zip(verts, colors):
        file.write("{:f} {:f} {:f} {:d} {:d} {:d}\n".format(vert[0], vert[1], vert[2], int(color[0] * 255),
                                                            int(color[1] * 255), int(color[2] * 255)))
    for ind in indices:
        file.write('3 {:d} {:d} {:d}\n'.format(ind[0], ind[1], ind[2]))
    file.close()


def box2obj(box, objname):
    corners = box.corners().T
    with open(objname, 'w') as f:
        for corner in corners:
            f.write('v %f %f %f\n' % (corner[0], corner[1], corner[2]))
        f.write('f %d %d %d %d\n' % (1, 2, 3, 4))
        f.write('f %d %d %d %d\n' % (5, 6, 7, 8))
        f.write('f %d %d %d %d\n' % (1, 5, 8, 4))
        f.write('f %d %d %d %d\n' % (2, 6, 7, 3))
        f.write('f %d %d %d %d\n' % (1, 2, 6, 5))
        f.write('f %d %d %d %d\n' % (4, 3, 7, 8))


def write_bbox(corners, mode, output_file):
    """
    bbox: (cx, cy, cz, lx, ly, lz, r), center and length in three axis, the last is the rotation
    output_file: string
    """

    def create_cylinder_mesh(radius, p0, p1, stacks=10, slices=10):

        import math

        def compute_length_vec3(vec3):
            return math.sqrt(vec3[0] * vec3[0] + vec3[1] * vec3[1] + vec3[2] * vec3[2])

        def rotation(axis, angle):
            rot = np.eye(4)
            c = np.cos(-angle)
            s = np.sin(-angle)
            t = 1.0 - c
            axis /= compute_length_vec3(axis)
            x = axis[0]
            y = axis[1]
            z = axis[2]
            rot[0, 0] = 1 + t * (x * x - 1)
            rot[0, 1] = z * s + t * x * y
            rot[0, 2] = -y * s + t * x * z
            rot[1, 0] = -z * s + t * x * y
            rot[1, 1] = 1 + t * (y * y - 1)
            rot[1, 2] = x * s + t * y * z
            rot[2, 0] = y * s + t * x * z
            rot[2, 1] = -x * s + t * y * z
            rot[2, 2] = 1 + t * (z * z - 1)
            return rot

        verts = []
        indices = []
        diff = (p1 - p0).astype(np.float32)
        height = compute_length_vec3(diff)
        for i in range(stacks + 1):
            for i2 in range(slices):
                theta = i2 * 2.0 * math.pi / slices
                pos = np.array([radius * math.cos(theta), radius * math.sin(theta), height * i / stacks])
                verts.append(pos)
        for i in range(stacks):
            for i2 in range(slices):
                i2p1 = math.fmod(i2 + 1, slices)
                indices.append(np.array([(i + 1) * slices + i2, i * slices + i2, i * slices + i2p1], dtype=np.uint32))
                indices.append(
                    np.array([(i + 1) * slices + i2, i * slices + i2p1, (i + 1) * slices + i2p1], dtype=np.uint32))
        transform = np.eye(4)
        va = np.array([0, 0, 1], dtype=np.float32)
        vb = diff
        vb /= compute_length_vec3(vb)
        axis = np.cross(vb, va)
        angle = np.arccos(np.clip(np.dot(va, vb), -1, 1))
        if angle != 0:
            if compute_length_vec3(axis) == 0:
                dotx = va[0]
                if (math.fabs(dotx) != 1.0):
                    axis = np.array([1, 0, 0]) - dotx * va
                else:
                    axis = np.array([0, 1, 0]) - va[1] * va
                axis /= compute_length_vec3(axis)
            transform = rotation(axis, -angle)
        transform[:3, 3] += p0
        verts = [np.dot(transform, np.array([v[0], v[1], v[2], 1.0])) for v in verts]
        verts = [np.array([v[0], v[1], v[2]]) / v[3] for v in verts]

        return verts, indices

    def get_bbox_edges(bbox_min, bbox_max):
        def get_bbox_verts(bbox_min, bbox_max):
            verts = [
                np.array([bbox_min[0], bbox_min[1], bbox_min[2]]),
                np.array([bbox_max[0], bbox_min[1], bbox_min[2]]),
                np.array([bbox_max[0], bbox_max[1], bbox_min[2]]),
                np.array([bbox_min[0], bbox_max[1], bbox_min[2]]),

                np.array([bbox_min[0], bbox_min[1], bbox_max[2]]),
                np.array([bbox_max[0], bbox_min[1], bbox_max[2]]),
                np.array([bbox_max[0], bbox_max[1], bbox_max[2]]),
                np.array([bbox_min[0], bbox_max[1], bbox_max[2]])
            ]
            return verts

        box_verts = get_bbox_verts(bbox_min, bbox_max)
        edges = [
            (box_verts[0], box_verts[1]),
            (box_verts[1], box_verts[2]),
            (box_verts[2], box_verts[3]),
            (box_verts[3], box_verts[0]),

            (box_verts[4], box_verts[5]),
            (box_verts[5], box_verts[6]),
            (box_verts[6], box_verts[7]),
            (box_verts[7], box_verts[4]),

            (box_verts[0], box_verts[4]),
            (box_verts[1], box_verts[5]),
            (box_verts[2], box_verts[6]),
            (box_verts[3], box_verts[7])
        ]
        return edges

    radius = 0.03
    offset = [0, 0, 0]
    verts = []
    indices = []
    colors = []

    box_min = np.min(corners, axis=0)
    box_max = np.max(corners, axis=0)
    palette = {
        0: [0, 255, 0],  # gt
        1: [0, 0, 255]  # pred
    }
    chosen_color = palette[mode]
    edges = get_bbox_edges(box_min, box_max)
    for k in range(len(edges)):
        cyl_verts, cyl_ind = create_cylinder_mesh(radius, edges[k][0], edges[k][1])
        cur_num_verts = len(verts)
        cyl_color = [[c / 255 for c in chosen_color] for _ in cyl_verts]
        cyl_verts = [x + offset for x in cyl_verts]
        cyl_ind = [x + cur_num_verts for x in cyl_ind]
        verts.extend(cyl_verts)
        indices.extend(cyl_ind)
        colors.extend(cyl_color)

    write_ply(verts, colors, indices, output_file)


def write_obj(points, file, rgb=False):
    fout = open('%s.obj' % file, 'w')
    for i in range(points.shape[0]):
        if not rgb:
            fout.write('v %f %f %f %d %d %d\n' % (
                points[i, 0], points[i, 1], points[i, 2], 255, 255, 0))
        else:
            fout.write('v %f %f %f %d %d %d\n' % (
                points[i, 0], points[i, 1], points[i, 2], points[i, -3] * 255, points[i, -2] * 255,
                points[i, -1] * 255))


================================================
FILE: datasets/waymo_data.py
================================================
# Created by Xu Yan at 2021/10/17

import copy
import random

from torch.utils.data import Dataset
from datasets.data_classes import PointCloud, Box
from pyquaternion import Quaternion
import numpy as np
import pandas as pd
import os
import warnings
import pickle
from functools import reduce
from tqdm import tqdm
from datasets.generate_waymo_sot import generate_waymo_data
from collections import defaultdict
from datasets import points_utils, base_dataset


class WaymoDataset(base_dataset.BaseDataset):
    def __init__(self, path, split, category_name="VEHICLE", **kwargs):
        super().__init__(path, split, category_name, **kwargs)
        self.Waymo_Folder = path
        self.category_name = category_name
        self.Waymo_velo = os.path.join(self.Waymo_Folder, split, "velodyne")
        self.Waymo_label = os.path.join(self.Waymo_Folder, split, "label_02")
        self.Waymo_calib = os.path.join(self.Waymo_Folder, split, "calib")
        self.velos = defaultdict(dict)
        self.calibs = {}

        self.split = self.split.lower()
        self.category_name = self.category_name.lower()
        self.split = 'val' if self.split == 'test' else self.split
        assert self.split in ['train', 'val']
        assert self.category_name in ['vehicle', 'pedestrian', 'cyclist']

        self.tiny = kwargs.get('tiny', False)
        self.tracklet_anno_list, self.tracklet_len_list = self._build_tracklet_anno()
        if self.tiny:
            self.tracklet_anno_list = self.tracklet_anno_list[:100]
            self.tracklet_len_list = self.tracklet_len_list[:100]

        self.preload_offset = kwargs.get('preload_offset', 10)
        if self.preloading:
            self.training_samples = self._load_data()

    def _load_data(self):
        print('preloading data into memory')
        if self.tiny:
            preload_data_path = os.path.join(self.Waymo_Folder,
                                             f"preload_{self.split}_{self.category_name}_{self.preload_offset}_tiny.dat")
        else:
            preload_data_path = os.path.join(self.Waymo_Folder,
                                             f"preload_{self.split}_{self.category_name}_{self.preload_offset}.dat")

        print(preload_data_path)

        if os.path.isfile(preload_data_path):
            print(f'loading from saved file {preload_data_path}.')
            with open(preload_data_path, 'rb') as f:
                training_samples = pickle.load(f)
        else:
            print('reading from annos')
            training_samples = []
            for i in tqdm(range(len(self.tracklet_anno_list)), total=len(self.tracklet_anno_list)):
                frames = []
                for anno in self.tracklet_anno_list[i]:
                    frames.append(self._get_frame_from_anno(anno, i))

                training_samples.append(frames)
            with open(preload_data_path, 'wb') as f:
                print(f'saving loaded data to {preload_data_path}')
                pickle.dump(training_samples, f)
        return training_samples

    def get_num_scenes(self):
        return len(self.scene_list)

    def get_num_tracklets(self):
        return len(self.tracklet_anno_list)

    def get_num_frames_total(self):
        return sum(self.tracklet_len_list)

    def get_num_frames_tracklet(self, tracklet_id):
        return self.tracklet_len_list[tracklet_id]

    def _build_tracklet_anno(self):
        preload_data_path = os.path.join(self.Waymo_Folder,
                                         f"sot_infos_{self.category_name.lower()}_{self.split.lower()}.pkl")
        if not os.path.exists(preload_data_path):
            print('Prepare %s' % preload_data_path)
            generate_waymo_data(self.Waymo_Folder, self.category_name, self.split)

        with open(preload_data_path, 'rb') as f:
            infos = pickle.load(f)

        list_of_tracklet_anno = []
        list_of_tracklet_len = []

        for scene in list(infos.keys()):
            anno = infos[scene]
            list_of_tracklet_anno.append(anno)
            list_of_tracklet_len.append(len(anno))

        return list_of_tracklet_anno, list_of_tracklet_len

    def get_frames(self, seq_id, frame_ids):
        if self.preloading:
            frames = [self.training_samples[seq_id][f_id] for f_id in frame_ids]
        else:
            seq_annos = self.tracklet_anno_list[seq_id]
            frames = [self._get_frame_from_anno(seq_annos[f_id]) for f_id in frame_ids]

        return frames

    def _get_frame_from_anno(self, anno, track_id=None, check=False):
        '''
        'box': np.array([box.center_x, box.center_y, box.center_z,
                         box.length, box.width, box.height, ref_velocity[0],
                         ref_velocity[1], box.heading], dtype=np.float32),
        '''
        sample_data_lidar = anno['PC']
        gt_boxes = anno['Box']

        with open(sample_data_lidar, 'rb') as f:
            pc_info = pickle.load(f)

        pointcloud = pc_info['lidars']['points_xyz'].transpose((1, 0))

        with open(sample_data_lidar.replace('lidar', 'annos'), 'rb') as f:
            ref_obj = pickle.load(f)

        ref_pose = np.reshape(ref_obj['veh_to_global'], [4, 4])
        global_from_car, _ = self.veh_pos_to_transform(ref_pose)
        nbr_points = pointcloud.shape[1]
        pointcloud[:3, :] = global_from_car.dot(
            np.vstack((pointcloud[:3, :], np.ones(nbr_points)))
        )[:3, :]

        # transform from Waymo to KITTI coordinate
        # Waymo: x, y, z, length, width, height, rotation from the positive x axis, clockwise
        # KITTI: x, y, z, width, length, height, rotation from the negative y axis, counterclockwise
        gt_boxes[[3, 4]] = gt_boxes[[4, 3]]

        pc = PointCloud(pointcloud)
        bb = Box(gt_boxes[0:3], gt_boxes[3:6], Quaternion(axis=[0, 0, 1], radians=-gt_boxes[-1]),
                 velocity=gt_boxes[6:9], name=anno['Class'])
        bb.rotate(Quaternion(matrix=global_from_car))
        bb.translate(global_from_car[:3, -1])
        if self.preload_offset > 0:
            pc = points_utils.crop_pc_axis_aligned(pc, bb, offset=self.preload_offset)

        if check:
            from datasets.utils import write_bbox, write_obj, get_3d_box, box2obj
            print('check', pc_info['frame_id'])
            path = 'visual_%s_track%d/' % (pc_info['scene_name'], track_id)
            os.makedirs(path, exist_ok=True)
            if pc_info['frame_id'] % 50 == 0:
                write_obj(pc.points.transpose((1, 0)), path + 'frames_%d' % pc_info['frame_id'])
                # write_bbox(get_3d_box(bb.wlh, bb.orientation.radians * bb.orientation.axis[-1], bb.center), 0, path + 'box_%d.ply' % pc_info['frame_id'])
                box2obj(bb, path + 'box_%d.obj' % pc_info['frame_id'])
            print(path + 'box_%d.obj' % pc_info['frame_id'])
            # exit()

        return {"pc": pc, "3d_bbox": bb, 'meta': anno}

    @staticmethod
    def veh_pos_to_transform(veh_pos):
        def transform_matrix(translation: np.ndarray = np.array([0, 0, 0]),
                             rotation: Quaternion = Quaternion([1, 0, 0, 0]),
                             inverse: bool = False) -> np.ndarray:
            """
            Convert pose to transformation matrix.
            :param translation: <np.float32: 3>. Translation in x, y, z.
            :param rotation: Rotation in quaternions (w ri rj rk).
            :param inverse: Whether to compute inverse transform matrix.
            :return: <np.float32: 4, 4>. Transformation matrix.
            """
            tm = np.eye(4)

            if inverse:
                rot_inv = rotation.rotation_matrix.T
                trans = np.transpose(-np.array(translation))
                tm[:3, :3] = rot_inv
                tm[:3, 3] = rot_inv.dot(trans)
            else:
                tm[:3, :3] = rotation.rotation_matrix
                tm[:3, 3] = np.transpose(np.array(translation))

            return tm

        "convert vehicle pose to two transformation matrix"
        rotation = veh_pos[:3, :3]
        tran = veh_pos[:3, 3]

        global_from_car = transform_matrix(
            tran, Quaternion(matrix=rotation), inverse=False
        )

        car_from_global = transform_matrix(
            tran, Quaternion(matrix=rotation), inverse=True
        )

        return global_from_car, car_from_global




================================================
FILE: main.py
================================================
"""
main.py
Created by zenn at 2021/7/18 15:08
"""
import pytorch_lightning as pl
import argparse

import pytorch_lightning.utilities.distributed
import torch
import yaml
from easydict import EasyDict
import os

from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader

from datasets import get_dataset
from models import get_model


# os.environ["NCCL_DEBUG"] = "INFO"

def load_yaml(file_name):
    with open(file_name, 'r') as f:
        try:
            config = yaml.load(f, Loader=yaml.FullLoader)
        except AttributeError:  # older PyYAML versions have no FullLoader
            config = yaml.load(f)
    return config


def parse_config():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=100, help='input batch size')
    parser.add_argument('--epoch', type=int, default=60, help='number of epochs')
    parser.add_argument('--save_top_k', type=int, default=-1, help='save top k checkpoints')
    parser.add_argument('--check_val_every_n_epoch', type=int, default=1, help='check_val_every_n_epoch')
    parser.add_argument('--workers', type=int, default=10, help='number of data loading workers')
    parser.add_argument('--cfg', type=str, help='the config_file')
    parser.add_argument('--checkpoint', type=str, default=None, help='checkpoint location')
    parser.add_argument('--log_dir', type=str, default=None, help='log location')
    parser.add_argument('--test', action='store_true', default=False, help='test mode')
    parser.add_argument('--preloading', action='store_true', default=False, help='preload dataset into memory')

    args = parser.parse_args()
    config = load_yaml(args.cfg)
    config.update(vars(args))  # override the configuration using the value in args

    return EasyDict(config)


cfg = parse_config()
env_cp = os.environ.copy()
try:
    node_rank, local_rank, world_size = env_cp['NODE_RANK'], env_cp['LOCAL_RANK'], env_cp['WORLD_SIZE']

    is_in_ddp_subprocess = env_cp['PL_IN_DDP_SUBPROCESS']
    pl_trainer_gpus = env_cp['PL_TRAINER_GPUS']
    print(node_rank, local_rank, world_size, is_in_ddp_subprocess, pl_trainer_gpus)

    if int(local_rank) == int(world_size) - 1:
        print(cfg)
except KeyError:
    pass

# init model
if cfg.checkpoint is None:
    net = get_model(cfg.net_model)(cfg)
else:
    net = get_model(cfg.net_model).load_from_checkpoint(cfg.checkpoint, config=cfg)
if not cfg.test:
    # dataset and dataloader
    train_data = get_dataset(cfg, type=cfg.train_type, split=cfg.train_split)
    val_data = get_dataset(cfg, type='test', split=cfg.val_split)
    train_loader = DataLoader(train_data, batch_size=cfg.batch_size, num_workers=cfg.workers, shuffle=True,
                              drop_last=True, pin_memory=True)
    val_loader = DataLoader(val_data, batch_size=1, num_workers=cfg.workers, collate_fn=lambda x: x, pin_memory=True)
    checkpoint_callback = ModelCheckpoint(monitor='precision/test', mode='max', save_last=True,
                                          save_top_k=cfg.save_top_k)

    # init trainer
    trainer = pl.Trainer(gpus=-1, accelerator='ddp', max_epochs=cfg.epoch, resume_from_checkpoint=cfg.checkpoint,
                         callbacks=[checkpoint_callback], default_root_dir=cfg.log_dir,
                         check_val_every_n_epoch=cfg.check_val_every_n_epoch, num_sanity_val_steps=2,
                         gradient_clip_val=cfg.gradient_clip_val)
    trainer.fit(net, train_loader, val_loader)
else:
    test_data = get_dataset(cfg, type='test', split=cfg.test_split)
    test_loader = DataLoader(test_data, batch_size=1, num_workers=cfg.workers, collate_fn=lambda x: x, pin_memory=True)

    trainer = pl.Trainer(gpus=-1, accelerator='ddp', default_root_dir=cfg.log_dir,
                         resume_from_checkpoint=cfg.checkpoint)
    trainer.test(net, test_loader)
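
# Usage sketch (paths are illustrative; all flags are defined in parse_config above):
#   train: python main.py --cfg cfgs/BAT_Car.yaml --batch_size 64 --epoch 60 --preloading
#   test:  python main.py --cfg cfgs/BAT_Car.yaml --checkpoint pretrained_models/bat_kitti_car.ckpt --test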


================================================
FILE: models/__init__.py
================================================
""" 
__init__.py
Created by zenn at 2021/7/15 21:40
"""

import importlib
# import pkgutil
# import os
# import inspect
# __all__ = []
# for loader, module_name, is_pkg in pkgutil.walk_packages(os.path.abspath(__file__)):
#     print(loader, module_name, is_pkg)
#     module = loader.find_module(module_name).load_module(module_name)


from models import p2b, bat, m2track


def get_model(name):
    model = globals()[name.lower()].__getattribute__(name.upper())
    return model
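
# Example (illustrative): the lookup is purely name-based, so cfg.net_model must match
# a module/class pair, e.g. get_model('BAT') -> models.bat.BAT and
# get_model('M2Track') -> models.m2track.M2TRACK.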


================================================
FILE: models/backbone/pointnet.py
================================================
"""
pointnet.py
Created by zenn at 2021/5/9 13:41
"""

import torch
import torch.nn as nn

from pointnet2.utils.pointnet2_modules import PointnetSAModule


class Pointnet_Backbone(nn.Module):
    r"""
        PointNet2 with single-scale grouping
        Semantic segmentation network that uses feature propogation layers

        Parameters
        ----------
        num_classes: int
            Number of semantics classes to predict over -- size of softmax classifier that run for each point
        input_channels: int = 6
            Number of input channels in the feature descriptor for each point.  If the point cloud is Nx9, this
            value should be 6 as in an Nx9 point cloud, 3 of the channels are xyz, and 6 are feature descriptors
        use_xyz: bool = True
            Whether or not to use the xyz position of a point as a feature
    """

    def __init__(self, use_fps=False, normalize_xyz=False, return_intermediate=False, input_channels=0):
        super(Pointnet_Backbone, self).__init__()
        self.return_intermediate = return_intermediate
        self.SA_modules = nn.ModuleList()
        self.SA_modules.append(
            PointnetSAModule(
                radius=0.3,
                nsample=32,
                mlp=[input_channels, 64, 64, 128],
                use_xyz=True,
                use_fps=use_fps, normalize_xyz=normalize_xyz
            )
        )
        self.SA_modules.append(
            PointnetSAModule(
                radius=0.5,
                nsample=32,
                mlp=[128, 128, 128, 256],
                use_xyz=True,
                use_fps=False, normalize_xyz=normalize_xyz
            )
        )
        self.SA_modules.append(
            PointnetSAModule(
                radius=0.7,
                nsample=32,
                mlp=[256, 256, 256, 256],
                use_xyz=True,
                use_fps=False, normalize_xyz=normalize_xyz
            )
        )

    def _break_up_pc(self, pc):
        xyz = pc[..., 0:3].contiguous()
        features = pc[..., 3:].transpose(1, 2).contiguous() if pc.size(-1) > 3 else None

        return xyz, features

    def forward(self, pointcloud, numpoints):
        r"""
            Forward pass of the network

            Parameters
            ----------
            pointcloud: Variable(torch.cuda.FloatTensor)
                (B, N, 3 + input_channels) tensor
                Point cloud to run predicts on
                Each point in the point-cloud MUST
                be formated as (x, y, z, features...)
        """
        xyz, features = self._break_up_pc(pointcloud)

        l_xyz, l_features, l_idxs = [xyz], [features], []
        for i in range(len(self.SA_modules)):
            li_xyz, li_features, sample_idxs = self.SA_modules[i](l_xyz[i], l_features[i], numpoints[i], True)
            l_xyz.append(li_xyz)
            l_features.append(li_features)
            l_idxs.append(sample_idxs)
        if self.return_intermediate:
            return l_xyz[1:], l_features[1:], l_idxs[0]
        return l_xyz[-1], l_features[-1], l_idxs[0]


class MiniPointNet(nn.Module):

    def __init__(self, input_channel, per_point_mlp, hidden_mlp, output_size=0):
        """

        :param input_channel: int
        :param per_point_mlp: list
        :param hidden_mlp: list
        :param output_size: int, if output_size <=0, then the final fc will not be used
        """
        super(MiniPointNet, self).__init__()
        seq_per_point = []
        in_channel = input_channel
        for out_channel in per_point_mlp:
            seq_per_point.append(nn.Conv1d(in_channel, out_channel, 1))
            seq_per_point.append(nn.BatchNorm1d(out_channel))
            seq_per_point.append(nn.ReLU())
            in_channel = out_channel
        seq_hidden = []
        for out_channel in hidden_mlp:
            seq_hidden.append(nn.Linear(in_channel, out_channel))
            seq_hidden.append(nn.BatchNorm1d(out_channel))
            seq_hidden.append(nn.ReLU())
            in_channel = out_channel

        # self.per_point_mlp = nn.Sequential(*seq)
        # self.pooling = nn.AdaptiveMaxPool1d(output_size=1)
        # self.hidden_mlp = nn.Sequential(*seq_hidden)

        self.features = nn.Sequential(*seq_per_point,
                                      nn.AdaptiveMaxPool1d(output_size=1),
                                      nn.Flatten(),
                                      *seq_hidden)
        self.output_size = output_size
        if output_size > 0:  # matches forward(): the final fc is only used when output_size > 0
            self.fc = nn.Linear(in_channel, output_size)

    def forward(self, x):
        """

        :param x: B,C,N
        :return: B,output_size
        """

        # x = self.per_point_mlp(x)
        # x = self.pooling(x)
        # x = self.hidden_mlp(x)
        x = self.features(x)
        if self.output_size > 0:
            x = self.fc(x)
        return x


class SegPointNet(nn.Module):

    def __init__(self, input_channel, per_point_mlp1, per_point_mlp2, output_size=0, return_intermediate=False):
        """

        :param input_channel: int
        :param per_point_mlp: list
        :param hidden_mlp: list
        :param output_size: int, if output_size <=0, then the final fc will not be used
        """
        super(SegPointNet, self).__init__()
        self.return_intermediate = return_intermediate
        self.seq_per_point = nn.ModuleList()
        in_channel = input_channel
        for out_channel in per_point_mlp1:
            self.seq_per_point.append(
                nn.Sequential(
                    nn.Conv1d(in_channel, out_channel, 1),
                    nn.BatchNorm1d(out_channel),
                    nn.ReLU()
                ))
            in_channel = out_channel

        self.pool = nn.AdaptiveMaxPool1d(output_size=1)

        self.seq_per_point2 = nn.ModuleList()
        in_channel = in_channel + per_point_mlp1[1]
        for out_channel in per_point_mlp2:
            self.seq_per_point2.append(
                nn.Sequential(
                    nn.Conv1d(in_channel, out_channel, 1),
                    nn.BatchNorm1d(out_channel),
                    nn.ReLU()
                ))
            in_channel = out_channel

        self.output_size = output_size
        if output_size > 0:  # matches forward(): the final fc is only used when output_size > 0
            self.fc = nn.Conv1d(in_channel, output_size, 1)

    def forward(self, x):
        """

        :param x: B,C,N
        :return: B,output_size,N
        """
        second_layer_out = None
        for i, mlp in enumerate(self.seq_per_point):
            x = mlp(x)
            if i == 1:
                second_layer_out = x
        pooled_feature = self.pool(x)  # B,C,1
        pooled_feature_expand = pooled_feature.expand_as(x)
        x = torch.cat([second_layer_out, pooled_feature_expand], dim=1)
        for mlp in self.seq_per_point2:
            x = mlp(x)
        if self.output_size > 0:
            x = self.fc(x)
        if self.return_intermediate:
            return x, pooled_feature.squeeze(dim=-1)
        return x
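

if __name__ == '__main__':
    # Minimal shape check (illustrative sizes; the channel layouts mirror how M2TRACK
    # instantiates these backbones with 3+1+1 input channels).
    mini = MiniPointNet(input_channel=5, per_point_mlp=[64, 128, 256, 512],
                        hidden_mlp=[512, 256], output_size=-1)
    feat = mini(torch.rand(2, 5, 1024))  # (B, C, N) -> (2, 256) global feature
    seg = SegPointNet(input_channel=5, per_point_mlp1=[64, 64, 64, 128, 1024],
                      per_point_mlp2=[512, 256, 128, 128], output_size=2)
    logits = seg(torch.rand(2, 5, 1024))  # -> (2, 2, 1024) per-point logits
    print(feat.shape, logits.shape)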



================================================
FILE: models/base_model.py
================================================
""" 
base_model.py
Created by zenn at 2021/5/9 14:40
"""

import torch
from easydict import EasyDict
import pytorch_lightning as pl
from datasets import points_utils
from utils.metrics import TorchSuccess, TorchPrecision
from utils.metrics import estimateOverlap, estimateAccuracy
import torch.nn.functional as F
import numpy as np
from nuscenes.utils import geometry_utils


class BaseModel(pl.LightningModule):
    def __init__(self, config=None, **kwargs):
        super().__init__()
        if config is None:
            config = EasyDict(kwargs)
        self.config = config

        # testing metrics
        self.prec = TorchPrecision()
        self.success = TorchSuccess()

    def configure_optimizers(self):
        if self.config.optimizer.lower() == 'sgd':
            optimizer = torch.optim.SGD(self.parameters(), lr=self.config.lr, momentum=0.9, weight_decay=self.config.wd)
        else:
            optimizer = torch.optim.Adam(self.parameters(), lr=self.config.lr, weight_decay=self.config.wd,
                                         betas=(0.5, 0.999), eps=1e-06)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=self.config.lr_decay_step,
                                                    gamma=self.config.lr_decay_rate)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    def compute_loss(self, data, output):
        raise NotImplementedError

    def build_input_dict(self, sequence, frame_id, results_bbs, **kwargs):
        raise NotImplementedError

    def evaluate_one_sample(self, data_dict, ref_box):
        end_points = self(data_dict)

        estimation_box = end_points['estimation_boxes']
        estimation_box_cpu = estimation_box.squeeze(0).detach().cpu().numpy()

        if len(estimation_box.shape) == 3:
            best_box_idx = estimation_box_cpu[:, 4].argmax()
            estimation_box_cpu = estimation_box_cpu[best_box_idx, 0:4]

        candidate_box = points_utils.getOffsetBB(ref_box, estimation_box_cpu, degrees=self.config.degrees,
                                                 use_z=self.config.use_z,
                                                 limit_box=self.config.limit_box)
        return candidate_box

    def evaluate_one_sequence(self, sequence):
        """
        :param sequence: a sequence of annos {"pc": pc, "3d_bbox": bb, 'meta': anno}
        :return:
        """
        ious = []
        distances = []
        results_bbs = []
        for frame_id in range(len(sequence)):  # tracklet
            this_bb = sequence[frame_id]["3d_bbox"]
            if frame_id == 0:
                # the first frame
                results_bbs.append(this_bb)
            else:

                # construct input dict
                data_dict, ref_bb = self.build_input_dict(sequence, frame_id, results_bbs)
                # run the tracker
                candidate_box = self.evaluate_one_sample(data_dict, ref_box=ref_bb)
                results_bbs.append(candidate_box)

            this_overlap = estimateOverlap(this_bb, results_bbs[-1], dim=self.config.IoU_space,
                                           up_axis=self.config.up_axis)
            this_accuracy = estimateAccuracy(this_bb, results_bbs[-1], dim=self.config.IoU_space,
                                             up_axis=self.config.up_axis)
            ious.append(this_overlap)
            distances.append(this_accuracy)
        return ious, distances, results_bbs

    def validation_step(self, batch, batch_idx):
        sequence = batch[0]  # unwrap the batch with batch size = 1
        ious, distances, *_ = self.evaluate_one_sequence(sequence)
        # update metrics
        self.success(torch.tensor(ious, device=self.device))
        self.prec(torch.tensor(distances, device=self.device))
        self.log('success/test', self.success, on_step=True, on_epoch=True)
        self.log('precision/test', self.prec, on_step=True, on_epoch=True)

    def validation_epoch_end(self, outputs):
        self.logger.experiment.add_scalars('metrics/test',
                                           {'success': self.success.compute(),
                                            'precision': self.prec.compute()},
                                           global_step=self.global_step)

    def test_step(self, batch, batch_idx):
        sequence = batch[0]  # unwrap the batch with batch size = 1
        ious, distances, result_bbs = self.evaluate_one_sequence(sequence)
        # update metrics
        self.success(torch.tensor(ious, device=self.device))
        self.prec(torch.tensor(distances, device=self.device))
        self.log('success/test', self.success, on_step=True, on_epoch=True)
        self.log('precision/test', self.prec, on_step=True, on_epoch=True)
        return result_bbs

    def test_epoch_end(self, outputs):
        self.logger.experiment.add_scalars('metrics/test',
                                           {'success': self.success.compute(),
                                            'precision': self.prec.compute()},
                                           global_step=self.global_step)


class MatchingBaseModel(BaseModel):

    def compute_loss(self, data, output):
        """

        :param data: input data
        :param output:
        :return:
        """
        estimation_boxes = output['estimation_boxes']  # B,num_proposal,5
        estimation_cla = output['estimation_cla']  # B,N
        seg_label = data['seg_label']
        box_label = data['box_label']  # B,4
        proposal_center = output["center_xyz"]  # B,num_proposal,3
        vote_xyz = output["vote_xyz"]

        loss_seg = F.binary_cross_entropy_with_logits(estimation_cla, seg_label)

        loss_vote = F.smooth_l1_loss(vote_xyz, box_label[:, None, :3].expand_as(vote_xyz), reduction='none')  # B,N,3
        loss_vote = (loss_vote.mean(2) * seg_label).sum() / (seg_label.sum() + 1e-06)

        dist = torch.sum((proposal_center - box_label[:, None, :3]) ** 2, dim=-1)

        dist = torch.sqrt(dist + 1e-6)  # B, K
        objectness_label = torch.zeros_like(dist, dtype=torch.float)
        objectness_label[dist < 0.3] = 1
        objectness_score = estimation_boxes[:, :, 4]  # B, K
        # proposals within 0.3m of the GT center are positives, those beyond 0.6m are
        # negatives, and the ambiguous 0.3-0.6m band is masked out of the loss
        objectness_mask = torch.zeros_like(objectness_label, dtype=torch.float)
        objectness_mask[dist < 0.3] = 1
        objectness_mask[dist > 0.6] = 1
        loss_objective = F.binary_cross_entropy_with_logits(objectness_score, objectness_label,
                                                            pos_weight=torch.tensor([2.0], device=objectness_score.device),
                                                            reduction='none')
        loss_objective = torch.sum(loss_objective * objectness_mask) / (
                torch.sum(objectness_mask) + 1e-6)
        loss_box = F.smooth_l1_loss(estimation_boxes[:, :, :4],
                                    box_label[:, None, :4].expand_as(estimation_boxes[:, :, :4]),
                                    reduction='none')
        loss_box = torch.sum(loss_box.mean(2) * objectness_label) / (objectness_label.sum() + 1e-6)

        return {
            "loss_objective": loss_objective,
            "loss_box": loss_box,
            "loss_seg": loss_seg,
            "loss_vote": loss_vote,
        }

    def generate_template(self, sequence, current_frame_id, results_bbs):
        """
        generate template for evaluating.
        the template can be updated using the previous predictions.
        :param sequence: the list of the whole sequence
        :param current_frame_id:
        :param results_bbs: predicted box for previous frames
        :return:
        """
        first_pc = sequence[0]['pc']
        previous_pc = sequence[current_frame_id - 1]['pc']
        if "firstandprevious".upper() in self.config.shape_aggregation.upper():
            template_pc, canonical_box = points_utils.getModel([first_pc, previous_pc],
                                                               [results_bbs[0], results_bbs[current_frame_id - 1]],
                                                               scale=self.config.model_bb_scale,
                                                               offset=self.config.model_bb_offset)
        elif "first".upper() in self.config.shape_aggregation.upper():
            template_pc, canonical_box = points_utils.cropAndCenterPC(first_pc, results_bbs[0],
                                                                      scale=self.config.model_bb_scale,
                                                                      offset=self.config.model_bb_offset)
        elif "previous".upper() in self.config.hape_aggregation.upper():
            template_pc, canonical_box = points_utils.cropAndCenterPC(previous_pc, results_bbs[current_frame_id - 1],
                                                                      scale=self.config.model_bb_scale,
                                                                      offset=self.config.model_bb_offset)
        elif "all".upper() in self.config.shape_aggregation.upper():
            template_pc, canonical_box = points_utils.getModel([frame["pc"] for frame in sequence[:current_frame_id]],
                                                               results_bbs,
                                                               scale=self.config.model_bb_scale,
                                                               offset=self.config.model_bb_offset)
        return template_pc, canonical_box
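
    # shape_aggregation is matched case-insensitively by substring; per the branches
    # above the recognized values are "firstandprevious", "first", "previous", "all".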

    def generate_search_area(self, sequence, current_frame_id, results_bbs):
        """
        generate search area for evaluating.

        :param sequence:
        :param current_frame_id:
        :param results_bbs:
        :return:
        """
        this_bb = sequence[current_frame_id]["3d_bbox"]
        this_pc = sequence[current_frame_id]["pc"]
        if ("previous_result".upper() in self.config.reference_BB.upper()):
            ref_bb = results_bbs[-1]
        elif ("previous_gt".upper() in self.config.reference_BB.upper()):
            previous_bb = sequence[current_frame_id - 1]["3d_bbox"]
            ref_bb = previous_bb
        elif ("current_gt".upper() in self.config.reference_BB.upper()):
            ref_bb = this_bb
        search_pc_crop = points_utils.generate_subwindow(this_pc, ref_bb,
                                                         scale=self.config.search_bb_scale,
                                                         offset=self.config.search_bb_offset)
        return search_pc_crop, ref_bb
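
    # reference_BB selects which box centers the search crop: "previous_result"
    # (the tracker's last prediction), "previous_gt", or "current_gt".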

    def prepare_input(self, template_pc, search_pc, template_box, *args, **kwargs):
        """
        construct input dict for evaluating
        :param template_pc:
        :param search_pc:
        :param template_box:
        :return:
        """
        template_points, idx_t = points_utils.regularize_pc(template_pc.points.T, self.config.template_size,
                                                            seed=1)
        search_points, idx_s = points_utils.regularize_pc(search_pc.points.T, self.config.search_size,
                                                          seed=1)
        template_points_torch = torch.tensor(template_points, device=self.device, dtype=torch.float32)
        search_points_torch = torch.tensor(search_points, device=self.device, dtype=torch.float32)
        data_dict = {
            'template_points': template_points_torch[None, ...],
            'search_points': search_points_torch[None, ...],
        }
        return data_dict

    def build_input_dict(self, sequence, frame_id, results_bbs, **kwargs):
        # preparing search area
        search_pc_crop, ref_bb = self.generate_search_area(sequence, frame_id, results_bbs)
        # update template
        template_pc, canonical_box = self.generate_template(sequence, frame_id, results_bbs)
        # construct input dict
        data_dict = self.prepare_input(template_pc, search_pc_crop, canonical_box)
        return data_dict, ref_bb


class MotionBaseModel(BaseModel):
    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.save_hyperparameters()

    def build_input_dict(self, sequence, frame_id, results_bbs):
        assert frame_id > 0, "no need to construct an input_dict at frame 0"

        prev_frame = sequence[frame_id - 1]
        this_frame = sequence[frame_id]
        prev_pc = prev_frame['pc']
        this_pc = this_frame['pc']
        ref_box = results_bbs[-1]
        prev_frame_pc = points_utils.generate_subwindow(prev_pc, ref_box,
                                                        scale=self.config.bb_scale,
                                                        offset=self.config.bb_offset)
        this_frame_pc = points_utils.generate_subwindow(this_pc, ref_box,
                                                        scale=self.config.bb_scale,
                                                        offset=self.config.bb_offset)

        canonical_box = points_utils.transform_box(ref_box, ref_box)
        prev_points, idx_prev = points_utils.regularize_pc(prev_frame_pc.points.T,
                                                           self.config.point_sample_size,
                                                           seed=1)

        this_points, idx_this = points_utils.regularize_pc(this_frame_pc.points.T,
                                                           self.config.point_sample_size,
                                                           seed=1)
        seg_mask_prev = geometry_utils.points_in_box(canonical_box, prev_points.T, 1.25).astype(float)

        # Here we use 0.2/0.8 instead of 0/1 to indicate that the previous box is not GT.
        # When boxcloud is used, the actual value of prior-targetness mask doesn't really matter.
        if frame_id != 1:
            seg_mask_prev[seg_mask_prev == 0] = 0.2
            seg_mask_prev[seg_mask_prev == 1] = 0.8
        seg_mask_this = np.full(seg_mask_prev.shape, fill_value=0.5)

        timestamp_prev = np.full((self.config.point_sample_size, 1), fill_value=0)
        timestamp_this = np.full((self.config.point_sample_size, 1), fill_value=0.1)
        prev_points = np.concatenate([prev_points, timestamp_prev, seg_mask_prev[:, None]], axis=-1)
        this_points = np.concatenate([this_points, timestamp_this, seg_mask_this[:, None]], axis=-1)

        stack_points = np.concatenate([prev_points, this_points], axis=0)

        data_dict = {"points": torch.tensor(stack_points[None, :], device=self.device, dtype=torch.float32),
                     }
        if getattr(self.config, 'box_aware', False):
            candidate_bc_prev = points_utils.get_point_to_box_distance(
                stack_points[:self.config.point_sample_size, :3], canonical_box)
            candidate_bc_this = np.zeros_like(candidate_bc_prev)
            candidate_bc = np.concatenate([candidate_bc_prev, candidate_bc_this], axis=0)
            data_dict.update({'candidate_bc': points_utils.np_to_torch_tensor(candidate_bc.astype('float32'),
                                                                              device=self.device)})
        return data_dict, results_bbs[-1]
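
    # Resulting per-point layout is [x, y, z, timestamp, prior-targetness]: timestamp is
    # 0 for the previous frame and 0.1 for the current one; the prior-targetness channel
    # is 0/1 at frame 1 (GT box) or 0.2/0.8 afterwards for the previous frame, and a
    # uniform 0.5 for the current frame, whose box is unknown.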


================================================
FILE: models/bat.py
================================================
""" 
bat.py
Created by zenn at 2021/7/21 14:16
"""

import torch
from torch import nn
from models.backbone.pointnet import Pointnet_Backbone
from models.head.xcorr import BoxAwareXCorr
from models.head.rpn import P2BVoteNetRPN
from models import base_model
import torch.nn.functional as F
from datasets import points_utils
from pointnet2.utils import pytorch_utils as pt_utils


class BAT(base_model.MatchingBaseModel):
    def __init__(self, config=None, **kwargs):
        super().__init__(config, **kwargs)
        self.save_hyperparameters()
        self.backbone = Pointnet_Backbone(self.config.use_fps, self.config.normalize_xyz, return_intermediate=False)
        self.conv_final = nn.Conv1d(256, self.config.feature_channel, kernel_size=1)
        self.mlp_bc = (pt_utils.Seq(3 + self.config.feature_channel)
                       .conv1d(self.config.feature_channel, bn=True)
                       .conv1d(self.config.feature_channel, bn=True)
                       .conv1d(self.config.bc_channel, activation=None))

        self.xcorr = BoxAwareXCorr(feature_channel=self.config.feature_channel,
                                   hidden_channel=self.config.hidden_channel,
                                   out_channel=self.config.out_channel,
                                   k=self.config.k,
                                   use_search_bc=self.config.use_search_bc,
                                   use_search_feature=self.config.use_search_feature,
                                   bc_channel=self.config.bc_channel)
        self.rpn = P2BVoteNetRPN(self.config.feature_channel,
                                 vote_channel=self.config.vote_channel,
                                 num_proposal=self.config.num_proposal,
                                 normalize_xyz=self.config.normalize_xyz)


    def prepare_input(self, template_pc, search_pc, template_box):
        template_points, idx_t = points_utils.regularize_pc(template_pc.points.T, self.config.template_size,
                                                            seed=1)
        search_points, idx_s = points_utils.regularize_pc(search_pc.points.T, self.config.search_size,
                                                          seed=1)
        template_points_torch = torch.tensor(template_points, device=self.device, dtype=torch.float32)
        search_points_torch = torch.tensor(search_points, device=self.device, dtype=torch.float32)
        template_bc = points_utils.get_point_to_box_distance(template_points, template_box)
        template_bc_torch = torch.tensor(template_bc, device=self.device, dtype=torch.float32)
        data_dict = {
            'template_points': template_points_torch[None, ...],
            'search_points': search_points_torch[None, ...],
            'points2cc_dist_t': template_bc_torch[None, ...]
        }
        return data_dict

    def compute_loss(self, data, output):
        out_dict = super(BAT, self).compute_loss(data, output)
        search_bc = data['points2cc_dist_s']
        pred_search_bc = output['pred_search_bc']
        seg_label = data['seg_label']
        loss_bc = F.smooth_l1_loss(pred_search_bc, search_bc, reduction='none')
        loss_bc = torch.sum(loss_bc.mean(2) * seg_label) / (seg_label.sum() + 1e-6)
        out_dict["loss_bc"] = loss_bc
        return out_dict

    def forward(self, input_dict):
        """
        :param input_dict:
        {
        'template_points': template_points.astype('float32'),
        'search_points': search_points.astype('float32'),
        'box_label': np.array(search_bbox_reg).astype('float32'),
        'bbox_size': search_box.wlh,
        'seg_label': seg_label.astype('float32'),
        'points2cc_dist_t': template_bc,
        'points2cc_dist_s': search_bc,
        }

        :return:
        """
        template = input_dict['template_points']
        search = input_dict['search_points']
        template_bc = input_dict['points2cc_dist_t']
        M = template.shape[1]
        N = search.shape[1]

        # backbone
        template_xyz, template_feature, sample_idxs_t = self.backbone(template, [M // 2, M // 4, M // 8])
        search_xyz, search_feature, sample_idxs = self.backbone(search, [N // 2, N // 4, N // 8])
        template_feature = self.conv_final(template_feature)
        search_feature = self.conv_final(search_feature)
        # prepare bc
        pred_search_bc = self.mlp_bc(torch.cat([search_xyz.transpose(1, 2), search_feature], dim=1))  # B, 9, N // 8
        pred_search_bc = pred_search_bc.transpose(1, 2)
        sample_idxs_t = sample_idxs_t[:, :M // 8, None]
        template_bc = template_bc.gather(dim=1, index=sample_idxs_t.repeat(1, 1, self.config.bc_channel).long())
        # box-aware xcorr
        fusion_feature = self.xcorr(template_feature, search_feature, template_xyz, search_xyz, template_bc,
                                    pred_search_bc)
        # proposal generation
        estimation_boxes, estimation_cla, vote_xyz, center_xyzs = self.rpn(search_xyz, fusion_feature)
        end_points = {"estimation_boxes": estimation_boxes,
                      "vote_center": vote_xyz,
                      "pred_seg_score": estimation_cla,
                      "center_xyz": center_xyzs,
                      'sample_idxs': sample_idxs,
                      'estimation_cla': estimation_cla,
                      "vote_xyz": vote_xyz,
                      "pred_search_bc": pred_search_bc
                      }
        return end_points
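
    # Shape sketch (assuming template_size=512 and search_size=1024, as in typical
    # configs -- an assumption): the backbone keeps M//8=64 template seeds and
    # N//8=128 search seeds; template_bc is gathered down to those 64 seeds before
    # the box-aware xcorr, and the RPN emits num_proposal candidate boxes.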

    def training_step(self, batch, batch_idx):
        """
        {"estimation_boxes": estimation_boxs.transpose(1, 2).contiguous(),
                  "vote_center": vote_xyz,
                  "pred_seg_score": estimation_cla,
                  "center_xyz": center_xyzs,
                  "seed_idxs":
                  "seg_label"
                  "pred_search_bc": pred_search_bc
        }
        """
        end_points = self(batch)

        search_pc = batch['points2cc_dist_s']
        estimation_cla = end_points['estimation_cla']  # B,N
        N = estimation_cla.shape[1]
        seg_label = batch['seg_label']
        sample_idxs = end_points['sample_idxs']  # B,N
        seg_label = seg_label.gather(dim=1, index=sample_idxs[:, :N].long())  # B,N
        search_pc = search_pc.gather(dim=1, index=sample_idxs[:, :N, None].repeat(1, 1, self.config.bc_channel).long())
        # update label
        batch['seg_label'] = seg_label
        batch['points2cc_dist_s'] = search_pc
        # compute loss
        loss_dict = self.compute_loss(batch, end_points)
        loss = loss_dict['loss_objective'] * self.config.objectiveness_weight \
               + loss_dict['loss_box'] * self.config.box_weight \
               + loss_dict['loss_seg'] * self.config.seg_weight \
               + loss_dict['loss_vote'] * self.config.vote_weight \
               + loss_dict['loss_bc'] * self.config.bc_weight

        # log
        self.log('loss/train', loss.item(), on_step=True, on_epoch=True, prog_bar=True, logger=False)
        self.log('loss_box/train', loss_dict['loss_box'].item(), on_step=True, on_epoch=True, prog_bar=True,
                 logger=False)
        self.log('loss_seg/train', loss_dict['loss_seg'].item(), on_step=True, on_epoch=True, prog_bar=True,
                 logger=False)
        self.log('loss_vote/train', loss_dict['loss_vote'].item(), on_step=True, on_epoch=True, prog_bar=True,
                 logger=False)
        self.log('loss_bc/train', loss_dict['loss_bc'].item(), on_step=True, on_epoch=True, prog_bar=True,
                 logger=False)
        self.log('loss_objective/train', loss_dict['loss_objective'].item(), on_step=True, on_epoch=True, prog_bar=True,
                 logger=False)

        self.logger.experiment.add_scalars('loss', {'loss_total': loss.item(),
                                                    'loss_box': loss_dict['loss_box'].item(),
                                                    'loss_seg': loss_dict['loss_seg'].item(),
                                                    'loss_vote': loss_dict['loss_vote'].item(),
                                                    'loss_objective': loss_dict['loss_objective'].item(),
                                                    'loss_bc': loss_dict['loss_bc'].item()},
                                           global_step=self.global_step)

        return loss


================================================
FILE: models/head/rpn.py
================================================
""" 
rpn.py
Created by zenn at 2021/5/8 20:55
"""
import torch
from torch import nn
from pointnet2.utils import pytorch_utils as pt_utils

from pointnet2.utils.pointnet2_modules import PointnetSAModule


class P2BVoteNetRPN(nn.Module):

    def __init__(self, feature_channel, vote_channel=256, num_proposal=64, normalize_xyz=False):
        super().__init__()
        self.num_proposal = num_proposal
        self.FC_layer_cla = (
            pt_utils.Seq(feature_channel)
                .conv1d(feature_channel, bn=True)
                .conv1d(feature_channel, bn=True)
                .conv1d(1, activation=None))
        self.vote_layer = (
            pt_utils.Seq(3 + feature_channel)
                .conv1d(feature_channel, bn=True)
                .conv1d(feature_channel, bn=True)
                .conv1d(3 + feature_channel, activation=None))

        self.vote_aggregation = PointnetSAModule(
            radius=0.3,
            nsample=16,
            mlp=[1 + feature_channel, vote_channel, vote_channel, vote_channel],
            use_xyz=True,
            normalize_xyz=normalize_xyz)

        self.FC_proposal = (
            pt_utils.Seq(vote_channel)
                .conv1d(vote_channel, bn=True)
                .conv1d(vote_channel, bn=True)
                .conv1d(3 + 1 + 1, activation=None))

    def forward(self, xyz, feature):
        """

        :param xyz: B,N,3
        :param feature: B,f,N
        :return: estimation_boxes B,num_proposal,4+1 (x,y,z,theta,targetness logit),
                 estimation_cla B,N, vote_xyz B,N,3, center_xyzs B,num_proposal,3
        """
        estimation_cla = self.FC_layer_cla(feature).squeeze(1)
        score = estimation_cla.sigmoid()

        xyz_feature = torch.cat((xyz.transpose(1, 2).contiguous(), feature), dim=1)

        offset = self.vote_layer(xyz_feature)
        vote = xyz_feature + offset
        vote_xyz = vote[:, 0:3, :].transpose(1, 2).contiguous()
        vote_feature = vote[:, 3:, :]

        vote_feature = torch.cat((score.unsqueeze(1), vote_feature), dim=1)

        center_xyzs, proposal_features = self.vote_aggregation(vote_xyz, vote_feature, self.num_proposal)
        proposal_offsets = self.FC_proposal(proposal_features)
        estimation_boxes = torch.cat(
            (proposal_offsets[:, 0:3, :] + center_xyzs.transpose(1, 2).contiguous(), proposal_offsets[:, 3:5, :]),
            dim=1)

        estimation_boxes = estimation_boxes.transpose(1, 2).contiguous()
        return estimation_boxes, estimation_cla, vote_xyz, center_xyzs
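
# Note: the fifth channel of each proposal is a raw targetness logit;
# evaluate_one_sample ranks proposals by argmax over this column, which is
# equivalent to ranking by the sigmoid score.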


================================================
FILE: models/head/xcorr.py
================================================
# Created by zenn at 2021/5/8
import torch
from torch import nn
from pointnet2.utils import pytorch_utils as pt_utils
from pointnet2.utils import pointnet2_utils

import torch.nn.functional as F


class BaseXCorr(nn.Module):
    def __init__(self, in_channel, hidden_channel, out_channel):
        super().__init__()
        self.cosine = nn.CosineSimilarity(dim=1)
        self.mlp = pt_utils.SharedMLP([in_channel, hidden_channel, hidden_channel, hidden_channel], bn=True)
        self.fea_layer = (pt_utils.Seq(hidden_channel)
                          .conv1d(hidden_channel, bn=True)
                          .conv1d(out_channel, activation=None))


class P2B_XCorr(BaseXCorr):
    def __init__(self, feature_channel, hidden_channel, out_channel):
        mlp_in_channel = feature_channel + 4
        super().__init__(mlp_in_channel, hidden_channel, out_channel)

    def forward(self, template_feature, search_feature, template_xyz):
        """

        :param template_feature: B,f,M
        :param search_feature: B,f,N
        :param template_xyz: B,M,3
        :return:
        """
        B = template_feature.size(0)
        f = template_feature.size(1)
        n1 = template_feature.size(2)
        n2 = search_feature.size(2)
        final_out_cla = self.cosine(template_feature.unsqueeze(-1).expand(B, f, n1, n2),
                                    search_feature.unsqueeze(2).expand(B, f, n1, n2))  # B,n1,n2

        fusion_feature = torch.cat(
            (final_out_cla.unsqueeze(1), template_xyz.transpose(1, 2).contiguous().unsqueeze(-1).expand(B, 3, n1, n2)),
            dim=1)  # B,1+3,n1,n2

        fusion_feature = torch.cat((fusion_feature, template_feature.unsqueeze(-1).expand(B, f, n1, n2)),
                                   dim=1)  # B,1+3+f,n1,n2

        fusion_feature = self.mlp(fusion_feature)

        fusion_feature = F.max_pool2d(fusion_feature, kernel_size=[fusion_feature.size(2), 1])  # B, f, 1, n2
        fusion_feature = fusion_feature.squeeze(2)  # B, f, n2
        fusion_feature = self.fea_layer(fusion_feature)

        return fusion_feature
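
# Usage sketch of P2B_XCorr (illustrative sizes only):
#   xcorr = P2B_XCorr(feature_channel=256, hidden_channel=256, out_channel=256)
#   fused = xcorr(torch.rand(2, 256, 64),   # template features B,f,M
#                 torch.rand(2, 256, 128),  # search features B,f,N
#                 torch.rand(2, 64, 3))     # template xyz B,M,3
#   fused.shape  # -> (2, 256, 128): one fused feature per search point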


class BoxAwareXCorr(BaseXCorr):
    def __init__(self, feature_channel, hidden_channel, out_channel, k=8, use_search_bc=False,
                 use_search_feature=False, bc_channel=9):
        mlp_in_channel = feature_channel + 3 + bc_channel
        if use_search_bc:
            mlp_in_channel += bc_channel
        if use_search_feature:
            mlp_in_channel += feature_channel
        super(BoxAwareXCorr, self).__init__(mlp_in_channel, hidden_channel, out_channel)
        self.k = k
        self.use_search_bc = use_search_bc
        self.use_search_feature = use_search_feature

    def forward(self, template_feature, search_feature, template_xyz,
                search_xyz=None, template_bc=None, search_bc=None):
        """

        :param template_feature: B,f,M
        :param search_feature: B,f,N
        :param template_xyz: B,M,3
        :param search_xyz: B,N,3
        :param template_bc: B,M,9
        :param search_bc: B,N,9
        :param args:
        :param kwargs:
        :return:
        """
        dist_matrix = torch.cdist(template_bc, search_bc)  # B, M, N
        template_xyz_feature_box = torch.cat([template_xyz.transpose(1, 2).contiguous(),
                                              template_bc.transpose(1, 2).contiguous(),
                                              template_feature], dim=1)
        # search_xyz_feature = torch.cat([search_xyz.transpose(1, 2).contiguous(), search_feature], dim=1)

        top_k_nearest_idx_b = torch.argsort(dist_matrix, dim=1)[:, :self.k, :]  # B, K, N
        top_k_nearest_idx_b = top_k_nearest_idx_b.transpose(1, 2).contiguous().int()  # B, N, K
        correspondences_b = pointnet2_utils.grouping_operation(template_xyz_feature_box,
                                                               top_k_nearest_idx_b)  # B,3+9+D,N,K
        if self.use_search_bc:
            search_bc_expand = search_bc.transpose(1, 2).unsqueeze(dim=-1).repeat(1, 1, 1, self.k)  # B,9,N,K
            correspondences_b = torch.cat([search_bc_expand, correspondences_b], dim=1)
        if self.use_search_feature:
            search_feature_expand = search_feature.unsqueeze(dim=-1).repeat(1, 1, 1, self.k)  # B,D,N,K
            correspondences_b = torch.cat([search_feature_expand, correspondences_b], dim=1)

        # correspondence fusion head
        fusion_feature = self.mlp(correspondences_b)  # B,D,N,K
        fusion_feature, _ = torch.max(fusion_feature, dim=-1)  # B,D,N
        fusion_feature = self.fea_layer(fusion_feature)  # B,out_channel,N

        return fusion_feature
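
# In short: for every search point, the k template points with the most similar
# box coordinates (9-D point-to-box distances) are gathered via grouping_operation,
# and their xyz+bc+feature descriptors are fused by a shared MLP plus max-pooling.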


================================================
FILE: models/m2track.py
================================================
"""
m2track.py
Created by zenn at 2021/11/24 13:10
"""
from datasets import points_utils
from models import base_model
from models.backbone.pointnet import MiniPointNet, SegPointNet

import torch
from torch import nn
import torch.nn.functional as F

from utils.metrics import estimateOverlap, estimateAccuracy
from torchmetrics import Accuracy


class M2TRACK(base_model.MotionBaseModel):
    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.seg_acc = Accuracy(num_classes=2, average='none')

        self.box_aware = getattr(config, 'box_aware', False)
        self.use_motion_cls = getattr(config, 'use_motion_cls', True)
        self.use_second_stage = getattr(config, 'use_second_stage', True)
        self.use_prev_refinement = getattr(config, 'use_prev_refinement', True)
        self.seg_pointnet = SegPointNet(input_channel=3 + 1 + 1 + (9 if self.box_aware else 0),
                                        per_point_mlp1=[64, 64, 64, 128, 1024],
                                        per_point_mlp2=[512, 256, 128, 128],
                                        output_size=2 + (9 if self.box_aware else 0))
        self.mini_pointnet = MiniPointNet(input_channel=3 + 1 + (9 if self.box_aware else 0),
                                          per_point_mlp=[64, 128, 256, 512],
                                          hidden_mlp=[512, 256],
                                          output_size=-1)
        if self.use_second_stage:
            self.mini_pointnet2 = MiniPointNet(input_channel=3 + (9 if self.box_aware else 0),
                                               per_point_mlp=[64, 128, 256, 512],
                                               hidden_mlp=[512, 256],
                                               output_size=-1)

            self.box_mlp = nn.Sequential(nn.Linear(256, 128),
                                         nn.BatchNorm1d(128),
                                         nn.ReLU(),
                                         nn.Linear(128, 128),
                                         nn.BatchNorm1d(128),
                                         nn.ReLU(),
                                         nn.Linear(128, 4))
        if self.use_prev_refinement:
            self.final_mlp = nn.Sequential(nn.Linear(256, 128),
                                           nn.BatchNorm1d(128),
                                           nn.ReLU(),
                                           nn.Linear(128, 128),
                                           nn.BatchNorm1d(128),
                                           nn.ReLU(),
                                           nn.Linear(128, 4))
        if self.use_motion_cls:
            self.motion_state_mlp = nn.Sequential(nn.Linear(256, 128),
                                                  nn.BatchNorm1d(128),
                                                  nn.ReLU(),
                                                  nn.Linear(128, 128),
                                                  nn.BatchNorm1d(128),
                                                  nn.ReLU(),
                                                  nn.Linear(128, 2))
            self.motion_acc = Accuracy(num_classes=2, average='none')

        self.motion_mlp = nn.Sequential(nn.Linear(256, 128),
                                        nn.BatchNorm1d(128),
                                        nn.ReLU(),
                                        nn.Linear(128, 128),
                                        nn.BatchNorm1d(128),
                                        nn.ReLU(),
                                        nn.Linear(128, 4))

    def forward(self, input_dict):
        """
        Args:
            input_dict: {
            "points": (B,N,3+1+1)
            "candidate_bc": (B,N,9)

        }

        Returns: B,4

        """
        output_dict = {}
        x = input_dict["points"].transpose(1, 2)
        if self.box_aware:
            candidate_bc = input_dict["candidate_bc"].transpose(1, 2)
            x = torch.cat([x, candidate_bc], dim=1)

        B, _, N = x.shape

        seg_out = self.seg_pointnet(x)
        seg_logits = seg_out[:, :2, :]  # B,2,N
        pred_cls = torch.argmax(seg_logits, dim=1, keepdim=True)  # B,1,N
        mask_points = x[:, :4, :] * pred_cls
        mask_xyz_t0 = mask_points[:, :3, :N // 2]  # B,3,N//2
        mask_xyz_t1 = mask_points[:, :3, N // 2:]
        if self.box_aware:
            pred_bc = seg_out[:, 2:, :]
            mask_pred_bc = pred_bc * pred_cls
            # mask_pred_bc_t0 = mask_pred_bc[:, :, :N // 2]  # B,9,N//2
            # mask_pred_bc_t1 = mask_pred_bc[:, :, N // 2:]
            mask_points = torch.cat([mask_points, mask_pred_bc], dim=1)
            output_dict['pred_bc'] = pred_bc.transpose(1, 2)

        point_feature = self.mini_pointnet(mask_points)

        # motion state prediction
        motion_pred = self.motion_mlp(point_feature)  # B,4
        if self.use_motion_cls:
            motion_state_logits = self.motion_state_mlp(point_feature)  # B,2
            motion_mask = torch.argmax(motion_state_logits, dim=1, keepdim=True)  # B,1
            motion_pred_masked = motion_pred * motion_mask
            output_dict['motion_cls'] = motion_state_logits
        else:
            motion_pred_masked = motion_pred
        # previous bbox refinement
        if self.use_prev_refinement:
            prev_boxes = self.final_mlp(point_feature)  # previous bb, B,4
            output_dict["estimation_boxes_prev"] = prev_boxes[:, :4]
        else:
            prev_boxes = torch.zeros_like(motion_pred)

        # 1st stage prediction
        aux_box = points_utils.get_offset_box_tensor(prev_boxes, motion_pred_masked)

        # 2nd stage refinement
        if self.use_second_stage:
            mask_xyz_t0_2_t1 = points_utils.get_offset_points_tensor(mask_xyz_t0.transpose(1, 2),
                                                                     prev_boxes[:, :4],
                                                                     motion_pred_masked).transpose(1, 2)  # B,3,N//2
            mask_xyz_t01 = torch.cat([mask_xyz_t0_2_t1, mask_xyz_t1], dim=-1)  # B,3,N

            # transform to the aux_box coordinate system
            mask_xyz_t01 = points_utils.remove_transform_points_tensor(mask_xyz_t01.transpose(1, 2),
                                                                       aux_box).transpose(1, 2)

            if self.box_aware:
                mask_xyz_t01 = torch.cat([mask_xyz_t01, mask_pred_bc], dim=1)
            output_offset = self.box_mlp(self.mini_pointnet2(mask_xyz_t01))  # B,4
            output = points_utils.get_offset_box_tensor(aux_box, output_offset)
            output_dict["estimation_boxes"] = output
        else:
            output_dict["estimation_boxes"] = aux_box
        output_dict.update({"seg_logits": seg_logits,
                            "motion_pred": motion_pred,
                            'aux_estimation_boxes': aux_box,
                            })

        return output_dict
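
    # Pipeline recap: per-point segmentation masks out background, a mini PointNet
    # encodes the masked points, the motion head predicts a relative 4-DoF motion
    # (optionally gated by the static/dynamic classifier), the refined previous box
    # plus that motion gives the 1st-stage box, and an optional 2nd stage re-encodes
    # the points in the 1st-stage box frame to regress a final offset.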

    def compute_loss(self, data, output):
        loss_total = 0.0
        loss_dict = {}
        aux_estimation_boxes = output['aux_estimation_boxes']  # B,4
        motion_pred = output['motion_pred']  # B,4
        seg_logits = output['seg_logits']
        with torch.no_grad():
            seg_label = data['seg_label']
            box_label = data['box_label']
            box_label_prev = data['box_label_prev']
            motion_label = data['motion_label']
            motion_state_label = data['motion_state_label']
            center_label = box_label[:, :3]
            angle_label = torch.sin(box_label[:, 3])
            center_label_prev = box_label_prev[:, :3]
            angle_label_prev = torch.sin(box_label_prev[:, 3])
            center_label_motion = motion_label[:, :3]
            angle_label_motion = torch.sin(motion_label[:, 3])

        loss_seg = F.cross_entropy(seg_logits, seg_label, weight=torch.tensor([0.5, 2.0], device=seg_logits.device))
        if self.use_motion_cls:
            motion_cls = output['motion_cls']  # B,2
            loss_motion_cls = F.cross_entropy(motion_cls, motion_state_label)
            loss_total += loss_motion_cls * self.config.motion_cls_seg_weight
            loss_dict['loss_motion_cls'] = loss_motion_cls

            loss_center_motion = F.smooth_l1_loss(motion_pred[:, :3], center_label_motion, reduction='none')
            loss_center_motion = (motion_state_label * loss_center_motion.mean(dim=1)).sum() / (
                    motion_state_label.sum() + 1e-6)
            loss_angle_motion = F.smooth_l1_loss(torch.sin(motion_pred[:, 3]), angle_label_motion, reduction='none')
            loss_angle_motion = (motion_state_label * loss_angle_motion).sum() / (motion_state_label.sum() + 1e-6)
        else:
            loss_center_motion = F.smooth_l1_loss(motion_pred[:, :3], center_label_motion)
            loss_angle_motion = F.smooth_l1_loss(torch.sin(motion_pred[:, 3]), angle_label_motion)

        if self.use_second_stage:
            estimation_boxes = output['estimation_boxes']  # B,4
            loss_center = F.smooth_l1_loss(estimation_boxes[:, :3], center_label)
            loss_angle = F.smooth_l1_loss(torch.sin(estimation_boxes[:, 3]), angle_label)
            loss_total += 1 * (loss_center * self.config.center_weight + loss_angle * self.config.angle_weight)
            loss_dict["loss_center"] = loss_center
            loss_dict["loss_angle"] = loss_angle
        if self.use_prev_refinement:
            estimation_boxes_prev = output['estimation_boxes_prev']  # B,4
            loss_center_prev = F.smooth_l1_loss(estimation_boxes_prev[:, :3], center_label_prev)
            loss_angle_prev = F.smooth_l1_loss(torch.sin(estimation_boxes_prev[:, 3]), angle_label_prev)
            loss_total += (loss_center_prev * self.config.center_weight + loss_angle_prev * self.config.angle_weight)
            loss_dict["loss_center_prev"] = loss_center_prev
            loss_dict["loss_angle_prev"] = loss_angle_prev

        loss_center_aux = F.smooth_l1_loss(aux_estimation_boxes[:, :3], center_label)

        loss_angle_aux = F.smooth_l1_loss(torch.sin(aux_estimation_boxes[:, 3]), angle_label)

        loss_total += loss_seg * self.config.seg_weight \
                      + 1 * (loss_center_aux * self.config.center_weight + loss_angle_aux * self.config.angle_weight) \
                      + 1 * (
                              loss_center_motion * self.config.center_weight + loss_angle_motion * self.config.angle_weight)
        loss_dict.update({
            "loss_total": loss_total,
            "loss_seg": loss_seg,
            "loss_center_aux": loss_center_aux,
            "loss_center_motion": loss_center_motion,
            "loss_angle_aux": loss_angle_aux,
            "loss_angle_motion": loss_angle_motion,
        })
        if self.box_aware:
            prev_bc = data['prev_bc']
            this_bc = data['this_bc']
            bc_label = torch.cat([prev_bc, this_bc], dim=1)
            pred_bc = output['pred_bc']
            loss_bc = F.smooth_l1_loss(pred_bc, bc_label)
            loss_total += loss_bc * self.config.bc_weight
            loss_dict.update({
                "loss_total": loss_total,
                "loss_bc": loss_bc
            })

        return loss_dict

    def training_step(self, batch, batch_idx):
        """
        Args:
            batch: {
            "points": stack_frames, (B,N,3+9+1)
            "seg_label": stack_label,
            "box_label": np.append(this_gt_bb_transform.center, theta),
            "box_size": this_gt_bb_transform.wlh
        }
        Returns:

        """
        output = self(batch)
        loss_dict = self.compute_loss(batch, output)
        loss = loss_dict['loss_total']

        # log
        seg_acc = self.seg_acc(torch.argmax(output['seg_logits'], dim=1, keepdim=False), batch['seg_label'])
        self.log('seg_acc_background/train', seg_acc[0], on_step=True, on_epoch=True, prog_bar=False, logger=True)
        self.log('seg_acc_foreground/train', seg_acc[1], on_step=True, on_epoch=True, prog_bar=False, logger=True)
        if self.use_motion_cls:
            motion_acc = self.motion_acc(torch.argmax(output['motion_cls'], dim=1, keepdim=False),
                                         batch['motion_state_label'])
            self.log('motion_acc_static/train', motion_acc[0], on_step=True, on_epoch=True, prog_bar=False, logger=True)
            self.log('motion_acc_dynamic/train', motion_acc[1], on_step=True, on_epoch=True, prog_bar=False,
                     logger=True)

        log_dict = {k: v.item() for k, v in loss_dict.items()}

        self.logger.experiment.add_scalars('loss', log_dict,
                                           global_step=self.global_step)
        return loss




================================================
FILE: models/p2b.py
================================================
""" 
p2b.py
Created by zenn at 2021/5/9 13:47
"""

from torch import nn
from models.backbone.pointnet import Pointnet_Backbone
from models.head.xcorr import P2B_XCorr
from models.head.rpn import P2BVoteNetRPN
from models import base_model


class P2B(base_model.MatchingBaseModel):
    def __init__(self, config=None, **kwargs):
        super().__init__(config, **kwargs)
        self.save_hyperparameters()
        self.backbone = Pointnet_Backbone(self.config.use_fps, self.config.normalize_xyz, return_intermediate=False)
        self.conv_final = nn.Conv1d(256, self.config.feature_channel, kernel_size=1)

        self.xcorr = P2B_XCorr(feature_channel=self.config.feature_channel,
                               hidden_channel=self.config.hidden_channel,
                               out_channel=self.config.out_channel)
        self.rpn = P2BVoteNetRPN(self.config.feature_channel,
                                 vote_channel=self.config.vote_channel,
                                 num_proposal=self.config.num_proposal,
                                 normalize_xyz=self.config.normalize_xyz)

    def forward(self, input_dict):
        """
        :param input_dict:
        {
        'template_points': template_points.astype('float32'),
        'search_points': search_points.astype('float32'),
        'box_label': np.array(search_bbox_reg).astype('float32'),
        'bbox_size': search_box.wlh,
        'seg_label': seg_label.astype('float32'),
        }

        :return:
        """
        template = input_dict['template_points']
        search = input_dict['search_points']
        M = template.shape[1]
        N = search.shape[1]
        template_xyz, template_feature, _ = self.backbone(template, [M // 2, M // 4, M // 8])
        search_xyz, search_feature, sample_idxs = self.backbone(search, [N // 2, N // 4, N // 8])
        template_feature = self.conv_final(template_feature)
        search_feature = self.conv_final(search_feature)
        fusion_feature = self.xcorr(template_feature, search_feature, template_xyz)
        estimation_boxes, estimation_cla, vote_xyz, center_xyzs = self.rpn(search_xyz, fusion_feature)
        end_points = {"estimation_boxes": estimation_boxes,
                      "vote_center": vote_xyz,
                      "pred_seg_score": estimation_cla,
                      "center_xyz": center_xyzs,
                      'sample_idxs': sample_idxs,
                      'estimation_cla': estimation_cla,
                      "vote_xyz": vote_xyz,
                      }
        return end_points

    def training_step(self, batch, batch_idx):
        """
        {"estimation_boxes": estimation_boxs.transpose(1, 2).contiguous(),
                  "vote_center": vote_xyz,
                  "pred_seg_score": estimation_cla,
                  "center_xyz": center_xyzs,
                  "seed_idxs":
                  "seg_label"
        }
        """
        end_points = self(batch)
        estimation_cla = end_points['estimation_cla']  # B,N
        N = estimation_cla.shape[1]
        seg_label = batch['seg_label']
        sample_idxs = end_points['sample_idxs']  # B,N
        # update label
        seg_label = seg_label.gather(dim=1, index=sample_idxs[:, :N].long())  # B,N
        batch["seg_label"] = seg_label
        # compute loss
        loss_dict = self.compute_loss(batch, end_points)
        loss = loss_dict['loss_objective'] * self.config.objectiveness_weight \
               + loss_dict['loss_box'] * self.config.box_weight \
               + loss_dict['loss_seg'] * self.config.seg_weight \
               + loss_dict['loss_vote'] * self.config.vote_weight
        self.log('loss/train', loss.item(), on_step=True, on_epoch=True, prog_bar=True, logger=False)
        self.log('loss_box/train', loss_dict['loss_box'].item(), on_step=True, on_epoch=True, prog_bar=True,
                 logger=False)
        self.log('loss_seg/train', loss_dict['loss_seg'].item(), on_step=True, on_epoch=True, prog_bar=True,
                 logger=False)
        self.log('loss_vote/train', loss_dict['loss_vote'].item(), on_step=True, on_epoch=True, prog_bar=True,
                 logger=False)
        self.log('loss_objective/train', loss_dict['loss_objective'].item(), on_step=True, on_epoch=True, prog_bar=True,
                 logger=False)
        self.logger.experiment.add_scalars('loss', {'loss_total': loss.item(),
                                                    'loss_box': loss_dict['loss_box'].item(),
                                                    'loss_seg': loss_dict['loss_seg'].item(),
                                                    'loss_vote': loss_dict['loss_vote'].item(),
                                                    'loss_objective': loss_dict['loss_objective'].item()},
                                           global_step=self.global_step)

        return loss
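
# A minimal sketch of the label re-indexing in training_step above
# (hypothetical tensors, pure torch; not part of the original model): the
# backbone subsamples the search points, so the per-point seg labels are
# gathered with the surviving sample indices before the segmentation loss.
def _gather_seg_labels_sketch():
    import torch  # local import so this sketch is self-contained
    seg_label = torch.tensor([[0., 1., 1., 0., 1.]])  # (B=1, N=5) original labels
    sample_idxs = torch.tensor([[4, 0, 2]])           # indices kept by the backbone
    relabeled = seg_label.gather(dim=1, index=sample_idxs.long())
    assert torch.equal(relabeled, torch.tensor([[1., 0., 1.]]))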


================================================
FILE: pointnet2/__init__.py
================================================
from __future__ import (
    division,
    absolute_import,
    with_statement,
    print_function,
    unicode_literals,
)

__version__ = "2.1.1"

try:
    __POINTNET2_SETUP__
except NameError:
    __POINTNET2_SETUP__ = False

if not __POINTNET2_SETUP__:
    from pointnet2 import utils


================================================
FILE: pointnet2/utils/__init__.py
================================================
from __future__ import (
    division,
    absolute_import,
    with_statement,
    print_function,
    unicode_literals,
)
from . import pointnet2_utils
from . import pointnet2_modules


================================================
FILE: pointnet2/utils/linalg_utils.py
================================================
from __future__ import (
    division,
    absolute_import,
    with_statement,
    print_function,
    unicode_literals,
)
import torch
from enum import Enum

PDist2Order = Enum("PDist2Order", "d_first d_second")


def pdist2(X, Z=None, order=PDist2Order.d_second):
    # type: (torch.Tensor, torch.Tensor, PDist2Order) -> torch.Tensor
    r""" Calculates the pairwise distance between X and Z

    D[b, i, j] = l2 distance X[b, i] and Z[b, j]

    Parameters
    ---------
    X : torch.Tensor
        X is a (B, N, d) tensor.  There are B batches, and N vectors of dimension d
    Z: torch.Tensor
        Z is a (B, M, d) tensor.  If Z is None, then Z = X

    Returns
    -------
    torch.Tensor
        Distance matrix is size (B, N, M)
    """

    if order == PDist2Order.d_second:
        if X.dim() == 2:
            X = X.unsqueeze(0)
        if Z is None:
            Z = X
            G = torch.matmul(X, Z.transpose(-2, -1))
            S = (X * X).sum(-1, keepdim=True)
            R = S.transpose(-2, -1)
        else:
            if Z.dim() == 2:
                Z = Z.unsqueeze(0)
            G = torch.matmul(X, Z.transpose(-2, -1))
            S = (X * X).sum(-1, keepdim=True)
            R = (Z * Z).sum(-1, keepdim=True).transpose(-2, -1)
    else:
        if X.dim() == 2:
            X = X.unsqueeze(0)
        if Z is None:
            Z = X
            G = torch.matmul(X.transpose(-2, -1), Z)
            R = (X * X).sum(-2, keepdim=True)
            S = R.transpose(-2, -1)
        else:
            if Z.dim() == 2:
                Z = Z.unsqueeze(0)
            G = torch.matmul(X.transpose(-2, -1), Z)
            S = (X * X).sum(-2, keepdim=True).transpose(-2, -1)
            R = (Z * Z).sum(-2, keepdim=True)

    return torch.abs(R + S - 2 * G).squeeze(0)


def pdist2_slow(X, Z=None):
    if Z is None:
        Z = X
    D = torch.zeros(X.size(0), X.size(2), Z.size(2))

    for b in range(D.size(0)):
        for i in range(D.size(1)):
            for j in range(D.size(2)):
                D[b, i, j] = torch.dist(X[b, :, i], Z[b, :, j])
    return D
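
# A quick sanity check (hypothetical helper, not part of the original file):
# pdist2 computes *squared* pairwise distances via the expansion
# ||x - z||^2 = ||x||^2 + ||z||^2 - 2<x, z>, so it should match
# torch.cdist(...) ** 2 up to floating-point error.
def _check_pdist2_against_cdist():
    X = torch.randn(2, 4, 3)  # (B, N, d) for the default d_second order
    Z = torch.randn(2, 5, 3)
    assert torch.allclose(pdist2(X, Z), torch.cdist(X, Z) ** 2, atol=1e-4)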


if __name__ == "__main__":
    X = torch.randn(2, 3, 5)
    Z = torch.randn(2, 3, 3)

    print(pdist2(X, order=PDist2Order.d_first))
    print(pdist2_slow(X))
    # pdist2 returns squared distances, so take the sqrt before comparing
    print(torch.dist(pdist2(X, order=PDist2Order.d_first).sqrt(), pdist2_slow(X)))


================================================
FILE: pointnet2/utils/pointnet2_modules.py
================================================
""" PointNet++ Layers
Modified by Zenn
Date: Feb 2021
"""
from __future__ import (
    division,
    absolute_import,
    with_statement,
    print_function,
    unicode_literals,
)
import torch
import torch.nn as nn
import torch.nn.functional as F
from pointnet2.utils import pytorch_utils as pt_utils

from pointnet2.utils import pointnet2_utils

if False:
    # Workaround for type hints without depending on the `typing` module
    from typing import *


class _PointnetSAModuleBase(nn.Module):
    def __init__(self, use_fps=False):
        super(_PointnetSAModuleBase, self).__init__()
        self.groupers = None
        self.mlps = None
        self.use_fps = use_fps

    def forward(self, xyz, features, npoint, return_idx=False):
        # modified to return sample idxs
        r"""
        Parameters
        ----------
        xyz : torch.Tensor
            (B, N, 3) tensor of the xyz coordinates of the features
        features : torch.Tensor
            (B, C, N) tensor of the descriptors of the features
        npoint : int
            number of centroids to sample

        Returns
        -------
        new_xyz : torch.Tensor
            (B, npoint, 3) tensor of the new features' xyz
        new_features : torch.Tensor
            (B,  \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors
        """

        self.npoint = npoint
        new_features_list = []

        xyz_flipped = xyz.transpose(1, 2).contiguous()
        if self.use_fps:
            sample_idxs = pointnet2_utils.furthest_point_sample(xyz, self.npoint)
        else:
            # without FPS, simply take the first npoint points as centroids
            sample_idxs = torch.arange(self.npoint).repeat(xyz.size(0), 1).int().cuda()

        new_xyz = (
            pointnet2_utils.gather_operation(xyz_flipped, sample_idxs)
                .transpose(1, 2)
                .contiguous()
        )

        for i in range(len(self.groupers)):
            new_features = self.groupers[i](
                xyz, new_xyz, features
            )  # (B, C, npoint, nsample)

            new_features = self.mlps[i](new_features)  # (B, mlp[-1], npoint, nsample)
            new_features = F.max_pool2d(
                new_features, kernel_size=[1, new_features.size(3)]
            )  # (B, mlp[-1], npoint, 1)
            new_features = new_features.squeeze(-1)  # (B, mlp[-1], npoint)

            new_features_list.append(new_features)
        if return_idx:
            return new_xyz, torch.cat(new_features_list, dim=1), sample_idxs
        else:
            return new_xyz, torch.cat(new_features_list, dim=1)
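
# A minimal usage sketch of the modified interface (hypothetical helper;
# requires the compiled pointnet2 CUDA ops): npoint is chosen per call, so a
# caller can pick the sampling resolution on the fly, e.g. N // 2 as the
# tracking backbones do.
def _sa_module_sketch(sa_module, xyz, features):
    # xyz: (B, N, 3), features: (B, C, N)
    npoint = xyz.shape[1] // 2
    new_xyz, new_features, idxs = sa_module(xyz, features, npoint, return_idx=True)
    return new_xyz, new_features, idxs  # (B, npoint, 3), (B, C', npoint), (B, npoint)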


class PointnetSAModuleMSG(_PointnetSAModuleBase):
    r"""Pointnet set abstrction layer with multiscale grouping

    Parameters
    ----------
    npoint : int
        Number of features
    radii : list of float32
        list of radii to group with
    nsamples : list of int32
        Number of samples in each ball query
    mlps : list of list of int32
        Spec of the pointnet before the global max_pool for each scale
    bn : bool
        Use batchnorm
    """

    def __init__(self, radii, nsamples, mlps, bn=True, use_xyz=True, use_fps=False, normalize_xyz=False):
        # type: (PointnetSAModuleMSG, List[float],List[int], List[List[int]], bool, bool,bool) -> None
        super(PointnetSAModuleMSG, self).__init__(use_fps=use_fps)

        assert len(radii) == len(nsamples) == len(mlps)

        self.groupers = nn.ModuleList()
        self.mlps = nn.ModuleList()
        for i in range(len(radii)):
            radius = radii[i]
            nsample = nsamples[i]
            self.groupers.append(
                pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz, normalize_xyz=normalize_xyz))

            mlp_spec = mlps[i]
            if use_xyz:
                mlp_spec[0] += 3

            self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn))


class PointnetSAModule(PointnetSAModuleMSG):
    r"""Pointnet set abstrction layer

    Parameters
    ----------
    npoint : int
        Number of features
    radius : float
        Radius of ball
    nsample : int
        Number of samples in the ball query
    mlp : list
        Spec of the pointnet before the global max_pool
    bn : bool
        Use batchnorm
    """

    def __init__(
            self, mlp, radius=None, nsample=None, bn=True, use_xyz=True, use_fps=False, normalize_xyz=False
    ):
        # type: (PointnetSAModule, List[int], float, int, bool, bool, bool,bool) -> None
        super(PointnetSAModule, self).__init__(
            mlps=[mlp],
            radii=[radius],
            nsamples=[nsample],
            bn=bn,
            use_xyz=use_xyz,
            use_fps=use_fps,
            normalize_xyz=normalize_xyz
        )


class PointnetFPModule(nn.Module):
    r"""Propigates the features of one set to another

    Parameters
    ----------
    mlp : list
        Pointnet module parameters
    bn : bool
        Use batchnorm
    """

    def __init__(self, mlp, bn=True):
        # type: (PointnetFPModule, List[int], bool) -> None
        super(PointnetFPModule, self).__init__()
        self.mlp = pt_utils.SharedMLP(mlp, bn=bn)

    def forward(self, unknown, known, unknow_feats, known_feats):
        # type: (PointnetFPModule, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""
        Parameters
        ----------
        unknown : torch.Tensor
            (B, n, 3) tensor of the xyz positions of the unknown features
        known : torch.Tensor
            (B, m, 3) tensor of the xyz positions of the known features
        unknow_feats : torch.Tensor
            (B, C1, n) tensor of the features to be propagated to
        known_feats : torch.Tensor
            (B, C2, m) tensor of features to be propagated

        Returns
        -------
        new_features : torch.Tensor
            (B, mlp[-1], n) tensor of the features of the unknown features
        """

        if known is not None:
            dist, idx = pointnet2_utils.three_nn(unknown, known)
            dist_recip = 1.0 / (dist + 1e-8)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm

            interpolated_feats = pointnet2_utils.three_interpolate(
                known_feats, idx, weight
            )
        else:
            interpolated_feats = known_feats.expand(
                *(list(known_feats.size()[0:2]) + [unknown.size(1)])
            )

        if unknow_feats is not None:
            new_features = torch.cat(
                [interpolated_feats, unknow_feats], dim=1
            )  # (B, C2 + C1, n)
        else:
            new_features = interpolated_feats

        new_features = new_features.unsqueeze(-1)
        new_features = self.mlp(new_features)

        return new_features.squeeze(-1)
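
# A pure-torch sketch of the interpolation weights used above (hypothetical
# helper; assumes at least 3 known points, no CUDA extension needed): each
# unknown point gets a convex combination of its 3 nearest known features,
# weighted by inverse distance.
def _inverse_distance_weights(unknown, known):
    # unknown: (B, n, 3), known: (B, m, 3)
    dist, idx = torch.topk(torch.cdist(unknown, known), k=3, dim=2, largest=False)
    dist_recip = 1.0 / (dist + 1e-8)
    weight = dist_recip / dist_recip.sum(dim=2, keepdim=True)  # rows sum to 1
    return weight, idx  # both (B, n, 3)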


class FlowEmbedding(nn.Module):
    """Modified from https://github.com/hyangwinter/flownet3d_pytorch/blob/master/util.py"""

    def __init__(self, radius, nsample, in_channel, mlp, pooling='max', corr_func='concat', knn=True):
        super(FlowEmbedding, self).__init__()
        self.radius = radius
        self.nsample = nsample
        self.knn = knn
        self.pooling = pooling
        self.corr_func = corr_func
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        if corr_func == 'concat':
            last_channel = in_channel * 2 + 3
        else:
            raise ValueError("unsupported corr_func: {}".format(corr_func))
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1, bias=False))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel

    def forward(self, xyz1, xyz2, feature1, feature2):
        """
        Input:
            xyz1: (batch_size, npoint, 3)
            xyz2: (batch_size, npoint, 3)
            feat1: (batch_size, channel, npoint)
            feat2: (batch_size, channel, npoint)
        Output:
            xyz1: (batch_size, npoint, 3)
            feat1_new: (batch_size, mlp[-1], npoint)
        """
        xyz1_t = xyz1.permute(0, 2, 1).contiguous()
        xyz2_t = xyz2.permute(0, 2, 1).contiguous()
        # feature1 = feature1.permute(0, 2, 1).contiguous()
        # feature2 = feature2.permute(0, 2, 1).contiguous()

        B, N, C = xyz1.shape
        if self.knn:
            idx = pointnet2_utils.knn_point(self.nsample, xyz1, xyz2)  # (B, npoint, nsample)
        else:
            idx = pointnet2_utils.ball_query(self.radius, self.nsample, xyz2, xyz1)  # (B, npoint, nsample)

        xyz2_grouped = pointnet2_utils.grouping_operation(xyz2_t, idx)  # (B, 3, npoint, nsample)
        pos_diff = xyz2_grouped - xyz1_t.view(B, -1, N, 1)  # (B, 3, npoint, nsample)

        feat2_grouped = pointnet2_utils.grouping_operation(feature2, idx)  # [B, C, npoint, nsample]
        if self.corr_func == 'concat':
            feat_diff = torch.cat([feat2_grouped, feature1.view(B, -1, N, 1).repeat(1, 1, 1, self.nsample)], dim=1)

        feat1_new = torch.cat([pos_diff, feat_diff], dim=1)  # [B, 2*C+3,npoint, nsample]
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            feat1_new = F.relu(bn(conv(feat1_new)))

        feat1_new = torch.max(feat1_new, -1)[0]  # [B, mlp[-1], npoint]
        return xyz1, feat1_new
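
# A hypothetical smoke test for FlowEmbedding (requires the compiled
# pointnet2 CUDA ops and a GPU): neighbours of each xyz1 point are grouped
# in xyz2 and both feature sets are concatenated before the shared MLP.
def _flow_embedding_sketch():
    fe = FlowEmbedding(radius=0.5, nsample=4, in_channel=16, mlp=[32, 32]).cuda()
    xyz1 = torch.randn(2, 128, 3).cuda()
    xyz2 = torch.randn(2, 128, 3).cuda()
    feat1 = torch.randn(2, 16, 128).cuda()
    feat2 = torch.randn(2, 16, 128).cuda()
    _, feat1_new = fe(xyz1, xyz2, feat1, feat2)
    assert feat1_new.shape == (2, 32, 128)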


class PointNetSetUpConv(nn.Module):
    def __init__(self, nsample, radius, f1_channel, f2_channel, mlp, mlp2, knn=True):
        super(PointNetSetUpConv, self).__init__()
        self.nsample = nsample
        self.radius = radius
        self.knn = knn
        self.mlp1_convs = nn.ModuleList()
        self.mlp2_convs = nn.ModuleList()
        last_channel = f2_channel + 3
        for out_channel in mlp:
            self.mlp1_convs.append(nn.Sequential(nn.Conv2d(last_channel, out_channel, 1, bias=False),
                                                 nn.BatchNorm2d(out_channel),
                                                 nn.ReLU(inplace=False)))
            last_channel = out_channel
        if len(mlp) != 0:
            last_channel = mlp[-1] + f1_channel
        else:
            last_channel = last_channel + f1_channel
        for out_channel in mlp2:
            self.mlp2_convs.append(nn.Sequential(nn.Conv1d(last_channel, out_channel, 1, bias=False),
                                                 nn.BatchNorm1d(out_channel),
                                                 nn.ReLU(inplace=False)))
            last_channel = out_channel

    def forward(self, xyz1, xyz2, feature1, feature2):
        """
            Feature propagation from xyz2 (less points) to xyz1 (more points)
        Inputs:
            xyz1: (batch_size, npoint1, 3)
            xyz2: (batch_size, npoint2, 3)
            feat1: (batch_size, channel1, npoint1) features for xyz1 points (earlier layers)
            feat2: (batch_size, channel2, npoint2) features for xyz2 points
        Output:
            feat1_new: (batch_size, mlp[-1] or mlp2[-1] or channel1+3, npoint2)
            TODO: Add support for skip links. Study how delta(XYZ) plays a role in feature updating.
        """
        xyz1_t = xyz1.permute(0, 2, 1).contiguous()
        xyz2_t = xyz2.permute(0, 2, 1).contiguous()
        # feature1 = feature1.permute(0, 2, 1).contiguous()
        # feature2 = feature2.permute(0, 2, 1).contiguous()
        B, C, N = xyz1_t.shape
        if self.knn:
            idx = pointnet2_utils.knn_point(self.nsample, xyz1, xyz2)  # (B, npoint1, nsample)
        else:
            idx = pointnet2_utils.ball_query(self.radius, self.nsample, xyz2, xyz1)  # (B, npoint1, nsample)

        xyz2_grouped = pointnet2_utils.grouping_operation(xyz2_t, idx)
        pos_diff = xyz2_grouped - xyz1_t.view(B, -1, N, 1)  # [B,3,N1,S]

        feat2_grouped = pointnet2_utils.grouping_operation(feature2, idx)
        feat_new = torch.cat([feat2_grouped, pos_diff], dim=1)  # [B,C1+3,N1,S]
        for conv in self.mlp1_convs:
            feat_new = conv(feat_new)
        # max pooling
        feat_new = feat_new.max(-1)[0]  # [B,mlp1[-1],N1]
        # concatenate feature in early layer
        if feature1 is not None:
            feat_new = torch.cat([feat_new, feature1], dim=1)
        # feat_new = feat_new.view(B,-1,N,1)
        for conv in self.mlp2_convs:
            feat_new = conv(feat_new)

        return feat_new
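
# A hypothetical smoke test for PointNetSetUpConv (requires the compiled
# pointnet2 CUDA ops and a GPU): features flow from the coarse set xyz2
# (64 points) back to the dense set xyz1 (256 points).
def _set_upconv_sketch():
    up = PointNetSetUpConv(nsample=4, radius=0.5, f1_channel=8, f2_channel=16,
                           mlp=[32], mlp2=[32]).cuda()
    xyz1 = torch.randn(2, 256, 3).cuda()
    xyz2 = torch.randn(2, 64, 3).cuda()
    feat1 = torch.randn(2, 8, 256).cuda()
    feat2 = torch.randn(2, 16, 64).cuda()
    assert up(xyz1, xyz2, feat1, feat2).shape == (2, 32, 256)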


if __name__ == "__main__":
    from torch.autograd import Variable

    torch.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    xyz = Variable(torch.randn(2, 9, 3).cuda(), requires_grad=True)
    # features are (B, C, N); C=6 here, and use_xyz adds 3 more input channels
    xyz_feats = Variable(torch.randn(2, 6, 9).cuda(), requires_grad=True)

    # npoint is passed to forward() in this modified version, not the constructor
    test_module = PointnetSAModuleMSG(
        radii=[5.0, 10.0], nsamples=[6, 3], mlps=[[6, 3], [6, 6]]
    )
    test_module.cuda()
    print(test_module(xyz, xyz_feats, npoint=2))

    #  test_module = PointnetFPModule(mlp=[6, 6])
    #  test_module.cuda()
    #  from torch.autograd import gradcheck
    #  inputs = (xyz, xyz, None, xyz_feats)
    #  test = gradcheck(test_module, inputs, eps=1e-6, atol=1e-4)
    #  print(test)

    for _ in range(1):
        _, new_features = test_module(xyz, xyz_feats, npoint=2)
        new_features.backward(torch.cuda.FloatTensor(*new_features.size()).fill_(1))
        print(new_features)
        print(xyz.grad)


================================================
FILE: pointnet2/utils/pointnet2_utils.py
================================================
""" PointNet++ utils
Modified by Zenn
Date: Feb 2021
"""
from __future__ import (
    division,
    absolute_import,
    with_statement,
    print_function,
    unicode_literals,
)
import torch
from torch.autograd import Function
import torch.nn as nn
from pointnet2.utils import pytorch_utils as pt_utils

import pointnet2_ops._ext as _ext

if False:
    # Workaround for type hints without depending on the `typing` module
    from typing import *


class RandomDropout(nn.Module):
    def __init__(self, p=0.5, inplace=False):
        super(RandomDropout, self).__init__()
        self.p = p
        self.inplace = inplace

    def forward(self, X):
        theta = torch.Tensor(1).uniform_(0, self.p)[0]
        return pt_utils.feature_dropout_no_scaling(X, theta, self.training, self.inplace)


class FurthestPointSampling(Function):
    @staticmethod
    def forward(ctx, xyz, npoint):
        # type: (Any, torch.Tensor, int) -> torch.Tensor
        r"""
        Uses iterative furthest point sampling to select a set of npoint features that have the largest
        minimum distance

        Parameters
        ----------
        xyz : torch.Tensor
            (B, N, 3) tensor where N > npoint
        npoint : int32
            number of features in the sampled set

        Returns
        -------
        torch.Tensor
            (B, npoint) tensor containing the set
        """
        # return _ext.furthest_point_sampling(xyz, npoint)
        fps_inds = _ext.furthest_point_sampling(xyz, npoint)
        ctx.mark_non_differentiable(fps_inds)
        return fps_inds

    @staticmethod
    def backward(ctx, a=None):
        return None, None


furthest_point_sample = FurthestPointSampling.apply


class GatherOperation(Function):
    @staticmethod
    def forward(ctx, features, idx):
        # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""

        Parameters
        ----------
        features : torch.Tensor
            (B, C, N) tensor

        idx : torch.Tensor
            (B, npoint) tensor of the features to gather

        Returns
        -------
        torch.Tensor
            (B, C, npoint) tensor
        """

        _, C, N = features.size()

        ctx.for_backwards = (idx, C, N)

        return _ext.gather_points(features, idx)

    @staticmethod
    def backward(ctx, grad_out):
        idx, C, N = ctx.for_backwards

        grad_features = _ext.gather_points_grad(grad_out.contiguous(), idx, N)
        return grad_features, None


gather_operation = GatherOperation.apply
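
# A minimal usage sketch (hypothetical helper; requires the compiled
# pointnet2_ops extension and CUDA tensors): furthest_point_sample picks
# npoint well-spread point indices, and gather_operation pulls out the
# corresponding columns of a (B, C, N) feature tensor.
def _fps_gather_sketch(xyz, features, npoint):
    # xyz: (B, N, 3), features: (B, C, N)
    idx = furthest_point_sample(xyz, npoint)         # (B, npoint) int32
    sampled_feats = gather_operation(features, idx)  # (B, C, npoint)
    sampled_xyz = gather_operation(xyz.transpose(1, 2).contiguous(), idx)
    return sampled_xyz.transpose(1, 2).contiguous(), sampled_feats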


class ThreeNN(Function):
    @staticmethod
    def forward(ctx, unknown, known):
        # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
        r"""
            Find the three nearest neighbors of unknown in known
        Parameters
        ----------
        unknown : torch.Tensor
            (B, n, 3) tensor of unknown features
        known : torch.Tensor
            (B, m, 3) tensor of known features

        Returns
        -------
        dist : torch.Tensor
            (B, n, 3) l2 distance to the three nearest neighbors
        idx : torch.Tensor
            (B, n, 3) index of 3 nearest neighbors
        """
        dist2, idx = _ext.three_nn(unknown, known)

        return torch.sqrt(dist2), idx

    @staticmethod
    def backward(ctx, a=None, b=None):
        return None, None


three_nn = ThreeNN.apply


class ThreeInterpolate(Function):
    @staticmethod
    def forward(ctx, features, idx, weight):
        # type: (Any, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""
            Performs weighted linear interpolation on 3 features
        Parameters
        ----------
        features : torch.Tensor
            (B, c, m) Features descriptors to be interpolated from
        idx : torch.Tensor
            (B, n, 3) three nearest neighbors of the target features in features
        weight : torch.Tensor
            (B, n, 3) weights

        Returns
        -------
        torch.Tensor
            (B, c, n) tensor of the interpolated features
        """
        B, c, m = features.size()
        n = idx.size(1)

        ctx.three_interpolate_for_backward = (idx, weight, m)

        return _ext.three_interpolate(features, idx, weight)

    @staticmethod
    def backward(ctx, grad_out):
        # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        r"""
        Parameters
        ----------
        grad_out : torch.Tensor
            (B, c, n) tensor with gradients of outputs

        Returns
        -------
        grad_features : torch.Tensor
            (B, c, m) tensor with gradients of features

        None

        None
        """
        idx, weight, m = ctx.three_interpolate_for_backward

        grad_features = _ext.three_interpolate_grad(
            grad_out.contiguous(), idx, weight, m
        )

        return grad_features, None, None


three_interpolate = ThreeInterpolate.apply


class GroupingOperation(Function):
    @staticmethod
    def forward(ctx, features, idx):
        # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""

        Parameters
        ----------
        features : torch.Tensor
            (B, C, N) tensor of features to group
        idx : torch.Tensor
            (B, npoint, nsample) tensor containing the indices of features to group with

        Returns
        -------
        torch.Tensor
            (B, C, npoint, nsample) tensor
        """
        B, nfeatures, nsample = idx.size()
        _, C, N = features.size()

        ctx.for_backwards = (idx, N)

        return _ext.group_points(features, idx)

    @staticmethod
    def backward(ctx, grad_out):
        # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
        r"""

        Parameters
        ----------
        grad_out : torch.Tensor
            (B, C, npoint, nsample) tensor of the gradients of the output from forward

        Returns
        -------
        torch.Tensor
            (B, C, N) gradient of the features
        None
        """
        idx, N = ctx.for_backwards

        grad_features = _ext.group_points_grad(grad_out.contiguous(), idx, N)

        return grad_features, None


grouping_operation = GroupingOperation.apply


class BallQuery(Function):
    @staticmethod
    def forward(ctx, radius, nsample, xyz, new_xyz):
        # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""

        Parameters
        ----------
        radius : float
            radius of the balls
        nsample : int
            maximum number of features in the balls
        xyz : torch.Tensor
            (B, N, 3) xyz coordinates of the features
        new_xyz : torch.Tensor
            (B, npoint, 3) centers of the ball query

        Returns
        -------
        torch.Tensor
            (B, npoint, nsample) tensor with the indices of the features that form the query balls
        """
        # return _ext.ball_query(new_xyz, xyz, radius, nsample)
        inds = _ext.ball_query(new_xyz, xyz, radius, nsample)
        ctx.mark_non_differentiable(inds)
        return inds

    @staticmethod
    def backward(ctx, a=None):
        return None, None, None, None


ball_query = BallQuery.apply


class QueryAndGroup(nn.Module):
    r"""
    Groups with a ball query of radius

    Parameters
    ---------
    radius : float32
        Radius of ball
    nsample : int32
        Maximum number of features to gather in the ball
    """

    def __init__(self, radius, nsample, use_xyz=True, return_idx=False, normalize_xyz=False):
        # type: (QueryAndGroup, float, int, bool,bool,bool) -> None
        super(QueryAndGroup, self).__init__()
        self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz
        self.return_idx = return_idx
        self.normalize_xyz = normalize_xyz

    def forward(self, xyz, new_xyz, features=None):
        # type: (QueryAndGroup, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor]
        r"""
        Parameters
        ----------
        xyz : torch.Tensor
            xyz coordinates of the features (B, N, 3)
        new_xyz : torch.Tensor
            centroids (B, npoint, 3)
        features : torch.Tensor
            Descriptors of the features (B, C, N)

        Returns
        -------
        new_features : torch.Tensor
            (B, 3 + C, npoint, nsample) tensor
        """

        idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
        xyz_trans = xyz.transpose(1, 2).contiguous()
        grouped_xyz = grouping_operation(xyz_trans, idx)  # (B, 3, npoint, nsample)
        grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)
        if self.normalize_xyz:
            grouped_xyz /= self.radius

        if features is not None:
            grouped_features = grouping_operation(features, idx)
            if self.use_xyz:
                new_features = torch.cat(
                    [grouped_xyz, grouped_features], dim=1
                )  # (B, C + 3, npoint, nsample)
            else:
                new_features = grouped_features
        else:
            assert (
                self.use_xyz
            ), "Cannot have no features and not use xyz as a feature!"
            new_features = grouped_xyz
        if self.return_idx:
            return new_features, idx
        return new_features
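
# A hypothetical smoke test for QueryAndGroup (requires the compiled
# pointnet2 CUDA ops and a GPU): up to nsample neighbours within radius are
# grouped around each centroid, with relative xyz stacked on the descriptors.
def _query_and_group_sketch():
    qg = QueryAndGroup(radius=0.4, nsample=16, use_xyz=True)
    xyz = torch.rand(2, 1024, 3).cuda()
    new_xyz = xyz[:, :128, :].contiguous()  # any (B, npoint, 3) centroids work
    feats = torch.rand(2, 32, 1024).cuda()
    grouped = qg(xyz, new_xyz, feats)       # (B, 3 + C, npoint, nsample)
    assert grouped.shape == (2, 35, 128, 16)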


class GroupAll(nn.Module):
    r"""
    Groups all features

    Parameters
    ---------
    """

    def __init__(self, use_xyz=True):
        # type: (GroupAll, bool) -> None
        super(GroupAll, self).__init__()
        self.use_xyz = use_xyz

    def forward(self, xyz, new_xyz, features=None):
        # type: (GroupAll, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor]
        r"""
        Parameters
        ----------
        xyz : torch.Tensor
            xyz coordinates of the features (B, N, 3)
        new_xyz : torch.Tensor
            Ignored
        features : torch.Tensor
            Descriptors of the features (B, C, N)

        Returns
        -------
        new_features : torch.Tensor
            (B, C + 3, 1, N) tensor
        """

        grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
        if features is not None:
            grouped_features = features.unsqueeze(2)
            if self.use_xyz:
                new_features = torch.cat(
                    [grouped_xyz, grouped_features], dim=1
                )  # (B, 3 + C, 1, N)
            else:
                new_features = grouped_features
        else:
            new_features = grouped_xyz

        return new_features


def knn_point(k, points1, points2):
    """
    find for each point in points1 the knn in points2
    Args:
        k: k for kNN
        points1: B x npoint1 x d
        points2: B x npoint2 x d

    Returns:
        top_k_neareast_idx: (batch_size, npoint1, k) int32 array, indices to input points
    """
    dist_matrix = torch.cdist(points1, points2)  # B, npoint1, npoint2
    top_k_neareast_idx = torch.argsort(dist_matrix, dim=-1)[:, :, :k]  # B, npoint1, K
    top_k_neareast_idx = top_k_neareast_idx.int().contiguous()
    return top_k_neareast_idx
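
# A quick pure-torch check (hypothetical helper, no CUDA extension needed):
# knn_point should agree with torch.topk over the same distance matrix.
def _knn_point_check():
    points1 = torch.randn(2, 8, 3)
    points2 = torch.randn(2, 10, 3)
    idx = knn_point(3, points1, points2).long()  # (2, 8, 3)
    ref = torch.topk(torch.cdist(points1, points2), k=3, dim=-1, largest=False).indices
    assert torch.equal(idx, ref)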


================================================
FILE: pointnet2/utils/pytorch_utils.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

''' Modified based on Ref: https://github.com/erikwijmans/Pointnet2_PyTorch '''
import torch
import torch.nn as nn
from typing import List, Tuple


class SharedMLP(nn.Sequential):

    def __init__(
            self,
            args: List[int],
            *,
            bn: bool = False,
            activation=nn.ReLU(inplace=True),
            preact: bool = False,
            first: bool = False,
            name: str = ""
    ):
        super().__init__()

        for i in range(len(args) - 1):
            self.add_module(
                name + 'layer{}'.format(i),
                Conv2d(
                    args[i],
                    args[i + 1],
                    bn=(not first or not preact or (i != 0)) and bn,
                    activation=activation
                    if (not first or not preact or (i != 0)) else None,
                    preact=preact
                )
            )
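
# A minimal usage sketch (hypothetical shapes): SharedMLP chains 1x1 Conv2d
# layers, so the same weights are applied independently at every
# (point, sample) location.
def _shared_mlp_sketch():
    mlp = SharedMLP([6, 64, 128], bn=True)
    x = torch.randn(2, 6, 512, 1)  # (B, C_in, npoint, nsample)
    assert mlp(x).shape == (2, 128, 512, 1)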


class _BNBase(nn.Sequential):

    def __init__(self, in_size, batch_norm=None, name=""):
        super().__init__()
        self.add_module(name + "bn", batch_norm(in_size))

        nn.init.constant_(self[0].weight, 1.0)
        nn.init.constant_(self[0].bias, 0)


class BatchNorm1d(_BNBase):

    def __init__(self, in_size: int, *, name: str = ""):
        super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name)


class BatchNorm2d(_BNBase):

    def __init__(self, in_size: int, name: str = ""):
        super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name)


class BatchNorm3d(_BNBase):

    def __init__(self, in_size: int, name: str = ""):
        super().__init__(in_size, batch_norm=nn.BatchNorm3d, name=name)


class _ConvBase(nn.Sequential):

    def __init__(
            self,
            in_size,
            out_size,
            kernel_size,
            stride,
            padding,
            activation,
            bn,
            init,
            conv=None,
            batch_norm=None,
            bias=True,
            preact=False,
            name=""
    ):
        super().__init__()

        bias = bias and (not bn)
        conv_unit = conv(
            in_size,
            out_size,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias
        )
        init(conv_unit.weight)
        if bias:
            nn.init.constant_(conv_unit.bias, 0)

        if bn:
            if not preact:
                bn_unit = batch_norm(out_size)
            else:
                bn_unit = batch_norm(in_size)

        if preact:
            if bn:
                self.add_module(name + "bn", bn_unit)

            if activation is not None:
                self.add_module(name + "activation", activation)

        self.add_module(name + "conv", conv_unit)

        if not preact:
            if bn:
                self.add_module(name + "bn", bn_unit)

            if activation is not None:
                self.add_module(name + "activation", activation)


================================================
SYMBOL INDEX (286 symbols across 25 files)
================================================

FILE: datasets/__init__.py
  function get_dataset (line 8) | def get_dataset(config, type='train', **kwargs):

FILE: datasets/base_dataset.py
  class BaseDataset (line 7) | class BaseDataset:
    method __init__ (line 8) | def __init__(self, path, split, category_name="Car", **kwargs):
    method get_num_tracklets (line 15) | def get_num_tracklets(self):
    method get_num_frames_total (line 18) | def get_num_frames_total(self):
    method get_num_frames_tracklet (line 21) | def get_num_frames_tracklet(self, tracklet_id):
    method get_frames (line 24) | def get_frames(self, seq_id, frame_ids):

FILE: datasets/data_classes.py
  class PointCloud (line 11) | class PointCloud:
    method __init__ (line 13) | def __init__(self, points):
    method load_pcd_bin (line 23) | def load_pcd_bin(file_name):
    method from_file (line 34) | def from_file(cls, file_name):
    method nbr_points (line 50) | def nbr_points(self):
    method subsample (line 57) | def subsample(self, ratio):
    method remove_close (line 67) | def remove_close(self, radius):
    method translate (line 79) | def translate(self, x):
    method rotate (line 88) | def rotate(self, rot_matrix):
    method transform (line 96) | def transform(self, transf_matrix):
    method convertToPytorch (line 105) | def convertToPytorch(self):
    method fromPytorch (line 113) | def fromPytorch(cls, pytorchTensor):
    method normalize (line 123) | def normalize(self, wlh):
  class Box (line 128) | class Box:
    method __init__ (line 131) | def __init__(self, center, size, orientation, label=np.nan, score=np.n...
    method __eq__ (line 156) | def __eq__(self, other):
    method __repr__ (line 167) | def __repr__(self):
    method encode (line 177) | def encode(self):
    method decode (line 185) | def decode(cls, data):
    method rotation_matrix (line 195) | def rotation_matrix(self):
    method translate (line 202) | def translate(self, x):
    method rotate (line 210) | def rotate(self, quaternion):
    method transform (line 220) | def transform(self, transf_matrix):
    method corners (line 226) | def corners(self, wlh_factor=1.0):
    method bottom_corners (line 252) | def bottom_corners(self):

FILE: datasets/generate_waymo_sot.py
  function lood_pickle (line 15) | def lood_pickle(root):
  function generate_waymo_data (line 21) | def generate_waymo_data(root, cla, split):

FILE: datasets/kitti.py
  class kittiDataset (line 18) | class kittiDataset(base_dataset.BaseDataset):
    method __init__ (line 19) | def __init__(self, path, split, category_name="Car", **kwargs):
    method _build_scene_list (line 36) | def _build_scene_list(split):
    method _load_data (line 58) | def _load_data(self):
    method get_num_scenes (line 79) | def get_num_scenes(self):
    method get_num_tracklets (line 82) | def get_num_tracklets(self):
    method get_num_frames_total (line 85) | def get_num_frames_total(self):
    method get_num_frames_tracklet (line 88) | def get_num_frames_tracklet(self, tracklet_id):
    method _build_tracklet_anno (line 91) | def _build_tracklet_anno(self):
    method get_frames (line 130) | def get_frames(self, seq_id, frame_ids):
    method _get_frame_from_anno (line 139) | def _get_frame_from_anno(self, anno):
    method _read_calib_file (line 192) | def _read_calib_file(filepath):

FILE: datasets/nuscenes_data.py
  class NuScenesDataset (line 58) | class NuScenesDataset(base_dataset.BaseDataset):
    method __init__ (line 59) | def __init__(self, path, split, category_name="Car", version='v1.0-tra...
    method filter_instance (line 71) | def filter_instance(self, split, category_name=None, min_points=-1):
    method _build_tracklet_anno (line 93) | def _build_tracklet_anno(self):
    method _load_data (line 115) | def _load_data(self):
    method get_num_tracklets (line 136) | def get_num_tracklets(self):
    method get_num_frames_total (line 139) | def get_num_frames_total(self):
    method get_num_frames_tracklet (line 142) | def get_num_frames_tracklet(self, tracklet_id):
    method get_frames (line 145) | def get_frames(self, seq_id, frame_ids):
    method _get_frame_from_anno_data (line 154) | def _get_frame_from_anno_data(self, anno):

FILE: datasets/points_utils.py
  function random_choice (line 11) | def random_choice(num_samples, size, replacement=False, seed=None):
  function regularize_pc (line 24) | def regularize_pc(points, sample_size, seed=None):
  function getOffsetBB (line 43) | def getOffsetBB(box, offset, degrees=True, use_z=False, limit_box=True, ...
  function getModel (line 88) | def getModel(PCs, boxes, offset=0, scale=1.0, normalize=False):
  function cropAndCenterPC (line 103) | def cropAndCenterPC(PC, box, offset=0, scale=1.0, normalize=False):
  function get_point_to_box_distance (line 127) | def get_point_to_box_distance(pc, box, wlh_factor=1.0):
  function crop_pc_axis_aligned (line 146) | def crop_pc_axis_aligned(PC, box, offset=0, scale=1.0, return_mask=False):
  function crop_pc_oriented (line 174) | def crop_pc_oriented(PC, box, offset=0, scale=1.0, return_mask=False):
  function generate_subwindow (line 218) | def generate_subwindow(pc, sample_bb, scale, offset=2, oriented=True):
  function transform_box (line 253) | def transform_box(box, ref_box, inplace=False):
  function transform_pc (line 261) | def transform_pc(pc, ref_box, inplace=False):
  function get_in_box_mask (line 269) | def get_in_box_mask(PC, box):
  function apply_transform (line 299) | def apply_transform(in_box_pc, box, translation, rotation, flip_x, flip_...
  function apply_augmentation (line 348) | def apply_augmentation(pc, box, wlh_factor=1.25):
  function roty_batch_tensor (line 364) | def roty_batch_tensor(t):
  function rotz_batch_tensor (line 377) | def rotz_batch_tensor(t):
  function get_offset_points_tensor (line 390) | def get_offset_points_tensor(points, ref_box_params, offset_box_params):
  function get_offset_box_tensor (line 418) | def get_offset_box_tensor(ref_box_params, offset_box_params):
  function remove_transform_points_tensor (line 437) | def remove_transform_points_tensor(points, ref_box_params):
  function np_to_torch_tensor (line 454) | def np_to_torch_tensor(data, device=None):

FILE: datasets/sampler.py
  function no_processing (line 12) | def no_processing(data, *args):
  function siamese_processing (line 16) | def siamese_processing(data, config, template_transform=None, search_tra...
  function motion_processing (line 82) | def motion_processing(data, config, template_transform=None, search_tran...
  class PointTrackingSampler (line 183) | class PointTrackingSampler(torch.utils.data.Dataset):
    method __init__ (line 184) | def __init__(self, dataset, random_sample, sample_per_epoch=10000, pro...
    method get_anno_index (line 206) | def get_anno_index(self, index):
    method get_candidate_index (line 209) | def get_candidate_index(self, index):
    method __len__ (line 212) | def __len__(self):
    method __getitem__ (line 218) | def __getitem__(self, index):
  class TestTrackingSampler (line 246) | class TestTrackingSampler(torch.utils.data.Dataset):
    method __init__ (line 247) | def __init__(self, dataset, config=None, **kwargs):
    method __len__ (line 253) | def __len__(self):
    method __getitem__ (line 256) | def __getitem__(self, index):
  class MotionTrackingSampler (line 262) | class MotionTrackingSampler(PointTrackingSampler):
    method __init__ (line 263) | def __init__(self, dataset, config=None, **kwargs):
    method __getitem__ (line 267) | def __getitem__(self, index):

FILE: datasets/searchspace.py
  class SearchSpace (line 6) | class SearchSpace(object):
    method reset (line 8) | def reset(self):
    method sample (line 11) | def sample(self):
    method addData (line 14) | def addData(self, data, score):
  class ExhaustiveSearch (line 18) | class ExhaustiveSearch(SearchSpace):
    method __init__ (line 20) | def __init__(self,
    method reset (line 42) | def reset(self):
    method sample (line 45) | def sample(self, n=0):
  class ParticleFiltering (line 49) | class ParticleFiltering(SearchSpace):
    method __init__ (line 50) | def __init__(self, bnd=[1, 1, 10]):
    method sample (line 54) | def sample(self, n=10):
    method addData (line 71) | def addData(self, data, score):
    method reset (line 76) | def reset(self):
  class KalmanFiltering (line 85) | class KalmanFiltering(SearchSpace):
    method __init__ (line 86) | def __init__(self, bnd=[1, 1, 10]):
    method sample (line 90) | def sample(self, n=10):
    method addData (line 93) | def addData(self, data, score):
    method reset (line 100) | def reset(self):
  class GaussianMixtureModel (line 110) | class GaussianMixtureModel(SearchSpace):
    method __init__ (line 112) | def __init__(self, n_comp=5, dim=3):
    method sample (line 116) | def sample(self, n=10):
    method addData (line 157) | def addData(self, data, score):
    method reset (line 173) | def reset(self, n_comp=5):

FILE: datasets/utils.py
  function roty (line 10) | def roty(t):
  function get_3d_box (line 18) | def get_3d_box(box_size, heading_angle, center):
  function write_ply (line 39) | def write_ply(verts, colors, indices, output_file):
  function box2obj (line 66) | def box2obj(box, objname):
  function write_bbox (line 79) | def write_bbox(corners, mode, output_file):
  function write_obj (line 209) | def write_obj(points, file, rgb=False):

FILE: datasets/waymo_data.py
  class WaymoDataset (line 21) | class WaymoDataset(base_dataset.BaseDataset):
    method __init__ (line 22) | def __init__(self, path, split, category_name="VEHICLE", **kwargs):
    method _load_data (line 48) | def _load_data(self):
    method get_num_scenes (line 77) | def get_num_scenes(self):
    method get_num_tracklets (line 80) | def get_num_tracklets(self):
    method get_num_frames_total (line 83) | def get_num_frames_total(self):
    method get_num_frames_tracklet (line 86) | def get_num_frames_tracklet(self, tracklet_id):
    method _build_tracklet_anno (line 89) | def _build_tracklet_anno(self):
    method get_frames (line 109) | def get_frames(self, seq_id, frame_ids):
    method _get_frame_from_anno (line 118) | def _get_frame_from_anno(self, anno, track_id=None, check=False):
    method veh_pos_to_transform (line 170) | def veh_pos_to_transform(veh_pos):

FILE: main.py
  function load_yaml (line 23) | def load_yaml(file_name):
  function parse_config (line 32) | def parse_config():

FILE: models/__init__.py
  function get_model (line 19) | def get_model(name):

FILE: models/backbone/pointnet.py
  class Pointnet_Backbone (line 12) | class Pointnet_Backbone(nn.Module):
    method __init__ (line 28) | def __init__(self, use_fps=False, normalize_xyz=False, return_intermed...
    method _break_up_pc (line 60) | def _break_up_pc(self, pc):
    method forward (line 66) | def forward(self, pointcloud, numpoints):
  class MiniPointNet (line 91) | class MiniPointNet(nn.Module):
    method __init__ (line 93) | def __init__(self, input_channel, per_point_mlp, hidden_mlp, output_si...
    method forward (line 128) | def forward(self, x):
  class SegPointNet (line 144) | class SegPointNet(nn.Module):
    method __init__ (line 146) | def __init__(self, input_channel, per_point_mlp1, per_point_mlp2, outp...
    method forward (line 184) | def forward(self, x):

FILE: models/base_model.py
  class BaseModel (line 17) | class BaseModel(pl.LightningModule):
    method __init__ (line 18) | def __init__(self, config=None, **kwargs):
    method configure_optimizers (line 28) | def configure_optimizers(self):
    method compute_loss (line 38) | def compute_loss(self, data, output):
    method build_input_dict (line 41) | def build_input_dict(self, sequence, frame_id, results_bbs, **kwargs):
    method evaluate_one_sample (line 44) | def evaluate_one_sample(self, data_dict, ref_box):
    method evaluate_one_sequence (line 59) | def evaluate_one_sequence(self, sequence):
    method validation_step (line 88) | def validation_step(self, batch, batch_idx):
    method validation_epoch_end (line 97) | def validation_epoch_end(self, outputs):
    method test_step (line 103) | def test_step(self, batch, batch_idx):
    method test_epoch_end (line 113) | def test_epoch_end(self, outputs):
  class MatchingBaseModel (line 120) | class MatchingBaseModel(BaseModel):
    method compute_loss (line 122) | def compute_loss(self, data, output):
    method generate_template (line 166) | def generate_template(self, sequence, current_frame_id, results_bbs):
    method generate_search_area (line 197) | def generate_search_area(self, sequence, current_frame_id, results_bbs):
    method prepare_input (line 220) | def prepare_input(self, template_pc, search_pc, template_box, *args, *...
    method build_input_dict (line 240) | def build_input_dict(self, sequence, frame_id, results_bbs, **kwargs):
  class MotionBaseModel (line 250) | class MotionBaseModel(BaseModel):
    method __init__ (line 251) | def __init__(self, config, **kwargs):
    method build_input_dict (line 255) | def build_input_dict(self, sequence, frame_id, results_bbs):

FILE: models/bat.py
  class BAT (line 17) | class BAT(base_model.MatchingBaseModel):
    method __init__ (line 18) | def __init__(self, config=None, **kwargs):
    method prepare_input (line 41) | def prepare_input(self, template_pc, search_pc, template_box):
    method compute_loss (line 57) | def compute_loss(self, data, output):
    method forward (line 67) | def forward(self, input_dict):
    method training_step (line 114) | def training_step(self, batch, batch_idx):

FILE: models/head/rpn.py
  class P2BVoteNetRPN (line 12) | class P2BVoteNetRPN(nn.Module):
    method __init__ (line 14) | def __init__(self, feature_channel, vote_channel=256, num_proposal=64,...
    method forward (line 41) | def forward(self, xyz, feature):

FILE: models/head/xcorr.py
  class BaseXCorr (line 10) | class BaseXCorr(nn.Module):
    method __init__ (line 11) | def __init__(self, in_channel, hidden_channel, out_channel):
  class P2B_XCorr (line 20) | class P2B_XCorr(BaseXCorr):
    method __init__ (line 21) | def __init__(self, feature_channel, hidden_channel, out_channel):
    method forward (line 25) | def forward(self, template_feature, search_feature, template_xyz):
  class BoxAwareXCorr (line 56) | class BoxAwareXCorr(BaseXCorr):
    method __init__ (line 57) | def __init__(self, feature_channel, hidden_channel, out_channel, k=8, ...
    method forward (line 67) | def forward(self, template_feature, search_feature, template_xyz,

FILE: models/m2track.py
  class M2TRACK (line 17) | class M2TRACK(base_model.MotionBaseModel):
    method __init__ (line 18) | def __init__(self, config, **kwargs):
    method forward (line 73) | def forward(self, input_dict):
    method compute_loss (line 153) | def compute_loss(self, data, output):
    method training_step (line 233) | def training_step(self, batch, batch_idx):

FILE: models/p2b.py
  class P2B (line 13) | class P2B(base_model.MatchingBaseModel):
    method __init__ (line 14) | def __init__(self, config=None, **kwargs):
    method forward (line 28) | def forward(self, input_dict):
    method training_step (line 61) | def training_step(self, batch, batch_idx):

FILE: pointnet2/utils/linalg_utils.py
  function pdist2 (line 15) | def pdist2(X, Z=None, order=PDist2Order.d_second):
  function pdist2_slow (line 66) | def pdist2_slow(X, Z=None):

FILE: pointnet2/utils/pointnet2_modules.py
  class _PointnetSAModuleBase (line 24) | class _PointnetSAModuleBase(nn.Module):
    method __init__ (line 25) | def __init__(self, use_fps=False):
    method forward (line 31) | def forward(self, xyz, features, npoint, return_idx=False):
  class PointnetSAModuleMSG (line 82) | class PointnetSAModuleMSG(_PointnetSAModuleBase):
    method __init__ (line 99) | def __init__(self, radii, nsamples, mlps, bn=True, use_xyz=True, use_f...
  class PointnetSAModule (line 120) | class PointnetSAModule(PointnetSAModuleMSG):
    method __init__ (line 137) | def __init__(
  class PointnetFPModule (line 152) | class PointnetFPModule(nn.Module):
    method __init__ (line 163) | def __init__(self, mlp, bn=True):
    method forward (line 168) | def forward(self, unknown, known, unknow_feats, known_feats):
  class FlowEmbedding (line 215) | class FlowEmbedding(nn.Module):
    method __init__ (line 218) | def __init__(self, radius, nsample, in_channel, mlp, pooling='max', co...
    method forward (line 234) | def forward(self, xyz1, xyz2, feature1, feature2):
  class PointNetSetUpConv (line 272) | class PointNetSetUpConv(nn.Module):
    method __init__ (line 273) | def __init__(self, nsample, radius, f1_channel, f2_channel, mlp, mlp2,...
    method forward (line 296) | def forward(self, xyz1, xyz2, feature1, feature2):

FILE: pointnet2/utils/pointnet2_utils.py
  class RandomDropout (line 24) | class RandomDropout(nn.Module):
    method __init__ (line 25) | def __init__(self, p=0.5, inplace=False):
    method forward (line 30) | def forward(self, X):
  class FurthestPointSampling (line 35) | class FurthestPointSampling(Function):
    method forward (line 37) | def forward(ctx, xyz, npoint):
    method backward (line 61) | def backward(xyz, a=None):
  class GatherOperation (line 68) | class GatherOperation(Function):
    method forward (line 70) | def forward(ctx, features, idx):
    method backward (line 95) | def backward(ctx, grad_out):
  class ThreeNN (line 105) | class ThreeNN(Function):
    method forward (line 107) | def forward(ctx, unknown, known):
    method backward (line 130) | def backward(ctx, a=None, b=None):
  class ThreeInterpolate (line 137) | class ThreeInterpolate(Function):
    method forward (line 139) | def forward(ctx, features, idx, weight):
    method backward (line 165) | def backward(ctx, grad_out):
  class GroupingOperation (line 194) | class GroupingOperation(Function):
    method forward (line 196) | def forward(ctx, features, idx):
    method backward (line 220) | def backward(ctx, grad_out):
  class BallQuery (line 245) | class BallQuery(Function):
    method forward (line 247) | def forward(ctx, radius, nsample, xyz, new_xyz):
    method backward (line 273) | def backward(ctx, a=None):
  class QueryAndGroup (line 280) | class QueryAndGroup(nn.Module):
    method __init__ (line 292) | def __init__(self, radius, nsample, use_xyz=True, return_idx=False, no...
    method forward (line 299) | def forward(self, xyz, new_xyz, features=None):
  class GroupAll (line 342) | class GroupAll(nn.Module):
    method __init__ (line 350) | def __init__(self, use_xyz=True):
    method forward (line 355) | def forward(self, xyz, new_xyz, features=None):
  function knn_point (line 388) | def knn_point(k, points1, points2):

FILE: pointnet2/utils/pytorch_utils.py
  class SharedMLP (line 12) | class SharedMLP(nn.Sequential):
    method __init__ (line 14) | def __init__(
  class _BNBase (line 40) | class _BNBase(nn.Sequential):
    method __init__ (line 42) | def __init__(self, in_size, batch_norm=None, name=""):
  class BatchNorm1d (line 50) | class BatchNorm1d(_BNBase):
    method __init__ (line 52) | def __init__(self, in_size: int, *, name: str = ""):
  class BatchNorm2d (line 56) | class BatchNorm2d(_BNBase):
    method __init__ (line 58) | def __init__(self, in_size: int, name: str = ""):
  class BatchNorm3d (line 62) | class BatchNorm3d(_BNBase):
    method __init__ (line 64) | def __init__(self, in_size: int, name: str = ""):
  class _ConvBase (line 68) | class _ConvBase(nn.Sequential):
    method __init__ (line 70) | def __init__(
  class Conv1d (line 124) | class Conv1d(_ConvBase):
    method __init__ (line 126) | def __init__(
  class Conv2d (line 158) | class Conv2d(_ConvBase):
    method __init__ (line 160) | def __init__(
  class Conv3d (line 192) | class Conv3d(_ConvBase):
    method __init__ (line 194) | def __init__(
  class FC (line 226) | class FC(nn.Sequential):
    method __init__ (line 228) | def __init__(
  function set_bn_momentum_default (line 263) | def set_bn_momentum_default(bn_momentum):
  class BNMomentumScheduler (line 272) | class BNMomentumScheduler(object):
    method __init__ (line 274) | def __init__(
    method step (line 292) | def step(self, epoch=None):
  class Seq (line 300) | class Seq(nn.Sequential):
    method __init__ (line 302) | def __init__(self, input_channels):
    method conv1d (line 307) | def conv1d(self,
    method conv2d (line 341) | def conv2d(self,
    method conv3d (line 375) | def conv3d(self,
    method fc (line 409) | def fc(self,
    method dropout (line 432) | def dropout(self, p=0.5):
    method maxpool2d (line 439) | def maxpool2d(self,

FILE: utils/metrics.py
  class AverageMeter (line 8) | class AverageMeter(object):
    method __init__ (line 11) | def __init__(self):
    method reset (line 14) | def reset(self):
    method update (line 20) | def update(self, val, n=1):
  function estimateAccuracy (line 27) | def estimateAccuracy(box_a, box_b, dim=3, up_axis=(0, -1, 0)):
  function fromBoxToPoly (line 36) | def fromBoxToPoly(box, up_axis=(0, -1, 0)):
  function estimateOverlap (line 49) | def estimateOverlap(box_a, box_b, dim=2, up_axis=(0, -1, 0)):
  class TorchPrecision (line 75) | class TorchPrecision(Metric):
    method __init__ (line 78) | def __init__(self, n=21, max_accuracy=2, dist_sync_on_step=False):
    method value (line 84) | def value(self, accs):
    method update (line 91) | def update(self, val):
    method compute (line 94) | def compute(self):
  class TorchSuccess (line 101) | class TorchSuccess(Metric):
    method __init__ (line 104) | def __init__(self, n=21, max_overlap=1, dist_sync_on_step=False):
    method value (line 110) | def value(self, overlaps):
    method compute (line 117) | def compute(self):
    method update (line 124) | def update(self, val):
  },
  {
    "path": "models/bat.py",
    "chars": 8411,
    "preview": "\"\"\" \nbat.py\nCreated by zenn at 2021/7/21 14:16\n\"\"\"\n\nimport torch\nfrom torch import nn\nfrom models.backbone.pointnet impo"
  },
  {
    "path": "models/head/rpn.py",
    "chars": 2429,
    "preview": "\"\"\" \nrpn.py\nCreated by zenn at 2021/5/8 20:55\n\"\"\"\nimport torch\nfrom torch import nn\nfrom pointnet2.utils import pytorch_"
  },
  {
    "path": "models/head/xcorr.py",
    "chars": 4652,
    "preview": "# Created by zenn at 2021/5/8\nimport torch\nfrom torch import nn\nfrom pointnet2.utils import pytorch_utils as pt_utils\nfr"
  },
  {
    "path": "models/m2track.py",
    "chars": 12861,
    "preview": "\"\"\"\nm2track.py\nCreated by zenn at 2021/11/24 13:10\n\"\"\"\nfrom datasets import points_utils\nfrom models import base_model\nf"
  },
  {
    "path": "models/p2b.py",
    "chars": 4894,
    "preview": "\"\"\" \np2b.py\nCreated by zenn at 2021/5/9 13:47\n\"\"\"\n\nfrom torch import nn\nfrom models.backbone.pointnet import Pointnet_Ba"
  },
  {
    "path": "pointnet2/__init__.py",
    "chars": 288,
    "preview": "from __future__ import (\n    division,\n    absolute_import,\n    with_statement,\n    print_function,\n    unicode_literals"
  },
  {
    "path": "pointnet2/utils/__init__.py",
    "chars": 186,
    "preview": "from __future__ import (\n    division,\n    absolute_import,\n    with_statement,\n    print_function,\n    unicode_literals"
  },
  {
    "path": "pointnet2/utils/linalg_utils.py",
    "chars": 2335,
    "preview": "from __future__ import (\n    division,\n    absolute_import,\n    with_statement,\n    print_function,\n    unicode_literals"
  },
  {
    "path": "pointnet2/utils/pointnet2_modules.py",
    "chars": 13131,
    "preview": "\"\"\" PointNet++ Layers\nModified by Zenn\nDate: Feb 2021\n\"\"\"\nfrom __future__ import (\n    division,\n    absolute_import,\n  "
  },
  {
    "path": "pointnet2/utils/pointnet2_utils.py",
    "chars": 11171,
    "preview": "\"\"\" PointNet++ utils\nModified by Zenn\nDate: Feb 2021\n\"\"\"\nfrom __future__ import (\n    division,\n    absolute_import,\n   "
  },
  {
    "path": "pointnet2/utils/pytorch_utils.py",
    "chars": 12048,
    "preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MIT license found in the\n#"
  },
  {
    "path": "requirement.txt",
    "chars": 346,
    "preview": "protobuf==3.19\neasydict==1.9\nnumpy==1.22\npandas==1.1.5\ngit+https://github.com/erikwijmans/Pointnet2_PyTorch.git#egg=poin"
  },
  {
    "path": "utils/__init__.py",
    "chars": 56,
    "preview": "\"\"\" \n__init__.py\nCreated by zenn at 2021/7/16 14:13\n\"\"\"\n"
  },
  {
    "path": "utils/metrics.py",
    "chars": 3908,
    "preview": "import numpy as np\nimport torch\nimport torchmetrics.utilities.data\nfrom shapely.geometry import Polygon\nfrom torchmetric"
  }
]

// ... and 6 more files (download for full content)
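Because the condensed preview above is a plain JSON array, it can be post-processed directly; note that the trailing "// ..." truncation comment is not valid JSON and must be stripped before parsing. A short sketch, assuming the array has been saved locally as preview.json (a hypothetical filename):

import json

# Load the condensed preview saved locally as preview.json (hypothetical name);
# the "// ..." truncation comment must be removed first.
with open("preview.json") as f:
    entries = json.load(f)

# Rank files by size to decide what fits within an LLM context budget.
entries.sort(key=lambda e: e["chars"], reverse=True)
for e in entries[:5]:
    print(f'{e["chars"]:>6} chars  {e["path"]}')

total = sum(e["chars"] for e in entries)
print(f"{len(entries)} files listed, {total} preview characters in total")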

About this extraction

This page contains the full source code of the Ghostish/Open3DSOT GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 50 files (128.7 MB), approximately 52.8k tokens, and a symbol index with 286 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — a free GitHub-repo-to-text converter for AI. Built by Nikandr Surkov.
