Full Code of Nicholasli1995/EgoNet for AI

master 13e3758388ab cached
56 files
467.4 KB
121.0k tokens
428 symbols
1 requests
Download .txt
Showing preview only (489K chars total). Download the full file or copy to clipboard to get everything.
Repository: Nicholasli1995/EgoNet
Branch: master
Commit: 13e3758388ab
Files: 56
Total size: 467.4 KB

Directory structure:
gitextract_zura5gg8/

├── .gitignore
├── LICENSE
├── README.md
├── configs/
│   ├── KITTI_inference:demo.yml
│   ├── KITTI_inference:test_submission.yml
│   ├── KITTI_train_IGRs.yml
│   ├── KITTI_train_IGRs_Ped.yml
│   └── KITTI_train_lifting.yml
├── docs/
│   ├── demo.md
│   ├── inference.md
│   ├── preparation.md
│   ├── spec-list.txt
│   └── training.md
├── libs/
│   ├── arguments/
│   │   ├── __init__.py
│   │   └── parse.py
│   ├── common/
│   │   ├── __init__.py
│   │   ├── format.py
│   │   ├── img_proc.py
│   │   ├── transformation.py
│   │   └── utils.py
│   ├── dataset/
│   │   ├── KITTI/
│   │   │   ├── __init__.py
│   │   │   └── car_instance.py
│   │   ├── __init__.py
│   │   ├── basic/
│   │   │   ├── __init__.py
│   │   │   └── basic_classes.py
│   │   └── normalization/
│   │       ├── __init__.py
│   │       └── operations.py
│   ├── logger/
│   │   ├── __init__.py
│   │   └── logger.py
│   ├── loss/
│   │   ├── __init__.py
│   │   └── function.py
│   ├── metric/
│   │   └── criterions.py
│   ├── model/
│   │   ├── FCmodel.py
│   │   ├── __init__.py
│   │   ├── egonet.py
│   │   └── heatmapModel/
│   │       ├── __init__.py
│   │       ├── hrnet.py
│   │       └── resnet.py
│   ├── optimizer/
│   │   ├── __init__.py
│   │   └── optimizer.py
│   ├── trainer/
│   │   ├── __init__.py
│   │   ├── accuracy.py
│   │   └── trainer.py
│   └── visualization/
│       ├── __init__.py
│       ├── debug.py
│       ├── egonet_utils.py
│       └── points.py
└── tools/
    ├── inference.py
    ├── inference_legacy.py
    ├── kitti-eval/
    │   ├── README.md
    │   ├── evaluate_object_3d.cpp
    │   ├── evaluate_object_3d_offline.cpp
    │   ├── evaluate_object_3d_offline_r40.cpp
    │   └── mail.h
    ├── train_IGRs.py
    └── train_lifting.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
#/*
**/__pycache__
.spyproject/
*.log
*.ini
*.bak
*.pth
*.csv
*.jpg
*.png
*.pdf
/tools/kitti-eval/evaluate_object_3d_offline
/outputs



================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2022 Shichao Li

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/exploring-intermediate-representation-for/vehicle-pose-estimation-on-kitti-cars-hard)](https://paperswithcode.com/sota/vehicle-pose-estimation-on-kitti-cars-hard?p=exploring-intermediate-representation-for)
# EgoNet
Official project website for the CVPR 2021 paper "Exploring intermediate representation for monocular vehicle pose estimation". This repo includes an implementation that performs vehicle orientation estimation on the KITTI dataset from a single RGB image. 

News:

(2022-??-??): v-1.1 will be released which include pre-trained models for other object classes (Pedestrian and Cyclist in KITTI).

(2021-08-16): v-1.0 is released. The training documentation is added.

(2021-06-21): v-0.9 (beta version) is released. **The inference utility is here!** For Q&A, go to [discussions](https://github.com/Nicholasli1995/EgoNet/discussions). If you believe there is a technical problem, submit to [issues](https://github.com/Nicholasli1995/EgoNet/issues). 

(2021-06-16): This repo is under final code cleaning and documentation preparation. Stay tuned and come back in a week!

**Check our 5-min video ([Youtube](https://www.youtube.com/watch?v=isKo0F3MU68), [爱奇艺](https://www.iqiyi.com/v_y6lrdy33kg.html)) for an introduction.**

**中文详解**:[哔哩哔哩](https://www.bilibili.com/video/BV1jP4y1t7ee)
<p align="center">
  <img src="https://github.com/Nicholasli1995/EgoNet/blob/master/imgs/teaser.jpg"  width="830" height="200" />
</p>

## Run a demo with a one-line command!
Check instructions [here](https://github.com/Nicholasli1995/EgoNet/blob/master/docs/demo.md).
<p align="center">
  <img src="https://github.com/Nicholasli1995/EgoNet/blob/master/imgs/Ego-Net_demo.png" height="175"/>
  <img src="https://github.com/Nicholasli1995/EgoNet/blob/master/imgs/Ego-Net_demo.gif" height="175"/>
</p>

## Performance: AP<sup>BEV</sup>@R<sub>40</sub> on KITTI val set for Car (monocular RGB)
The validation results in the paper was based on R<sub>11</sub>, the results using R<sub>40</sub> are attached here.
| Method                    | Reference|Easy|Moderate|Hard|
| ------------------------- | ---------------| --------------| --------------| --------------| 
|[M3D-RPN](https://arxiv.org/abs/1907.06038)|ICCV 2019|20.85| 15.62| 11.88|
|[MonoDIS](https://openaccess.thecvf.com/content_ICCV_2019/papers/Simonelli_Disentangling_Monocular_3D_Object_Detection_ICCV_2019_paper.pdf)|ICCV 2019|18.45 |12.58 |10.66|
|[MonoPair](https://arxiv.org/abs/2003.00504)|CVPR 2020|24.12| 18.17| 15.76|
|[D4LCN](https://github.com/dingmyu/D4LCN)|CVPR 2020|31.53 |22.58  |17.87|
|[Kinematic3D](https://arxiv.org/abs/2007.09548)|ECCV 2020|27.83| 19.72| 15.10|
|[GrooMeD-NMS](https://github.com/abhi1kumar/groomed_nms)|CVPR 2021 |27.38|19.75|15.92|
|[MonoDLE](https://github.com/xinzhuma/monodle)|CVPR 2021|24.97| 19.33| 17.01|
|Ours (@R<sub>11</sub>)           |CVPR 2021 |**33.60**|**25.38**|**22.80**|
|Ours (@R<sub>40</sub>)           |CVPR 2021 |**34.31**|**24.80**|**20.16**|

## Performance: AOS@R<sub>40</sub> on KITTI test set for Car (RGB)

| Method                    | Reference|Configuration|Easy|Moderate|Hard|
| ------------------------- | ---------------| --------------| --------------| --------------| --------------| 
|[M3D-RPN](https://arxiv.org/abs/1907.06038)|ICCV 2019|Monocular|88.38 |82.81| 67.08|
|[DSGN](https://github.com/Jia-Research-Lab/DSGN)|CVPR 2020|Stereo|95.42|86.03| 78.27|
|[Disp-RCNN](https://github.com/zju3dv/disprcnn)|CVPR 2020|Stereo |93.02 |	81.70 |	67.16|
|[MonoPair](https://arxiv.org/abs/2003.00504)|CVPR 2020|Monocular|91.65 |86.11 |76.45|
|[D4LCN](https://github.com/dingmyu/D4LCN)|CVPR 2020|Monocular|90.01|82.08| 63.98|
|[Kinematic3D](https://arxiv.org/abs/2007.09548)|ECCV 2020|Monocular|58.33 |	45.50 |	34.81|
|[MonoDLE](https://github.com/xinzhuma/monodle)|CVPR 2021|Monocular|93.46| 90.23| 80.11|
|[Ours](http://www.cvlibs.net/datasets/kitti/eval_object_detail.php?&result=e5233225fd5ef36fa63eb00252d9c00024961f2c)           |CVPR 2021 |Monocular|**96.11**|**91.23**|**80.96**|

## Inference/Deployment
Check instructions [here](https://github.com/Nicholasli1995/EgoNet/blob/master/docs/inference.md) to **reproduce** the above quantitative results.

## Training
Check instructions [here](https://github.com/Nicholasli1995/EgoNet/blob/master/docs/training.md) to train Ego-Net and learn how to prepare your own training dataset other than KITTI.

## Citation
Please star this repository and cite the following paper in your publications if it helps your research:

    @InProceedings{Li_2021_CVPR,
    author    = {Li, Shichao and Yan, Zengqiang and Li, Hongyang and Cheng, Kwang-Ting},
    title     = {Exploring intermediate representation for monocular vehicle pose estimation},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2021},
    pages     = {1873-1883}
    }

## License
This repository can be used freely for non-commercial purposes. Contact me if you are interested in a commercial license.

## Links
Link to the paper:
[Exploring intermediate representation for monocular vehicle pose estimation](https://arxiv.org/abs/2011.08464)

Link to the presentation video:
[Youtube](https://www.youtube.com/watch?v=isKo0F3MU68), [爱奇艺](https://www.iqiyi.com/v_y6lrdy33kg.html)

Relevant ECCV 2020 work: [GSNet](https://github.com/lkeab/gsnet)


================================================
FILE: configs/KITTI_inference:demo.yml
================================================
# This is a YAML file storing experimental configurations for KITTI dataset

## general settings
name: 'refine a given set of predictions from D4LCN'
exp_type: 'inference'
model_type: 'heatmapModel'
use_gpu: True
use_pred_box: True
use_gt_box: True
gpu_id: [0]

## operations
train: False
save: False
evaluate: False
inference: True

## used directories
dirs:
    # output directory
    output: 'YOUR_OURPUT_DIR' 
    ckpt: 'YOUR_PRETRAINED_DIR' 
    load_prediction_file: '../resources/D4LCN/data'

## CUDNN settings
cudnn:
    enabled: True
    deterministic: True
    benchmark: False

## dataset settings
dataset:
    name: 'KITTI'
    split: 'valid'
    detect_classes: ['Car']
    3d_kpt_sample_style: 'bbox9'
    # interpolate the 3D bbox
    interpolate:
        flag: True
        style: 'bbox12'
        coef: [0.332, 0.667]
    # do some pre-processing
    pre-process: False
    root: 'YOUR_KITTI_DIR'
    # augmentation parameters
    scaling_factor: 0.2
    rotation_factor: 30. # degrees
    # pytorch image transformation setting
    pth_transform:
# mean: [0.485, 0.456, 0.406, 0., 0.] 
# std: [0.229, 0.224, 0.225, 1., 1.]    
        mean: [0.485, 0.456, 0.406] 
        std: [0.229, 0.224, 0.225]          
    # annotation style for 2d key-point
    2d_kpt_style: 'bbox9' # projected 3d bounding box corner and center points
    # input-output representation for 2d-to-3d lifting
    lft_in_rep: 'coordinates2d' # 2d coordinates on screen
    lft_out_rep: 'R3d+T' # 3d coordinates relative to centroid plus translation vector   

## model settings for a fully-connected network if used
FCModel:
    name: 'lifter'
    refine_3d: False 
    norm_twoD: False
    num_blocks: 2 
    input_size: 66 
    output_size: 96 
    num_neurons: 1024
    dropout: 0.5
    leaky: False

## settings for a fully-convolutional heatmap/coordinate regression model
heatmapModel:
    name: hrnet # here a high-resolution (hr) model is used
    add_xy: False # concatenate xy coodrinate maps along with the input
    jitter_bbox: False
    jitter_params:
        shift:
        - 0.1
        - 0.1
        scaling:
        - 0.4
        - 0.4
    input_size: 
    - 256
    - 256
    # rotate and scaling and input images
    augment_input: False
    # one can choose to regress dense semantic heatmaps or coordinates 
    head_type: 'coordinates'
    # up-sampling with pixel-shuffle
    pixel_shuffle: False
    # if an intermediate heatmap is produced
    heatmap_size:
    - 64
    - 64
    init_weights: true
    num_joints: 33
    extra:
        pretrained_layers:
        - 'conv1'
        - 'bn1'
        - 'conv2'
        - 'bn2'
        - 'layer1'
        - 'transition1'
        - 'stage2'
        - 'transition2'
        - 'stage3'
        - 'transition3'
        - 'stage4'
        final_conv_kernel: 1
        stage2:
            num_modules: 1
            num_branches: 2
            block: basic
            num_blocks:
            - 4
            - 4
            num_channels:
            - 48
            - 96
            fuse_method: sum
        stage3:
            num_modules: 4
            num_branches: 3
            block: basic
            num_blocks:
            - 4
            - 4
            - 4
            num_channels:
            - 48
            - 96
            - 192
            fuse_method: sum
        stage4:
            num_modules: 3
            num_branches: 4
            block: basic
            num_blocks:
            - 4
            - 4
            - 4
            - 4
            num_channels:
            - 48
            - 96
            - 192
            - 384
            fuse_method: sum

## testing settings
testing_settings:
    batch_size: 1
    num_threads: 0
    shuffle: True
    pin_memory: False
    apply_dropout: False
    unnormalize: False
    alpha_mode: 'proj'

================================================
FILE: configs/KITTI_inference:test_submission.yml
================================================
# YAML file storing experimental configurations for KITTI dataset

## general settings
name: 'produce vehicle pose predictions on KITTI test split given detected bounding boxes'
exp_type: 'inference'
model_type: 'heatmapModel'
use_gpu: True
use_pred_box: True
use_gt_box: False
gpu_id: [0]

## operations
train: False
save: False
visualize: False # visualize during inference
batch_to_show: 1000000 # how many batches to visualize if needed
evaluate: False
inference: True
conf_thres: 0.1 # discard low score boxes

## used directories
dirs:
    # output directory
    output: 'YOUR_OURPUT_DIR' 
    ckpt: 'YOUR_PRETRAINED_DIR' 
    # raw detection results on test set by using RRC-Net
    load_prediction_file: '../resources/test_boxes'

## CUDNN settings
cudnn:
    enabled: True
    deterministic: True
    benchmark: False

## dataset settings
dataset:
    name: 'KITTI'
    split: 'test'
    detect_classes: ['Car']
    3d_kpt_sample_style: 'bbox9'
    # interpolate the 3D bbox
    interpolate:
        flag: True
        style: 'bbox12'
        coef: [0.332, 0.667]
    # do some pre-processing
    pre-process: False
    root: 'YOUR_KITTI_DIR'
    # augmentation parameters
    scaling_factor: 0.2
    rotation_factor: 30. # degrees
    # pytorch image transformation setting
    pth_transform:
        mean: [0.485, 0.456, 0.406] # TODO re-calculate this: R, G, B, X, Y 
        std: [0.229, 0.224, 0.225]          
    # annotation style for 2d key-point
    2d_kpt_style: 'bbox9' # projected 3d bounding box corner and center points
    # input-output representation for 2d-to-3d lifting
    lft_in_rep: 'coordinates2d' # 2d coordinates on screen
    lft_out_rep: 'R3d+T' # 3d coordinates relative to centroid plus translation vector   

## model settings for a fully-connected network if used
FCModel:
    name: 'lifter'
    refine_3d: False 
    norm_twoD: False
    num_blocks: 2 
    input_size: 66 
    output_size: 96 
    num_neurons: 1024
    dropout: 0.5
    leaky: False

## settings for a fully-convolutional heatmap regression model
heatmapModel:
    name: hrnet # here a high-resolution (hr) model is used
    add_xy: False # concatenate xy coodrinate maps along with the input
    jitter_bbox: True
    jitter_params:
        shift:
        - 0.1
        - 0.1
        scaling:
        - 0.4
        - 0.4
    input_size: 
    - 256
    - 256
    # rotate and scaling and input images
    augment_input: True
    head_type: 'coordinates'
    pixel_shuffle: False
    # if an intermediate heatmap is produced
    heatmap_size:
    - 64
    - 64
    init_weights: true
    num_joints: 33
    use_different_joints_weight: False
    extra:
        pretrained_layers:
        - 'conv1'
        - 'bn1'
        - 'conv2'
        - 'bn2'
        - 'layer1'
        - 'transition1'
        - 'stage2'
        - 'transition2'
        - 'stage3'
        - 'transition3'
        - 'stage4'
        final_conv_kernel: 1
        stage2:
            num_modules: 1
            num_branches: 2
            block: basic
            num_blocks:
            - 4
            - 4
            num_channels:
            - 48
            - 96
            fuse_method: sum
        stage3:
            num_modules: 4
            num_branches: 3
            block: basic
            num_blocks:
            - 4
            - 4
            - 4
            num_channels:
            - 48
            - 96
            - 192
            fuse_method: sum
        stage4:
            num_modules: 3
            num_branches: 4
            block: basic
            num_blocks:
            - 4
            - 4
            - 4
            - 4
            num_channels:
            - 48
            - 96
            - 192
            - 384
            fuse_method: sum

## testing settings
testing_settings:
    batch_size: 1
    num_threads: 0
    shuffle: True
    pin_memory: False
    apply_dropout: False
    unnormalize: False
    alpha_mode: 'proj'

================================================
FILE: configs/KITTI_train_IGRs.yml
================================================
# YAML file storing experimental configurations for training on KITTI dataset

## general settings
name: 'kitti_kpt_loc'
exp_type: 'instanceto2d'
model_type: 'heatmapModel'
use_gpu: True
gpu_id: [0,1,2] # MODIFY this to the GPU/GPUs ids in your computer

## operations
train: True
save: True
visualize: False
evaluate: False

## output directories
dirs:
    # MODIFY them to your preferred directories
    output: '../outputs/training_record' 
    # This directory save intermediate training results (optional)
    debug: '../outputs/training_record/debug' 

## CUDNN settings
cudnn:
    enabled: True
    deterministic: False
    benchmark: False

## dataset settings
dataset:
    name: 'KITTI'
    detect_classes: ['Car']
    3d_kpt_sample_style: 'bbox9'
    interpolate:
        flag: True
        style: 'bbox12'
        coef: [0.332, 0.667]
    # do some pre-processing
    pre-process: False
    # MODIFY this to your KITTI directory
    root: '$YOUR_DIR/KITTI'
    # augmentation parameters
    scaling_factor: 0.2
    rotation_factor: 30. # degrees
    # pytorch image transformation setting
    pth_transform:
#        mean: [0.485, 0.456, 0.406, 0., 0.]
#        std: [0.229, 0.224, 0.225, 1., 1.]    
        mean: [0.485, 0.456, 0.406] 
        std: [0.229, 0.224, 0.225]    
    2d_kpt_style: 'bbox9'

## self-supervision settings
ss:
    flag: False
    # MODIFY this to your unlabeled image record if you enable self-supervised representation learning
    record_path: '$YOUR_DIR/Apollo_ss_record.npy'
    img_root: '$YOUR_DIR/ApolloScape/images'
    max_per_img: 6

## settings for a fully-convolutional heatmap/coordinate regression model
heatmapModel:
    name: hrnet # here a high-resolution (hr) model is used
    add_xy: False # concatenate xy coodrinate maps along with the input
    # data augmentation by adding noise to bounding box location
    jitter_bbox: True
    jitter_params:
        shift:
        - 0.1
        - 0.1
        scaling:
        - 0.4
        - 0.4
    input_size: 
    - 256
    - 256
    # rotate and scaling and input images
    augment_input: True
    head_type: 'coordinates'
    # up-sampling with pixel-shuffle
    pixel_shuffle: False
    # if an intermediate heatmap is produced
    heatmap_size:
    - 64
    - 64
    loss_type: JointsCompositeLoss
    # the following two settings are only valid for JointsCompositeLoss
    loss_spec_list: ['mse', 'l1', 'sl1']
    loss_weight_list: [1.0, 0.1, 'None']
    cr_loss_threshold: 0.15
    init_weights: true
    num_joints: 33
    #use_different_joints_weight: False
    # use a pre-trained checkpoint to initialize the model
    # MODIFY it to your own checkpoint directory
    pretrained: '../resources/start_point.pth'
    target_type: gaussian
    sigma: 1
    extra:
        pretrained_layers:
        - 'conv1'
        - 'bn1'
        - 'conv2'
        - 'bn2'
        - 'layer1'
        - 'transition1'
        - 'stage2'
        - 'transition2'
        - 'stage3'
        - 'transition3'
        - 'stage4'
        final_conv_kernel: 1
        stage2:
            num_modules: 1
            num_branches: 2
            block: basic
            num_blocks:
            - 4
            - 4
            num_channels:
            - 48
            - 96
            fuse_method: sum
        stage3:
            num_modules: 4
            num_branches: 3
            block: basic
            num_blocks:
            - 4
            - 4
            - 4
            num_channels:
            - 48
            - 96
            - 192
            fuse_method: sum
        stage4:
            num_modules: 3
            num_branches: 4
            block: basic
            num_blocks:
            - 4
            - 4
            - 4
            - 4
            num_channels:
            - 48
            - 96
            - 192
            - 384
            fuse_method: sum

## training settings  
training_settings:
    total_epochs: 45
    resume: False
    batch_size: 24
    num_threads: 16 # MODIFY this accordingly based on your machine
    shuffle: True
    pin_memory: False
    # weighted loss computation
    use_target_weight: False
    report_every: 30
    eval_every: 130
    eval_during: False # set this to True if you want to evaluate during training
    eval_metrics: ['JointDistance2DSIP']
    plot_loss: False
    # debugging configurations 
    debug: 
        save: True # save some intermeadiate images with keypoint prediction
        save_images_kpts: True
        save_hms_gt: True
        save_hms_pred: True

## testing settings
testing_settings:
    batch_size: 2
    num_threads: 4
    shuffle: False
    pin_memory: False
    apply_dropout: False
    unnormalize: False
    eval_metrics: ['JointDistance2DSIP']

## optimizer settings
optimizer:
    # for ADAM
    optim_type: 'adam'
    lr: 0.001
    weight_decay: 0.0
    # for SGD
    momentum: 0.9
    # learning rate decay
    milestones: [10, 20, 30, 40]
    gamma: 0.5


================================================
FILE: configs/KITTI_train_IGRs_Ped.yml
================================================
# YAML file storing experimental configurations for training on KITTI dataset for the Pedestrian class

## general settings
name: 'kitti_kpt_loc_pedestrian'
exp_type: 'instanceto2d'
# baselin
model_type: 'heatmapModel'
use_gpu: True
gpu_id: [0,1,]

## operations
train: True
save: True
visualize: False
evaluate: False

## output directories
dirs:
    # MODIFY them to your preferred directories
    output: '../outputs/training_record' 
    # This directory save intermediate training results (optional)
    debug: '../outputs/training_record/debug' 

## CUDNN settings
cudnn:
    enabled: True
    deterministic: False
    benchmark: False

## dataset settings
dataset:
    name: 'KITTI'
    detect_classes: ['Pedestrian']
    3d_kpt_sample_style: 'bbox9'
    interpolate:
        flag: True
        style: 'bbox12'
        coef: [0.332, 0.667]
    enlarge_factor: 1.05 # patch size parameter
    # do some pre-processing
    pre-process: True
    root: '/media/nicholas/Database/datasets/KITTI'
    # augmentation parameters
    scaling_factor: 0.2
    rotation_factor: 30. # degrees
    # pytorch image transformation setting
    pth_transform:  
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]    
    2d_kpt_style: 'bbox9'

## self-supervision settings
ss:
    flag: False
    # MODIFY this to your unlabeled image record if you enable self-supervised representation learning
    record_path: '$YOUR_DIR/Apollo_ss_record.npy'
    img_root: '$YOUR_DIR/ApolloScape/images'
    max_per_img: 6

## settings for a fully-convolutional heatmap regression model
heatmapModel:
    name: hrnet # here a high-resolution (hr) model is used
    add_xy: False # concatenate xy coodrinate maps along with the input
    jitter_bbox: True
    jitter_params:
        shift:
        - 0.1
        - 0.1
        scaling:
        - 0.2
        - 0.2
    input_size: 
    - 192
    - 256
    # rotate and scaling and input images
    augment_input: True
    head_type: 'coordinates'
    # up-sampling with pixel-shuffle
    pixel_shuffle: False
    # if an intermediate heatmap is produced
    heatmap_size:
    - 48
    - 64
    loss_type: JointsCompositeLoss
    # the following two settings are only valid for JointsCompositeLoss
    loss_spec_list: ['mse', 'l1', 'None']
    loss_weight_list: [1.0, 0.1, 0.]
    cr_loss_threshold: 0.1
    init_weights: true
    num_joints: 33
    use_different_joints_weight: False
    # use a pre-trained checkpoint to initialize the model
    # MODIFY it to your own checkpoint directory
    pretrained: '../resources/start_point.pth'
    target_type: gaussian
    sigma: 2
    extra:
        pretrained_layers:
        - 'conv1'
        - 'bn1'
        - 'conv2'
        - 'bn2'
        - 'layer1'
        - 'transition1'
        - 'stage2'
        - 'transition2'
        - 'stage3'
        - 'transition3'
        - 'stage4'
        freeze_layers:
        - 'conv1'
        - 'bn1'
        - 'conv2'
        - 'bn2'
        - 'layer1'
        - 'transition1'
        - 'stage2'        
        final_conv_kernel: 1
        stage2:
            num_modules: 1
            num_branches: 2
            block: basic
            num_blocks:
            - 4
            - 4
            num_channels:
            - 32
            - 64
            fuse_method: sum
        stage3:
            num_modules: 4
            num_branches: 3
            block: basic
            num_blocks:
            - 4
            - 4
            - 4
            num_channels:
            - 32
            - 64
            - 128
            fuse_method: sum
        stage4:
            num_modules: 3
            num_branches: 4
            block: basic
            num_blocks:
            - 4
            - 4
            - 4
            - 4
            num_channels:
            - 32
            - 64
            - 128
            - 256
            fuse_method: sum

## training settings  
training_settings:
    total_epochs: 40
    resume: False
    begin_epoch: 1
    end_epoch: 10
    snapshot_epochs: [20, 30, 40]
    batch_size: 2
    num_threads: 0
    shuffle: True
    pin_memory: False
    # weighted loss computation
    use_target_weight: False
    report_every: 100
    eval_every: 1000
    eval_during: True
    eval_metrics: ['JointDistance2DSIP']
    plot_loss: False
    # debugging configurations 
    debug: 
        save: True # save some intermeadiate results
        save_images_kpts: True
        save_hms_gt: True
        save_hms_pred: True

## testing settings
testing_settings:
    batch_size: 2
    num_threads: 0
    shuffle: False
    pin_memory: False
    apply_dropout: False
    unnormalize: False
    eval_metrics: ['JointDistance2DSIP']
    # save_debug: True
    save_debug: False

## optimizer settings
optimizer:
    # for ADAM
    optim_type: 'adam'
    lr: 0.001
    weight_decay: 0.0
    # for SGD
    momentum: 0.9
    # learning rate decay
    milestones: [10, 20, 30]
    gamma: 0.5


================================================
FILE: configs/KITTI_train_lifting.yml
================================================
# YAML file storing experimental configurations for KITTI dataset

## general settings
name: 'lifter'
exp_type: '2dto3d'
model_type: 'FCModel'
use_gpu: True
gpu_id: [1] # modify this to the GPU ids that you use 

## operations
train: True # perform training
save: True # save the trained model
visualize: False # visualize the training results
evaluate: False # perform evaluation

## paths to the relevant directories
dirs:
    # output directory
    output: '../outputs/training_record' 
    debug: '../outputs/training_record/debug'
    data_vis: '../outputs/training_record/data_vis'

## CUDNN settings
cudnn:
    enabled: True
    deterministic: False
    benchmark: False

## evaluation metrics
metrics:
    R3D:
        T_style: 'direct'
        R_style: 'euler'

## dataset settings
dataset:
    name: 'KITTI'
    detect_classes: ['Car'] # used class for training
    3d_kpt_sample_style: 'bbox9' # construct a cuboid for each 3D bounding box
    # interpolate the 3D bbox
    interpolate:
        flag: True
        style: 'bbox12'
        coef: [0.332, 0.667]
    # do some pre-processing
    pre-process: False
    root: '$YOUR_DIR/KITTI' # MODIFY this to your own path    
    # input-output representation for 2d-to-3d lifting
    lft_in_rep: 'coordinates2d' # 2d coordinates on screen
    lft_out_rep: 'R3d' # 3d coordinates relative to centroid plus translation vector

## optional cascaded regression
cascade: 
    num_stages: 1 # the default is simply no cascade

## model settings for a fully-connected network if used
FCModel:
    name: 'lifter'
    refine_3d: False 
    norm_twoD: False
    num_blocks: 2 
    num_neurons: 1024
    dropout: 0.5
    leaky: False
    loss_type: MSELoss1D
    loss_reduction: 'mean'

## training settings  
training_settings:
    # total_epochs: 300
    total_epochs: 1
    eval_start_epoch: 250 # start evaluation after this epoch
    resume: False
    batch_size: 2048
    num_threads: 4 # set the number of workers that works for your machine
    shuffle: True
    pin_memory: False
#    report_every: 500 # report every 500 batches
#    eval_every: 500 # test on the evaluation set every 500 batches
    report_every: 5 # report every 500 batches
    eval_every: 5 # test on the evaluation set every 500 batches
    eval_during: False # MODIFY this to True if you want to evaluate during the training process
    # how many times to augment data for 2D-to-3D lifting
    lft_aug: True
    lft_aug_times: 100
    # what evaluation metrics to use
    eval_metrics: ['RError3D']
    plot_loss: False # visualize the loss function during training 

## testing settings if used
testing_settings:
    apply_dropout: False
    unnormalize: True
    batch_size: 1024
    num_threads: 4
    shuffle: False
#    vis_epoch: 290 # start ploting after this epoch

## optimizer settings
optimizer:
    # for ADAM
    optim_type: 'adam'
    lr: 0.001
    weight_decay: 0.0
    # for SGD
    momentum: 0.9
    # learning rate will decay at each milestone epoch
    milestones: [50, 100, 150, 250]
    gamma: 0.5


================================================
FILE: docs/demo.md
================================================
Firstly you need to prepare the dataset and pre-trained models as described [here](https://github.com/Nicholasli1995/EgoNet/blob/master/docs/preparation.md).

Then modify the directories by

```bash
cd ${EgoNet_DIR}/configs && vim KITTI_inference:demo.yml
```

Edit dirs:ckpt to your pre-trained model directory.

Edit dataset:root to your KITTI directory.

Finally, go to ${EgoNet_DIR}/tools and run

```bash
 python inference.py --cfg "../configs/KITTI_inference:demo.yml" --visualize True --batch_to_show 2
```

You can set --batch_to_show to other integers to see more results.

The visualized 3D bounding boxes are distinguished by their colors: 
1. Black indicates ground truth 3D boxes.
2. Magenta indicates 3D bounding boxes predicted by another 3D object detector ([D4LCN](https://github.com/dingmyu/D4LCN)).
3. Red indicates the predictions of Ego-Net, using the 2D bounding boxes from [D4LCN](https://github.com/dingmyu/D4LCN).
4. Yellow indicates the predictions of Ego-Net, using the ground truth 2D bounding boxes.


================================================
FILE: docs/inference.md
================================================
Firstly you need to prepare the dataset and pre-trained models as described [here](https://github.com/Nicholasli1995/EgoNet/blob/master/docs/preparation.md).

## Reproduce D4LCN + EgoNet on the val split
You need to modify the directories by

```bash
cd ${EgoNet_DIR}/configs && vim KITTI_inference:demo.yml
```
Edit dirs:output to where you want to save the predictions.

Edit dirs:ckpt to your pre-trained model directory.

Edit dataset:root to your KITTI directory.

Finally, go to ${EgoNet_DIR}/tools and run

```bash
 python inference.py --cfg "../configs/KITTI_inference:demo.yml"
```

This will load D4LCN predictions, refine their vehicle orientation predictions and save the results.
The official evaluation program will automatically run to produce quantitative performance.

## Reproduce results on the test split
You need to modify the directories by

```bash
cd ${EgoNet_DIR}/configs && vim KITTI_inference:test_submission.yml
```
Edit dirs:output to where you want to save the predictions.

Edit dirs:ckpt to your pre-trained model directory.

Edit dataset:root to your KITTI directory.

Finally, go to ${EgoNet_DIR}/tools and run

```bash
 python inference.py --cfg "../configs/KITTI_inference:test_submission.yml"
```

This will load prepared 2D bounding boxes, predict the vehicle orientation and save the predictions.

Now you can zip the results and submit it to the [official server](http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark=2d)!

You can hit [91.23% AOS](http://www.cvlibs.net/datasets/kitti/eval_object_detail.php?&result=e5233225fd5ef36fa63eb00252d9c00024961f2c) for the moderate setting! This is the **most important** metric for joint vehicle detection and pose estimation on KITTI. You achieved this with a single RGB image without extra training data.


================================================
FILE: docs/preparation.md
================================================
## Data Preparation 
You need to download KITTI dataset [here](http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark=3d). Download left images, calibration files and labels.
Download the split files [here](https://drive.google.com/drive/folders/1YLtptqspOFw08QG2MsxewDT9tjF2O45g?usp=sharing) and place them at ${YOUR_KITTI_DIR}/SPLIT/ImageSets.
Your data folder should look like this:

   ```
   ${YOUR_KITTI_DIR}
   ├── training
      ├── calib
          ├── xxxxxx.txt (Camera parameters for image xxxxxx)
      ├── image_2
          ├── xxxxxx.png (image xxxxxx)
      ├── label_2
          ├── xxxxxx.txt (object labels for image xxxxxx)
      ├── ImageSets
         ├── train.txt
         ├── val.txt   
         ├── trainval.txt        
   ├── testing
      ├── calib
          ├── xxxxxx.txt (Camera parameters for image xxxxxx)
      ├── image_2
          ├── xxxxxx.png (image xxxxxx)
      ├── ImageSets
         ├── test.txt
   ```

## Download pre-trained model
You need to download the pre-trained checkpoints [here](https://drive.google.com/file/d/1JsVzw7HMfchxOXoXgvWG1I_bPRD1ierE/view?usp=sharing) in order to use Ego-Net. Unzip it to ${YOUR_MODEL_DIR}.

## Compile the official evaluator
Go to the folder storing the source code
```bash
cd ${EgoNet_DIR}/tools/kitti-eval 
```
Compile the source code
```bash
g++ -o evaluate_object_3d_offline evaluate_object_3d_offline.cpp -O3
```

## Download the input bounding boxes
Download the [resources folder](https://drive.google.com/drive/folders/1atfXLmsLFG6XEtNnwZuEYLydKqjr7Icf?usp=sharing) and unzip its contents. Place the resource folder at ${EgoNet_DIR}/resources


## Environment
You need to create an environment that meets the following dependencies. 
The versions included in the parenthesis are **tested**. Other versions may also work but are **not tested**.

- Python (3.7.9)
- Numpy (1.19.2)
- PyTorch (1.6.0, GPU required)
- Scipy (1.5.2)
- Matplotlib (3.3.4)
- OpenCV (3.4.2)
- pyyaml (5.4.1)

For more details of my tested local environment, refer to [spec-list.txt](https://github.com/Nicholasli1995/EgoNet/blob/master/docs/spec-list.txt). 
The recommended environment manager is [Anaconda](https://www.anaconda.com/), which can create an environment using this provided spec-list. 
For debugging using an IDE, I personally use and recommend Spyder 4.2 which you can get by
```bash
conda install spyder
```


================================================
FILE: docs/spec-list.txt
================================================
# This file may be used to create an environment using:
# $ conda create --name <env> --file <this file>
# platform: linux-64
@EXPLICIT
https://repo.anaconda.com/pkgs/main/linux-64/_libgcc_mutex-0.1-main.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/blas-1.0-mkl.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/ca-certificates-2021.1.19-h06a4308_1.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/ld_impl_linux-64-2.33.1-h53a641e_7.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/libgfortran-ng-7.3.0-hdf63c60_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/libstdcxx-ng-9.1.0-hdf63c60_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/pandoc-2.12-h06a4308_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/cudatoolkit-9.2-0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/libgcc-ng-9.1.0-hdf63c60_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/bzip2-1.0.8-h7b6447c_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/expat-2.2.10-he6710b0_2.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/freeglut-3.0.0-hf484d3e_5.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/graphite2-1.3.14-h23475e2_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/icu-58.2-he6710b0_3.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/jpeg-9b-h024ee3a_2.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/libffi-3.3-he6710b0_2.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/libglu-9.0.0-hf484d3e_1.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/libopus-1.3.1-h7b6447c_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/libsodium-1.0.18-h7b6447c_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/libspatialindex-1.9.3-h2531618_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/libuuid-1.0.3-h1bed415_2.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/libvpx-1.7.0-h439df22_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/libwebp-base-1.2.0-h27cfd23_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/libxcb-1.14-h7b6447c_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/lz4-c-1.9.3-h2531618_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/ncurses-6.2-he6710b0_1.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/openssl-1.1.1k-h27cfd23_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/pcre-8.44-he6710b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/pixman-0.40.0-h7b6447c_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/xz-5.2.5-h7b6447c_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/yaml-0.2.5-h7b6447c_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/zlib-1.2.11-h7b6447c_3.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/glib-2.68.0-h36276a3_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/hdf5-1.10.2-hba1933b_1.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/jasper-2.0.14-h07fcdf6_1.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/libpng-1.6.37-hbc83047_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/libxml2-2.9.10-hb55368b_3.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/readline-8.1-h27cfd23_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/tk-8.6.10-hbc83047_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/zeromq-4.3.4-h2531618_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/zstd-1.4.5-h9ceee32_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/dbus-1.13.18-hb2f20db_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/freetype-2.10.4-h5ab3b9f_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/gstreamer-1.14.0-h28cd5cc_2.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/libtiff-4.2.0-h85742a9_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/sqlite-3.35.2-hdfb4753_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/ffmpeg-4.0-hcdf2ecd_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/fontconfig-2.13.1-h6c09931_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/gst-plugins-base-1.14.0-h8213a91_2.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/lcms2-2.11-h396b838_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/python-3.7.9-h7579374_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/alabaster-0.7.12-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/appdirs-1.4.4-py_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/argh-0.26.2-py37_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/atomicwrites-1.4.0-py_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/attrs-20.3.0-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/backcall-0.2.0-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/cairo-1.16.0-hf32fb01_1.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/certifi-2020.12.5-py37h06a4308_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/chardet-4.0.0-py37h06a4308_1003.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/click-7.1.2-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/cloudpickle-1.6.0-py_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/colorama-0.4.4-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/decorator-4.4.2-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/defusedxml-0.7.1-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/diff-match-patch-20200713-py_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/docutils-0.16-py37_1.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/entrypoints-0.3-py37_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/future-0.18.2-py37_1.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/idna-2.10-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/imagesize-1.2.0-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/ipython_genutils-0.2.0-pyhd3eb1b0_1.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/jeepney-0.6.0-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/kiwisolver-1.3.1-py37h2531618_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/lazy-object-proxy-1.6.0-py37h27cfd23_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/markupsafe-1.1.1-py37h14c3975_1.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/mccabe-0.6.1-py37_1.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/mistune-0.8.4-py37h14c3975_1001.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/mypy_extensions-0.4.3-py37_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/ninja-1.10.2-py37hff7bd54_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/olefile-0.46-py_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/pandocfilters-1.4.3-py37h06a4308_1.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/parso-0.7.0-py_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/pathspec-0.7.0-py_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/pickleshare-0.7.5-pyhd3eb1b0_1003.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/psutil-5.8.0-py37h27cfd23_1.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/ptyprocess-0.7.0-pyhd3eb1b0_2.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/pycodestyle-2.6.0-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/pycparser-2.20-py_2.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/pyflakes-2.2.0-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/pyparsing-2.4.7-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/pyrsistent-0.17.3-py37h7b6447c_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/pysocks-1.7.1-py37_1.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/pytz-2021.1-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/pyxdg-0.27-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/pyyaml-5.4.1-py37h27cfd23_1.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/pyzmq-20.0.0-py37h2531618_1.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/qdarkstyle-2.8.1-py_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/qt-5.9.7-h5867ecd_1.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/qtpy-1.9.0-py_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/regex-2021.3.17-py37h27cfd23_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/rope-0.18.0-py_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/rtree-0.9.4-py37_1.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/sip-4.19.8-py37hf484d3e_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/six-1.15.0-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/snowballstemmer-2.1.0-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/sortedcontainers-2.3.0-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/sphinxcontrib-applehelp-1.0.2-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/sphinxcontrib-devhelp-1.0.2-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/sphinxcontrib-htmlhelp-1.0.3-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/sphinxcontrib-jsmath-1.0.1-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/sphinxcontrib-qthelp-1.0.3-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/sphinxcontrib-serializinghtml-1.1.4-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/testpath-0.4.4-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/textdistance-4.2.1-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/toml-0.10.2-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/tornado-6.1-py37h27cfd23_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/typed-ast-1.4.2-py37h27cfd23_1.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/typing_extensions-3.7.4.3-pyha847dfd_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/ujson-4.0.2-py37h2531618_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/wcwidth-0.2.5-py_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/webencodings-0.5.1-py37_1.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/wheel-0.36.2-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/wrapt-1.12.1-py37h7b6447c_1.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/wurlitzer-2.0.1-py37_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/yapf-0.31.0-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/zipp-3.4.1-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/autopep8-1.5.6-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/babel-2.9.0-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/black-19.10b0-py_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/cffi-1.14.5-py37h261ae71_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/cycler-0.10.0-py37_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/harfbuzz-1.8.8-hffaf4a1_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/importlib-metadata-3.7.3-py37h06a4308_1.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/intervaltree-3.1.0-py_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/jedi-0.17.2-py37h06a4308_1.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/packaging-20.9-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/pexpect-4.8.0-pyhd3eb1b0_3.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/pillow-8.1.2-py37he98fc37_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/prompt-toolkit-3.0.17-pyh06a4308_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/pydocstyle-6.0.0-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/pyqt-5.9.2-py37h05f1152_2.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/python-dateutil-2.8.1-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/python-jsonrpc-server-0.4.0-py_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/qtawesome-1.0.2-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/setuptools-52.0.0-py37h06a4308_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/three-merge-0.1.1-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/traitlets-5.0.5-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/watchdog-1.0.2-py37h06a4308_1.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/astroid-2.5-py37h06a4308_1.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/bleach-3.3.0-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/brotlipy-0.7.0-py37h27cfd23_1003.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/cryptography-3.4.6-py37hd23ed53_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/flake8-3.9.0-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/importlib_metadata-3.7.3-hd3eb1b0_1.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/isort-5.8.0-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/jinja2-2.11.3-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/jupyter_core-4.7.1-py37h06a4308_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/libopencv-3.4.2-hb342d67_1.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/pip-21.0.1-py37h06a4308_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/pygments-2.8.1-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/ipython-7.21.0-py37hb070fc8_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/jsonschema-3.2.0-py_2.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/jupyter_client-6.1.7-py_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/pluggy-0.13.1-py37h06a4308_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/pylint-2.7.2-py37h06a4308_1.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/pyopenssl-20.0.1-pyhd3eb1b0_1.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/secretstorage-3.3.1-py37h06a4308_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/ipykernel-5.3.4-py37h5ca1d4c_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/keyring-22.3.0-py37h06a4308_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/nbformat-5.1.2-pyhd3eb1b0_1.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/python-language-server-0.36.2-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/urllib3-1.26.4-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/nbconvert-5.6.1-py37_1.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/pyls-black-0.4.6-hd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/pyls-spyder-0.3.2-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/qtconsole-5.0.3-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/requests-2.25.1-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/spyder-kernels-1.10.2-py37h06a4308_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/sphinx-3.5.3-pyhd3eb1b0_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/numpydoc-1.1.0-pyhd3eb1b0_1.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/spyder-4.2.4-py37h06a4308_0.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/imageio-2.9.0-py_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/intel-openmp-2020.2-254.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/mkl-2020.2-256.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/mkl-service-2.3.0-py37he8ac12f_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/matplotlib-3.3.4-py37h06a4308_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/matplotlib-base-3.3.4-py37h62a2d02_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/mkl_fft-1.3.0-py37h54f3939_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/mkl_random-1.1.1-py37h0573a6f_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/numpy-1.19.2-py37h54aff64_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/numpy-base-1.19.2-py37hfa32c7d_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/py-opencv-3.4.2-py37hb342d67_1.tar.bz2
https://conda.anaconda.org/pytorch/linux-64/pytorch-1.6.0-py3.7_cuda9.2.148_cudnn7.6.3_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/scipy-1.5.2-py37h0b6359f_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/opencv-3.4.2-py37h6fd60c2_1.tar.bz2
https://conda.anaconda.org/pytorch/linux-64/torchvision-0.7.0-py37_cu92.tar.bz2


================================================
FILE: docs/training.md
================================================
Firstly you need to prepare the dataset as described [here](https://github.com/Nicholasli1995/EgoNet/blob/master/docs/preparation.md).

Then download a start point model [here](https://drive.google.com/file/d/1VFtMGgBG0cLGnbr3brrnPnJii2xGYj-9/view?usp=sharing) and place it at ${EgoNet_DIR}/resources. 

The training phase consists of two stages which are described as follows. 

For training on other datasets. You need to prepare the training images and camera parameters accordingly.

## Stage 1: train a lifter (L.pth)
You need to modify the configuration by

```bash
cd ${EgoNet_DIR}/configs && vim KITTI_train_lifting.yml
```
Edit dataset:root to your KITTI directory.

(Optional) Edit dirs:output to where you want to save the output model.

(Optional) You can evaluate during training by setting eval_during to True.

Finally, run

```bash
 cd tools
 python train_lifting.py --cfg "../configs/KITTI_train_lifting.yml"
```


## Stage 2: train the remaining part (HC.pth)
You need to modify the configuration by

```bash
cd ${EgoNet_DIR}/configs && vim KITTI_train_IGRs.yml
```

Edit dataset:root to your KITTI directory.

Edit gpu_id according to your local machine and set batch_size based on how much GPU memory you have. 

(Optional) Edit dirs:output to where you want to save the output model.

(Optional) You can evaluate during training by setting eval_during to True.

(Optional) Edit ss to enable self-supervised representation learning. You need to prepare unlabeled ApolloScape images and download record [here](https://drive.google.com/file/d/1uPdOC7LioomMF5DieUNrx3aZKsgobP5U/view?usp=sharing).

(Optional) Edit training_settings:debug to disable saveing intermediate training results.

Finally, run

```bash
 cd tools
 python train_IGRs.py --cfg "../configs/KITTI_train_IGRs.yml"
```


================================================
FILE: libs/arguments/__init__.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Empty file.
"""




================================================
FILE: libs/arguments/parse.py
================================================
"""
Argument parser for command line inputs and experiment configuration file.

Author: Shichao Li
Contact: nicholas.li@connect.ust.hk
"""

import yaml
import argparse

def read_yaml_file(path):
    """
    Read a .yml file.
    """
    try: 
        with open (path, 'r') as file:
            configs = yaml.safe_load(file)
    except Exception as e:
        print('Error reading the config file: ', e)
    return configs

def parse_args():
    """
    Read a .yml experiment configuration file whose path is provided by the user.
    
    You can add more arguments and modify configs accordingly.
    """
    parser = argparse.ArgumentParser(description='a general parser')
    # path to the configuration file
    parser.add_argument('--cfg',
                        help='experiment configuration file path',
                        type=str
                        )
    parser.add_argument('--visualize',
                        default=False,
                        type=bool
                        )    
    parser.add_argument('--batch_to_show',
                        default=1000000,
                        type=int
                        )    
    args, unknown = parser.parse_known_args()
    configs = read_yaml_file(args.cfg)   
    configs['config_path'] = args.cfg
    configs['visualize'] = args.visualize
    configs['batch_to_show'] = args.batch_to_show
    return configs

================================================
FILE: libs/common/__init__.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Empty file.
"""




================================================
FILE: libs/common/format.py
================================================
"""
Methods for formatted output.

Author: Shichao Li
Contact: nicholas.li@connect.ust.hk
"""
import os

from copy import deepcopy

def format_str_submission(roll, pitch, yaw, x, y, z, score):
    """
    Get a prediction string in ApolloScape style.
    """      
    tempt_str = "{pitch:.3f} {yaw:.3f} {roll:.3f} {x:.3f} {y:.3f} {z:.3f} {score:.3f}".format(
            pitch=pitch,
            yaw=yaw,
            roll=roll,
            x=x,
            y=y,
            z=z,
            score=score)
    return tempt_str

def get_instance_str(dic):
    """
    Produce KITTI style prediction string for one instance.
    """     
    string = ""
    string += dic['class'] + " "
    string += "{:.1f} ".format(dic['truncation'])
    string += "{:.1f} ".format(dic['occlusion'])
    string += "{:.6f} ".format(dic['alpha'])
    string += "{:.6f} {:.6f} {:.6f} {:.6f} ".format(dic['bbox'][0], dic['bbox'][1], dic['bbox'][2], dic['bbox'][3])
    string += "{:.6f} {:.6f} {:.6f} ".format(dic['dimensions'][1], dic['dimensions'][2], dic['dimensions'][0])
    string += "{:.6f} {:.6f} {:.6f} ".format(dic['locations'][0], dic['locations'][1], dic['locations'][2])
    string += "{:.6f} ".format(dic['rot_y'])
    if 'score' in dic:
        string += "{:.8f} ".format(dic['score'])
    else:
        string += "{:.8f} ".format(1.0)
    return string

def get_pred_str(record):
    """
    Produce KITTI style prediction string for a record dictionary.
    """      
    # replace the rotation predictions of input bounding boxes
    updated_txt = deepcopy(record['raw_txt_format'])
    for instance_id in range(len(record['euler_angles'])):
        updated_txt[instance_id]['rot_y'] = record['euler_angles'][instance_id, 1]
        updated_txt[instance_id]['alpha'] = record['alphas'][instance_id]
    pred_str = ""
    angles = record['euler_angles']
    for instance_id in range(len(angles)):
        # format a string for submission
        tempt_str = get_instance_str(updated_txt[instance_id])
        if instance_id != len(angles) - 1:
            tempt_str += '\n'
        pred_str += tempt_str
    return pred_str

def save_txt_file(img_path, prediction, params):
    """
    Save a txt file for predictions of an image.
    """    
    if not params['flag']:
        return
    file_name = img_path.split('/')[-1][:-3] + 'txt'
    save_path = os.path.join(params['save_dir'], file_name) 
    with open(save_path, 'w') as f:
        f.write(prediction['pred_str'])
    print('Wrote prediction file at {:s}'.format(save_path))
    return

================================================
FILE: libs/common/img_proc.py
================================================
"""
Image processing utilities.

Author: Shichao Li
Contact: nicholas.li@connect.ust.hk
"""

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import os

SIZE = 200.0

def transform_preds(coords, center, scale, output_size):
    """
    Transform local coordinates within a patch to screen coordinates.
    """      
    target_coords = np.zeros(coords.shape)
    trans = get_affine_transform(center, scale, 0, output_size, inv=1)
    for p in range(coords.shape[0]):
        target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans)
    return target_coords

def get_affine_transform(center, 
                         scale, 
                         rot, 
                         output_size,
                         shift=np.array([0, 0], dtype=np.float32), 
                         inv=0
                         ):
    """
    Estimate an affine transformation given crop parameters (center, scale and
    rotation) and output resolution.                                                        
    """  
    if isinstance(scale, list):
        scale = np.array(scale)
    if isinstance(center, list):
        center = np.array(center)
    scale_tmp = scale * SIZE
    src_w = scale_tmp[0]
    dst_h, dst_w = output_size

    rot_rad = np.pi * rot / 180
    src_dir = get_dir([0, src_w * -0.5], rot_rad)
    dst_dir = np.array([0, dst_w * -0.5], np.float32)

    src = np.zeros((3, 2), dtype=np.float32)
    dst = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + scale_tmp * shift
    src[1, :] = center + src_dir + scale_tmp * shift
    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir

    src[2:, :] = get_3rd_point(src[0, :], src[1, :])
    dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])

    if inv:
        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
    else:
        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))

    return trans

def affine_transform(pt, t):
    new_pt = np.array([pt[0], pt[1], 1.]).T
    new_pt = np.dot(t, new_pt)
    return new_pt[:2]

def affine_transform_modified(pts, t):
    """
    Apply affine transformation with homogeneous coordinates.                                                    
    """ 
    # pts of shape [n, 2]
    new_pts = np.hstack([pts, np.ones((len(pts), 1))]).T
    new_pts = t @ new_pts
    return new_pts[:2, :].T

def get_3rd_point(a, b):
    direct = a - b
    return b + np.array([-direct[1], direct[0]], dtype=np.float32)

def get_dir(src_point, rot_rad):
    sn, cs = np.sin(rot_rad), np.cos(rot_rad)

    src_result = [0, 0]
    src_result[0] = src_point[0] * cs - src_point[1] * sn
    src_result[1] = src_point[0] * sn + src_point[1] * cs

    return src_result

def crop(img, center, scale, output_size, rot=0):
    """
    A cropping function implemented as warping.                                                      
    """     
    trans = get_affine_transform(center, scale, rot, output_size)

    dst_img = cv2.warpAffine(img, 
                             trans, 
                             (int(output_size[0]), int(output_size[1])),
                             flags=cv2.INTER_LINEAR
                             )   

    return dst_img

def simple_crop(input_image, center, crop_size):
    """
    A simple cropping function without warping.
    """  
    assert len(input_image.shape) == 3, 'Unsupported image format.'
    channel = input_image.shape[2]
    # crop a rectangular region around the center in the image
    start_x = int(center[0] - crop_size[0])
    end_x = int(center[0] + crop_size[0]) 
    start_y = int(center[1] - crop_size[1])
    end_y = int(center[1] + crop_size[1])
    cropped = np.zeros((end_y - start_y, end_x - start_x, channel), 
                       dtype = input_image.dtype)
    # new bounding box index 
    new_start_x = max(-start_x, 0)
    new_end_x = min(input_image.shape[1], end_x) - start_x
    new_start_y = max(-start_y, 0)
    new_end_y = min(input_image.shape[0], end_y) - start_y
    # clamped old bounding box index
    old_start_x = max(start_x, 0)
    old_end_x = min(end_x, input_image.shape[1])
    old_start_y = max(start_y, 0)
    old_end_y = min(end_y, input_image.shape[0])
    try:
        cropped[new_start_y:new_end_y, new_start_x:new_end_x,:] = input_image[
            old_start_y:old_end_y, old_start_x:old_end_x,:]
    except ValueError:
        print('Error: cropping fails')
    return cropped

def np_random():
    """
    Return a random number sampled uniformly from [-1, 1]
    """
    return np.random.rand()*2 - 1

def jitter_bbox_with_kpts(old_bbox, joints, parameters):
    """
    Randomly shifting and resizeing a bounding box and mask out occluded joints.
    Used as data augmentation to improve robustness to detector noise.
    
    bbox: [x1, y1, x2, y2]
    joints: [N, 3]
    """
    new_joints = joints.copy()
    width, height = old_bbox[2] - old_bbox[0], old_bbox[3] - old_bbox[1]
    old_center = [0.5*(old_bbox[0] + old_bbox[2]), 
                  0.5*(old_bbox[1] + old_bbox[3])]
    horizontal_shift = parameters['shift'][0]*width*np_random()
    vertical_shift = parameters['shift'][1]*height*np_random()
    new_center = [old_center[0] + horizontal_shift,
                  old_center[1] + vertical_shift]
    horizontal_scaling = parameters['scaling'][0]*np_random() + 1
    vertical_scaling = parameters['scaling'][1]*np_random() + 1
    new_width = width*horizontal_scaling
    new_height = height*vertical_scaling
    new_bbox = [new_center[0] - 0.5*new_width, new_center[1] - 0.5*new_height,
                new_center[0] + 0.5*new_width, new_center[1] + 0.5*new_height]
    # predicate from upper left corner
    predicate1 = joints[:, :2] - np.array([[new_bbox[0], new_bbox[1]]])
    predicate1 = (predicate1 > 0.).prod(axis=1)
    # predicate from lower right corner
    predicate2 = joints[:, :2] - np.array([[new_bbox[2], new_bbox[3]]])
    predicate2 = (predicate2 < 0.).prod(axis=1)
    new_joints[:, 2] *= predicate1*predicate2
    return new_bbox, new_joints

def jitter_bbox_with_kpts_no_occlu(old_bbox, joints, parameters):
    """
    Similar to the function above, but does not produce occluded joints
    """
    width, height = old_bbox[2] - old_bbox[0], old_bbox[3] - old_bbox[1]
    old_center = [0.5 * (old_bbox[0] + old_bbox[2]), 
                  0.5 * (old_bbox[1] + old_bbox[3])]
    horizontal_scaling = parameters['scaling'][0] * np.random.rand() + 1
    vertical_scaling = parameters['scaling'][1] * np.random.rand() + 1
    horizontal_shift = 0.5 * (horizontal_scaling - 1) * width * np_random()
    vertical_shift = 0.5 * (vertical_scaling - 1) * height * np_random()
    new_center = [old_center[0] + horizontal_shift,
                  old_center[1] + vertical_shift]
    new_width = width * horizontal_scaling
    new_height = height * vertical_scaling
    new_bbox = [new_center[0] - 0.5 * new_width, new_center[1] - 0.5 * new_height,
                new_center[0] + 0.5 * new_width, new_center[1] + 0.5 * new_height]
    return new_bbox, joints

def generate_xy_map(bbox, resolution, global_size):
    """
    Generate the normalized coordinates as 2D maps which encodes location 
    information.
    
    bbox: [x1, y1, x2, y2] the local region
    resolution (height, width): target resolution
    global_size (height, width): the size of original image
    """
    map_width, map_height = resolution
    g_height, g_width = global_size
    x_start, x_end = 2*bbox[0]/g_width - 1, 2*bbox[2]/g_width - 1
    y_start, y_end = 2*bbox[1]/g_height - 1, 2*bbox[3]/g_height - 1
    x_map = np.tile(np.linspace(x_start, x_end, map_width), (map_height, 1))
    x_map = x_map.reshape(map_height, map_width, 1)
    y_map = np.linspace(y_start, y_end, map_height).reshape(map_height, 1)
    y_map = np.tile(y_map, (1, map_width))
    y_map = y_map.reshape(map_height, map_width, 1)
    return np.concatenate([x_map, y_map], axis=2)

def crop_single_instance(data_numpy, bbox, joints, parameters, pth_trans=None):
    """
    Crop an instance from an image given the bounding box and part coordinates.
    """
    reso = parameters['input_size'] # (height, width)
    transformed_joints = joints.copy()
    if parameters['jitter_bbox']:
        bbox, joints = jitter_bbox_with_kpts_no_occlu(bbox, 
                                                      joints,
                                                      parameters['jitter_params']
                                                      )
    joints_vis = joints[:, 2]
    if parameters['resize']:
        ret = resize_bbox(bbox[0], bbox[1], bbox[2], bbox[3], 
                          target_ar=reso[0]/reso[1])
        c, s = ret['c'], ret['s']
    else:
        c, s = bbox2cs(bbox)    
    trans = get_affine_transform(c, s, 0.0, reso)
    input = cv2.warpAffine(data_numpy,
                           trans,
                           (int(reso[1]), int(reso[0])),
                           flags=cv2.INTER_LINEAR
                           )
    # add two more channels to encode object location
    if parameters['add_xy']:
        xymap = generate_xy_map(ret['bbox'], reso, parameters['global_size'])
        input = np.concatenate([input, xymap.astype(np.float32)], axis=2)
    #cv2.imwrite('test.jpg', input)
    #input = torch.from_numpy(input.transpose(2,0,1))
    input = input if pth_trans is None else pth_trans(input)
    for i in range(len(joints)):
        if joints_vis[i] > 0.0:
            transformed_joints[i, 0:2] = affine_transform(joints[i, 0:2], trans)   
    c = c.reshape(1, 2)
    s = s.reshape(1, 2)
    return input.unsqueeze(0), transformed_joints, c, s

def get_tensor_from_img(path, 
                        parameters,
                        sf=0.2, 
                        rf=30., 
                        r_prob=0.6, 
                        aug=False, 
                        rgb=True, 
                        joints=None,
                        global_box=None,
                        pth_trans=None,
                        generate_hm=False,
                        max_cnt=None
                        ):
    """
    Read image and apply data augmentation to obtain a tensor. 
    Keypoints are also transformed if given.
    
    path: image path
    c: cropping center
    s: cropping scale
    r: rotation
    reso: resolution of output image
    sf: scaling factor
    rf: rotation factor
    aug: apply data augmentation
    joints: key-point locations with optional visibility [N_instance, N_joint, 3]
    generate_hm: whether to generate heatmap based on joint locations
    """
#    data_numpy = cv2.imread(
#        path, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION
#        )
    data_numpy = cv2.imread(
        path, 1 | 128
        )    
    if data_numpy is None:
        raise ValueError('Fail to read {}'.format(path))    
    if rgb:
        data_numpy = cv2.cvtColor(data_numpy, cv2.COLOR_BGR2RGB)
    all_inputs = []
    all_target = []
    all_centers = []
    all_scales = []
    all_target_weight = []
    # the dimension of the image
    parameters['global_size'] = data_numpy.shape[:-1]
    all_transformed_joints = []
    if parameters['reference'] == 'bbox':
        # crop around the given bounding boxes
        # bbox = [0, 0, data_numpy.shape[1] - 1, data_numpy.shape[0] - 1] \
        #     if 'bbox' not in parameters else parameters['bbox']
        bboxes = parameters['boxes'] # [N_instance, 4]
        for idx, bbox in enumerate(bboxes):
            input, transformed_joints, c, s = crop_single_instance(data_numpy,
                                                                   bbox,
                                                                   joints[idx],
                                                                   parameters,
                                                                   pth_trans
                                                                   )
            all_inputs.append(input)
            all_centers.append(c)
            all_scales.append(s)
        # s = s * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf)
        # r = np.clip(np.random.randn() * rf, -rf, rf) if np.random.rand() <= r_prob else 0
            target = target_weight = 1.
            if generate_hm:
                target, target_weight = generate_target(transformed_joints, 
                                                        transformed_joints[:,2], 
                                                        parameters)
                target = torch.unsqueeze(torch.from_numpy(target), 0)
                target_weight = torch.unsqueeze(torch.from_numpy(target_weight), 0)
            all_target.append(target)
            all_target_weight.append(target_weight)
            all_transformed_joints.append(np.expand_dims(transformed_joints,0))
    all_transformed_joints = np.concatenate(all_transformed_joints)
    if max_cnt is not None and max_cnt < len(all_inputs):
        end = max_cnt
    else:
        end = len(all_inputs)
    end_indices = list(range(end))
    meta = {
        'path': path,
        'original_joints': joints[end_indices],
        'transformed_joints': all_transformed_joints[end_indices],
        'center': np.vstack(all_centers[:end]),
        'scale': np.vstack(all_scales[:end]),
        'joints_vis': all_transformed_joints[end_indices][:,:,2]
        # 'rotation': r,
    }
    inputs = torch.cat(all_inputs[:end], dim=0)
    if generate_hm:
        targets = torch.cat(all_target[:end], dim=0)
        target_weights = torch.cat(all_target_weight[:end], dim=0)
    else:
        targets, target_weights = None, None
    return inputs, targets, target_weights, meta

def generate_target(joints, joints_vis, parameters):
    """
    Generate heatmap targets by drawing Gaussian dots.
    
    joints:  [num_joints, 3]
    joints_vis: [num_joints]
    
    return: target, target_weight (1: visible, 0: invisible)
    """
    num_joints = parameters['num_joints']
    target_type = parameters['target_type']
    input_size = parameters['input_size']
    heatmap_size = parameters['heatmap_size']
    sigma = parameters['sigma']
    target_weight = np.ones((num_joints, 1), dtype=np.float32)
    target_weight[:, 0] = joints_vis

    
    assert target_type == 'gaussian', 'Only support gaussian map now!'

    if target_type == 'gaussian':
        target = np.zeros((num_joints, heatmap_size[0], heatmap_size[1]), 
                          dtype=np.float32)

        tmp_size = sigma * 3

        for joint_id in range(num_joints):
            if target_weight[joint_id] <= 0.5:
                continue
            feat_stride = input_size / heatmap_size
            mu_x = int(joints[joint_id][0] / feat_stride[0] + 0.5)
            mu_y = int(joints[joint_id][1] / feat_stride[1] + 0.5)
            # Check that any part of the gaussian is in-bounds
            ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
            br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
            if ul[0] >= heatmap_size[1] or ul[1] >= heatmap_size[0] \
                    or br[0] < 0 or br[1] < 0:
                # If not, just return the image as is
                target_weight[joint_id] = 0
                continue

            # # Generate gaussian
            size = 2 * tmp_size + 1
            x = np.arange(0, size, 1, np.float32)
            y = x[:, np.newaxis]
            x0 = y0 = size // 2
            # The gaussian is not normalized, we want the center value to equal 1
            g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))

            # Usable gaussian range
            g_x = max(0, -ul[0]), min(br[0], heatmap_size[1]) - ul[0]
            g_y = max(0, -ul[1]), min(br[1], heatmap_size[0]) - ul[1]
            # Image range
            img_x = max(0, ul[0]), min(br[0], heatmap_size[1])
            img_y = max(0, ul[1]), min(br[1], heatmap_size[0])

            target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \
                g[g_y[0]:g_y[1], g_x[0]:g_x[1]]

    if parameters['use_different_joints_weight']:
        target_weight = np.multiply(target_weight, parameters['joints_weight'])

    return target, target_weight

def resize_bbox(left, top, right, bottom, target_ar=1.):
    """
    Resize a bounding box to pre-defined aspect ratio.
    """ 
    width = right - left
    height = bottom - top
    aspect_ratio = height/width
    center_x = (left + right)/2
    center_y = (top + bottom)/2
    if aspect_ratio > target_ar:
        new_width = height*(1/target_ar)
        new_left = center_x - 0.5*new_width
        new_right = center_x + 0.5*new_width
        new_top = top
        new_bottom = bottom        
    else:
        new_height = width*target_ar
        new_left = left
        new_right = right
        new_top = center_y - 0.5*new_height
        new_bottom = center_y + 0.5*new_height
    return {'bbox': [new_left, new_top, new_right, new_bottom],
            'c': np.array([center_x, center_y]),
            's': np.array([(new_right - new_left)/SIZE, (new_bottom - new_top)/SIZE])
            }

def enlarge_bbox(left, top, right, bottom, enlarge):
    """
    Enlarge a bounding box.
    """ 
    width = right - left
    height = bottom - top
    new_width = width * enlarge[0]
    new_height = height * enlarge[1]
    center_x = (left + right) / 2
    center_y = (top + bottom) / 2    
    new_left = center_x - 0.5 * new_width
    new_right = center_x + 0.5 * new_width
    new_top = center_y - 0.5 * new_height
    new_bottom = center_y + 0.5 * new_height
    return [new_left, new_top, new_right, new_bottom]

def modify_bbox(bbox, target_ar, enlarge=1.1):
    """
    Modify a bounding box by enlarging/resizing.
    """
    lbbox = enlarge_bbox(bbox[0], bbox[1], bbox[2], bbox[3], [enlarge, enlarge])
    ret = resize_bbox(lbbox[0], lbbox[1], lbbox[2], lbbox[3], target_ar=target_ar)
    return ret
    
def resize_crop(crop_size, target_ar=None):
    """
    Resize a crop size to a pre-defined aspect ratio.
    """    
    if target_ar is None:
        return crop_size
    width = crop_size[0]
    height = crop_size[1]
    aspect_ratio = height / width    
    if aspect_ratio > target_ar:
        new_width = height * (1 / target_ar)
        new_height = height
    else:
        new_height = width*target_ar
        new_width = width
    return [new_width, new_height]

def bbox2cs(bbox):
    """
    Convert bounding box annotation to center and scale.
    """  
    return [(bbox[0] + bbox[2]/2), (bbox[1] + bbox[3]/2)], \
        [(bbox[2] - bbox[0]/SIZE), (bbox[3] - bbox[1]/SIZE)]

def cs2bbox(center, size):
    """
    Convert center/scale to a bounding box annotation.
    """  
    x1 = center[0] - size[0]
    y1 = center[1] - size[1]
    x2 = center[0] + size[0]
    y2 = center[1] + size[1]
    return [x1, y1, x2, y2]

def kpts2cs(keypoints, 
            enlarge=1.1, 
            method='boundary', 
            target_ar=None, 
            use_visibility=True
            ):
    """
    Convert instance screen coordinates to cropping center and size
    
    keypoints of shape [n_joints, 2/3]
    """   
    assert keypoints.shape[1] in [2, 3], 'Unsupported input.'
    if keypoints.shape[1] == 2:
        visible_keypoints = keypoints
        vis_rate = 1.0
    elif keypoints.shape[1] == 3 and use_visibility:
        visible_indices = keypoints[:, 2].nonzero()[0]
        visible_keypoints = keypoints[visible_indices, :2]
        vis_rate = len(visible_keypoints)/len(keypoints)
    else:
        visible_keypoints = keypoints[:, :2]
        visible_indices = np.array(range(len(keypoints)))
        vis_rate = 1.0
    if method == 'centroid':
        center = np.ceil(visible_keypoints.mean(axis=0, keepdims=True))
        dif = np.abs(visible_keypoints - center).max(axis=0, keepdims=True)
        crop_size = np.ceil(dif*enlarge).squeeze()
        center = center.squeeze()
    elif method == 'boundary':
        left_top = visible_keypoints.min(axis=0, keepdims=True)
        right_bottom = visible_keypoints.max(axis=0, keepdims=True)
        center = ((left_top + right_bottom) / 2).squeeze()
        crop_size = ((right_bottom - left_top)*enlarge/2).squeeze()
    else:
        raise NotImplementedError
    # resize the bounding box to a specified aspect ratio
    crop_size = resize_crop(crop_size, target_ar)
    x1, y1, x2, y2 = cs2bbox(center, crop_size)

    new_origin = np.array([[x1, y1]], dtype=keypoints.dtype)
    new_keypoints = keypoints.copy()
    if keypoints.shape[1] == 2:
        new_keypoints = visible_keypoints - new_origin
    elif keypoints.shape[1] == 3: 
        new_keypoints[visible_indices, :2] = visible_keypoints - new_origin
    return center, crop_size, new_keypoints, vis_rate

def draw_bboxes(img_path, bboxes_dict, save_path=None):
    """
    Draw bounding boxes with OpenCV.
    """
    data_numpy = cv2.imread(img_path, 1 | 128)  
    for name, (color, bboxes) in bboxes_dict.items():
        for bbox in bboxes:
            start_point = (bbox[0], bbox[1])
            end_point = (bbox[2], bbox[3])
            cv2.rectangle(data_numpy, start_point, end_point, color, 2)
    if save_path is not None:
        cv2.imwrite(save_path, data_numpy)
    return data_numpy

def imread_rgb(img_path):
    """
    Read image with OpenCV.
    """    
    data_numpy = cv2.imread(img_path, 1 | 128)  
    data_numpy = cv2.cvtColor(data_numpy, cv2.COLOR_BGR2RGB)
    return data_numpy

def save_cropped_patches(img_path, 
                         keypoints, 
                         save_dir="./", 
                         threshold=0.25,
                         enlarge=1.4, 
                         target_ar=None
                         ):
    """
    Crop instances from a image given part screen coordinates and save them.
    """
#    data_numpy = cv2.imread(
#        img_path, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION
#        )   
    data_numpy = cv2.imread(img_path, 1 | 128)  
    # data_numpy = cv2.cvtColor(data_numpy, cv2.COLOR_BGR2RGB)
    # debug
    # import matplotlib.pyplot as plt
    # plt.imshow(data_numpy[:,:,::-1])
    # plt.plot(keypoints[0][:,0], keypoints[0][:,1], 'ro')
    # plt.pause(0.1)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    new_paths = []
    all_new_keypoints = []
    all_bbox = []
    for i in range(len(keypoints)):
        center, crop_size, new_keypoints, vis_rate = kpts2cs(keypoints[i], 
                                                             enlarge, 
                                                             target_ar=target_ar)
        all_bbox.append(list(map(int, cs2bbox(center, crop_size))))
        if vis_rate < threshold:
            continue
        all_new_keypoints.append(new_keypoints.reshape(1, keypoints.shape[1], -1))
        cropped = simple_crop(data_numpy, center, crop_size)
        save_path = os.path.join(save_dir, "instance_{:d}.jpg".format(i))
        new_paths.append(save_path)
        cv2.imwrite(save_path, cropped)
        del cropped
    if len(new_paths) == 0:
        # No instances cropped
        return new_paths, np.zeros((0, keypoints.shape[1], 3)), all_bbox
    else:
        return new_paths, np.concatenate(all_new_keypoints, axis=0), all_bbox
    
def get_max_preds(batch_heatmaps):
    """
    Get predictions from heatmaps with hard arg-max.
    
    batch_heatmaps: numpy.ndarray([batch_size, num_joints, height, width])
    """
    assert isinstance(batch_heatmaps, np.ndarray), \
        'batch_heatmaps should be numpy.ndarray'
    assert batch_heatmaps.ndim == 4, 'batch_images should be 4-ndim'

    batch_size = batch_heatmaps.shape[0]
    num_joints = batch_heatmaps.shape[1]
    width = batch_heatmaps.shape[3]
    heatmaps_reshaped = batch_heatmaps.reshape((batch_size, num_joints, -1))
    idx = np.argmax(heatmaps_reshaped, 2)
    maxvals = np.amax(heatmaps_reshaped, 2)

    maxvals = maxvals.reshape((batch_size, num_joints, 1))
    idx = idx.reshape((batch_size, num_joints, 1))

    preds = np.tile(idx, (1, 1, 2)).astype(np.float32)

    preds[:, :, 0] = (preds[:, :, 0]) % width
    preds[:, :, 1] = np.floor((preds[:, :, 1]) / width)

    pred_mask = np.tile(np.greater(maxvals, 0.0), (1, 1, 2))
    pred_mask = pred_mask.astype(np.float32)

    preds *= pred_mask
    return preds, maxvals

def soft_arg_max_np(batch_heatmaps):
    """
    Soft-argmax instead of hard-argmax considering quantization errors.
    """
    assert isinstance(batch_heatmaps, np.ndarray), \
        'batch_heatmaps should be numpy.ndarray'
    assert batch_heatmaps.ndim == 4, 'batch_images should be 4-ndim'
    batch_size = batch_heatmaps.shape[0]
    num_joints = batch_heatmaps.shape[1]
    height = batch_heatmaps.shape[2]
    width = batch_heatmaps.shape[3]
    # get score/confidence for each joint
    heatmaps_reshaped = batch_heatmaps.reshape((batch_size, num_joints, -1))
    maxvals = np.amax(heatmaps_reshaped, 2)
    maxvals = maxvals.reshape((batch_size, num_joints, 1))    
    # normalize the heatmaps so that they sum to 1
    #assert batch_heatmaps.min() >= 0.0
    batch_heatmaps = np.clip(batch_heatmaps, a_min=0.0, a_max=None)
    temp_sum = heatmaps_reshaped.sum(axis = 2, keepdims=True)
    heatmaps_reshaped /= temp_sum
    ## another normalization method: softmax
    # spatial soft-max
    #heatmaps_reshaped = softmax(heatmaps_reshaped, axis=2)
    ##
    batch_heatmaps = heatmaps_reshaped.reshape(batch_size, num_joints, height, width)
    x = batch_heatmaps.sum(axis = 2)
    y = batch_heatmaps.sum(axis = 3)
    x_indices = np.arange(width).astype(np.float32).reshape(1,1,width)
    y_indices = np.arange(height).astype(np.float32).reshape(1,1,height)
    x *= x_indices
    y *= y_indices
    x = x.sum(axis = 2, keepdims=True)
    y = y.sum(axis = 2, keepdims=True)
    preds = np.concatenate([x, y], axis=2)
    pred_mask = np.tile(np.greater(maxvals, 0.0), (1, 1, 2))
    pred_mask = pred_mask.astype(np.float32)
    preds *= pred_mask
    return preds, maxvals

def soft_arg_max(batch_heatmaps):
    """
    A pytorch version of soft-argmax
    """
    assert len(batch_heatmaps.shape) == 4, 'batch_images should be 4-ndim'
    batch_size = batch_heatmaps.shape[0]
    num_joints = batch_heatmaps.shape[1]
    height = batch_heatmaps.shape[2]
    width = batch_heatmaps.shape[3]
    heatmaps_reshaped = batch_heatmaps.view((batch_size, num_joints, -1))
    # get score/confidence for each joint    
    maxvals = heatmaps_reshaped.max(dim=2)[0]
    maxvals = maxvals.view((batch_size, num_joints, 1))       
    # normalize the heatmaps so that they sum to 1
    heatmaps_reshaped = F.softmax(heatmaps_reshaped, dim=2)
    batch_heatmaps = heatmaps_reshaped.view(batch_size, num_joints, height, width)
    x = batch_heatmaps.sum(dim = 2)
    y = batch_heatmaps.sum(dim = 3)
    x_indices = torch.arange(width).type(torch.cuda.FloatTensor)
    x_indices = torch.cuda.comm.broadcast(x_indices, devices=[x.device.index])[0]
    x_indices = x_indices.view(1, 1, width)
    y_indices = torch.arange(height).type(torch.cuda.FloatTensor)
    y_indices = torch.cuda.comm.broadcast(y_indices, devices=[y.device.index])[0]
    y_indices = y_indices.view(1, 1, height)    
    x *= x_indices
    y *= y_indices
    x = x.sum(dim = 2, keepdim=True)
    y = y.sum(dim = 2, keepdim=True)
    preds = torch.cat([x, y], dim=2)
    return preds, maxvals

def appro_cr(coordinates):
    """
    Approximate the square of cross-ratio along four ordered 2D points using 
    inner-product
    
    coordinates: PyTorch tensor of shape [4, 2]
    """
    AC = coordinates[2] - coordinates[0]
    BD = coordinates[3] - coordinates[1]
    BC = coordinates[2] - coordinates[1]
    AD = coordinates[3] - coordinates[0]
    return (AC.dot(AC) * BD.dot(BD)) / (BC.dot(BC) * AD.dot(AD))

def to_npy(tensor):
    """
    Convert PyTorch tensor to numpy array.
    """
    if isinstance(tensor, np.ndarray):
        return tensor
    else:
        return tensor.data.cpu().numpy()

================================================
FILE: libs/common/transformation.py
================================================
"""
Coordinate transformation functions.

Author: Shichao Li
Contact: nicholas.li@connect.ust.hk
"""

import numpy as np
import cv2

def move_to(points, xyz=np.zeros((1,3))):
    # points of shape [n_points, 3]
    centroid = points.mean(axis=0, keepdims=True)
    return points - (centroid - xyz)

def world_to_camera_frame(P, R, T):
    """
    Convert points from world to camera coordinates
    
    P: Nx3 3d points in world coordinates
    R: 3x3 Camera rotation matrix
    T: 3x1 Camera translation parameters
    
    Returns
    X_cam: Nx3 3d points in camera coordinates
    """
    assert len(P.shape) == 2
    assert P.shape[1] == 3
    X_cam = R.dot( P.T - T ) # rotate and translate
    return X_cam.T

def camera_to_world_frame(P, R, T):
    """
    Inverse of world_to_camera_frame

    P: Nx3 points in camera coordinates
    R: 3x3 Camera rotation matrix
    T: 3x1 Camera translation parameters
    
    Returns
    X_cam: Nx3 points in world coordinates
    """
    assert len(P.shape) == 2
    assert P.shape[1] == 3
    X_cam = R.T.dot( P.T ) + T # rotate and translate
    return X_cam.T

def compute_similarity_transform(X, Y, compute_optimal_scale=False):
    """
    A port of MATLAB's `procrustes` function to Numpy.
    Adapted from http://stackoverflow.com/a/18927641/1884420
    
    Args
      X: array NxM of targets, with N number of points and M point dimensionality
      Y: array NxM of inputs
      compute_optimal_scale: whether we compute optimal scale or force it to be 1
    
    Returns:
      d: squared error after transformation
      Z: transformed Y
      T: computed rotation
      b: scaling
      c: translation
    """
    muX = X.mean(0)
    muY = Y.mean(0)
    X0 = X - muX
    Y0 = Y - muY
    ssX = (X0**2.).sum()
    ssY = (Y0**2.).sum()
    # centred Frobenius norm
    normX = np.sqrt(ssX)
    normY = np.sqrt(ssY)
    # scale to equal (unit) norm
    X0 = X0 / normX
    Y0 = Y0 / normY
    # optimum rotation matrix of Y
    A = np.dot(X0.T, Y0)
    U,s,Vt = np.linalg.svd(A,full_matrices=False)
    V = Vt.T
    T = np.dot(V, U.T)
    # Make sure we have a rotation
    detT = np.linalg.det(T)
    V[:,-1] *= np.sign( detT )
    s[-1]   *= np.sign( detT )
    T = np.dot(V, U.T)
    traceTA = s.sum()
    if compute_optimal_scale:  # Compute optimum scaling of Y.
        b = traceTA * normX / normY
        d = 1 - traceTA**2
        Z = normX*traceTA*np.dot(Y0, T) + muX
    else:  # If no scaling allowed
        b = 1
        d = 1 + ssY/ssX - 2 * traceTA * normY / normX
        Z = normY*np.dot(Y0, T) + muX
    c = muX - b*np.dot(muY, T)
    return d, Z, T, b, c

def compute_rigid_transform(X, Y, W=None, verbose=False):
    """
    A least-sqaure estimate of rigid transformation by SVD.
    
    Reference: https://content.sakai.rutgers.edu/access/content/group/
    7bee3f05-9013-4fc2-8743-3c5078742791/material/svd_ls_rotation.pdf
    
    X, Y: [d, N] N data points of dimention d
    W: [N, ] optional weight (importance) matrix for N data points
    """    
    assert len(X) == len(Y)
    assert (W is None) or (len(W.shape) in [1, 2])
    # find mean column wise
    centroid_X = np.mean(X, axis=1, keepdims=True)
    centroid_Y = np.mean(Y, axis=1, keepdims=True)
    # subtract mean
    Xm = X - centroid_X
    Ym = Y - centroid_Y
    if W is None:
        H = Xm @ Ym.T
    else:
        W = np.diag(W) if len(W.shape) == 1 else W
        H = Xm @ W @ Ym.T
    # find rotation
    U, S, Vt = np.linalg.svd(H)
    R = Vt.T @ U.T
    if np.linalg.det(R) < 0:
        # special reflection case
        if verbose:
            print("det(R) < R, reflection detected!, correcting for it ...\n");
        # the global minimizer with a orthogonal transformation is not possible
        # the next best transformation is chosen
        Vt[-1,:] *= -1
        R = Vt.T @ U.T
    t = -R @ centroid_X + centroid_Y
    return R, t

def procrustes_transform(X, Y):
    """
    Compute a rigid transformation trans() from X to Y and return trans(X)
    """
    R, t = compute_rigid_transform(X, Y)
    return R @ X + t

def pnp_refine(prediction, observation, intrinsics, dist_coeffs):
    """
    Refine 3D prediction with observed image projection based on  the PnP algorithm.
    """
    (success, R, T) = cv2.solvePnP(prediction,
                                   observation,
                                   intrinsics,
                                   dist_coeffs,
                                   flags=cv2.SOLVEPNP_ITERATIVE)
    if not success:
        print('PnP failed.')
        return prediction
    else:
        refined_prediction = cv2.Rodrigues(R)[0] @ prediction.T + T    
        return refined_prediction

================================================
FILE: libs/common/utils.py
================================================
"""
Common utilities.

Author: Shichao Li
Contact: nicholas.li@connect.ust.hk
"""

import torch
import torch.nn as nn
import numpy as np

from libs.metric.criterions import PCK_THRES

import os
from os.path import join as pjoin
from collections import namedtuple

def make_dir(name):
    """    
    Create a directory.
    """
    if not os.path.exists(os.path.dirname(name)):
        try:
            os.makedirs(os.path.dirname(name))
        except OSError as exc:
            print('make_dir failed.')
            raise exc
    return

def save_checkpoint(states, is_best, output_dir, filename='checkpoint.pth'):
    torch.save(states, pjoin(output_dir, filename))
    if is_best and 'state_dict' in states:
        torch.save(states['best_state_dict'], pjoin(output_dir, 'model_best.pth'))

def get_model_summary(model, *input_tensors, item_length=26, verbose=False):
    """
    Summarize a model. For now only convolution, batch normalization and 
    linear layers are considered for parameters and FLOPs.
    """
    summary = []
    ModuleDetails = namedtuple(
        "Layer", 
        ["name", "input_size", "output_size", "num_parameters", "multiply_adds"]
        )
    hooks = []
    layer_instances = {}

    def hook(module, input, output):
        class_name = str(module.__class__.__name__)
        instance_index = 1
        if class_name not in layer_instances:
            layer_instances[class_name] = instance_index
        else:
            instance_index = layer_instances[class_name] + 1
            layer_instances[class_name] = instance_index
    
        layer_name = class_name + "_" + str(instance_index)
    
        params = 0
    
        if class_name.find("Conv") != -1 or class_name.find("BatchNorm") != -1 or \
           class_name.find("Linear") != -1:
            for param_ in module.parameters():
                params += param_.view(-1).size(0)
    
        flops = "Not Available"
        if class_name.find("Conv") != -1 and hasattr(module, "weight"):
            flops = (
                torch.prod(
                    torch.LongTensor(list(module.weight.data.size()))) *
                torch.prod(
                    torch.LongTensor(list(output.size())[2:]))).item()
        elif isinstance(module, nn.Linear):
            flops = (torch.prod(torch.LongTensor(list(output.size()))) \
                     * input[0].size(1)).item()
    
        if isinstance(input[0], list):
            input = input[0]
        if isinstance(output, list):
            output = output[0]
    
        summary.append(
            ModuleDetails(
                name=layer_name,
                input_size=list(input[0].size()),
                output_size=list(output.size()),
                num_parameters=params,
                multiply_adds=flops)
        )

    def add_hooks(module):
        if not isinstance(module, nn.ModuleList) \
           and not isinstance(module, nn.Sequential) \
           and module != model:
            hooks.append(module.register_forward_hook(hook))

    model.eval()
    model.apply(add_hooks)

    space_len = item_length

    model(*input_tensors)
    for h in hooks:
        h.remove()

    details = ''
    if verbose:
        details = "Model Summary" + \
            os.linesep + \
            "Name{}Input Size{}Output Size{}Parameters{}Multiply Adds (Flops){}".format(
                ' ' * (space_len - len("Name")),
                ' ' * (space_len - len("Input Size")),
                ' ' * (space_len - len("Output Size")),
                ' ' * (space_len - len("Parameters")),
                ' ' * (space_len - len("Multiply Adds (Flops)"))) \
                + os.linesep + '-' * space_len * 5 + os.linesep

    params_sum = 0
    flops_sum = 0
    for layer in summary:
        params_sum += layer.num_parameters
        if layer.multiply_adds != "Not Available":
            flops_sum += layer.multiply_adds
        if verbose:
            details += "{}{}{}{}{}{}{}{}{}{}".format(
                layer.name,
                ' ' * (space_len - len(layer.name)),
                layer.input_size,
                ' ' * (space_len - len(str(layer.input_size))),
                layer.output_size,
                ' ' * (space_len - len(str(layer.output_size))),
                layer.num_parameters,
                ' ' * (space_len - len(str(layer.num_parameters))),
                layer.multiply_adds,
                ' ' * (space_len - len(str(layer.multiply_adds)))) \
                + os.linesep + '-' * space_len * 5 + os.linesep

    details += os.linesep \
        + "Total Parameters: {:,}".format(params_sum) \
        + os.linesep + '-' * space_len * 5 + os.linesep
    details += "Total Multiply Adds (For Convolution and Linear Layers only): {:,} GFLOPs".format(flops_sum/(1024**3)) \
        + os.linesep + '-' * space_len * 5 + os.linesep
    details += "Number of Layers" + os.linesep
    for layer in layer_instances:
        details += "{} : {} layers   ".format(layer, layer_instances[layer])

    return details

class AverageMeter(object):
    """
    An averaege meter object that computes and stores the average and current value.
    """
    def __init__(self):
        self.reset()
        self.PCK_stats = {}
        
    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
        return
    
    def update(self, val, n=1, others=None):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count if self.count != 0 else 0
        if others is not None and'correct_cnt' in others:
            if 'sum' not in self.PCK_stats:
                self.PCK_stats['sum'] = np.zeros(len(others['correct_cnt'])) 
            self.PCK_stats['sum'] += others['correct_cnt']
            if 'total' not in self.PCK_stats:
                self.PCK_stats['total'] = 0.
            self.PCK_stats['total'] += n         
        return
    
    def print_content(self):
        if 'sum' in self.PCK_stats:
            for idx, value in enumerate(self.PCK_stats['sum']):
                PCK = value / self.PCK_stats['total']
                print('Average PCK at threshold {:.2f}: {:.3f}'.format(PCK_THRES[idx], PCK))
        return

================================================
FILE: libs/dataset/KITTI/__init__.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Empty file.
"""




================================================
FILE: libs/dataset/KITTI/car_instance.py
================================================
"""
KITTI dataset implemented as PyTorch dataset object.

Author: Shichao Li
Contact: nicholas.li@connect.ust.hk
"""

import libs.dataset.basic.basic_classes as bc
import libs.visualization.points as vp
import libs.common.img_proc as lip

from libs.common.utils import make_dir
from libs.common.img_proc import get_affine_transform

import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms
import cv2
import csv
import copy

from PIL import Image
from mpl_toolkits.mplot3d import Axes3D
from torch.utils.data.dataloader import default_collate
from os.path import join as pjoin
from os.path import sep as osep
from os.path import exists
from os import listdir 

# maximum number of instances to the network depending on your GPU memory
MAX_INS_CNT = 140
#MAX_INS_CNT = 64

TYPE_ID_CONVERSION = {
    'Car': 0,
    'Cyclist': 1,
    'Pedestrian': 2,
}

# annotation style of KITTI dataset
FIELDNAMES = ['type', 
              'truncated', 
              'occluded', 
              'alpha', 
              'xmin', 
              'ymin', 
              'xmax', 
              'ymax', 
              'dh', 
              'dw',
              'dl', 
              'lx', 
              'ly', 
              'lz', 
              'ry']

# the format of prediction has one more field: confidence score
FIELDNAMES_P = FIELDNAMES.copy() + ['score']

# indices used for performing interpolation
# key->value: style->index arrays
interp_dict = {
    'bbox12':(np.array([1,3,5,7,# h direction
                        1,2,3,4,# l direction
                        1,2,5,6]), # w direction
              np.array([2,4,6,8,
                        5,6,7,8,
                        3,4,7,8])
              ),
    'bbox12l':(np.array([1,2,3,4,]), # w direction
              np.array([5,6,7,8])
              ),
    'bbox12h':(np.array([1,3,5,7]), # w direction
              np.array([2,4,6,8])
              ),
    'bbox12w':(np.array([1,2,5,6]), # w direction
              np.array([3,4,7,8])
              ),    
    }

# indices used for computing the cross ratio
cr_indices_dict = {
    'bbox12':np.array([[ 1,  9, 21,  2],
                       [ 3, 10, 22,  4],
                       [ 5, 11, 23,  6],
                       [ 7, 12, 24,  8],
                       [ 1, 13, 25,  5],
                       [ 2, 14, 26,  6],
                       [ 3, 15, 27,  7],
                       [ 4, 16, 28,  8],
                       [ 1, 17, 29,  3],
                       [ 2, 18, 30,  4],
                       [ 5, 19, 31,  7],
                       [ 6, 20, 32,  8]]
                      )
    }

def get_cr_indices():
    """
    Helper function to define the indices used in computing the cross-ratio.
    """
    num_base_pts = 9
    num_lines = 12
    parents, children = interp_dict['bbox12']
    cr_indices = []
    for line_idx in range(num_lines):
        parent_idx = parents[line_idx] # first point
        child_idx = children[line_idx] # last point
        second_point_idx = num_base_pts + line_idx
        third_point_idx = num_base_pts + num_lines + line_idx
        cr_indices.append(np.array([parent_idx, 
                                   second_point_idx, 
                                   third_point_idx,
                                   child_idx]
                                  ).reshape(1,4)
                         )
    cr_indices = np.vstack(cr_indices)
    return cr_indices

class KITTI(bc.SupervisedDataset):
    """
    KITTI dataset.
    """    
    def __init__(self, cfgs, split, logger, scale=1.0):
        super().__init__(cfgs, split, logger)
        self.logger = logger
        self.logger.info("Initializing KITTI {:s} set, please wait...".format(split))
        self.exp_type = cfgs['exp_type'] # exp_type: experiment type 
        self._data_dir = cfgs['dataset']['root'] # root directory
        self._classes = cfgs['dataset']['detect_classes'] # used object classes
        self._get_data_parameters(cfgs) # initialize hyper-parameters
        self._set_paths() # initialize paths
        self._inference_mode = False 
        self.car_sizes = [] # dimension of cars
        self._load_image_list()
        if self.split in ['train', 'valid', 'trainvalid'] and \
            self.exp_type in ['instanceto2d', 'baselinealpha', 'baselinetheta']:
            # prepare local coordinates used in certain types of experiments
            self._prepare_key_points(cfgs)
            # save cropped car instances for debugging
            # cropped_path = pjoin(self._data_config['cropped_dir'], self.kpts_style,
            #                      self.split)
            # if not exists(cropped_path) and cfgs['dataset']['pre-process']:
            #     self._save_cropped_instances()            
        # prepare data used for future loading
        self.generate_pairs()
        # self.visualize()
        if self.split in ['train', 'trainvalid'] and self.exp_type in ['2dto3d']:
            # 2dto3d means the data is used by the lifter that predicts 3D 
            # cuboid based on 2D screen coordinates 
            self.normalize() # data normalization used for the lifter network
        if 'ss' in cfgs and cfgs['ss']['flag']:
            # use unlabeled images for weak self-supervision
            self.use_ss = True
            self.ss_settings = cfgs['ss']
            self._initialize_unlabeled_data(cfgs)
        self.logger.info("Initialization finished for KITTI {:s} set".format(split))
        # self.show_statistics()
        # debugging code if you need
        # test = self[10]
        # test = self.extract_ss_sample(1)
    
    def _get_image_path_list(self):
        """
        Prepare list of image paths for the used split.
        """
        assert 'image_name_list' in self._data_config
        image_path_list = []
        for name in self._data_config['image_name_list']:
            img_path = pjoin(self._data_config['image_dir'], name)
            image_path_list.append(img_path)
        self._data_config['image_path_list'] = image_path_list        
        return
    
    def _initialize_unlabeled_data(self, cfgs):
        """
        Initialize unlabeled data for self-supervision experiment.
        """
        self.ss_record = np.load(cfgs['ss']['record_path'], allow_pickle=True).item()
        self.logger.info('Found prepared self-supervision record at: ' + cfgs['ss']['record_path'])
        return
    
    def _load_image_list(self):
        """
        Prepare list of image names for the used split.
        """
        path = self._data_config[self.split + '_list']       
        with open(path, "r") as f:
            image_name_list = f.read().splitlines()
        for idx, line in enumerate(image_name_list):
            base_name = line.replace("\n", "")
            image_name = base_name + ".png"
            image_name_list[idx] = image_name
        self._data_config['image_name_list'] = image_name_list
        self._get_image_path_list()
        return
    
    def _check_precomputed_file(self, path, name):
        """
        Check if a pre-computed numpy file exists or not.
        """
        if exists(path):
            self.logger.info('Found prepared {0:s} at {1:s}'.format(name, path))
            value = np.load(path, allow_pickle=True).item()
            setattr(self, name, value)
            return True
        else:
            return False
        
    def _save_precomputed_file(self, data_dic, pre_computed_path, name):
        """
        Save a pre-computed numpy file.
        """
        setattr(self, name, data_dic)
        make_dir(pre_computed_path)
        np.save(pre_computed_path, data_dic)
        self.logger.info('Save prepared {0:s} at {1:s}'.format(name, pre_computed_path))        
        return
    
    def _prepare_key_points_custom(self, style, interp_params, vis_thresh=0.25):
        """
        Project 3D bounding boxes to image planes to prepare screen coordinates.
        """        
        assert 'keypoint_dir' in self._data_config
        kpt_dir = self._data_config['keypoint_dir']
        if interp_params['flag']:
            style += str(interp_params['coef'])
        pre_computed_path_kpts = pjoin(kpt_dir, '{0:s}_{1:s}_{2:s}.npy'.format(style, self.split, str(self._classes)))
        pre_computed_path_ids = pjoin(kpt_dir, '{0:s}_{1:s}_{2:s}_ids.npy'.format(style, self.split, str(self._classes)))        
        pre_computed_path_rots = pjoin(kpt_dir, '{0:s}_{1:s}_{2:s}_rots.npy'.format(style, self.split, str(self._classes)))   
        if self._check_precomputed_file(pre_computed_path_kpts, 'keypoints'):
            pass
        if self._check_precomputed_file(pre_computed_path_ids, 'instance_ids'):
            pass     
        if self._check_precomputed_file(pre_computed_path_rots, 'rotations'):
            return    
        path_list = self._data_config['image_path_list']
        data_dic_kpts = {}
        data_dic_ids = {}
        data_dic_rots = {}
        for path in path_list:
            image_name = path.split(osep)[-1]
            # instances that lie out of the image plane will be discarded 
            list_2d, _, list_id, _, list_rots = self.get_2d_3d_pair(path, 
                                                                    style=style, 
                                                                    augment=False,
                                                                    add_visibility=True,
                                                                    filter_outlier=True,
                                                                    add_rotation=True
                                                                    )  
            if len(list_2d) == 0:
                continue
            for idx, kpts in enumerate(list_2d):
                list_2d[idx] = kpts.reshape(1, -1, 3)
            data_dic_kpts[image_name] = np.concatenate(list_2d, axis=0)
            data_dic_ids[image_name] = list_id
            data_dic_rots[image_name] = np.concatenate(list_rots, axis=0)
        self._save_precomputed_file(data_dic_kpts, pre_computed_path_kpts, 'keypoints')
        self._save_precomputed_file(data_dic_ids, pre_computed_path_ids, 'instance_ids')  
        self._save_precomputed_file(data_dic_rots, pre_computed_path_rots, 'rotations') 
        return
    
    def _prepare_key_points(self, cfgs):
        self.kpts_style = cfgs['dataset']['2d_kpt_style']
        self._prepare_key_points_custom(self.kpts_style, cfgs['dataset']['interpolate'])
        if 'enlarge_factor' in cfgs['dataset']:
            self.enlarge_factor = cfgs['dataset']['enlarge_factor']
        else:
            self.enlarge_factor = 1.1
        return
    
    def _save_cropped_instances(self):
        # DEPRECATED, will be removed in a future release
        """ 
        Crop and save car instance images with given 2d key-points
        """        
        assert hasattr(self, 'keypoints')
        all_save_paths = []
        all_keypoints = []
        all_bbox = []
        target_ar = self.hm_para['target_ar']
        for image_name in self.keypoints.keys():
            image_path = pjoin(self._data_config['image_dir'], image_name)
            save_dir = pjoin(self._data_config['cropped_dir'], self.kpts_style,
                             self.split, image_name[:-4])
            keypoints = self.keypoints[image_name]
            new_paths, new_keypoints, bboxes = lip.save_cropped_patches(image_path, 
                                                                keypoints, 
                                                                save_dir, 
                                                                enlarge=self.enlarge_factor,
                                                                target_ar=target_ar)
            all_save_paths += new_paths
            all_keypoints.append(new_keypoints)
            all_bbox += bboxes
        annot_save_name = pjoin(self._data_config['cropped_dir'], 
                                self.kpts_style, self.split, 'annot.npy')
        np.save(annot_save_name, {'paths': all_save_paths,
                                  'kpts': np.concatenate(all_keypoints, axis=0),
                                  'global_box': all_bbox
                                  })
        return
    
    def _prepare_2d_pose_annot(self, threshold=4):
        """ 
        Prepare annotation for training the coordinate regression model.
        """          
        all_paths = []
        all_boxes = []
        all_rotations = []
        all_keypoints = []
        all_keypoints_raw = []
        for image_name in self.keypoints.keys():
            image_path = pjoin(self._data_config['image_dir'], image_name)
            # raw keypoints using camera projection
            keypoints = self.keypoints[image_name]
            rotations = self.rotations[image_name]
            boxes_img = []
            rots_img = []
            visible_kpts_img = []
            for i in range(len(keypoints)):
                # Note here severely-occluded instances are ignored in the trainign data
                visible_cnt = np.sum(keypoints[i][:, 2])
                if visible_cnt < threshold:
                    continue
                else:
                    # now set all keypoints as visible
                    tempt_kpts = keypoints[i][:,:2]
                    visible_kpts_img.append(np.expand_dims(tempt_kpts, 0))
                center, crop_size, new_keypoints, vis_rate = lip.kpts2cs(tempt_kpts, enlarge=self.enlarge_factor)
                bbox_instance = np.array((list(map(int, lip.cs2bbox(center, crop_size)))))
                boxes_img.append(bbox_instance.reshape(1,4))
                rots_img.append(rotations[i].reshape(1,2))
            if len(boxes_img) == 0:
                continue
            all_paths.append(image_path)            
            all_boxes.append(np.concatenate(boxes_img))
            all_rotations.append(np.concatenate(rots_img))
            all_keypoints.append(np.concatenate(visible_kpts_img))
            all_keypoints_raw.append(keypoints)
        return {'paths':all_paths, 
                'boxes':all_boxes, 
                'rots':all_rotations,
                'kpts':all_keypoints,
                'raw_kpts':all_keypoints_raw
                }
    
    def _prepare_detection_records(self, save=False, threshold = 0.1):
        # DEPRECATED UNTIL FURTHER UPDATE
        raise ValueError

    def gather_annotations(self, 
                           threshold=0.1, 
                           use_raw_bbox=False, 
                           add_gt=True,
                           filter_outlier=False
                           ):
        """ 
        Read ground truth 3D bounding box labels.
        """           
        path_list = self._data_config['image_path_list'] 
        record_dict = {}
        for img_path in path_list:
            image_name = img_path.split(osep)[-1]
            if self.split != 'test':
                # default: use gt label and calibration
                label_path = pjoin(self._data_config['label_dir'], 
                                   image_name[:-4] + '.txt'
                                   )
                self.read_single_file(image_name, 
                                      record_dict, 
                                      label_path=label_path,
                                      fieldnames=FIELDNAMES,
                                      add_gt=add_gt,
                                      use_raw_bbox=use_raw_bbox,
                                      filter_outlier=filter_outlier
                                      )
            else:
                record_dict[image_name] = {}
        self.annot_dict = record_dict
        return     
    
    def read_single_file(self, 
                         image_name, 
                         record_dict, 
                         label_path=None,
                         calib_path=None,
                         threshold=0.1,
                         fieldnames=FIELDNAMES_P,
                         add_gt=False,
                         use_raw_bbox=True,
                         filter_outlier=False,
                         bbox_only=False
                         ):
        """ 
        Read labels and prepare annotation for a single image.
        """  
        style = self._data_config['3d_kpt_sample_style']
        image_path = pjoin(self._data_config['image_dir'], image_name)
        if label_path is None:
            # default is ground truth annotation
            label_path = pjoin(self._data_config['label_dir'], image_name[:-3] + 'txt')
        if calib_path is None:
            calib_path = pjoin(self._data_config['calib_dir'], image_name[:-3] + 'txt')
        list_2d, list_3d, list_id, pv, raw_bboxes = self.get_2d_3d_pair(image_path,
                                                                        label_path=label_path,
                                                                        calib_path=calib_path,
                                                                        style=style,
                                                                        augment=False,
                                                                        add_raw_bbox=True,
                                                                        bbox_only=bbox_only,
                                                                        filter_outlier=filter_outlier,
                                                                        fieldnames=fieldnames # also load the confidence score
                                                                        )  
        if len(raw_bboxes) == 0:
            return False        
        if image_name not in record_dict:
            record_dict[image_name] = {}
        raw_annot, P = self.load_annotations(label_path, calib_path, fieldnames=fieldnames)
        # use different (slightly) intrinsic parameters for different images
        K = P[:, :3]  
        if len(list_2d) != 0:
            for idx, kpts in enumerate(list_2d):
                list_2d[idx] = kpts.reshape(1, -1, 3)
                list_3d[idx] = list_3d[idx].reshape(1, -1, 3)
            all_keypoints_2d = np.concatenate(list_2d, axis=0)
            all_keypoints_3d = np.concatenate(list_3d, axis=0)                       
            # compute 2D bounding box based on the projected 3D boxes
            bboxes_kpt = []
            for idx, keypoints in enumerate(all_keypoints_2d):
                # relatively tight bounding box: use enlarge = 1.0
                # delete invisible instances
                center, crop_size, _, _ = lip.kpts2cs(keypoints[:,:2],
                                                      enlarge=1.01)
                bbox = np.array(lip.cs2bbox(center, crop_size))             
                bboxes_kpt.append(np.array(bbox).reshape(1, 4))
            record_dict[image_name]['kpts_3d'] = all_keypoints_3d
            if add_gt:
                # special key name representing ground truth
                record_dict[image_name]['kpts'] = all_keypoints_2d
                record_dict[image_name]['kpts_3d_gt'] = all_keypoints_3d
        if use_raw_bbox:
            bboxes = np.vstack(raw_bboxes)
        elif len(bboxes_kpt) != 0:
            bboxes = np.vstack(bboxes_kpt)
            
        record_dict[image_name]['bbox_2d'] = bboxes
        record_dict[image_name]['raw_txt_format'] = raw_annot
        record_dict[image_name]['K'] = K
        # add some key-value pairs as ground truth annotation
        if add_gt:         
            pvs = np.vstack(pv) if len(pv) != 0 else []
            tempt_dic = {'boxes': bboxes,
                         'pose_vecs_gt':pvs
                         }
            record_dict[image_name] = {**record_dict[image_name], **tempt_dic}              
        return True
    
    def read_predictions(self, path):
        """
        Read the prediction files in the same format as the ground truth.
        """
        self.logger.info("Reading predictions from {:s}".format(path))
        file_list = listdir(path)  
        record_dict = {}
        use_raw_bbox = True if self.split == 'test' else False
        for file_name in file_list:
            if not file_name.endswith(".txt"):
                continue
            image_name = file_name[:-4] + ".png"
            label_path = pjoin(path, file_name)            
            self.read_single_file(image_name, 
                                  record_dict, 
                                  label_path=label_path,
                                  use_raw_bbox=use_raw_bbox
                                  )
        self.logger.info("Reading predictions finished.")
        return record_dict
    
    def _get_data_parameters(self, cfgs):
        """
        Initialize dataset-relevant parameters.
        """
        self._data_config = {}
        self._data_config['image_size_raw'] = NotImplemented
        if self.exp_type in ['2dto3d', 'inference', 'finetune']:
            # parameters relevant to input/output representation
            for key in ['3d_kpt_sample_style', 'lft_in_rep', 'lft_out_rep']:
                self._data_config[key] = cfgs['dataset'][key] 
        if self.exp_type in ['2dto3d']:  
            # parameters relevant to data augmentation              
            for key in ['lft_aug','lft_aug_times']:
                self._data_config[key] = cfgs['training_settings'][key]
        # parameters relevant to cuboid interpolation
        self.interp_params = cfgs['dataset']['interpolate']
        # parameters relevant to heatmap regression model and image data augmentation
        if 'heatmapModel' in cfgs:
            hm = cfgs['heatmapModel']
            jitter_flag = hm['jitter_bbox'] and self.split=='train' and cfgs['train']
            self.hm_para = {'reference': 'bbox',
                            'resize': True,
                            'add_xy': hm['add_xy'],
                            'jitter_bbox': jitter_flag,
                            'jitter_params': hm['jitter_params'],
                            # (height, width)
                            'input_size': np.array([hm['input_size'][1],
                                             hm['input_size'][0]]),
                            'heatmap_size': np.array([hm['heatmap_size'][1],
                                               hm['heatmap_size'][0]]),
                            'target_ar': hm['heatmap_size'][1]/hm['heatmap_size'][0],
                            'augment': hm['augment_input'],
                            'sf': cfgs['dataset']['scaling_factor'],
                            'rf': cfgs['dataset']['rotation_factor'],
                            'num_joints': hm['num_joints'],
                            'sigma': hm['sigma'] if 'sigma' in hm else None,
                            'target_type': hm['target_type'] if 'target_type' in hm else None,
                            'use_different_joints_weight': 
                                hm['use_different_joints_weight'] if 'use_different_joints_weight' in hm else None                            
                              }
            self.num_joints = hm['num_joints']
        # parameters relevant to PyTorch image transformation operations
        if 'pth_transform' in cfgs['dataset']:
            pth_transform = cfgs['dataset']['pth_transform']
            normalize = transforms.Normalize(
                mean=pth_transform['mean'], 
                std=pth_transform['std']
                )
            transform_list = [transforms.ToTensor(), normalize]
            if self.exp_type == 'detect2D' and self.split == 'train':
                transform_list.append(transforms.RandomHorizontalFlip(0.5))
            self.pth_trans = transforms.Compose(transform_list)           

    def _set_paths(self):
        """
        Initialize relevant directories.
        """
        ROOT = self.root
        split = self.split
        # validation set is a sub-set of the official training split
        # train/val/test: 3712/3769/7518
        split = 'train' if self.split == 'valid' else split
        split += 'ing'
        self._data_config['image_dir'] = pjoin(ROOT, split, 'image_2')
        self._data_config['cropped_dir'] = pjoin(ROOT, split, 'cropped')
        self._data_config['drawn_dir'] = pjoin(ROOT, split, 'drawn')
        self._data_config['label_dir'] = pjoin(ROOT, split, 'label_2')
        self._data_config['calib_dir'] = pjoin(ROOT, split, 'calib')
        self._data_config['keypoint_dir'] = pjoin(ROOT, split, 'keypoints')
        self._data_config['stats_dir'] = pjoin(ROOT, 'instance_stats.npy')
        # list of images for each sub-set
        self._data_config['train_list'] = pjoin(ROOT, 'training/ImageSets/train.txt')
        self._data_config['valid_list'] = pjoin(ROOT, 'training/ImageSets/val.txt')
        self._data_config['test_list'] = pjoin(ROOT, 'testing/ImageSets/test.txt')
        self._data_config['trainvalid_list'] = pjoin(ROOT, 'training/ImageSets/trainval.txt')        
        return
    
    def project_3d_to_2d(self, points, K):
        """ 
        Get 2D projection of 3D points in the camera coordinate system. 
        """          
        projected = K @ points.T
        projected[:2, :] /= projected[2, :]
        return projected
    
    def render_car(self, ax, K, obj_class, rot_y, locs, dimension, shift):
        # DEPRECATED
        cam_cord = []
        self.get_cam_cord(cam_cord, shift, rot_y, dimension, locs)
        # get 2D projections 
        projected = self.project_3d_to_2d(cam_cord[0], K)
        ax.plot(projected[0, :], projected[1, :], 'ro')
        vp.plot_3d_bbox(ax, projected[:2, 1:].T)
        return
    
    def show_statistics(self):
        # DEPRECATED
        path = self._data_config['stats_dir']       
        if self._check_precomputed_file(path, 'instance_stats') or self.split != 'train':
            return
        self.instance_statistics = {}
        if hasattr(self, 'car_sizes') and len(self.car_sizes) != 0:
            all_sizes = np.concatenate(self.car_sizes)
            fig, axes = plt.subplots(3,1)
            names = ['x', 'y', 'z']
            for axe_id in range(3):
                axes[axe_id].hist(all_sizes[:, axe_id])
                axes[axe_id].set_xlabel('Car size in {:s} direction'.format(names[axe_id]))
                axes[axe_id].set_ylabel('Counts')
            mean_size = all_sizes.mean(axis=0)
            std_size = all_sizes.std(axis=0)
            self.instance_statistics['size'] = {'mean':mean_size,
                                                'std': std_size
                                                }
            # prepare a reference 3D bounding box
            xmax, xmin = mean_size[0], -mean_size[0]
            ymax, ymin = mean_size[1], -mean_size[1]
            zmax, zmin = mean_size[2], -mean_size[2]
            bbox = np.array([[xmax, ymin, zmax],
                             [xmax, ymax, zmax],
                             [xmax, ymin, zmin],
                             [xmax, ymax, zmin],
                             [xmin, ymin, zmax],
                             [xmin, ymax, zmax],
                             [xmin, ymin, zmin],
                             [xmin, ymax, zmin]])
            bbox = np.vstack([np.array([[0., 0., 0.]]), bbox])            
            self.instance_statistics['ref_box3d'] = bbox
        self._save_precomputed_file(self.instance_statistics, path, 'instance_stats')            
        return
    
    def augment_pose_vector(self, 
                            locs,
                            rot_y,
                            obj_class,
                            dimension,
                            augment,
                            augment_times,
                            std_rot = np.array([15., 50., 15.])*np.pi/180.,
                            std_trans = np.array([0.2, 0.01, 0.2]),
                            ):
        """
        Data augmentation used for training the lifter sub-model.
        
        std_rot: standard deviation of rotation around x, y and z axis
        std_trans: standard deviation of translation along x, y and z axis
        """
        aug_ids, aug_pose_vecs = [], []
        aug_ids.append((obj_class, dimension))
        # KITTI only annotates rotation around y-axis (yaw)
        pose_vec = np.concatenate([locs, np.array([0., rot_y, 0.])]).reshape(1, 6)
        aug_pose_vecs.append(pose_vec)
        if not augment:
            return aug_ids, aug_pose_vecs
        rots_random = np.random.randn(augment_times, 3) * std_rot.reshape(1, 3)
        # y-axis
        rots_random[:, 1] += rot_y
        trans_random = 1 + np.random.randn(augment_times, 3) * std_trans.reshape(1, 3)
        trans_random *= locs.reshape(1, 3)
        for i in range(augment_times):
            # augment 6DoF pose
            aug_ids.append((obj_class, dimension))
            pose_vec = np.concatenate([trans_random[i], rots_random[i]]).reshape(1, 6)
            aug_pose_vecs.append(pose_vec)
        return aug_ids, aug_pose_vecs
    
    def get_representation(self, p2d, p3d, in_rep, out_rep):
        """
        Get input-output representations based on 3d point cloud and its 
        projected 2D screen coordinates.
        """        
        # input representation
        if len(p2d) > 0:
            num_kpts = len(p2d[0])
        if in_rep == 'coordinates2d':
            input_list = [points.reshape(1, num_kpts, -1) for points in p2d]
        elif in_rep == 'coordinates2d+area' and self._data_config['3d_kpt_sample_style'] == 'bbox9':
            # indices: [corner, neighbour1, neighbour2]
            indices = self.area_indices
            input_list = [vp.get_area(points, indices, True) for points in p2d]
        else:
            raise NotImplementedError('Undefined input representation.')
        # output representation
        if out_rep == 'R3d+T':
            # R3D stands for relative 3D shape, T stands for translation
            # center the camera coordinates to remove depth
            output_list = []
            for i in range(len(p3d)):
                # format: the root should be pre-computed as the first 3d point 
                root = p3d[i][[0], :]
                relative_shape = p3d[i][1:, :] - root
                output = np.concatenate([root, relative_shape], axis=0)
                output_list.append(output.reshape(1, -1)) 
        elif out_rep == 'R3d': # relative 3D shape
            output_list = []
            # save a copy of the 3D object roots
            if not hasattr(self, 'root_list'):
                self.root_list = []
            for i in range(len(p3d)):
                # format: the root should be pre-computed as the first 3d point 
                root = p3d[i][[0], :]
                self.root_list.append(root)
                relative_shape = p3d[i][1:, :] - root
                output_list.append(relative_shape.reshape(1, -1)) 
        else:
            raise NotImplementedError('undefined output representation.')
        return input_list, output_list
    
    def get_input_output_size(self):
        """
        Get the input/output size for 2d-to-3d lifting.
        """
        num_joints = self.num_joints
        if self._data_config['lft_in_rep'] == 'coordinates2d':
             input_size = num_joints*2
        else:
             raise NotImplementedError
        if self._data_config['lft_out_rep'] in ['R3d+T']:
             output_size = num_joints*3
        elif self._data_config['lft_out_rep'] in ['R3d']:
             output_size = (num_joints - 1) * 3             
        else:
             raise NotImplementedError        
        return input_size, output_size
    
    def interpolate(self, 
                    bbox_3d, 
                    style, 
                    interp_coef=[0.5], 
                    dimension=None, 
                    strings=['l','h','w']
                    ):
        """
        Interpolate 3d points on a 3D bounding box with a specified style.
        """
        if dimension is not None:
            # size-encoded representation
            l = dimension[0]
            if l < 3.5:
                style += 'l'
            elif l < 4.5:
                style += 'h'
            else:
                style += 'w'       
        pidx, cidx = interp_dict[style]
        parents, children = bbox_3d[:, pidx], bbox_3d[:, cidx]
        lines = children - parents
        new_joints = [(parents + interp_coef[i]*lines) for i in range(len(interp_coef))]
        return np.hstack([bbox_3d, np.hstack(new_joints)])
    
    def construct_box_3d(self, l, h, w, interp_params):
        """
        Construct 3D bounding box corners in the canonical pose.
        """        
        x_corners = [0.5*l, l, l, l, l, 0, 0, 0, 0]
        y_corners = [0.5*h, 0, h, 0, h, 0, h, 0, h]
        z_corners = [0.5*w, w, w, 0, 0, w, w, 0, 0]
        x_corners += - np.float32(l) / 2
        y_corners += - np.float32(h)
        z_corners += - np.float32(w) / 2
        corners_3d = np.array([x_corners, y_corners, z_corners])     
        if interp_params['flag']:
            corners_3d = self.interpolate(corners_3d, 
                                          interp_params['style'],
                                          interp_params['coef'],
                                          #dimension=np.array([l,h,w]) # dimension aware
                                          )
        return corners_3d
    
    def get_cam_cord(self, cam_cord, shift, ids, pose_vecs, rot_xz=False):
        """
        Construct 3D bounding box corners in the camera coordinate system.
        """         
        # does not augment the dimension for now
        dims = ids[0][1]
        l, h, w = dims[0], dims[1], dims[2]
        corners_3d_fixed = self.construct_box_3d(l, h, w, self.interp_params)
        for pose_vec in pose_vecs:
            # translation
            locs = pose_vec[0, :3]
            rots = pose_vec[0, 3:]
            x, y, z = locs[0], locs[1], locs[2] # bottom center of the labeled 3D box
            rx, ry, rz = rots[0], rots[1], rots[2]
            # This purturbation turns out to work well for rotation estimation
#            x *= (1 + np.random.randn()*0.1)
#            y *= (1 + np.random.randn()*0.05)
#            z *= (1 + np.random.randn()*0.1)
            if self.split == 'train' and self.exp_type == '2dto3d' and not self._inference_mode:
                ry += np.random.randn()*np.pi # random perturbation
            rot_maty = np.array([[np.cos(ry), 0, np.sin(ry)],
                                [0, 1, 0],
                                [-np.sin(ry), 0, np.cos(ry)]])
            if rot_xz:
                # rotation. Only yaw angle is considered in KITTI dataset
                rot_matx = np.array([[1, 0, 0],
                                    [0, np.cos(rx), -np.sin(rx)],
                                    [0, np.sin(rx), np.cos(rx)]])        
    
                rot_matz = np.array([[np.cos(rz), -np.sin(rz), 0],
                                    [np.sin(rz), np.cos(rz), 0],
                                    [0, 0, 1]])        
                # TODO: correct here
                rot_mat = rot_matz @ rot_maty @ rot_matx     
            else:
                rot_mat = rot_maty
            corners_3d = np.matmul(rot_mat, corners_3d_fixed)
            # translation
            corners_3d += np.array([x, y, z]).reshape([3, 1])
            camera_coordinates = corners_3d + shift
            cam_cord.append(camera_coordinates.T)
        return 
    
    def csv_read_annot(self, file_path, fieldnames):
        """
        Read instance attributes in the KITTI format. Instances not in the 
        selected class will be ignored. 
        
        A list of python dictionary is returned where each dictionary 
        represents one instsance.
        """        
        annotations = []
        with open(file_path, 'r') as csv_file:
            reader = csv.DictReader(csv_file, delimiter=' ', fieldnames=fieldnames)
            for line, row in enumerate(reader):
                if row["type"] in self._classes:
                    annot_dict = {
                        "class": row["type"],
                        "label": TYPE_ID_CONVERSION[row["type"]],
                        "truncation": float(row["truncated"]),
                        "occlusion": float(row["occluded"]),
                        "alpha": float(row["alpha"]),
                        "dimensions": [float(row['dl']), 
                                       float(row['dh']), 
                                       float(row['dw'])
                                       ],
                        "locations": [float(row['lx']), 
                                      float(row['ly']), 
                                      float(row['lz'])
                                      ],
                        "rot_y": float(row["ry"]),
                        "bbox": [float(row["xmin"]),
                                 float(row["ymin"]),
                                 float(row["xmax"]),
                                 float(row["ymax"])
                                 ]
                    }
                    if "score" in fieldnames:
                        annot_dict["score"] = float(row["score"])
                    annotations.append(annot_dict)        
        return annotations
    
    def csv_read_calib(self, file_path):
        """
        Read camera projection matrix in the KITTI format.
        """  
        with open(file_path, 'r') as csv_file:
            reader = csv.reader(csv_file, delimiter=' ')
            for line, row in enumerate(reader):
                if row[0] == 'P2:':
                    P = row[1:]
                    P = [float(i) for i in P]
                    P = np.array(P, dtype=np.float32).reshape(3, 4)
                    break        
        return P
    
    def load_annotations(self, label_path, calib_path, fieldnames=FIELDNAMES): 
        """
        Read 3D annotation and camera parameters.
        """          
        if self.split in ['train', 'valid', 'trainvalid', 'test']:
            annotations = self.csv_read_annot(label_path, fieldnames)
        # get camera intrinsic matrix K
        P = self.csv_read_calib(calib_path)
        return annotations, P
    
    def add_visibility(self, joints, img_width=1242, img_height=375):
        """
        Compute binary visibility of projected 2D parts.
        """  
        assert joints.shape[1] == 2
        visibility = np.ones((len(joints), 1))
        # predicate from upper left corner
        predicate1 = joints - np.array([[0., 0.]])
        predicate1 = (predicate1 > 0.).prod(axis=1)
        # predicate from lower right corner
        predicate2 = joints - np.array([[img_width, img_height]])
        predicate2 = (predicate2 < 0.).prod(axis=1)
        visibility[:, 0] *= predicate1*predicate2      
        return np.hstack([joints, visibility])
    
    def get_inlier_indices(self, p_2d, threshold=0.3):
        """
        Get indices of instances that are visible 'enough'.
        """  
        indices = []
        num_joints = p_2d[0].shape[0]
        for idx, kpts in enumerate(p_2d):
            if p_2d[idx][:, 2].sum() / num_joints >= threshold:
                indices.append(idx)        
        return indices
    
    def filter_outlier(self, p_2d, p_3d, threshold=0.3):
        """
        Keep instances that are visible 'enough'.
        """  
        p_2d_filtered, p_3d_filtered, indices = [], [], []
        num_joints = p_2d[0].shape[0]
        for idx, kpts in enumerate(p_2d):
            if p_2d[idx][:, 2].sum() / num_joints >= threshold:
                p_2d_filtered.append(p_2d[idx])
                p_3d_filtered.append(p_3d[idx])
                indices.append(idx)
        return p_2d_filtered, p_3d_filtered
    
    def get_img_size(self, path):
        """
        Get the resolution of an image without loading it.
        """
        with Image.open(path) as image:
            size = image.size 
        return size
    
    def get_2d_3d_pair(self, 
                       image_path, 
                       label_path=None,
                       calib_path=None,
                       style='null',
                       in_rep = 'coordinates2d',
                       out_rep = 'R3d+T',
                       augment=False, 
                       augment_times=1,
                       add_visibility=True,
                       add_raw_bbox=False, # add original bbox annotation from KITTI
                       add_rotation=False, # add orientation angles
                       bbox_only=False, # only returns raw bounding box
                       filter_outlier=True,
                       fieldnames=FIELDNAMES
                       ):
        """
        Get (input, output) pair used for training a lifter sub-model from a 
        single image.
        """
        image_name = image_path.split(osep)[-1]
        if label_path is None:
            # default is ground truth annotation
            label_path = pjoin(self._data_config['label_dir'], image_name[:-3] + 'txt')
        if calib_path is None:
            calib_path = pjoin(self._data_config['calib_dir'], image_name[:-3] + 'txt')
        anns, P = self.load_annotations(label_path, calib_path, fieldnames=fieldnames)
        # The intrinsics may vary slightly for different images
        # Yet one may convert them to a fixed one by applying a homography
        K = P[:, :3]
        # Debug: use pre-defined intrinsic parameters
        # K = np.array([[707.0493,   0.    , 604.0814],
        #               [  0.    , 707.0493, 180.5066],
        #               [  0.    ,   0.    ,   1.    ]], dtype=np.float32)
        shift = np.linalg.inv(K) @ P[:, 3].reshape(3,1)      
        # P containes intrinsics and extrinsics, I factorize P to K[I|K^-1t] 
        # and use extrinsics to compute the camera coordinate
        # here the extrinsics represent the shift between current camera to
        # the reference grayscale camera        
        # For more calibration details, refer to "Vision meets Robotics: The KITTI Dataset"
        camera_coordinates = []
        pose_vecs = []
        # id includes the class and size of the object
        ids = []
        if add_raw_bbox:
            bboxes = []
        if add_rotation:
            rotations = []
        for i, a in enumerate(anns):
            a = a.copy()
            obj_class = a["label"]
            dimension = a["dimensions"]
            locs = np.array(a["locations"])
            rot_y = np.array(a["rot_y"])
            if add_raw_bbox:
                bboxes.append(np.array(a["bbox"]).reshape(1,4))
            if add_rotation:
                rotations.append(np.array([a["alpha"], a["rot_y"]]).reshape(1,2))
            # apply data augmentation to represent a larger variation of
            # 3D pose and translation 
            if bbox_only:
                continue
            aug_ids, aug_pose_vecs = self.augment_pose_vector(locs,
                                                              rot_y,
                                                              obj_class,
                                                              dimension,
                                                              augment,
                                                              augment_times
                                                              )
            self.get_cam_cord(camera_coordinates, 
                              shift, 
                              aug_ids, 
                              aug_pose_vecs
                              )                    
            ids += aug_ids
            pose_vecs += aug_pose_vecs
        num_instances = len(camera_coordinates)
        # get 2D projections 
        if len(camera_coordinates) != 0:
            camera_coordinates = np.vstack(camera_coordinates)
            projected = self.project_3d_to_2d(camera_coordinates, K)[:2, :].T
            # target is camera coordinates
            p_2d = np.split(projected, num_instances, axis=0) 
            p_3d = np.split(camera_coordinates, num_instances, axis=0) 
            # set visibility to 0 if the projected keypoints lie out of the image plane
            if add_visibility:
                width, height = self.get_img_size(image_path)
                for idx, joints in enumerate(p_2d):
                    p_2d[idx] = self.add_visibility(joints, width, height)
            # filter out the instances that lie outside of the image
            if filter_outlier:
                indices = self.get_inlier_indices(p_2d)
                p_2d = [p_2d[idx] for idx in indices]
                p_3d = [p_3d[idx] for idx in indices]
                # p_2d, p_3d = self.filter_outlier(p_2d, p_3d)
            if filter_outlier and add_raw_bbox:
                bboxes = [bboxes[idx] for idx in indices]
            if filter_outlier and add_rotation:
                rotations = [rotations[idx] for idx in indices]            
            list_2d, list_3d = self.get_representation(p_2d, p_3d, in_rep, out_rep)

        else:
            list_2d, list_3d, ids, pose_vecs = [], [], [], []
        ret = list_2d, list_3d, ids, pose_vecs
        if add_raw_bbox:
            ret = ret + (bboxes, )
        if add_rotation:
            ret = ret + (rotations, )
        return ret            
    
    def show_annot(self, 
                   image_path, 
                   label_file=None, 
                   calib_file=None, 
                   save_dir=None
                   ):
        """
        Show the annotation of an image.
        """      
        image_name = image_path.split(osep)[-1]
        if label_file is None:
            label_file = pjoin(self._data_config['label_dir'], image_name[:-3] + 'txt')
        if calib_file is None:
            calib_file = pjoin(self._data_config['calib_dir'], image_name[:-3] + 'txt')
        anns, P = self.load_annotations(label_file, calib_file)
        K = P[:, :3]
        shift = np.linalg.inv(K) @ P[:, 3].reshape(3,1)        
        image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)[:, :, ::-1]
        fig1 = plt.figure(figsize=(11.3, 9))
        ax = plt.subplot(111)
        ax.imshow(image)
        fig2 = plt.figure(figsize=(11.3, 9))
        ax = plt.subplot(111)
        ax.imshow(image)        
        for i, a in enumerate(anns):
            a = a.copy()
            obj_class = a["label"]
            dimension = a["dimensions"]
            locs = np.array(a["locations"])
            rot_y = np.array(a["rot_y"])
            self.render_car(ax, K, obj_class, rot_y, locs, dimension, shift) 
        if save_dir is not None:
            output_path1 =  pjoin(save_dir, image_name + '_original.png')
            output_path2 = pjoin(save_dir, image_name + '_annotated.png')
            make_dir(output_path1)
            fig1.savefig(output_path1, dpi=300)
            fig2.savefig(output_path2, dpi=300)
        return
    
    def _generate_2d_3d_paris(self):
        """
        Prepare pair of 2D screen coordinates and 3D cuboid representation.
        
        """
        path_list = self._data_config['image_path_list']
        kpt_3d_style = self._data_config['3d_kpt_sample_style']
        in_rep = self._data_config['lft_in_rep']
        out_rep = self._data_config['lft_out_rep'] # R3d (Relative 3D shape) encodes 3D rotation
        input_list = []
        output_list = []
        id_list = []
        augment = self._data_config['lft_aug'] if self.split == 'train' else False
        augment_times = self._data_config['lft_aug_times']
        for path in path_list:
            list_2d, list_3d, ids, _ = self.get_2d_3d_pair(path, 
                                                           style=kpt_3d_style,
                                                           in_rep = in_rep,
                                                           out_rep = out_rep,
                                                           augment=augment,
                                                           augment_times=augment_times,
                                                           add_visibility=True
                                                           )            
            input_list += list_2d
            output_list += list_3d
            id_list += ids
        # does not use visibility as input
        num_instance = len(input_list)
        self.input = np.vstack(input_list)[:, :, :2].reshape(num_instance, -1)
        # use visibility as input
        # self.input = np.vstack(input_list).reshape(num_instance, -1)
        self.output = np.vstack(output_list) 
        if hasattr(self, 'root_list'):
            self.root_list = np.vstack(self.root_list)
        self.num_joints = int(self.input.shape[1]/2)      
        return
    
    def generate_pairs(self):
        """
        Prepare data (e.g., input-output pairs and metadata) that will be used 
        depending on the type of experiment.
        """
        if self.exp_type == '2dto3d':           
            # generate 2D screen coordinates and 3D cuboid
            self._generate_2d_3d_paris()
        elif self.exp_type in ['instanceto2d', 'baselinealpha', 'baselinetheta']:
            # # load the annotations containing cropped car instances 
            # path = pjoin(self._data_config['cropped_dir'], 
            #              self.kpts_style, self.split, 'annot.npy')
            # assert exists(path), 'Please prepare instance annotation first.'
            # self.annot_2dpose = np.load(path, allow_pickle=True).item() 
            self.annot_2dpose = self._prepare_2d_pose_annot()
        elif self.exp_type in ['detection2d']:
            self._prepare_detection_records()
            self.total_data = len(self.detection_records)
        elif self.exp_type == 'inference':
            self.gather_annotations()
            self.total_data = len(self.annot_dict)
            self.annoted_img_paths = list(self.annot_dict.keys())
        elif self.exp_type == 'finetune':
            self.gather_annotations(use_raw_bbox=False, 
                                    add_gt=True, 
                                    filter_outlier=True
                                    )
            self.total_data = len(self.annot_dict)
            self.annoted_img_paths = list(self.annot_dict.keys())            
        else:
            raise NotImplementedError('Unknown experiment type.')
        # count of total data
        if self.exp_type == '2dto3d':
            self.input = self.input.astype(np.float32())
            self.output = self.output.astype(np.float32())
            self.total_data = len(self.input)
        elif self.exp_type in ['instanceto2d', 'baselinealpha', 'baselinetheta']:
            self.total_data = len(self.annot_2dpose['paths'])
        return
    
    def visualize(self, plot_num = 1, save_dir=None):
        """
        Show some random images with annotations.
        """        
        path_list = self._data_config['image_path_list']
        chosen = np.random.choice(len(path_list), plot_num, replace=False)
        for img_idx in chosen:
            self.show_annot(path_list[img_idx], save_dir=save_dir)
        return
    
    def get_collate_fn(self):
        return my_collate_fn
    
    def inference(self, flags=[True, True]):
        self._inference_mode = flags[0]
        self._read_img_during_inference = flags[1]
    
    def extract_ss_sample(self, cnt):
        """
        Prepare data for self-supervised representation learning.
        """           
        # cnt: number of fully supervised samples
        extract_cnt = self.ss_settings['max_per_img'] - cnt
        if extract_cnt <= 0:
            num_channel = 5 if self.hm_para['add_xy'] else 3
            return torch.zeros(0, num_channel, 256, 256), None, None, None
        idx = np.random.randint(0, len(self.ss_record['paths']))
        parameters = self.hm_para
        parameters['boxes'] = self.ss_record['boxes'][idx]
        joints = self.ss_record['kpts'][idx]
        img_name = self.ss_record['paths'][idx].split(osep)[-1]
        img_path = pjoin(self.ss_settings['img_root'], img_name)
        image, target, weights, meta = lip.get_tensor_from_img(img_path, 
                                                               parameters, 
                                                               joints=joints,
                                                               pth_trans=self.pth_trans,
                                                               rf=parameters['rf'],
                                                               sf=parameters['sf'],
                                                               generate_hm=False,
                                                               max_cnt=extract_cnt
                                                               )        
        return image, target, weights, meta
    
    def prepare_ft_dict(self, idx):
        """
        Prepare data for fine-tuning.
        """  
        img_name = self.annoted_img_paths[idx]
        img_annot = self.annot_dict[img_name]
        ret = {}
        img_path = pjoin(self._data_config['image_dir'], img_name)
        kpts = img_annot['kpts']
        # the croping bounding box in the original image
        # global_box = self.annot_2dpose['global_box'][idx]
        parameters = self.hm_para
        parameters['boxes'] = img_annot['bbox_2d']
        # fs: fully-supervised ss: self-supervised
        images_fs, heatmaps_fs, weights_fs, meta_fs = lip.get_tensor_from_img(img_path, 
                                                                              parameters, 
                                                                              joints=kpts,
                                                                              pth_trans=self.pth_trans,
                                                                              rf=parameters['rf'],
                                                                              sf=parameters['sf'],
                                                                              generate_hm=True)
        ret['path'] = img_path
        ret['images_fs'] = images_fs
        ret['heatmaps_fs'] = heatmaps_fs
        # ret['meta_fs'] = meta_fs
        ret['kpts_3d'] = img_annot['kpts_3d']
        ret['crop_center'] = meta_fs['center']
        ret['crop_scale'] = meta_fs['scale']
        ret['kpts_local'] = meta_fs['transformed_joints']
        # prepare the affine transformation matrices so map local coordinates
        # back to global screen coordinates
        ret['af_mats'] = []
        for idx in range(len(ret['crop_center'])):
            trans_inv = get_affine_transform(ret['crop_center'][idx],
                                             ret['crop_scale'][idx], 
                                             0., 
                                             self.hm_para['input_size'], 
                                             inv=1)  
            ret['af_mats'].append(trans_inv)
        # use random unlabeled images for data augmentation
        if self.split == 'train' and self.use_ss:
            images_ss, heatmaps_ss, weights_ss, meta_ss = self.extract_ss_sample(len(images_fs))
            ret['images_ss'] = images_ss
            ret['meta_ss'] = meta_ss
        return ret
    
    def __getitem__(self, idx):
        """
        Required by dataloader.
        """  
        # only return testing images during inference
        if self.split == 'test' or self._inference_mode:
            #TODO: consider classes except for cars in the future
            img_name = self.annoted_img_paths[idx]
            # debug: use a specified image for visualization
            # img_name = "006658.png"
            img_path = pjoin(self._data_config['image_dir'], img_name)
            if self._read_img_during_inference:
                image = lip.imread_rgb(img_path)
            else:
                image = None
            if self._read_img_during_inference and hasattr(self, 'pth_trans'):
                # pytorch transformation if provided
                image = self.pth_trans(image)
            record = {'path':img_path}
            # add other available annotations
            if hasattr(self, 'annot_dict'):
                record = {**record, **self.annot_dict[img_name]}
            return image, record
        # for training and validation splits
        if self.exp_type == '2dto3d':
            # the 2D-3D pairs are stored in RAM
            meta_data = {}
            # the 3D global position
            if hasattr(self, 'root_list'):
                meta_data['roots'] = self.root_list[idx]
            return self.input[idx], self.output[idx], np.zeros((0,1)), meta_data
        elif self.exp_type in ['baselinealpha', 'baselinetheta']:
            img_path = self.annot_2dpose['paths'][idx]
            rots = self.annot_2dpose['rots'][idx]
            kpts = self.annot_2dpose['kpts'][idx]
            if kpts.shape[2] == 2:
                kpts = np.concatenate([kpts, np.ones((kpts.shape[0], kpts.shape[1], 1))], axis=2)            
            parameters = self.hm_para
            parameters['boxes'] = self.annot_2dpose['boxes'][idx]
            images_fs, heatmaps_fs, weights_fs, meta_fs = lip.get_tensor_from_img(img_path, 
                                                                                  parameters, 
                                                                                  joints=kpts,
                                                                                  pth_trans=self.pth_trans,
                                                                                  rf=parameters['rf'],
                                                                                  sf=parameters['sf'],
                                                                                  generate_hm=False
                                                                                  )
            if self.exp_type == 'baselinealpha':
                targets = [np.array([[np.cos(rots[idx][0]), np.sin(rots[idx][0])]])  for idx in range(len(rots))]
                meta_fs['angles_gt'] = rots[:, 0]
            elif self.exp_type == 'baselinetheta':
                targets = [np.array([[np.cos(rots[idx][1]), np.sin(rots[idx][1])]]) for idx in range(len(rots))]
                meta_fs['angles_gt'] = rots[:, 1]
            targets = torch.from_numpy(np.concatenate(targets).astype(np.float32))
            return images_fs, targets, weights_fs, meta_fs
        elif self.exp_type == 'instanceto2d':
            # the input images and target heatmaps are produced online
            img_path = self.annot_2dpose['paths'][idx]
            kpts = self.annot_2dpose['kpts'][idx]
            # the croping bounding box in the original image
            # global_box = self.annot_2dpose['global_box'][idx]
            if kpts.shape[2] == 2:
                kpts = np.concatenate([kpts, np.ones((kpts.shape[0], kpts.shape[1], 1))], axis=2)
            parameters = self.hm_para
            parameters['boxes'] = self.annot_2dpose['boxes'][idx]
            # fs: fully-supervised ss: self-supervised
            images_fs, heatmaps_fs, weights_fs, meta_fs = \
                lip.get_tensor_from_img(img_path, 
                                        parameters, 
                                        joints=kpts,
                                        pth_trans=self.pth_trans,
                                        rf=parameters['rf'],
                                        sf=parameters['sf'],
                                        generate_hm=True
                                        )
            # use random unlabeled images for data augmentation
            if self.split == 'train' and hasattr(self, 'use_ss') and self.use_ss:
                images_ss, heatmaps_ss, weights_ss, meta_ss = self.extract_ss_sample(len(images_fs))
                images = [images_fs, images_ss]
                targets = heatmaps_fs
                weights = weights_fs
                meta = meta_fs
            else:
                images = images_fs
                targets = heatmaps_fs
                weights = weights_fs
                meta = meta_fs
            return images, targets, weights, meta
        elif self.exp_type == 'detection2d':
            record = copy.deepcopy(self.detection_records[idx])
            path = record['path']
            image = lip.imread_rgb(path)
            target = record['target']
            if hasattr(self, 'pth_trans'):
                # pytorch transformation if provided
                image = self.pth_trans(image)
            return image, target
        elif self.exp_type == 'finetune':
            # prepare images, 2D and 3D annotations as a dictionary for finetuning 
            ret = self.prepare_ft_dict(idx)
            return ret
        else:
            raise NotImplementedError

def prepare_data(cfgs, logger):
    """
    Prepare training and validation dataset objects.
    """  
    train_set = KITTI(cfgs, 'train', logger)
    valid_set = KITTI(cfgs, 'valid', logger)
    if cfgs['exp_type'] == '2dto3d':
        # normalize 2D keypoints
        valid_set.normalize(train_set.statistics)
    return train_set, valid_set

def get_dataset(cfgs, logger, split):
    return KITTI(cfgs, split, logger)

def collate_dict(dict_list):
    ret = {}
    ret['path'] = [item['path'] for item in dict_list]
    for key in dict_list[0]:
        if key == 'path':
            continue
        ret[key] = np.concatenate([d[key] for d in dict_list], axis=0)
    return ret

def length_limit(instances, targets, target_weights, meta):
    if len(instances) > MAX_INS_CNT and len(instances) == len(targets):
        # normal training
        chosen = np.random.choice(len(instances), MAX_INS_CNT, replace=False)
        ins, tar, tw, = instances[chosen], targets[chosen], target_weights[chosen]
        m = {'path':meta['path']}
        for key in meta:
            if key != 'path':
                m[key] = meta[key][chosen]
    elif len(instances) > MAX_INS_CNT and len(instances) > len(targets) and meta['fs_instance_cnt'] > MAX_INS_CNT:
        # mixed training: fully-supervised instances are too many
        chosen = np.random.choice(meta['fs_instance_cnt'], MAX_INS_CNT, replace=False)
        ins, tar, tw, = instances[chosen], targets[chosen], target_weights[chosen]
        m = {'path':meta['path']}
        for key in meta:
            if key != 'path' and key != 'fs_instance_cnt':
                m[key] = meta[key][chosen]
    elif len(instances) > MAX_INS_CNT and len(instances) > len(targets) and meta['fs_instance_cnt'] <= MAX_INS_CNT:
        # mixed training: self-supervised instances are too many
        ins, tar, tw, m = instances[:MAX_INS_CNT], targets, target_weights, meta
    else:
        ins, tar, tw, m = instances, targets, target_weights, meta
    return ins, tar, tw, m

def my_collate_fn(batch):
    # the collate function for 2d pose training
    instances, targets, target_weights, meta = list(zip(*batch))
    if isinstance(instances[0], list):
        # each batch comes in the format of (fs_instances, ss_instances)
        fs_instances, ss_instances = list(zip(*instances))
        fs_instances = torch.cat(fs_instances)
        ss_instances = torch.cat(ss_instances)
        instances = torch.cat([fs_instances, ss_instances])
        targets = torch.cat(targets, dim=0)
        # target_weights = torch.cat(target_weights, dim=0)
        meta = collate_dict(meta)
        meta['fs_instance_cnt'] = len(fs_instances)
    else:
        instances = torch.cat(instances, dim=0)
        targets = torch.cat(targets, dim=0)
        # target_weights = torch.cat(target_weights, dim=0)
        meta = collate_dict(meta)
    if target_weights[0] is not None:
        target_weights = torch.cat(target_weights, dim=0)
    else:
        #dummy weight
        target_weights = torch.ones(1)
    return length_limit(instances, targets, target_weights, meta)

================================================
FILE: libs/dataset/__init__.py
================================================
#import libs.dataset.ApolloScape
import libs.dataset.KITTI



================================================
FILE: libs/dataset/basic/__init__.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Empty file.
"""




================================================
FILE: libs/dataset/basic/basic_classes.py
================================================
"""
Basic classes for customized dataset classes to inherit.

Author: Shichao Li
Contact: nicholas.li@connect.ust.hk
"""
import torch.utils.data
import libs.dataset.normalization.operations as nop

class SupervisedDataset(torch.utils.data.Dataset):
    def __init__(self, cfgs, split, logger=None):
        self.cfgs = cfgs
        self.split = split
        self.logger = logger
        self.root = cfgs['dataset']['root']
        return
    
    def generate_pairs(self, synthetic=True):
        # sub-classes need to override this method to specify the inputs and
        # outputs
        self.input = None
        self.output = None
        self.total_data = 0
        return
    
    def normalize(self, statistics=None):
        """ 
        Normalize the (input, output) pairs with optional statistics.
        """
        if statistics is None:
            mean_in, std_in = nop.get_statistics_1d(self.input)
            mean_out, std_out = nop.get_statistics_1d(self.output)
            self.statistics = {'mean_in': mean_in,
                               'mean_out': mean_out,
                               'std_in': std_in,
                               'std_out': std_out
                               }
        else:
            mean_in, std_in = statistics['mean_in'], statistics['std_in']
            mean_out, std_out = statistics['mean_out'], statistics['std_out']
            self.statistics = statistics
        self.input = nop.normalize_1d(self.input, mean_in, std_in)
        self.output = nop.normalize_1d(self.output, mean_out, std_out)
        return
    
    def unnormalize(self, data, mean, std):
        return nop.unnormalize_1d(data, mean, std)
    
    def __len__(self):
        return self.total_data

    def __getitem__(self, idx):
        return self.input[idx], self.output[idx]

================================================
FILE: libs/dataset/normalization/__init__.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Empty file.
"""




================================================
FILE: libs/dataset/normalization/operations.py
================================================
"""
Dataset normalization operations.

Author: Shichao Li
Contact: nicholas.li@connect.ust.hk
"""

import numpy as np

def get_statistics_1d(data):
    """
    Compute statistics of 1D data.
    
    data of shape [num_sample, vector_length]
    """  
    assert len(data.shape) == 2
    mean = data.mean(axis=0, keepdims=True)
    std = data.std(axis=0, keepdims=True)
    return mean, std

def normalize_1d(data, mean, std, individual=False):
    """
    Normalizes 1D data with mean and standard deviation.
    
    data: dictionary where values are
    mean: np vector with the mean of the data
    std: np vector with the standard deviation of the data
    individual: whether to perform normalization independently for each input
    
    Returns
    data_out: normalized data
    """
    if individual:
        # this representation has the implicit assumption that the representation
        # is translational and scaling invariant
        num_data = len(data)
        data = data.reshape(num_data, -1, 2)
        mean_x = np.mean(data[:,:,0], axis=1).reshape(num_data, 1)
        std_x = np.std(data[:,:,0], axis=1)
        mean_y = np.mean(data[:,:,1], axis=1).reshape(num_data, 1)
        std_y = np.std(data[:,:,1], axis=1)
        denominator = (0.5 * (std_x + std_y)).reshape(num_data, 1)
        data[:,:,0] = (data[:,:,0] - mean_x)/denominator
        data[:,:,1] = (data[:,:,1] - mean_y)/denominator
        data_out = data.reshape(num_data, -1)
    else:
        data_out = (data - mean)/std
    return data_out

def unnormalize_1d(normalized_data, mean, std):
    orig_data = normalized_data*std + mean
    return orig_data

================================================
FILE: libs/logger/__init__.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Empty file.
"""




================================================
FILE: libs/logger/logger.py
================================================
"""
Basic logging functions.

Author: Shichao Li
Contact: nicholas.li@connect.ust.hk
"""

import logging
import os
import time

from libs.common import utils

initialized = False

def get_dirs(cfgs):
    """
    Prepare file directories for a logger object.
    """     
    root_output_dir = cfgs['dirs']['output']
    dataset_name = cfgs['dataset']['name']
    cfg_name = cfgs['name']
    final_output_dir = [root_output_dir, dataset_name]    
    final_output_dir = os.path.join(*final_output_dir)
    time_str = time.strftime('%Y-%m-%d %H:%M')
    log_file = '{}_{}.log'.format(cfg_name, time_str)
    final_log_file = os.path.join(final_output_dir, log_file)
    return final_output_dir, final_log_file

def get_logger(cfgs, head = '%(asctime)-15s %(message)s'):
    """
    Prepare a logger object.
    """     
    final_output_dir, final_log_file = get_dirs(cfgs)
    utils.make_dir(final_log_file)
    logging.basicConfig(filename=str(final_log_file), format=head)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if len(logger.handlers) == 1:    
        console = logging.StreamHandler()
        logging.getLogger('').addHandler(console)    
    return logger, final_output_dir

================================================
FILE: libs/loss/__init__.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Empty file.
"""




================================================
FILE: libs/loss/function.py
================================================
"""
Loss functions.

Author: Shichao Li
Contact: nicholas.li@connect.ust.hk
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy.spatial import distance_matrix

from libs.common.img_proc import soft_arg_max, appro_cr


loss_dict = {'mse': nn.MSELoss(reduction='mean'),
             'sl1': nn.SmoothL1Loss(reduction='mean'),
             'l1': nn.L1Loss(reduction='mean')
             }

class JointsMSELoss(nn.Module):
    def __init__(self, use_target_weight):
        super(JointsMSELoss, self).__init__()
        self.criterion = nn.MSELoss(reduction='mean')
        self.use_target_weight = use_target_weight

    def forward(self, output, target, target_weight, meta=None):
        batch_size = output.size(0)
        num_joints = output.size(1)
        heatmaps_pred = output.reshape((batch_size, num_joints, -1)).split(1, 1)
        heatmaps_gt = target.reshape((batch_size, num_joints, -1)).split(1, 1)
        loss = 0

        for idx in range(num_joints):
            heatmap_pred = heatmaps_pred[idx].squeeze()
            heatmap_gt = heatmaps_gt[idx].squeeze()
            if self.use_target_weight:
                loss += 0.5 * self.criterion(
                    heatmap_pred.mul(target_weight[:, idx]),
                    heatmap_gt.mul(target_weight[:, idx])
                )
            else:
                loss += 0.5 * self.criterion(heatmap_pred, heatmap_gt)

        return loss / num_joints

def get_comp_dict(spec_list = ['None', 'None', 'None'], 
                  loss_weights = [1,1,1]
                  ):
    comp_dict = {}

    if spec_list[0] != 'None':
        comp_dict['hm'] = (loss_dict[spec_list[0]], loss_weights[0])
    if spec_list[1] != 'None':
        comp_dict['coor'] = (loss_dict[spec_list[1]], loss_weights[1])
    if spec_list[2] != 'None':
        comp_dict['cr'] = (loss_dict[spec_list[2]], loss_weights[2])     
    return comp_dict

class JointsCompositeLoss(nn.Module):
    """
    Loss function for 2d screen coordinate regression which consists of 
    multiple terms.
    """
    def __init__(self,
                 spec_list,
                 img_size,
                 hm_size,
                 loss_weights = [1,1,1],
                 target_cr = None,
                 cr_loss_thres = 0.15,
                 use_target_weight=False
                 ):
        """
        comp_dict specify the optional terms used in the loss computation, 
        which is specified with spec_list.
        loss for each component follows the format of [loss_type, weight],
        loss_type speficy the loss type for each component (e.g. L1 or L2) while
        weight gives the weight for this component.
        
        hm: a supervised loss defined with a heatmap target
        coor: a supervised loss defined with 2D coordinates
        cr: a self-supervised loss defined with prior cross-ratio
        """
        super(JointsCompositeLoss, self).__init__()
        self.comp_dict = get_comp_dict(spec_list, loss_weights)
        self.img_size = img_size
        self.hm_size = hm_size
        self.target_cr = target_cr
        self.use_target_weight = use_target_weight
        self.apply_cr_loss = False
        self.cr_loss_thres = cr_loss_thres

    def calc_hm_loss(self, output, target):
        """
        Heatmap loss which corresponds to L_{hm} in the paper.
        
        output: predicted heatmaps of shape [N, K, H, W]
        target: ground truth heatmaps of shape [N, K, H, W]
        """        
        batch_size = output.size(0)
        num_parts = output.size(1)
        heatmaps_pred = output.reshape((batch_size, num_parts, -1)).split(1, 1)
        heatmaps_gt = target.reshape((batch_size, num_parts, -1)).split(1, 1)
        loss = 0
        for idx in range(num_parts):
            heatmap_pred = heatmaps_pred[idx].squeeze()
            heatmap_gt = heatmaps_gt[idx].squeeze()
            loss += 0.5 * self.comp_dict['hm'][0](heatmap_pred, heatmap_gt)        
        return loss / num_parts
    
    def calc_cross_ratio_loss(self, pred_coor, target_cr, mask):
        """
        Cross-ratio loss which corresponds to L_{cr} in the paper.
        
        pred_coor: predicted local coordinates
        target_cr: ground truth cross ratio
        """  
        assert hasattr(self, 'cr_indices')
        # this indices is assumed to be initialized by the user
        loss = 0
        mask = mask.to(pred_coor.device)
        if mask.sum() == 0:
            return loss
        for sample_idx in range(len(pred_coor)):
            for line_idx in range(len(self.cr_indices)):
                if mask[sample_idx][line_idx] == 0:
                    continue
                # predicted cross-ratio square
                pred_cr_sqr = appro_cr(pred_coor[sample_idx][self.cr_indices[line_idx]])
                # normalize the predicted cross-ratio square
                pred_cr_sqr /= target_cr**2
                line_loss = self.comp_dict['cr'][0](pred_cr_sqr, torch.ones(1).to(pred_cr_sqr.device))
                loss += line_loss * mask[sample_idx][line_idx][0]
        return loss/mask.sum()
    
    def get_cr_mask(self, coordinates, threshold = 0.15):
        """
        Mask some edges out when computing the cross-ratio loss.
        Ignore the fore-shortened edges since they will produce large and 
        unstable gradient.
        """          
        assert hasattr(self, 'cr_indices')
        mask = torch.zeros(coordinates.shape[0], len(self.cr_indices), 1)
        for sample_idx in range(len(coordinates)):
            for line_idx in range(len(self.cr_indices)):
                pts = coordinates[sample_idx][self.cr_indices[line_idx]]
                dm = distance_matrix(pts, pts)
                minval = np.min(dm[np.nonzero(dm)])
                if minval > threshold:
                    mask[sample_idx][line_idx] = 1.0
        return mask
    
    def calc_colinear_loss(self):
        # DEPRECATED
        return 0.
    
    def calc_coor_loss(self, coordinates_pred, coordinates_gt):
        """
        Coordinate loss which corresponds to L_{2d} in the paper.
        coordinates_pred: [N, K, 2]
        coordinates_gt: [N, K, 2]
        """  
        coordinates_gt[:, :, 0] /= self.img_size[0]
        coordinates_gt[:, :, 1] /= self.img_size[1]   
        loss = self.comp_dict['coor'][0](coordinates_pred, coordinates_gt) 
        return loss
    
    def forward(self, output, target, target_weight=None, meta=None):
        """
        Loss evaluation.
        Output is in the format of (heatmaps, coordinates) where coordinates
        is optional.
        target refers to the ground truth heatmaps.
        """  
        if type(output) is tuple:
            heatmaps_pred, coordinates_pred = output
        else:
            heatmaps_pred, coordinates_pred = output, None
        total_loss = 0
        if 'hm' in self.comp_dict:
            # some heatmaps map be produced by unlabeled data
            if len(heatmaps_pred) != len(target):
                heatmaps_pred = heatmaps_pred[:len(target)]
            total_loss += self.calc_hm_loss(heatmaps_pred, target) * self.comp_dict['hm'][1]
        if 'coor' in self.comp_dict:
            coordinates_gt = meta['transformed_joints'][:, :, :2].astype(np.float32)
            coordinates_gt = torch.from_numpy(coordinates_gt).cuda()           
            if coordinates_pred == None:
                coordinates_pred, max_vals = soft_arg_max(heatmaps_pred)
                coordinates_pred[:, :, 0] /= self.hm_size[1]
                coordinates_pred[:, :, 1] /= self.hm_size[0]     
            if len(coordinates_pred) != len(coordinates_gt):
                coordinates_pred_fs = coordinates_pred[:len(coordinates_gt)]
            else:
                coordinates_pred_fs = coordinates_pred
            total_loss += self.calc_coor_loss(coordinates_pred_fs, coordinates_gt) * self.comp_dict['coor'][1] 
        if 'cr' in self.comp_dict and self.comp_dict['cr'][1] != "None" and self.apply_cr_loss:
            cr_loss_mask = self.get_cr_mask(coordinates_pred.clone().detach().data.cpu().numpy(), self.cr_loss_thres)
            total_loss += self.calc_cross_ratio_loss(coordinates_pred, self.target_cr, cr_loss_mask) * self.comp_dict['cr'][1]
        return total_loss
    
class MSELoss1D(nn.Module):
    """
    Mean squared error loss.
    """     
    def __init__(self, use_target_weight=False, reduction='mean'):
        super(MSELoss1D, self).__init__()
        self.criterion = nn.MSELoss(reduction=reduction)
        self.use_target_weight = use_target_weight

    def forward(self, output, target, target_weight=None, meta=None):
        loss = self.criterion(output, target)
        return loss
    
class SmoothL1Loss1D(nn.Module):
    """
    Smooth L1 loss.
    """
    def __init__(self, use_target_weight=False):
        super(SmoothL1Loss1D, self).__init__()
        self.criterion = nn.SmoothL1Loss(reduction='mean')
        self.use_target_weight = use_target_weight

    def forward(self, output, target, target_weight=None, meta=None):
        loss = self.criterion(output, target)
        return loss

class DecoupledSL1Loss(nn.Module):
    # DEPRECATED
    def __init__(self, use_target_weight=None):
        super(DecoupledSL1Loss, self).__init__()
        self.criterion = F.smooth_l1_loss

    def forward(self, output, target, target_weight=None):
        # balance the loss for translation and rotation regression
        loss_center = self.criterion(output[:, :3], target[:, :3], reduction='mean')
        loss_else = self.criterion(output[:, 3:], target[:, 3:], reduction='mean')
        return loss_center + loss_else
    
class JointsOHKMMSELoss(nn.Module):
    # DEPRECATED
    def __init__(self, use_target_weight, topk=8):
        super(JointsOHKMMSELoss, self).__init__()
        self.criterion = nn.MSELoss(reduction='none')
        self.use_target_weight = use_target_weight
        self.topk = topk

    def ohkm(self, loss):
        ohkm_loss = 0.
        for i in range(loss.size()[0]):
            sub_loss = loss[i]
            topk_val, topk_idx = torch.topk(
                sub_loss, k=self.topk, dim=0, sorted=False
            )
            tmp_loss = torch.gather(sub_loss, 0, topk_idx)
            ohkm_loss += torch.sum(tmp_loss) / self.topk
        ohkm_loss /= loss.size()[0]
        return ohkm_loss

    def forward(self, output, target, target_weight):
        batch_size = output.size(0)
        num_joints = output.size(1)
        heatmaps_pred = output.reshape((batch_size, num_joints, -1)).split(1, 1)
        heatmaps_gt = target.reshape((batch_size, num_joints, -1)).split(1, 1)

        loss = []
        for idx in range(num_joints):
            heatmap_pred = heatmaps_pred[idx].squeeze()
            heatmap_gt = heatmaps_gt[idx].squeeze()
            if self.use_target_weight:
                loss.append(0.5 * self.criterion(
                    heatmap_pred.mul(target_weight[:, idx]),
                    heatmap_gt.mul(target_weight[:, idx])
                ))
            else:
                loss.append(
                    0.5 * self.criterion(heatmap_pred, heatmap_gt)
                )

        loss = [l.mean(dim=1).unsqueeze(dim=1) for l in loss]
        loss = torch.cat(loss, dim=1)

        return self.ohkm(loss)

class WingLoss(nn.Module):
    # DEPRECATED
    def __init__(self, use_target_weight, width=5, curvature=0.5, image_size=(384, 288)):
        super(WingLoss, self).__init__()
        self.width = width
        self.curvature = curvature
        self.C = self.width - self.width * np.log(1 + self.width / self.curvature)
        self.image_size = image_size
        
    def forward(self, output, target, target_weight):
        prediction, _ = soft_arg_max(output)
        # normalize the coordinates to 0-1
        prediction[:, :, 0] /= self.image_size[1]
        prediction[:, :, 1] /= self.image_size[0]
        target[:, :, 0] /= self.image_size[1]
        target[:, :, 1] /= self.image_size[0]  
        diff = target - prediction
        diff_abs = diff.abs()
        loss = diff_abs.clone()

        idx_smaller = diff_abs < self.width
        idx_bigger = diff_abs >= self.width

        loss[idx_smaller] = self.width * torch.log(1 + diff_abs[idx_smaller] / self.curvature)
        loss[idx_bigger]  = loss[idx_bigger] - self.C
        loss = loss.mean()
        return loss

================================================
FILE: libs/metric/criterions.py
================================================
"""
Metric functions used for validation.

Author: Shichao Li
Contact: nicholas.li@connect.ust.hk
"""

import libs.common.transformation as ltr
import libs.common.img_proc as lip
from libs.common.transformation import compute_similarity_transform

import numpy as np
import torch
from scipy.spatial.transform import Rotation

# threshold for percentage of correct key-points (PCK)
PCK_THRES = np.array([0.1, 0.2, 0.3])

def get_distance(gt, pred):
    """
    2D Euclidean distance of two groups of points with visibility considered. 
    
    gt: [n_joints, 2 or 3]
    pred: [n_joints, 2]
    """    
    if gt.shape[1] == 2:
        sqerr = (gt - pred)**2
        sqerr = sqerr.sum(axis = 1)
        dist_list = list(np.sqrt(sqerr))
    elif gt.shape[1] == 3:
        dist_list = []
        sqerr = (gt[:, :2] - pred)**2
        sqerr = sqerr.sum(axis = 1)
        indices = np.nonzero(gt[:, 2])[0]
        dist_list = list(np.sqrt(sqerr[indices]))        
    else:
        raise ValueError('Array shape not supported.')
    return dist_list

def get_angle_error(pred, meta_data, cfgs=None):
    """
    Compute error for angle prediction.
    """    
    if not isinstance(pred, np.ndarray):
        pred = pred.data.cpu().numpy()    
    angles_pred = np.arctan2(pred[:,1], pred[:,0])
    angles_gt = meta_data['angles_gt']
    dif = np.abs(angles_gt - angles_pred) * 180 / np.pi
    # add or minus 2*pi
    indices = dif > 180
    dif[indices] = 360 - dif[indices]
    cnt = len(pred)
    avg_acc = dif.sum()/cnt
    others = None
    return avg_acc, cnt, others

def get_PCK(pred, gt):
    """
    Get percentage of correct key-points
    """
    distance = np.array(get_distance(gt, pred))
    denominator = (gt[:, 1].max() - gt[:, 1].min()) * 1/3
    correct_cnt = np.zeros((len(PCK_THRES)))
    for idx, thres in enumerate(PCK_THRES):
        correct_cnt[idx] = (distance < thres * denominator).sum()
    return correct_cnt

def get_distance_src(output,
                     meta_data,
                     cfgs=None,
                     image_size = (256.0, 256.0),
                     arg_max='hard'
                     ):
    """
    From predicted heatmaps, obtain local coordinates (\phi_l in the paper) 
    and transform them back to the source images based on metadata. 
    Error is then evaluated on the source image for the screen coordinates 
    (\phi_g in the paper).
    """
    # the error is reported as distance in terms of pixels in the source image
    if type(output) is tuple:
        pred, max_vals = output[1].data.cpu().numpy(), None
    elif isinstance(output, np.ndarray) and arg_max == 'soft':
        pred, max_vals = lip.soft_arg_max_np(output)
    elif isinstance(output, torch.Tensor) and arg_max == 'soft': 
        pred, max_vals = lip.soft_arg_max(output)
    elif isinstance(output, np.ndarray) or isinstance(output, torch.Tensor) and arg_max == 'hard':
        if not isinstance(output, np.ndarray):
            output = output.data.cpu().numpy()        
        pred, max_vals = lip.get_max_preds(output)
    else:
        raise NotImplementedError
    image_size = image_size if cfgs is None else cfgs['heatmapModel']['input_size']
    width, height = image_size
    # multiply by down-sample ratio
    if not isinstance(pred, np.ndarray):
        pred = pred.data.cpu().numpy()
    if (max_vals is not None) and (not isinstance(max_vals, np.ndarray)):
        max_vals = max_vals.data.cpu().numpy()
    # the coordinates need to be rescaled for different cases
    if type(output) is tuple:
        pred *= np.array(image_size).reshape(1, 1, 2)
    else:
        pred *= image_size[0] / output.shape[3]
    # inverse transform and compare pixel didstance
    centers, scales = meta_data['center'], meta_data['scale']
    # some predictions are generated for unlabeled data
    if len(pred) != len(centers):
        pred_used = pred[:len(centers)]
    else:
        pred_used = pred
    if 'rotation' in meta_data:
        rots = meta_data['rotation']
    else:
        rots = [0. for i in range(len(centers))]
    joints_original_batch = meta_data['original_joints']
    distance_list = []
    correct_cnt_sum = np.zeros((len(PCK_THRES)))
    all_src_coordinates = []
    for sample_idx in range(len(pred_used)):
        trans_inv = lip.get_affine_transform(centers[sample_idx], 
                                             scales[sample_idx], 
                                             rots[sample_idx], 
                                             (height, width), 
                                             inv=1
                                             )
        joints_original = joints_original_batch[sample_idx]        
        pred_src_coordinates = lip.affine_transform_modified(pred_used[sample_idx], 
                                                             trans_inv
                                                             ) 
        all_src_coordinates.append(pred_src_coordinates.reshape(1, len(pred_src_coordinates), 2))
        distance_list += get_distance(joints_original, pred_src_coordinates)
        correct_cnt_sum += get_PCK(pred_src_coordinates, joints_original)
    cnt = len(distance_list)
    avg_acc = sum(distance_list) / cnt
    others = {
        'src_coord': np.concatenate(all_src_coordinates, axis=0), # screen coordinates
        'joints_pred': pred, # predicted local coordinates
        'max_vals': max_vals, 
        'correct_cnt': correct_cnt_sum,
        'PCK_batch': correct_cnt_sum / cnt
        }
    return avg_acc, cnt, others

class AngleError():
    """
    Angle error in degrees. 
    """  
    def __init__(self, cfgs, num_joints=None):
        self.name = 'Angle error in degrees'
        self.num_joints = num_joints
        self.count = 0
        self.mean = 0.
        return
  
    def update(self, prediction, meta_data, ground_truth=None, logger=None):
        """
        the prediction and transformation parameters in meta_data are used.
        """    
        avg_acc, cnt, others = get_angle_error(prediction, meta_data)
        self.mean = (self.mean * self.count + cnt * avg_acc) / (self.count + cnt)
        self.count += cnt
        return 
    
    def report(self, logger):
        msg = 'Error type: {error_type:s}\t' \
              'Error: {Error}\t'.format(
                      error_type = self.name,
                      Error = self.mean)     
        logger.info(msg)        
        return

class JointDistance2DSIP():
    """
    Joint distance error evaluated for screen coordinates in the source image plane (SIP). 
    """  
    def __init__(self, cfgs, num_joints=None):
        self.name = 'Joint distance in the source image plane'
        if num_joints is not None:
            self.num_joints = num_joints
        else:
            self.num_joints = cfgs['heatmapModel']['num_joints']
        self.image_size = cfgs['heatmapModel']['input_size']
        if 'arg_max' in cfgs['testing_settings']:
            self.arg_max = cfgs['testing_settings']['arg_max']
        else:
            self.arg_max = None
            self.count = 0
            self.mean = 0.
            self.PCK_counts = np.zeros(len(PCK_THRES))
        return
  
    def update(self, prediction, meta_data, ground_truth=None, logger=None):
        """
        Update statistics for a batch.
        The prediction and transformation parameters in meta_data are used.
        """    
        avg_acc, cnt, others = get_distance_src(prediction, 
                                                meta_data,
                                                arg_max=self.arg_max,
                                                image_size=self.image_size
                                                )       
        self.mean = (self.mean * self.count + cnt * avg_acc) / (self.count + cnt)
        self.count += cnt
        self.PCK_counts += others['correct_cnt']
        return 
    
    def report(self, logger):
        """
        Report final evaluation results.
        """  
        logger.info("Ealuaton Results:")
        msg = 'Error type: {error_type:s}\t' \
              'MPJPE: {MPJPE}\t'.format(error_type = self.name, 
                                        MPJPE = self.mean
                                        )     
        logger.info(msg)        
        for idx, value in enumerate(self.PCK_counts):
            PCK = value / self.count
            logger.info('PCK at threshold {:.2f}: {:.3f}'.format(PCK_THRES[idx], PCK))        
        return

def update_statistics(self, update, num_data, name_str):
    """
    Update error statistics for a data batch.
    """ 
    old_count = getattr(self, 'count'+name_str)
    old_mean = getattr(self, 'mean'+name_str)
    old_max = getattr(self, 'max'+name_str)
    old_min = getattr(self, 'min'+name_str)
    new_mean = (old_count * old_mean + np.sum(update, axis=0)) / (old_count + num_data) 
    new_count = old_count + num_data
    new_max = np.maximum(old_max, update.max(axis=0))
    new_min = np.minimum(old_min, update.min(axis=0))
    setattr(self, 'mean'+name_str, new_mean)
    setattr(self, 'count'+name_str, new_count)
    setattr(self, 'max'+name_str, new_max)
    setattr(self, 'min'+name_str, new_min)    
    return

def update_rotation_error(self, 
                          prediction, 
                          ground_truth, 
                          meta_data=None, 
                          logger=None,
                          name_str='',
                          style='euler'
                          ):
    """
    Get rotation error between two 3D point clouds. 
    """    
    num_data = len(prediction)
    prediction = prediction.reshape(num_data, -1, 3)
    ground_truth = ground_truth.reshape(num_data, -1, 3)
    if style == 'euler':
        results = -np.ones((num_data, 3))
    for data_idx in range(num_data):
        R, T = ltr.compute_rigid_transform(prediction[data_idx].T, 
                                           ground_truth[data_idx].T
                                           )
        if style == 'euler':
            results[data_idx] = np.abs(Rotation.from_matrix(R).as_euler('xyz', 
                                                                        degrees=True
                                                                        )
                                       )
        else:
            raise NotImplementedError
    update_statistics(self, results, num_data, name_str)
    return

def update_joints_3d_error(self, 
                           prediction, 
                           ground_truth, 
                           meta_data=None, 
                           logger=None,
                           name_str='',
                           style='direct'
                           ):
    """
    Get distance error between prediction and ground truth.
    """
    ground_truth = ground_truth.reshape(len(ground_truth), -1, 3)
    prediction = prediction.reshape(len(prediction), -1, 3)
    num_joints = prediction.shape[1]
    if style == 'procrustes':
        # Apply procrustes alignment if asked to do so
        for j in range(len(prediction)):
            gt  = ground_truth[j]
            out = prediction[j]
            _, Z, T, b, c = compute_similarity_transform(gt, out, compute_optimal_scale=True)
            out = (b * out.dot(T)) + c
            prediction[j] = np.reshape(out, [num_joints, 3])
    sqerr = (ground_truth - prediction)**2 
    distance = np.sqrt(np.sum(sqerr, axis=2))        
    num_data = len(prediction)
    update_statistics(self, distance, num_data, name_str)    
    # provide detailed L1 errors if there is only one joint
    if num_joints == 1:
        error_xyz = np.abs(ground_truth - prediction)
        update_statistics(self, error_xyz, num_data, name_str + '_xyz')
    return    

class RotationError3D():
    """
    Helper class for recording rotation estimation error.
    """
    def __init__(self, cfgs):
        self.name = 'Rotation error'
        self.style = cfgs['metrics']['R3D']['style']
        self.count = 0
        if self.style == 'euler':
            self.mean = np.zeros((3))
            self.max = -np.ones((3))
            self.min = np.ones((3))*1e16
        return
    
    def update(self, prediction, ground_truth, meta_data=None, logger=None):
        """
        get rotation error between two point clouds 
        """    
        update_rotation_error(self, 
                              prediction, 
                              ground_truth, 
                              meta_data=meta_data, 
                              logger=logger,
                              style=self.style
                              )
        return 
    
    def report(self, logger):
        msg = 'Error type: {error_type:s}\t' \
              'Mean error: {mean_error}\t' \
              'Max error: {max_error}\t' \
              'Min error: {min_error}\t'.format(
                      error_type = self.name,
                      mean_error= self.mean, 
                      max_error= self.max,
                      min_error= self.min
                      )     
        logger.info(msg)        
        return
    
class JointDistance3D():
    """
    Helper class for recording joint distance error.
    """
    def __init__(self, cfgs):
        self.name = 'Joint distance'
        self.style = cfgs['metrics']['JD3D']['style']
        self.num_joints = int(cfgs['FCModel']['output_size']/3)
        self.count = 0
        if self.style in ['direct', 'procrustes']:
            self.mean = np.zeros((self.num_joints))
            self.max = -np.ones((self.num_joints))
            self.min = np.ones((self.num_joints))*1e16
        else:
            raise NotImplementedError
        return
  
    def update(self, prediction, ground_truth, meta_data=None, logger=None):
        """
        get Euclidean distance between two point clouds 
        """    
        update_joints_3d_error(self, 
                               prediction,
                               ground_truth,
                               meta_data=meta_data,
                               logger=logger,
                               name_str='',
                               style=self.style
                               )        
        return 
    
    def report(self, logger):
        MPJPE = self.mean.sum() / self.num_joints
        msg = 'Error type: {error_type:s}\t' \
              'MPJPE: {MPJPE}\t' \
              'Mean error for each joint: {mean_error}\t' \
              'Max error for each joint: {max_error}\t' \
              'Min error for each joint: {min_error}\t'.format(
                      error_type = self.name,
                      MPJPE = MPJPE,
                      mean_error= self.mean, 
                      max_error= self.max,
                      min_error= self.min
                      )     
        logger.info(msg)        
        return

class RError3D():
    def __init__(self, cfgs, num_joints):
        """
        Relative shape error
        The point cloud should have a format [shape_relative_to_root]
        """           
        self.name = 'RError3D'
        self.T_style = cfgs['metrics']['R3D']['T_style']
        self.R_style = cfgs['metrics']['R3D']['R_style']
        if cfgs['dataset']['3d_kpt_sample_style'] == 'bbox9': 
            self.num_joints = num_joints - 1 # discount the root joint
        else:
            raise NotImplementedError
        self.count_rT = self.count_R = 0
        # translation error of the shape relative to the root
        self.mean_rT = np.zeros((self.num_joints))
        self.max_rT = -np.ones((self.num_joints))
        self.min_rT = np.ones((self.num_joints))*1e16            
        # relative rotation between the ground truth shape and predicted shape
        self.mean_R = np.zeros((3))
        self.max_R = -np.ones((3))
        self.min_R = np.ones((3))*1e16            
        return
  
    def update(self, prediction, ground_truth, meta_data=None, logger=None):
        update_joints_3d_error(self, 
                               prediction=prediction,
                               ground_truth=ground_truth,
                               meta_data=meta_data,
                               logger=logger,
                               name_str='_rT',
                               style=self.T_style
                               )
        update_rotation_error(self,
                              prediction=prediction,
                              ground_truth=ground_truth,
                              meta_data=meta_data,
                              logger=logger,
                              name_str='_R',
                              style=self.R_style
                              )        
        return 
    
    def report(self, logger):
        MPJPE = self.mean_rT.sum() / self.num_joints
        msg = 'Error type: {error_type:s}\t' \
              'MPJPE of the shape relative to the root:\t' \
              'MPJPE: {MPJPE}\t' \
              'Rotation error of the shape relative to the root:\t' \
              'Mean error: {mean_R}\t' \
              'Max error: {max_R}\t' \
              'Min error: {min_R}\t'.format(
                  error_type = self.name,
                  MPJPE = MPJPE,
                  mean_R = self.mean_R,
                  max_R = self.max_R,
                  min_R = self.min_R
                  )     
        logger.info(msg)        
        return
    
class RTError3D():
    def __init__(self, cfgs, num_joints):
        """
        Rotation and translation error combined.
        The point cloud should have a format [root, shape_relative_to_root]
        """           
        self.name = 'RTError3D'
        self.T_style = cfgs['metrics']['RTError3D']['T_style']
        self.R_style = cfgs['metrics']['RTError3D']['R_style']
        if cfgs['dataset']['3d_kpt_sample_style'] == 'bbox9': 
            self.num_joints = num_joints - 1 # discount the root joint
        else:
            raise NotImplementedError
        self.count_T = self.count_T_xyz = self.count_rT = self.count_R = 0
        if self.T_style in ['direct', 'procrustes']:
            # translation error of the root vector
            self.mean_T = np.zeros((1))
            # L1 error for each component
            self.mean_T_xyz = np.zeros((3))
            self.max_T = -np.ones((1))
            self.max_T_xyz = -np.ones((3))
            self.min_T = np.ones((1))*1e16
            self.min_T_xyz = np.ones((3))*1e16
            # translation error of the shape relative to the root
            self.mean_rT = np.zeros((self.num_joints))
            self.max_rT = -np.ones((self.num_joints))
            self.min_rT = np.ones((self.num_joints))*1e16            
        else:
            raise NotImplementedError
        # relative rotation between the ground truth shape and predicted shape
        self.mean_R = np.zeros((3))
        self.max_R = -np.ones((3))
        self.min_R = np.ones((3))*1e16            
        return
  
    def update(self, prediction, ground_truth, meta_data=None, logger=None):
        update_joints_3d_error(self, 
                               prediction=prediction[:, :3],
                               ground_truth=ground_truth[:, :3],
                               meta_data=meta_data,
                               logger=logger,
                               name_str='_T',
                               style=self.T_style
                               )
        update_joints_3d_error(self, 
                               prediction=prediction[:, 3:],
                               ground_truth=ground_truth[:, 3:],
                               meta_data=meta_data,
                               logger=logger,
                               name_str='_rT',
                               style=self.T_style
                               )
        update_rotation_error(self,
                              prediction=prediction[:, 3:],
                              ground_truth=ground_truth[:, 3:],
                              meta_data=meta_data,
                              logger=logger,
                              name_str='_R',
                              style=self.R_style
          
Download .txt
gitextract_zura5gg8/

├── .gitignore
├── LICENSE
├── README.md
├── configs/
│   ├── KITTI_inference:demo.yml
│   ├── KITTI_inference:test_submission.yml
│   ├── KITTI_train_IGRs.yml
│   ├── KITTI_train_IGRs_Ped.yml
│   └── KITTI_train_lifting.yml
├── docs/
│   ├── demo.md
│   ├── inference.md
│   ├── preparation.md
│   ├── spec-list.txt
│   └── training.md
├── libs/
│   ├── arguments/
│   │   ├── __init__.py
│   │   └── parse.py
│   ├── common/
│   │   ├── __init__.py
│   │   ├── format.py
│   │   ├── img_proc.py
│   │   ├── transformation.py
│   │   └── utils.py
│   ├── dataset/
│   │   ├── KITTI/
│   │   │   ├── __init__.py
│   │   │   └── car_instance.py
│   │   ├── __init__.py
│   │   ├── basic/
│   │   │   ├── __init__.py
│   │   │   └── basic_classes.py
│   │   └── normalization/
│   │       ├── __init__.py
│   │       └── operations.py
│   ├── logger/
│   │   ├── __init__.py
│   │   └── logger.py
│   ├── loss/
│   │   ├── __init__.py
│   │   └── function.py
│   ├── metric/
│   │   └── criterions.py
│   ├── model/
│   │   ├── FCmodel.py
│   │   ├── __init__.py
│   │   ├── egonet.py
│   │   └── heatmapModel/
│   │       ├── __init__.py
│   │       ├── hrnet.py
│   │       └── resnet.py
│   ├── optimizer/
│   │   ├── __init__.py
│   │   └── optimizer.py
│   ├── trainer/
│   │   ├── __init__.py
│   │   ├── accuracy.py
│   │   └── trainer.py
│   └── visualization/
│       ├── __init__.py
│       ├── debug.py
│       ├── egonet_utils.py
│       └── points.py
└── tools/
    ├── inference.py
    ├── inference_legacy.py
    ├── kitti-eval/
    │   ├── README.md
    │   ├── evaluate_object_3d.cpp
    │   ├── evaluate_object_3d_offline.cpp
    │   ├── evaluate_object_3d_offline_r40.cpp
    │   └── mail.h
    ├── train_IGRs.py
    └── train_lifting.py
Download .txt
SYMBOL INDEX (428 symbols across 29 files)

FILE: libs/arguments/parse.py
  function read_yaml_file (line 11) | def read_yaml_file(path):
  function parse_args (line 22) | def parse_args():

FILE: libs/common/format.py
  function format_str_submission (line 11) | def format_str_submission(roll, pitch, yaw, x, y, z, score):
  function get_instance_str (line 25) | def get_instance_str(dic):
  function get_pred_str (line 44) | def get_pred_str(record):
  function save_txt_file (line 63) | def save_txt_file(img_path, prediction, params):

FILE: libs/common/img_proc.py
  function transform_preds (line 16) | def transform_preds(coords, center, scale, output_size):
  function get_affine_transform (line 26) | def get_affine_transform(center,
  function affine_transform (line 66) | def affine_transform(pt, t):
  function affine_transform_modified (line 71) | def affine_transform_modified(pts, t):
  function get_3rd_point (line 80) | def get_3rd_point(a, b):
  function get_dir (line 84) | def get_dir(src_point, rot_rad):
  function crop (line 93) | def crop(img, center, scale, output_size, rot=0):
  function simple_crop (line 107) | def simple_crop(input_image, center, crop_size):
  function np_random (line 137) | def np_random():
  function jitter_bbox_with_kpts (line 143) | def jitter_bbox_with_kpts(old_bbox, joints, parameters):
  function jitter_bbox_with_kpts_no_occlu (line 174) | def jitter_bbox_with_kpts_no_occlu(old_bbox, joints, parameters):
  function generate_xy_map (line 193) | def generate_xy_map(bbox, resolution, global_size):
  function crop_single_instance (line 213) | def crop_single_instance(data_numpy, bbox, joints, parameters, pth_trans...
  function get_tensor_from_img (line 251) | def get_tensor_from_img(path,
  function generate_target (line 347) | def generate_target(joints, joints_vis, parameters):
  function resize_bbox (line 411) | def resize_bbox(left, top, right, bottom, target_ar=1.):
  function enlarge_bbox (line 437) | def enlarge_bbox(left, top, right, bottom, enlarge):
  function modify_bbox (line 453) | def modify_bbox(bbox, target_ar, enlarge=1.1):
  function resize_crop (line 461) | def resize_crop(crop_size, target_ar=None):
  function bbox2cs (line 478) | def bbox2cs(bbox):
  function cs2bbox (line 485) | def cs2bbox(center, size):
  function kpts2cs (line 495) | def kpts2cs(keypoints,
  function draw_bboxes (line 542) | def draw_bboxes(img_path, bboxes_dict, save_path=None):
  function imread_rgb (line 556) | def imread_rgb(img_path):
  function save_cropped_patches (line 564) | def save_cropped_patches(img_path,
  function get_max_preds (line 608) | def get_max_preds(batch_heatmaps):
  function soft_arg_max_np (line 639) | def soft_arg_max_np(batch_heatmaps):
  function soft_arg_max (line 678) | def soft_arg_max(batch_heatmaps):
  function appro_cr (line 709) | def appro_cr(coordinates):
  function to_npy (line 722) | def to_npy(tensor):

FILE: libs/common/transformation.py
  function move_to (line 11) | def move_to(points, xyz=np.zeros((1,3))):
  function world_to_camera_frame (line 16) | def world_to_camera_frame(P, R, T):
  function camera_to_world_frame (line 32) | def camera_to_world_frame(P, R, T):
  function compute_similarity_transform (line 48) | def compute_similarity_transform(X, Y, compute_optimal_scale=False):
  function compute_rigid_transform (line 99) | def compute_rigid_transform(X, Y, W=None, verbose=False):
  function procrustes_transform (line 136) | def procrustes_transform(X, Y):
  function pnp_refine (line 143) | def pnp_refine(prediction, observation, intrinsics, dist_coeffs):

FILE: libs/common/utils.py
  function make_dir (line 18) | def make_dir(name):
  function save_checkpoint (line 30) | def save_checkpoint(states, is_best, output_dir, filename='checkpoint.pt...
  function get_model_summary (line 35) | def get_model_summary(model, *input_tensors, item_length=26, verbose=Fal...
  class AverageMeter (line 149) | class AverageMeter(object):
    method __init__ (line 153) | def __init__(self):
    method reset (line 157) | def reset(self):
    method update (line 164) | def update(self, val, n=1, others=None):
    method print_content (line 178) | def print_content(self):

FILE: libs/dataset/KITTI/car_instance.py
  function get_cr_indices (line 99) | def get_cr_indices():
  class KITTI (line 121) | class KITTI(bc.SupervisedDataset):
    method __init__ (line 125) | def __init__(self, cfgs, split, logger, scale=1.0):
    method _get_image_path_list (line 164) | def _get_image_path_list(self):
    method _initialize_unlabeled_data (line 176) | def _initialize_unlabeled_data(self, cfgs):
    method _load_image_list (line 184) | def _load_image_list(self):
    method _check_precomputed_file (line 199) | def _check_precomputed_file(self, path, name):
    method _save_precomputed_file (line 211) | def _save_precomputed_file(self, data_dic, pre_computed_path, name):
    method _prepare_key_points_custom (line 221) | def _prepare_key_points_custom(self, style, interp_params, vis_thresh=...
    method _prepare_key_points (line 264) | def _prepare_key_points(self, cfgs):
    method _save_cropped_instances (line 273) | def _save_cropped_instances(self):
    method _prepare_2d_pose_annot (line 304) | def _prepare_2d_pose_annot(self, threshold=4):
    method _prepare_detection_records (line 348) | def _prepare_detection_records(self, save=False, threshold = 0.1):
    method gather_annotations (line 352) | def gather_annotations(self,
    method read_single_file (line 383) | def read_single_file(self,
    method read_predictions (line 459) | def read_predictions(self, path):
    method _get_data_parameters (line 480) | def _get_data_parameters(self, cfgs):
    method _set_paths (line 533) | def _set_paths(self):
    method project_3d_to_2d (line 557) | def project_3d_to_2d(self, points, K):
    method render_car (line 565) | def render_car(self, ax, K, obj_class, rot_y, locs, dimension, shift):
    method show_statistics (line 575) | def show_statistics(self):
    method augment_pose_vector (line 611) | def augment_pose_vector(self,
    method get_representation (line 646) | def get_representation(self, p2d, p3d, in_rep, out_rep):
    method get_input_output_size (line 688) | def get_input_output_size(self):
    method interpolate (line 705) | def interpolate(self,
    method construct_box_3d (line 730) | def construct_box_3d(self, l, h, w, interp_params):
    method get_cam_cord (line 749) | def get_cam_cord(self, cam_cord, shift, ids, pose_vecs, rot_xz=False):
    method csv_read_annot (line 792) | def csv_read_annot(self, file_path, fieldnames):
    method csv_read_calib (line 831) | def csv_read_calib(self, file_path):
    method load_annotations (line 845) | def load_annotations(self, label_path, calib_path, fieldnames=FIELDNAM...
    method add_visibility (line 855) | def add_visibility(self, joints, img_width=1242, img_height=375):
    method get_inlier_indices (line 870) | def get_inlier_indices(self, p_2d, threshold=0.3):
    method filter_outlier (line 881) | def filter_outlier(self, p_2d, p_3d, threshold=0.3):
    method get_img_size (line 894) | def get_img_size(self, path):
    method get_2d_3d_pair (line 902) | def get_2d_3d_pair(self,
    method show_annot (line 1012) | def show_annot(self,
    method _generate_2d_3d_paris (line 1051) | def _generate_2d_3d_paris(self):
    method generate_pairs (line 1088) | def generate_pairs(self):
    method visualize (line 1128) | def visualize(self, plot_num = 1, save_dir=None):
    method get_collate_fn (line 1138) | def get_collate_fn(self):
    method inference (line 1141) | def inference(self, flags=[True, True]):
    method extract_ss_sample (line 1145) | def extract_ss_sample(self, cnt):
    method prepare_ft_dict (line 1171) | def prepare_ft_dict(self, idx):
    method __getitem__ (line 1217) | def __getitem__(self, idx):
  function prepare_data (line 1321) | def prepare_data(cfgs, logger):
  function get_dataset (line 1332) | def get_dataset(cfgs, logger, split):
  function collate_dict (line 1335) | def collate_dict(dict_list):
  function length_limit (line 1344) | def length_limit(instances, targets, target_weights, meta):
  function my_collate_fn (line 1368) | def my_collate_fn(batch):

FILE: libs/dataset/basic/basic_classes.py
  class SupervisedDataset (line 10) | class SupervisedDataset(torch.utils.data.Dataset):
    method __init__ (line 11) | def __init__(self, cfgs, split, logger=None):
    method generate_pairs (line 18) | def generate_pairs(self, synthetic=True):
    method normalize (line 26) | def normalize(self, statistics=None):
    method unnormalize (line 46) | def unnormalize(self, data, mean, std):
    method __len__ (line 49) | def __len__(self):
    method __getitem__ (line 52) | def __getitem__(self, idx):

FILE: libs/dataset/normalization/operations.py
  function get_statistics_1d (line 10) | def get_statistics_1d(data):
  function normalize_1d (line 21) | def normalize_1d(data, mean, std, individual=False):
  function unnormalize_1d (line 50) | def unnormalize_1d(normalized_data, mean, std):

FILE: libs/logger/logger.py
  function get_dirs (line 16) | def get_dirs(cfgs):
  function get_logger (line 30) | def get_logger(cfgs, head = '%(asctime)-15s %(message)s'):

FILE: libs/loss/function.py
  class JointsMSELoss (line 22) | class JointsMSELoss(nn.Module):
    method __init__ (line 23) | def __init__(self, use_target_weight):
    method forward (line 28) | def forward(self, output, target, target_weight, meta=None):
  function get_comp_dict (line 48) | def get_comp_dict(spec_list = ['None', 'None', 'None'],
  class JointsCompositeLoss (line 61) | class JointsCompositeLoss(nn.Module):
    method __init__ (line 66) | def __init__(self,
    method calc_hm_loss (line 95) | def calc_hm_loss(self, output, target):
    method calc_cross_ratio_loss (line 113) | def calc_cross_ratio_loss(self, pred_coor, target_cr, mask):
    method get_cr_mask (line 138) | def get_cr_mask(self, coordinates, threshold = 0.15):
    method calc_colinear_loss (line 155) | def calc_colinear_loss(self):
    method calc_coor_loss (line 159) | def calc_coor_loss(self, coordinates_pred, coordinates_gt):
    method forward (line 170) | def forward(self, output, target, target_weight=None, meta=None):
  class MSELoss1D (line 204) | class MSELoss1D(nn.Module):
    method __init__ (line 208) | def __init__(self, use_target_weight=False, reduction='mean'):
    method forward (line 213) | def forward(self, output, target, target_weight=None, meta=None):
  class SmoothL1Loss1D (line 217) | class SmoothL1Loss1D(nn.Module):
    method __init__ (line 221) | def __init__(self, use_target_weight=False):
    method forward (line 226) | def forward(self, output, target, target_weight=None, meta=None):
  class DecoupledSL1Loss (line 230) | class DecoupledSL1Loss(nn.Module):
    method __init__ (line 232) | def __init__(self, use_target_weight=None):
    method forward (line 236) | def forward(self, output, target, target_weight=None):
  class JointsOHKMMSELoss (line 242) | class JointsOHKMMSELoss(nn.Module):
    method __init__ (line 244) | def __init__(self, use_target_weight, topk=8):
    method ohkm (line 250) | def ohkm(self, loss):
    method forward (line 262) | def forward(self, output, target, target_weight):
  class WingLoss (line 287) | class WingLoss(nn.Module):
    method __init__ (line 289) | def __init__(self, use_target_weight, width=5, curvature=0.5, image_si...
    method forward (line 296) | def forward(self, output, target, target_weight):

FILE: libs/metric/criterions.py
  function get_distance (line 19) | def get_distance(gt, pred):
  function get_angle_error (line 40) | def get_angle_error(pred, meta_data, cfgs=None):
  function get_PCK (line 57) | def get_PCK(pred, gt):
  function get_distance_src (line 68) | def get_distance_src(output,
  class AngleError (line 145) | class AngleError():
    method __init__ (line 149) | def __init__(self, cfgs, num_joints=None):
    method update (line 156) | def update(self, prediction, meta_data, ground_truth=None, logger=None):
    method report (line 165) | def report(self, logger):
  class JointDistance2DSIP (line 173) | class JointDistance2DSIP():
    method __init__ (line 177) | def __init__(self, cfgs, num_joints=None):
    method update (line 193) | def update(self, prediction, meta_data, ground_truth=None, logger=None):
    method report (line 208) | def report(self, logger):
  function update_statistics (line 223) | def update_statistics(self, update, num_data, name_str):
  function update_rotation_error (line 241) | def update_rotation_error(self,
  function update_joints_3d_error (line 271) | def update_joints_3d_error(self,
  class RotationError3D (line 303) | class RotationError3D():
    method __init__ (line 307) | def __init__(self, cfgs):
    method update (line 317) | def update(self, prediction, ground_truth, meta_data=None, logger=None):
    method report (line 330) | def report(self, logger):
  class JointDistance3D (line 343) | class JointDistance3D():
    method __init__ (line 347) | def __init__(self, cfgs):
    method update (line 360) | def update(self, prediction, ground_truth, meta_data=None, logger=None):
    method report (line 374) | def report(self, logger):
  class RError3D (line 390) | class RError3D():
    method __init__ (line 391) | def __init__(self, cfgs, num_joints):
    method update (line 414) | def update(self, prediction, ground_truth, meta_data=None, logger=None):
    method report (line 433) | def report(self, logger):
  class RTError3D (line 451) | class RTError3D():
    method __init__ (line 452) | def __init__(self, cfgs, num_joints):
    method update (line 486) | def update(self, prediction, ground_truth, meta_data=None, logger=None):
    method report (line 513) | def report(self, logger):
  class Evaluator (line 540) | class Evaluator():
    method __init__ (line 544) | def __init__(self, metrics, cfgs=None, num_joints=9):
    method update (line 553) | def update(self,
    method report (line 570) | def report(self, logger):

FILE: libs/model/FCmodel.py
  class ResidualBlock (line 9) | class ResidualBlock(nn.Module):
    method __init__ (line 10) | def __init__(self,
    method forward (line 33) | def forward(self, x):
  class FCModel (line 45) | class FCModel(nn.Module):
    method __init__ (line 46) | def __init__(self,
    method forward (line 92) | def forward(self, x):
    method get_representation (line 97) | def get_representation(self, x):
  function get_fc_model (line 107) | def get_fc_model(stage_id,
  function get_cascade (line 123) | def get_cascade():

FILE: libs/model/egonet.py
  class EgoNet (line 28) | class EgoNet(nn.Module):
    method __init__ (line 29) | def __init__(self,
    method crop_single_instance (line 68) | def crop_single_instance(self,
    method load_cv2 (line 97) | def load_cv2(self, path, rgb=True):
    method crop_instances (line 105) | def crop_instances(self,
    method add_orientation_arrow (line 157) | def add_orientation_arrow(self, record):
    method write_annot_dict (line 181) | def write_annot_dict(self, annot_dict, records):
    method get_observation_angle_trans (line 203) | def get_observation_angle_trans(self, euler_angles, translations):
    method get_observation_angle_proj (line 219) | def get_observation_angle_proj(self, euler_angles, kpts, K):
    method get_template (line 238) | def get_template(self, prediction, interp_coef=[0.332, 0.667]):
    method kpts_to_euler (line 265) | def kpts_to_euler(self, template, prediction):
    method get_6d_rep (line 279) | def get_6d_rep(self, predictions, ax=None, color="black"):
    method gather_lifting_results (line 297) | def gather_lifting_results(self,
    method plot_one_image (line 341) | def plot_one_image(self,
    method post_process (line 385) | def post_process(self,
    method new_img_dict (line 410) | def new_img_dict(self):
    method get_keypoints (line 424) | def get_keypoints(self,
    method lift_2d_to_3d (line 469) | def lift_2d_to_3d(self, records, cuda=True):
    method forward (line 488) | def forward(self, annot_dict):

FILE: libs/model/heatmapModel/hrnet.py
  function conv3x3 (line 24) | def conv3x3(in_planes, out_planes, stride=1):
  function basicdownsample (line 29) | def basicdownsample(in_planes, out_planes):
  class BasicLinearModule (line 44) | class BasicLinearModule(nn.Module):
    method __init__ (line 45) | def __init__(self, in_channels, out_channels, mid_channels=512):
    method forward (line 53) | def forward(self, x):
  class BasicBlock (line 63) | class BasicBlock(nn.Module):
    method __init__ (line 66) | def __init__(self, inplanes, planes, stride=1, downsample=None):
    method forward (line 76) | def forward(self, x):
  class Bottleneck (line 95) | class Bottleneck(nn.Module):
    method __init__ (line 98) | def __init__(self, inplanes, planes, stride=1, downsample=None):
    method forward (line 113) | def forward(self, x):
  class HighResolutionModule (line 136) | class HighResolutionModule(nn.Module):
    method __init__ (line 137) | def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
    method _check_branches (line 154) | def _check_branches(self, num_branches, blocks, num_blocks,
    method _make_one_branch (line 174) | def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
    method _make_branches (line 212) | def _make_branches(self, num_branches, block, num_blocks, num_channels):
    method _make_fuse_layers (line 222) | def _make_fuse_layers(self):
    method get_num_inchannels (line 279) | def get_num_inchannels(self):
    method forward (line 282) | def forward(self, x):
  class PoseHighResolutionNet (line 309) | class PoseHighResolutionNet(nn.Module):
    method __init__ (line 311) | def __init__(self, cfgs, **kwargs):
    method _make_transition_layer (line 471) | def _make_transition_layer(
    method _make_layer (line 512) | def _make_layer(self, block, planes, blocks, stride=1):
    method _make_stage (line 531) | def _make_stage(self, layer_config, num_inchannels,
    method forward (line 563) | def forward(self, x):
    method init_weights (line 616) | def init_weights(self, pretrained=''):
    method modify_input_channel (line 649) | def modify_input_channel(self, num_channels):
    method load_my_state_dict (line 661) | def load_my_state_dict(self, state_dict):
  function is_freezed (line 669) | def is_freezed(name, freeze_names):
  function get_pose_net (line 675) | def get_pose_net(cfgs, is_train, **kwargs):

FILE: libs/model/heatmapModel/resnet.py
  function conv3x3 (line 22) | def conv3x3(in_planes, out_planes, stride=1):
  class BasicBlock (line 30) | class BasicBlock(nn.Module):
    method __init__ (line 33) | def __init__(self, inplanes, planes, stride=1, downsample=None):
    method forward (line 43) | def forward(self, x):
  class Bottleneck (line 62) | class Bottleneck(nn.Module):
    method __init__ (line 65) | def __init__(self, inplanes, planes, stride=1, downsample=None):
    method forward (line 80) | def forward(self, x):
  class PoseResNet (line 103) | class PoseResNet(nn.Module):
    method __init__ (line 105) | def __init__(self, block, layers, cfg, **kwargs):
    method _make_layer (line 136) | def _make_layer(self, block, planes, blocks, stride=1):
    method _get_deconv_cfg (line 153) | def _get_deconv_cfg(self, deconv_kernel, index):
    method _make_deconv_layer (line 166) | def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
    method forward (line 193) | def forward(self, x):
    method init_weights (line 209) | def init_weights(self, pretrained=''):
  function get_pose_net (line 261) | def get_pose_net(cfg, is_train, **kwargs):

FILE: libs/optimizer/optimizer.py
  function prepare_optim (line 9) | def prepare_optim(model, cfgs):

FILE: libs/trainer/accuracy.py
  function get_distance (line 10) | def get_distance(gt, pred):
  function accuracy_pixel (line 27) | def accuracy_pixel(output,

FILE: libs/trainer/trainer.py
  function train_cascade (line 25) | def train_cascade(train_dataset, valid_dataset, cfgs, logger):
  function evaluate_cascade (line 73) | def evaluate_cascade(cascade,
  function get_loader (line 113) | def get_loader(dataset, cfgs, split, collate_fn=None):
  function train (line 127) | def train(train_dataset,
  function initialize_plot (line 265) | def initialize_plot():
  function update_curve (line 276) | def update_curve(ax, line, x_buffer, y_buffer):
  function logger_print (line 290) | def logger_print(epoch,
  function visualize_lifting_results (line 323) | def visualize_lifting_results(data,
  function evaluate (line 395) | def evaluate(eval_dataset,

FILE: libs/visualization/debug.py
  function draw_circles (line 18) | def draw_circles(ndarr,
  function save_batch_image_with_joints (line 51) | def save_batch_image_with_joints(batch_image,
  function save_batch_heatmaps (line 83) | def save_batch_heatmaps(batch_image,
  function save_debug_images (line 151) | def save_debug_images(epoch,

FILE: libs/visualization/egonet_utils.py
  function plot_2d_objects (line 14) | def plot_2d_objects(img_path, record, color_dict):
  function plot_3d_objects (line 62) | def plot_3d_objects(prediction, target, pose_vecs_gt, record, color):

FILE: libs/visualization/points.py
  function check_points (line 13) | def check_points(points, dimension):
  function set_3d_axe_limits (line 26) | def set_3d_axe_limits(ax, points=None, center=None, radius=None, ratio=1...
  function plot_3d_points (line 48) | def plot_3d_points(ax,
  function plot_lines (line 96) | def plot_lines(ax,
  function plot_mesh (line 132) | def plot_mesh(ax, vertices, faces, color='grey'):
  function plot_3d_coordinate_system (line 149) | def plot_3d_coordinate_system(ax,
  function plot_3d_bbox (line 173) | def plot_3d_bbox(ax,
  function plot_2d_bbox (line 193) | def plot_2d_bbox(ax,
  function plot_comparison_relative (line 214) | def plot_comparison_relative(points_pred, points_gt):
  function plot_scene_3dbox (line 244) | def plot_scene_3dbox(points_pred, points_gt=None, ax=None, color='r'):
  function get_area (line 270) | def get_area(points, indices, preserve_points=False):
  function interpolate (line 284) | def interpolate(start, end, num_interp):
  function get_interpolated_points (line 293) | def get_interpolated_points(points, indices, num_interp):
  function draw_pose_vecs (line 303) | def draw_pose_vecs(ax, pose_vecs=None, color='black'):
  function get_bbox_3d (line 321) | def get_bbox_3d(points, add_center=False, interp_style=""):
  function ray_intersect_triangle (line 364) | def ray_intersect_triangle(p0, p1, triangle):
  function get_visibility (line 414) | def get_visibility(box3d, triangles):

FILE: tools/inference.py
  function filter_detection (line 32) | def filter_detection(detected, thres=0.7):
  function merge (line 46) | def merge(dict_a, dict_b):
  function collate_dict (line 51) | def collate_dict(dict_list):
  function my_collate_fn (line 57) | def my_collate_fn(batch):
  function filter_conf (line 63) | def filter_conf(record, thres=0.0):
  function gather_dict (line 80) | def gather_dict(request,
  function make_output_dir (line 129) | def make_output_dir(cfgs, name):
  function inference (line 136) | def inference(testset, model, results, cfgs):
  function generate_empty_file (line 201) | def generate_empty_file(output_dir, label_dir):
  function main (line 215) | def main():

FILE: tools/inference_legacy.py
  function prepare_models (line 33) | def prepare_models(cfgs, is_cuda=True):
  function modify_bbox (line 59) | def modify_bbox(bbox, target_ar, enlarge=1.1):
  function crop_single_instance (line 67) | def crop_single_instance(img, bbox, resolution, pth_trans=None, xy_dict=...
  function crop_instances (line 91) | def crop_instances(annot_dict,
  function get_keypoints (line 149) | def get_keypoints(instances,
  function kpts_to_euler (line 211) | def kpts_to_euler(template, prediction):
  function get_template (line 225) | def get_template(prediction, interp_coef=[0.332, 0.667]):
  function get_observation_angle_trans (line 252) | def get_observation_angle_trans(euler_angles, translations):
  function get_observation_angle_proj (line 268) | def get_observation_angle_proj(euler_angles, kpts, K):
  function get_6d_rep (line 287) | def get_6d_rep(predictions, ax=None, color="black"):
  function format_str_submission (line 308) | def format_str_submission(roll, pitch, yaw, x, y, z, score):
  function get_instance_str (line 322) | def get_instance_str(dic):
  function get_pred_str (line 341) | def get_pred_str(record):
  function lift_2d_to_3d (line 360) | def lift_2d_to_3d(records, model, stats, template, cuda=True):
  function filter_detection (line 379) | def filter_detection(detected, thres=0.7):
  function add_orientation_arrow (line 393) | def add_orientation_arrow(record):
  function process_batch (line 417) | def process_batch(images,
  function to_npy (line 466) | def to_npy(tensor):
  function refine_with_perfect_size (line 475) | def refine_with_perfect_size(pred,
  function refine_with_predicted_bbox (line 518) | def refine_with_predicted_bbox(pred,
  function draw_pose_vecs (line 549) | def draw_pose_vecs(ax, pose_vecs=None, color='black'):
  function refine_solution (line 567) | def refine_solution(est_3d,
  function gather_lifting_results (line 597) | def gather_lifting_results(record,
  function save_txt_file (line 693) | def save_txt_file(img_path, prediction, params):
  function refine_one_image (line 705) | def refine_one_image(img_path,
  function post_process (line 819) | def post_process(records,
  function merge (line 844) | def merge(dict_a, dict_b):
  function collate_dict (line 849) | def collate_dict(dict_list):
  function my_collate_fn (line 855) | def my_collate_fn(batch):
  function filter_conf (line 861) | def filter_conf(record, thres=0.0):
  function gather_dict (line 878) | def gather_dict(request, references, filter_c=True):
  function inference (line 918) | def inference(testset, model_settings, results, cfgs):
  function generate_empty_file (line 1012) | def generate_empty_file(output_dir, label_dir):
  function main (line 1024) | def main():

FILE: tools/kitti-eval/evaluate_object_3d.cpp
  type DIFFICULTY (line 37) | enum DIFFICULTY{EASY=0, MODERATE=1, HARD=2}
  type METRIC (line 40) | enum METRIC{IMAGE=0, GROUND=1, BOX3D=2}
  type CLASSES (line 48) | enum CLASSES{CAR=0, PEDESTRIAN=1, CYCLIST=2}
  function initGlobals (line 61) | void initGlobals () {
  type tPrData (line 72) | struct tPrData {
    method tPrData (line 78) | tPrData () :
  type tBox (line 83) | struct tBox {
    method tBox (line 90) | tBox (string type, double x1,double y1,double x2,double y2,double alph...
  type tGroundtruth (line 95) | struct tGroundtruth {
    method tGroundtruth (line 102) | tGroundtruth () :
    method tGroundtruth (line 104) | tGroundtruth (tBox box,double truncation,int32_t occlusion) :
    method tGroundtruth (line 106) | tGroundtruth (string type,double x1,double y1,double x2,double y2,doub...
  type tDetection (line 111) | struct tDetection {
    method tDetection (line 117) | tDetection ():
    method tDetection (line 119) | tDetection (tBox box,double thresh) :
    method tDetection (line 121) | tDetection (string type,double x1,double y1,double x2,double y2,double...
  function loadDetections (line 131) | vector<tDetection> loadDetections(string file_name, bool &compute_aos,
  function loadGroundtruth (line 178) | vector<tGroundtruth> loadGroundtruth(string file_name,bool &success) {
  function saveStats (line 204) | void saveStats (const vector<double> &precision, const vector<double> &a...
  function imageBoxOverlap (line 227) | inline double imageBoxOverlap(tBox a, tBox b, int32_t criterion=-1){
  function imageBoxOverlap (line 263) | inline double imageBoxOverlap(tDetection a, tGroundtruth b, int32_t crit...
  function Polygon (line 269) | Polygon toPolygon(const T& g) {
  function groundBoxOverlap (line 294) | inline double groundBoxOverlap(tDetection d, tGroundtruth g, int32_t cri...
  function box3DOverlap (line 317) | inline double box3DOverlap(tDetection d, tGroundtruth g, int32_t criteri...
  function getThresholds (line 346) | vector<double> getThresholds(vector<double> &v, double n_groundtruth){
  function cleanData (line 381) | void cleanData(CLASSES current_class, const vector<tGroundtruth> &gt, co...
  function tPrData (line 455) | tPrData computeStatistics(CLASSES current_class, const vector<tGroundtru...
    method tPrData (line 78) | tPrData () :
  function eval_class (line 619) | bool eval_class (FILE *fp_det, FILE *fp_ori, CLASSES current_class,
  function saveAndPlotPlots (line 705) | void saveAndPlotPlots(string dir_name,string file_name,string obj_type,v...
  function eval (line 768) | bool eval(string result_sha,Mail* mail){
  function main (line 887) | int32_t main (int32_t argc,char *argv[]) {

FILE: tools/kitti-eval/evaluate_object_3d_offline.cpp
  type DIFFICULTY (line 37) | enum DIFFICULTY{EASY=0, MODERATE=1, HARD=2}
  type METRIC (line 40) | enum METRIC{IMAGE=0, GROUND=1, BOX3D=2}
  type CLASSES (line 48) | enum CLASSES{CAR=0, PEDESTRIAN=1, CYCLIST=2}
  function initGlobals (line 61) | void initGlobals () {
  type tPrData (line 72) | struct tPrData {
    method tPrData (line 78) | tPrData () :
  type tBox (line 83) | struct tBox {
    method tBox (line 90) | tBox (string type, double x1,double y1,double x2,double y2,double alph...
  type tGroundtruth (line 95) | struct tGroundtruth {
    method tGroundtruth (line 102) | tGroundtruth () :
    method tGroundtruth (line 104) | tGroundtruth (tBox box,double truncation,int32_t occlusion) :
    method tGroundtruth (line 106) | tGroundtruth (string type,double x1,double y1,double x2,double y2,doub...
  type tDetection (line 111) | struct tDetection {
    method tDetection (line 117) | tDetection ():
    method tDetection (line 119) | tDetection (tBox box,double thresh) :
    method tDetection (line 121) | tDetection (string type,double x1,double y1,double x2,double y2,double...
  function loadDetections (line 131) | vector<tDetection> loadDetections(string file_name, bool &compute_aos,
  function loadGroundtruth (line 178) | vector<tGroundtruth> loadGroundtruth(string file_name,bool &success) {
  function saveStats (line 204) | void saveStats (const vector<double> &precision, const vector<double> &a...
  function imageBoxOverlap (line 227) | inline double imageBoxOverlap(tBox a, tBox b, int32_t criterion=-1){
  function imageBoxOverlap (line 263) | inline double imageBoxOverlap(tDetection a, tGroundtruth b, int32_t crit...
  function Polygon (line 269) | Polygon toPolygon(const T& g) {
  function groundBoxOverlap (line 294) | inline double groundBoxOverlap(tDetection d, tGroundtruth g, int32_t cri...
  function box3DOverlap (line 317) | inline double box3DOverlap(tDetection d, tGroundtruth g, int32_t criteri...
  function getThresholds (line 346) | vector<double> getThresholds(vector<double> &v, double n_groundtruth){
  function cleanData (line 381) | void cleanData(CLASSES current_class, const vector<tGroundtruth> &gt, co...
  function tPrData (line 456) | tPrData computeStatistics(CLASSES current_class, const vector<tGroundtru...
    method tPrData (line 78) | tPrData () :
  function eval_class (line 622) | bool eval_class (FILE *fp_det, FILE *fp_ori, CLASSES current_class,
  function saveAndPlotPlots (line 708) | void saveAndPlotPlots(string dir_name,string file_name,string obj_type,v...
  function getEvalIndices (line 778) | vector<int32_t> getEvalIndices(const string& result_dir) {
  function eval (line 795) | bool eval(string gt_dir, string result_dir, Mail* mail){
  function main (line 917) | int32_t main (int32_t argc,char *argv[]) {

FILE: tools/kitti-eval/evaluate_object_3d_offline_r40.cpp
  type DIFFICULTY (line 37) | enum DIFFICULTY{EASY=0, MODERATE=1, HARD=2}
  type METRIC (line 40) | enum METRIC{IMAGE=0, GROUND=1, BOX3D=2}
  type CLASSES (line 48) | enum CLASSES{CAR=0, PEDESTRIAN=1, CYCLIST=2}
  function initGlobals (line 61) | void initGlobals () {
  type tPrData (line 72) | struct tPrData {
    method tPrData (line 78) | tPrData () :
  type tBox (line 83) | struct tBox {
    method tBox (line 90) | tBox (string type, double x1,double y1,double x2,double y2,double alph...
  type tGroundtruth (line 95) | struct tGroundtruth {
    method tGroundtruth (line 102) | tGroundtruth () :
    method tGroundtruth (line 104) | tGroundtruth (tBox box,double truncation,int32_t occlusion) :
    method tGroundtruth (line 106) | tGroundtruth (string type,double x1,double y1,double x2,double y2,doub...
  type tDetection (line 111) | struct tDetection {
    method tDetection (line 117) | tDetection ():
    method tDetection (line 119) | tDetection (tBox box,double thresh) :
    method tDetection (line 121) | tDetection (string type,double x1,double y1,double x2,double y2,double...
  function loadDetections (line 131) | vector<tDetection> loadDetections(string file_name, bool &compute_aos,
  function loadGroundtruth (line 178) | vector<tGroundtruth> loadGroundtruth(string file_name,bool &success) {
  function saveStats (line 204) | void saveStats (const vector<double> &precision, const vector<double> &a...
  function imageBoxOverlap (line 227) | inline double imageBoxOverlap(tBox a, tBox b, int32_t criterion=-1){
  function imageBoxOverlap (line 263) | inline double imageBoxOverlap(tDetection a, tGroundtruth b, int32_t crit...
  function Polygon (line 269) | Polygon toPolygon(const T& g) {
  function groundBoxOverlap (line 294) | inline double groundBoxOverlap(tDetection d, tGroundtruth g, int32_t cri...
  function box3DOverlap (line 317) | inline double box3DOverlap(tDetection d, tGroundtruth g, int32_t criteri...
  function getThresholds (line 346) | vector<double> getThresholds(vector<double> &v, double n_groundtruth){
  function cleanData (line 381) | void cleanData(CLASSES current_class, const vector<tGroundtruth> &gt, co...
  function tPrData (line 456) | tPrData computeStatistics(CLASSES current_class, const vector<tGroundtru...
    method tPrData (line 78) | tPrData () :
  function eval_class (line 622) | bool eval_class (FILE *fp_det, FILE *fp_ori, CLASSES current_class,
  function saveAndPlotPlots (line 708) | void saveAndPlotPlots(string dir_name,string file_name,string obj_type,v...
  function getEvalIndices (line 778) | vector<int32_t> getEvalIndices(const string& result_dir) {
  function eval (line 795) | bool eval(string gt_dir, string result_dir, Mail* mail){
  function main (line 917) | int32_t main (int32_t argc,char *argv[]) {

FILE: tools/kitti-eval/mail.h
  function class (line 8) | class Mail {

FILE: tools/train_IGRs.py
  function choose_loss_func (line 27) | def choose_loss_func(model_settings, cfgs):
  function train (line 49) | def train(model, model_settings, GPUs, cfgs, logger, final_output_dir):
  function evaluate (line 108) | def evaluate(model, model_settings, GPUs, cfgs, logger, final_output_dir...
  function main (line 127) | def main():

FILE: tools/train_lifting.py
  function main (line 24) | def main():
Condensed preview — 56 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (497K chars).
[
  {
    "path": ".gitignore",
    "chars": 135,
    "preview": "#/*\n**/__pycache__\n.spyproject/\n*.log\n*.ini\n*.bak\n*.pth\n*.csv\n*.jpg\n*.png\n*.pdf\n/tools/kitti-eval/evaluate_object_3d_off"
  },
  {
    "path": "LICENSE",
    "chars": 1067,
    "preview": "MIT License\n\nCopyright (c) 2022 Shichao Li\n\nPermission is hereby granted, free of charge, to any person obtaining a copy"
  },
  {
    "path": "README.md",
    "chars": 5466,
    "preview": "[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/exploring-intermediate-representation-f"
  },
  {
    "path": "configs/KITTI_inference:demo.yml",
    "chars": 3829,
    "preview": "# This is a YAML file storing experimental configurations for KITTI dataset\n\n## general settings\nname: 'refine a given s"
  },
  {
    "path": "configs/KITTI_inference:test_submission.yml",
    "chars": 3944,
    "preview": "# YAML file storing experimental configurations for KITTI dataset\n\n## general settings\nname: 'produce vehicle pose predi"
  },
  {
    "path": "configs/KITTI_train_IGRs.yml",
    "chars": 4953,
    "preview": "# YAML file storing experimental configurations for training on KITTI dataset\n\n## general settings\nname: 'kitti_kpt_loc'"
  },
  {
    "path": "configs/KITTI_train_IGRs_Ped.yml",
    "chars": 4951,
    "preview": "# YAML file storing experimental configurations for training on KITTI dataset for the Pedestrian class\n\n## general setti"
  },
  {
    "path": "configs/KITTI_train_lifting.yml",
    "chars": 3052,
    "preview": "# YAML file storing experimental configurations for KITTI dataset\n\n## general settings\nname: 'lifter'\nexp_type: '2dto3d'"
  },
  {
    "path": "docs/demo.md",
    "chars": 1054,
    "preview": "Firstly you need to prepare the dataset and pre-trained models as described [here](https://github.com/Nicholasli1995/Ego"
  },
  {
    "path": "docs/inference.md",
    "chars": 1851,
    "preview": "Firstly you need to prepare the dataset and pre-trained models as described [here](https://github.com/Nicholasli1995/Ego"
  },
  {
    "path": "docs/preparation.md",
    "chars": 2399,
    "preview": "## Data Preparation \nYou need to download KITTI dataset [here](http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_"
  },
  {
    "path": "docs/spec-list.txt",
    "chars": 15851,
    "preview": "# This file may be used to create an environment using:\n# $ conda create --name <env> --file <this file>\n# platform: lin"
  },
  {
    "path": "docs/training.md",
    "chars": 1804,
    "preview": "Firstly you need to prepare the dataset as described [here](https://github.com/Nicholasli1995/EgoNet/blob/master/docs/pr"
  },
  {
    "path": "libs/arguments/__init__.py",
    "chars": 69,
    "preview": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n\"\"\"\nEmpty file.\n\"\"\"\n\n\n"
  },
  {
    "path": "libs/arguments/parse.py",
    "chars": 1398,
    "preview": "\"\"\"\nArgument parser for command line inputs and experiment configuration file.\n\nAuthor: Shichao Li\nContact: nicholas.li@"
  },
  {
    "path": "libs/common/__init__.py",
    "chars": 69,
    "preview": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n\"\"\"\nEmpty file.\n\"\"\"\n\n\n"
  },
  {
    "path": "libs/common/format.py",
    "chars": 2542,
    "preview": "\"\"\"\nMethods for formatted output.\n\nAuthor: Shichao Li\nContact: nicholas.li@connect.ust.hk\n\"\"\"\nimport os\n\nfrom copy impor"
  },
  {
    "path": "libs/common/img_proc.py",
    "chars": 28060,
    "preview": "\"\"\"\nImage processing utilities.\n\nAuthor: Shichao Li\nContact: nicholas.li@connect.ust.hk\n\"\"\"\n\nimport cv2\nimport numpy as "
  },
  {
    "path": "libs/common/transformation.py",
    "chars": 4855,
    "preview": "\"\"\"\r\nCoordinate transformation functions.\r\n\r\nAuthor: Shichao Li\r\nContact: nicholas.li@connect.ust.hk\r\n\"\"\"\r\n\r\nimport nump"
  },
  {
    "path": "libs/common/utils.py",
    "chars": 6249,
    "preview": "\"\"\"\nCommon utilities.\n\nAuthor: Shichao Li\nContact: nicholas.li@connect.ust.hk\n\"\"\"\n\nimport torch\nimport torch.nn as nn\nim"
  },
  {
    "path": "libs/dataset/KITTI/__init__.py",
    "chars": 69,
    "preview": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n\"\"\"\nEmpty file.\n\"\"\"\n\n\n"
  },
  {
    "path": "libs/dataset/KITTI/car_instance.py",
    "chars": 64415,
    "preview": "\"\"\"\nKITTI dataset implemented as PyTorch dataset object.\n\nAuthor: Shichao Li\nContact: nicholas.li@connect.ust.hk\n\"\"\"\n\nim"
  },
  {
    "path": "libs/dataset/__init__.py",
    "chars": 60,
    "preview": "#import libs.dataset.ApolloScape\nimport libs.dataset.KITTI\n\n"
  },
  {
    "path": "libs/dataset/basic/__init__.py",
    "chars": 69,
    "preview": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n\"\"\"\nEmpty file.\n\"\"\"\n\n\n"
  },
  {
    "path": "libs/dataset/basic/basic_classes.py",
    "chars": 1821,
    "preview": "\"\"\"\nBasic classes for customized dataset classes to inherit.\n\nAuthor: Shichao Li\nContact: nicholas.li@connect.ust.hk\n\"\"\""
  },
  {
    "path": "libs/dataset/normalization/__init__.py",
    "chars": 69,
    "preview": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n\"\"\"\nEmpty file.\n\"\"\"\n\n\n"
  },
  {
    "path": "libs/dataset/normalization/operations.py",
    "chars": 1643,
    "preview": "\"\"\"\nDataset normalization operations.\n\nAuthor: Shichao Li\nContact: nicholas.li@connect.ust.hk\n\"\"\"\n\nimport numpy as np\n\nd"
  },
  {
    "path": "libs/logger/__init__.py",
    "chars": 69,
    "preview": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n\"\"\"\nEmpty file.\n\"\"\"\n\n\n"
  },
  {
    "path": "libs/logger/logger.py",
    "chars": 1210,
    "preview": "\"\"\"\nBasic logging functions.\n\nAuthor: Shichao Li\nContact: nicholas.li@connect.ust.hk\n\"\"\"\n\nimport logging\nimport os\nimpor"
  },
  {
    "path": "libs/loss/__init__.py",
    "chars": 69,
    "preview": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n\"\"\"\nEmpty file.\n\"\"\"\n\n\n"
  },
  {
    "path": "libs/loss/function.py",
    "chars": 12427,
    "preview": "\"\"\"\nLoss functions.\n\nAuthor: Shichao Li\nContact: nicholas.li@connect.ust.hk\n\"\"\"\n\nimport torch\nimport torch.nn as nn\nimpo"
  },
  {
    "path": "libs/metric/criterions.py",
    "chars": 22520,
    "preview": "\"\"\"\nMetric functions used for validation.\n\nAuthor: Shichao Li\nContact: nicholas.li@connect.ust.hk\n\"\"\"\n\nimport libs.commo"
  },
  {
    "path": "libs/model/FCmodel.py",
    "chars": 4440,
    "preview": "\"\"\"\nFully-connected model architecture for processing 1D data.\n\nAuthor: Shichao Li\nContact: nicholas.li@connect.ust.hk\n\""
  },
  {
    "path": "libs/model/__init__.py",
    "chars": 75,
    "preview": "import libs.model.heatmapModel.hrnet\nimport libs.model.heatmapModel.resnet\n"
  },
  {
    "path": "libs/model/egonet.py",
    "chars": 22766,
    "preview": "\"\"\"\nA PyTorch implementation of Ego-Net.\n\nAuthor: Shichao Li\nContact: nicholas.li@connect.ust.hk\n\"\"\"\n\nimport torch\nimpor"
  },
  {
    "path": "libs/model/heatmapModel/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "libs/model/heatmapModel/hrnet.py",
    "chars": 26117,
    "preview": "# ------------------------------------------------------------------------------\n# Copyright (c) Microsoft\n# Licensed un"
  },
  {
    "path": "libs/model/heatmapModel/resnet.py",
    "chars": 9460,
    "preview": "# ------------------------------------------------------------------------------\n# Copyright (c) Microsoft\n# Licensed un"
  },
  {
    "path": "libs/optimizer/__init__.py",
    "chars": 69,
    "preview": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n\"\"\"\nEmpty file.\n\"\"\"\n\n\n"
  },
  {
    "path": "libs/optimizer/optimizer.py",
    "chars": 1818,
    "preview": "\"\"\"\nOptimization utilities.\n\nAuthor: Shichao Li\nContact: nicholas.li@connect.ust.hk\n\"\"\"\nimport torch\n\ndef prepare_optim("
  },
  {
    "path": "libs/trainer/__init__.py",
    "chars": 69,
    "preview": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n\"\"\"\nEmpty file.\n\"\"\"\n\n\n"
  },
  {
    "path": "libs/trainer/accuracy.py",
    "chars": 3089,
    "preview": "\"\"\"\nDeprecated. Will be deleted in a future version.\nPre-defined accuracy functions.\n\"\"\"\n\nimport libs.common.img_proc as"
  },
  {
    "path": "libs/trainer/trainer.py",
    "chars": 19764,
    "preview": "\"\"\"\nUtilities for training and validation.\n\nAuthor: Shichao Li\nContact: nicholas.li@connect.ust.hk\n\"\"\"\n\nimport libs.mode"
  },
  {
    "path": "libs/visualization/__init__.py",
    "chars": 69,
    "preview": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n\"\"\"\nEmpty file.\n\"\"\"\n\n\n"
  },
  {
    "path": "libs/visualization/debug.py",
    "chars": 6917,
    "preview": "\"\"\"\nUtilities for saving debugging images.\n\nAuthor: Shichao Li\nContact: nicholas.li@connect.ust.hk\n\"\"\"\n\nfrom libs.common"
  },
  {
    "path": "libs/visualization/egonet_utils.py",
    "chars": 4223,
    "preview": "\"\"\"\nVisualization utilities for Ego-Net inference.\n\nAuthor: Shichao Li\nContact: nicholas.li@connect.ust.hk\n\"\"\"\n\nimport c"
  },
  {
    "path": "libs/visualization/points.py",
    "chars": 15412,
    "preview": "\"\"\"\nSimple visualization utilities for 2D and 3D points based on Matplotlib.\n\nAuthor: Shichao Li\nContact: nicholas.li@co"
  },
  {
    "path": "tools/inference.py",
    "chars": 10809,
    "preview": "\"\"\"\nInference of Ego-Net on KITTI dataset.\n\nThe user can provide the 3D bounding boxes predicted by other 3D object dete"
  },
  {
    "path": "tools/inference_legacy.py",
    "chars": 44576,
    "preview": "\"\"\"\nThis is the legacy inference code which includes some debugging functions.\nYou don't need to read this file to use E"
  },
  {
    "path": "tools/kitti-eval/README.md",
    "chars": 863,
    "preview": "# kitti_eval\n\n`evaluate_object_3d_offline.cpp`evaluates your KITTI detection locally on your own computer using your val"
  },
  {
    "path": "tools/kitti-eval/evaluate_object_3d.cpp",
    "chars": 33046,
    "preview": "#include <iostream>\n#include <algorithm>\n#include <stdio.h>\n#include <math.h>\n#include <vector>\n#include <numeric>\n#incl"
  },
  {
    "path": "tools/kitti-eval/evaluate_object_3d_offline.cpp",
    "chars": 33867,
    "preview": "#include <iostream>\n#include <algorithm>\n#include <stdio.h>\n#include <math.h>\n#include <vector>\n#include <numeric>\n#incl"
  },
  {
    "path": "tools/kitti-eval/evaluate_object_3d_offline_r40.cpp",
    "chars": 33867,
    "preview": "#include <iostream>\n#include <algorithm>\n#include <stdio.h>\n#include <math.h>\n#include <vector>\n#include <numeric>\n#incl"
  },
  {
    "path": "tools/kitti-eval/mail.h",
    "chars": 811,
    "preview": "#ifndef MAIL_H\n#define MAIL_H\n\n#include <stdio.h>\n#include <stdarg.h>\n#include <string.h>\n\nclass Mail {\n\npublic:\n\n  Mail"
  },
  {
    "path": "tools/train_IGRs.py",
    "chars": 6285,
    "preview": "\"\"\"\nTraining the coordinate localization sub-network.\n\nAuthor: Shichao Li\nContact: nicholas.li@connect.ust.hk\n\"\"\"\n\nimpor"
  },
  {
    "path": "tools/train_lifting.py",
    "chars": 2193,
    "preview": "\"\"\"\nTraining the sub-network \\mathcal{L}() that predicts 3D cuboid \ngiven 2D screen coordinates as input.\n\nAuthor: Shich"
  }
]

About this extraction

This page contains the full source code of the Nicholasli1995/EgoNet GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 56 files (467.4 KB), approximately 121.0k tokens, and a symbol index with 428 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!