Full Code of DecaYale/RNNPose for AI

main ff223f3cb6bf cached
126 files
12.7 MB
223.6k tokens
1087 symbols
1 requests
Download .txt
Showing preview only (846K chars total). Download the full file or copy to clipboard to get everything.
Repository: DecaYale/RNNPose
Branch: main
Commit: ff223f3cb6bf
Files: 126
Total size: 12.7 MB

Directory structure:
gitextract_p7a82okx/

├── .gitignore
├── LICENSE.md
├── README.md
├── builder/
│   ├── __init__.py
│   ├── dataset_builder.py
│   ├── input_reader_builder.py
│   ├── losses_builder.py
│   ├── lr_scheduler_builder.py
│   ├── optimizer_builder.py
│   └── rnnpose_builder.py
├── config/
│   ├── default.py
│   └── linemod/
│       ├── copy.sh
│       ├── copy_occ.sh
│       ├── template_fw0.5.yml
│       └── template_fw0.5_occ.yml
├── data/
│   ├── __init__.py
│   ├── dataset.py
│   ├── linemod/
│   │   └── linemod_config.py
│   ├── linemod_dataset.py
│   ├── preprocess.py
│   ├── transforms.py
│   └── ycb/
│       └── basic.py
├── doc/
│   └── prepare_data.md
├── docker/
│   ├── Dockerfile
│   └── freeze.yml
├── geometry/
│   ├── __init__.py
│   ├── cholesky.py
│   ├── diff_render.py
│   ├── diff_render_optim.py
│   ├── einsum.py
│   ├── intrinsics.py
│   ├── projective_ops.py
│   ├── se3.py
│   └── transformation.py
├── model/
│   ├── CFNet.py
│   ├── HybridNet.py
│   ├── PoseRefiner.py
│   ├── RNNPose.py
│   ├── descriptor2D.py
│   ├── descriptor3D.py
│   └── losses.py
├── scripts/
│   ├── compile_3rdparty.sh
│   ├── eval.sh
│   ├── eval_lmocc.sh
│   ├── run_dataformatter.sh
│   ├── run_datainfo_generation.sh
│   └── train.sh
├── thirdparty/
│   ├── kpconv/
│   │   ├── __init__.py
│   │   ├── cpp_wrappers/
│   │   │   ├── compile_wrappers.sh
│   │   │   ├── cpp_neighbors/
│   │   │   │   ├── build.bat
│   │   │   │   ├── neighbors/
│   │   │   │   │   ├── neighbors.cpp
│   │   │   │   │   └── neighbors.h
│   │   │   │   ├── setup.py
│   │   │   │   └── wrapper.cpp
│   │   │   ├── cpp_subsampling/
│   │   │   │   ├── build.bat
│   │   │   │   ├── grid_subsampling/
│   │   │   │   │   ├── grid_subsampling.cpp
│   │   │   │   │   └── grid_subsampling.h
│   │   │   │   ├── setup.py
│   │   │   │   └── wrapper.cpp
│   │   │   └── cpp_utils/
│   │   │       ├── cloud/
│   │   │       │   ├── cloud.cpp
│   │   │       │   └── cloud.h
│   │   │       └── nanoflann/
│   │   │           └── nanoflann.hpp
│   │   ├── kernels/
│   │   │   ├── dispositions/
│   │   │   │   └── k_015_center_3D.ply
│   │   │   └── kernel_points.py
│   │   ├── kpconv_blocks.py
│   │   └── lib/
│   │       ├── __init__.py
│   │       ├── ply.py
│   │       ├── timer.py
│   │       └── utils.py
│   ├── nn/
│   │   ├── _ext.c
│   │   ├── nn_utils.py
│   │   ├── setup.py
│   │   └── src/
│   │       ├── ext.h
│   │       └── nearest_neighborhood.cu
│   ├── raft/
│   │   ├── corr.py
│   │   ├── extractor.py
│   │   ├── update.py
│   │   └── utils/
│   │       ├── __init__.py
│   │       ├── augmentor.py
│   │       ├── flow_viz.py
│   │       ├── frame_utils.py
│   │       └── utils.py
│   └── vsd/
│       └── inout.py
├── tools/
│   ├── eval.py
│   ├── generate_data_info_deepim_0_orig.py
│   ├── generate_data_info_deepim_1_syn.py
│   ├── generate_data_info_deepim_2_posecnnval.py
│   ├── generate_data_info_v2_deepim.py
│   ├── train.py
│   └── transform_data_format.py
├── torchplus/
│   ├── __init__.py
│   ├── metrics.py
│   ├── nn/
│   │   ├── __init__.py
│   │   ├── functional.py
│   │   └── modules/
│   │       ├── __init__.py
│   │       ├── common.py
│   │       └── normalization.py
│   ├── ops/
│   │   ├── __init__.py
│   │   └── array_ops.py
│   ├── tools.py
│   └── train/
│       ├── __init__.py
│       ├── checkpoint.py
│       ├── common.py
│       ├── fastai_optim.py
│       ├── learning_schedules.py
│       ├── learning_schedules_fastai.py
│       └── optim.py
├── utils/
│   ├── __init__.py
│   ├── config_io.py
│   ├── distributed_utils.py
│   ├── eval_metric.py
│   ├── furthest_point_sample.py
│   ├── geometric.py
│   ├── img_utils.py
│   ├── log_tool.py
│   ├── pose_utils.py
│   ├── pose_utils_np.py
│   ├── progress_bar.py
│   ├── rand_utils.py
│   ├── singleton.py
│   ├── timer.py
│   ├── util.py
│   └── visualize.py
└── weights/
    ├── gru_update.pth
    ├── img_fea_enc.pth
    └── superpoint_v1.pth

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
**/*.old
**/*.bak

.DS_Store
# Created by https://www.toptal.com/developers/gitignore/api/vscode,python,jupyternotebooks
# Edit at https://www.toptal.com/developers/gitignore?templates=vscode,python,jupyternotebooks

### JupyterNotebooks ###
# gitignore template for Jupyter Notebooks
# website: http://jupyter.org/

.ipynb_checkpoints
*/.ipynb_checkpoints/*

# IPython
profile_default/
ipython_config.py

# Remove previous ipynb_checkpoints
#   git rm -r .ipynb_checkpoints/

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
#lib/
#lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
pytestdebug.log

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/
doc/_build/

# PyBuilder
target/

# Jupyter Notebook

# IPython

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
pythonenv*

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# profiling data
.prof

### vscode ###
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
*.code-workspace

# End of https://www.toptal.com/developers/gitignore/api/vscode,python,jupyternotebooks


================================================
FILE: LICENSE.md
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
# RNNPose: Recurrent 6-DoF Object Pose Refinement with Robust Correspondence Field Estimation and Pose Optimization

[Yan Xu](https://decayale.github.io/), [Kwan-Yee Lin](https://kwanyeelin.github.io/), [Guofeng Zhang](http://www.cad.zju.edu.cn/home/gfzhang/), [Xiaogang Wang](https://www.ee.cuhk.edu.hk/en-gb/people/academic-staff/professors/prof-xiaogang-wang), [Hongsheng Li](http://www.ee.cuhk.edu.hk/~hsli/). 

*Conference on Computer Vision and Pattern Recognition (CVPR), 2022.*

[[Paper]](https://scholar.google.com/scholar?hl=zh-CN&as_sdt=0%2C5&q=RNNPose%3A+Recurrent+6-DoF+Object+Pose+Refinement+with+Robust+Correspondence+Field+Estimation+and+Pose+Optimization&btnG=)




## 1. Framework 
The basic pipeline of our proposed RNNPose.  (a) Before refinement, a reference image is rendered according to the object initial pose (shown in a fused view).
(b) Our RNN-based framework recurrently refines the object pose based on the estimated correspondence field between the reference and target images. The pose is optimized to be consistent with the reliable correspondence estimations highlighted by the similarity score map (built from learned 3D-2D descriptors) via differentiable LM optimization.  (c) The output refined pose.  

<!-- ![image info](./demo/framework.png) -->
<p align="center">
<img src="./demo/idea.png" alt="alt text" width="450"/>
</p>

## 2. Pose Estimation with Occlusions and Erroneous Pose Initializations


### Estimated Poses and Intermediate System Outputs from Different Recurrent Iterations. 

<p align="center">
 <img src="demo/ape_short_small.gif" alt="animated" height=400/><img src="demo/driller_short_small.gif" alt="animated" height=400/>
</p>


### Pose Estimates with Erroneous Pose Initializations
Visualization of our pose estimations (first row) on Occlusion LINEMOD dataset and the similarity score maps (second row) for downweighting unreliable correspondences during pose optimization. 
For pose visualization, the white boxes represent the erroneous initial poses, the red boxes are estimated by our algorithm and the ground-truth boxes are in blue. Here, the initial poses for pose refinement are originally from PVNet but added with significant disturbances for robustness testing. 
<center class="half">
  <img src="./demo/est_vis.png" height=200 > 
</center>


## 3. Installation 
### Install the Docker 
A dockerfile is provided to help with the environment setup. 
You need to install [docker](https://docs.docker.com/get-docker/) and [nvidia-docker2](https://github.com/NVIDIA/nvidia-docker) first and then set up the docker image and start up a container with the following commands: 

```
cd RNNPose/docker
sudo docker build -t rnnpose .    
sudo docker run  -it  --runtime=nvidia --ipc=host  --volume="HOST_VOLUME_YOU_WANT_TO_MAP:DOCKER_VOLUME"  -e DISPLAY=$DISPLAY -e QT_X11_NO_MITSHM=1  rnnpose bash

```
If you are not familiar with [docker](https://docs.docker.com/get-docker/), you could also install the dependencies manually following the provided dockerfile.  

### Compile the Dependencies
```
cd RNNPose/scripts
bash compile_3rdparty.sh
```


## 4. Data Preparation
We follow [DeepIM](https://github.com/liyi14/mx-DeepIM) and [PVNet](https://github.com/zju3dv/pvnet-rendering) to preprocess the training data for LINEMOD. 
You could follow the steps [here](doc/prepare_data.md) for data preparation. 



## 5. Test with the Pretrained Models
We train our model with the mixture of the real data and the synthetic data on LINEMOD dataset. 
<!-- and evaluate the trained models on the test set of LINEMOD and LINEMOD OCCLUSION datasets.  -->
The trained models on the LINEMOD dataset have been uploaded to the [OneDrive](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155139432_link_cuhk_edu_hk/ESPTVyUryHdGl65fRAxN51gBBayJJb9NpCqWA-tY2CFKJQ?e=R9bcLW). 
You can download them 
and put them into the directory *weights/trained_models/* for testing. 


An example bash script is provided below for reference. 

```
export PYTHONPATH="$PROJECT_ROOT_PATH:$PYTHONPATH"
export PYTHONPATH="$PROJECT_ROOT_PATH/thirdparty:$PYTHONPATH"

seq=cat
gpu=1
start_gpu_id=0
mkdir $model_dir

train_file=/home/yxu/Projects/Works/RNNPose_release/tools/eval.py
config_path=/mnt/workspace/Works/RNNPose_release/config/linemod/"$seq"_fw0.5.yml
pretrain=$PROJECT_ROOT_PATH/weights/trained_models/"$seq".tckpt

python -u $train_file multi_proc_train  \
        --config_path $config_path \
        --model_dir $model_dir/results \
        --use_dist True \
        --dist_port 10000 \
        --gpus_per_node $gpu \
        --optim_eval True \
        --use_apex True \
        --world_size $gpu \
        --start_gpu_id $start_gpu_id \
        --pretrained_path $pretrain 

```

Note that before executing the commands you need to:
- set `PROJECT_ROOT_PATH` to the absolute path of the project folder *RNNPose*;
- set `model_dir` to the output directory;
- replace the machine-specific absolute paths in `train_file` and `config_path` with the corresponding paths under your project root;
- point the data paths in the configuration files to the locations of the downloaded data.

You can also use the provided scripts below for evaluation.

### Evaluation on LINEMOD
```
bash scripts/eval.sh 
```

### Evaluation on LINEMOD OCCLUSION
```
bash scripts/eval_lmocc.sh

```

## Training from Scratch
An example training script is provided. 
```
bash scripts/train.sh 
```



## 6. Citation
If you find our code useful, please cite our paper. 
```
@inproceedings{xu2022rnnpose,
  title={RNNPose: Recurrent 6-DoF Object Pose Refinement with Robust Correspondence Field Estimation and Pose Optimization},
  author={Xu, Yan and Lin, Kwan-Yee and Zhang, Guofeng and Wang, Xiaogang and Li, Hongsheng},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  year={2022}
}

@article{xu2024rnnpose,
  title={Rnnpose: 6-dof object pose estimation via recurrent correspondence field estimation and pose optimization},
  author={Xu, Yan and Lin, Kwan-Yee and Zhang, Guofeng and Wang, Xiaogang and Li, Hongsheng},
  journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
  year={2024},
  publisher={IEEE}
}
```


## 7. Acknowledgement

The skeleton of this code is borrowed from [RSLO](https://github.com/DecaYale/RSLO). We also would like to thank the public codebases [PVNet](https://github.com/zju3dv/pvnet), [RAFT](https://github.com/princeton-vl/RAFT), [SuperGlue](https://github.com/magicleap/SuperGluePretrainedNetwork) and [DeepV2D](https://github.com/princeton-vl/DeepV2D). 

<!-- ## TODO List and ETA
- [x] Inference code and pretrained models (25/12/2021)
- [ ] Training code
- [ ] Code cleaning and improvement -->







================================================
FILE: builder/__init__.py
================================================


================================================
FILE: builder/dataset_builder.py
================================================

from data.dataset import get_dataset_class
import numpy as np
from functools import partial
from data.preprocess import preprocess, preprocess_deepim, patch_crop


def build(input_reader_config,
          training,
          ):
    """Construct the dataset described by ``input_reader_config``.

    The dataset class is looked up by ``dataset.dataset_class_name`` and
    instantiated with a ``preprocess`` function whose keyword arguments are
    bound from the ``preprocess`` and ``dataset`` sections of the config.

    Args:
        input_reader_config: config node exposing a ``preprocess`` section
            and a ``dataset`` section (both read via attribute access and
            ``.get``).
        training: forwarded to the dataset constructor as ``is_train``.

    Returns:
        The instantiated dataset object.
    """
    prep_cfg = input_reader_config.preprocess
    dataset_cfg = input_reader_config.dataset

    dataset_cls = get_dataset_class(dataset_cfg.dataset_class_name)

    # Only the generic ``preprocess`` pipeline is wired up. A former
    # ``preprocess_deepim`` alternative sat behind a hard-coded ``if 0:`` and
    # could never run, so it was dropped.
    prep_func = partial(
        preprocess,
        max_points=dataset_cfg.max_points,
        correspondence_radius=prep_cfg.correspondence_radius_threshold,
        image_scale=prep_cfg.get('image_scale', 1),
        crop_param=prep_cfg.get('crop_param', None),
        kp_3d_param=prep_cfg.get('kp_3d_param', {"kp_num": 30}),
        use_coords_as_3d_feat=prep_cfg.get('use_coords_as_3d_feat', False),
    )

    return dataset_cls(
        info_path=dataset_cfg.info_path,
        root_path=dataset_cfg.root_path,
        model_point_dim=dataset_cfg.model_point_dim,
        is_train=training,
        prep_func=prep_func,
        seq_names=dataset_cfg.get('seq_names', None),
        cfg=dataset_cfg,
    )


================================================
FILE: builder/input_reader_builder.py
================================================

from torch.utils.data import Dataset

from builder import dataset_builder


class DatasetWrapper(Dataset):
    """Adapt a project dataset object to ``torch.utils.data.Dataset``.

    Length and item access are delegated to the wrapped object. The original
    object stays reachable through the ``dataset`` property.
    """

    def __init__(self, dataset):
        # Attribute name is kept as-is: the ``dataset`` property reads it.
        self._dataset = dataset

    def __len__(self):
        wrapped = self._dataset
        return len(wrapped)

    def __getitem__(self, idx):
        wrapped = self._dataset
        return wrapped[idx]

    @property
    def dataset(self):
        """The wrapped dataset object, returned unchanged."""
        return self._dataset


def build(input_reader_config,
          training,
          ) -> DatasetWrapper:
    """Build the dataset for ``input_reader_config`` as a torch ``Dataset``.

    Delegates construction to ``dataset_builder.build`` and wraps the result
    in ``DatasetWrapper``.

    Args:
        input_reader_config: config node forwarded to ``dataset_builder.build``.
        training: forwarded to ``dataset_builder.build``.

    Returns:
        The built dataset wrapped in a ``DatasetWrapper``.
    """
    raw_dataset = dataset_builder.build(input_reader_config, training)
    return DatasetWrapper(raw_dataset)


================================================
FILE: builder/losses_builder.py
================================================

from model import losses

def build(loss_config):
    """Create the training criteria described by ``loss_config``.

    Returns:
        A dict with two entries:
          * ``"metric_loss"``: ``losses.MetricLoss`` configured from
            ``loss_config.metric_loss``;
          * ``"pose_loss"``: ``losses.PointAlignmentLoss`` with
            ``loss_weight=1``.
    """
    metric_loss = losses.MetricLoss(configs=loss_config.metric_loss)
    pose_loss = losses.PointAlignmentLoss(loss_weight=1)
    return {"metric_loss": metric_loss, "pose_loss": pose_loss}


================================================
FILE: builder/lr_scheduler_builder.py
================================================

from torchplus.train import learning_schedules_fastai as lsf
import torch
import numpy as np 

def build(optimizer_config, optimizer, total_step):
    """Create the learning-rate scheduler that drives ``optimizer``.

    The first key of ``optimizer_config`` names the optimizer type. Its
    sub-config must provide a ``learning_rate`` section, which is handed to
    ``_create_learning_rate_scheduler``.

    Args:
        optimizer_config: mapping-like config node (supports ``keys()`` and
            attribute access) whose first key is the optimizer type.
        optimizer: the optimizer the scheduler will drive.
        total_step: total number of training steps.

    Returns:
        The scheduler built by ``_create_learning_rate_scheduler``.

    Raises:
        ValueError: if ``optimizer_config`` is empty or its first key is not
            a supported optimizer type.
    """
    supported_types = ('rms_prop_optimizer', 'momentum_optimizer',
                       'adam_optimizer')

    config_keys = list(optimizer_config.keys())
    if not config_keys:
        raise ValueError('Empty optimizer config.')
    optimizer_type = config_keys[0]
    # Without this check an unknown type used to fall through every branch
    # and crash with UnboundLocalError on the return statement.
    if optimizer_type not in supported_types:
        raise ValueError('Optimizer %s not supported.' % optimizer_type)

    config = getattr(optimizer_config, optimizer_type)
    return _create_learning_rate_scheduler(
        config.learning_rate, optimizer, total_step=total_step)


def _create_learning_rate_scheduler(learning_rate_config, optimizer, total_step):
    """Create a learning-rate scheduler from a ``learning_rate`` config node.

    The first key of ``learning_rate_config`` selects the schedule type:
    ``multi_phase``, ``one_cycle``, ``exponential_decay``,
    ``exponential_decay_warmup`` or ``manual_stepping``. The matching
    sub-config supplies the constructor arguments of the corresponding
    ``learning_schedules_fastai`` scheduler class.

    Args:
      learning_rate_config: mapping-like config node (``keys()`` plus
        attribute access) describing one schedule.
      optimizer: the optimizer whose learning rate is scheduled.
      total_step: total number of training steps.

    Returns:
      The constructed scheduler object.

    Raises:
      ValueError: when the config is empty, names an unsupported schedule
        type, or a ``one_cycle`` config lists more than one ``lr_maxs`` value
        but not exactly 4.
    """
    learning_rate_keys = list(learning_rate_config.keys())
    if not learning_rate_keys:
        raise ValueError('Empty learning_rate config.')
    learning_rate_type = learning_rate_keys[0]

    if learning_rate_type == 'multi_phase':
        config = learning_rate_config.multi_phase
        lr_phases = []
        mom_phases = []
        for phase_cfg in config.phases:
            lr_phases.append((phase_cfg.start, phase_cfg.lambda_func))
            mom_phases.append(
                (phase_cfg.start, phase_cfg.momentum_lambda_func))
        return lsf.LRSchedulerStep(
            optimizer, total_step, lr_phases, mom_phases)

    if learning_rate_type == 'one_cycle':
        config = learning_rate_config.one_cycle
        if len(config.lr_maxs) > 1:
            # A list of maxima (presumably one per optimizer layer group)
            # must have exactly 4 entries. Raise instead of assert so the
            # check survives ``python -O``.
            if len(config.lr_maxs) != 4:
                raise ValueError(
                    'one_cycle lr_maxs must have 4 entries, got %d.'
                    % len(config.lr_maxs))
            lr_max = np.array(list(config.lr_maxs))
        else:
            lr_max = config.lr_max
        return lsf.OneCycle(
            optimizer, total_step, lr_max, list(config.moms),
            config.div_factor, config.pct_start)

    if learning_rate_type == 'exponential_decay':
        config = learning_rate_config.exponential_decay
        return lsf.ExponentialDecay(
            optimizer, total_step, config.initial_learning_rate,
            config.decay_length, config.decay_factor, config.staircase)

    if learning_rate_type == 'exponential_decay_warmup':
        config = learning_rate_config.exponential_decay_warmup
        return lsf.ExponentialDecayWarmup(
            optimizer, total_step, config.initial_learning_rate,
            config.decay_length, config.decay_factor, config.div_factor,
            config.pct_start, config.staircase)

    if learning_rate_type == 'manual_stepping':
        config = learning_rate_config.manual_stepping
        return lsf.ManualStepping(
            optimizer, total_step, list(config.boundaries), list(config.rates))

    raise ValueError('Learning_rate %s not supported.' %
                     learning_rate_type)


================================================
FILE: builder/optimizer_builder.py
================================================
from torchplus.train import learning_schedules
from torchplus.train import optim
import torch
from torch import nn
from torchplus.train.fastai_optim import OptimWrapper, FastAIMixedOptim
from functools import partial


def children(m: nn.Module):
    """Return the direct (non-recursive) child modules of ``m`` as a list."""
    return [*m.children()]


def num_children(m: nn.Module) -> int:
    """Count the direct child modules of ``m`` (0 for a leaf module)."""
    kids = children(m)
    return len(kids)

# Return a flat list of the smallest (leaf, childless) modules of a model.


def flatten_model(m):
    """Return the leaf modules of ``m`` in depth-first order.

    A module without children is its own leaf, so a childless ``m`` yields
    ``[m]``; ``None`` yields an empty list.
    """
    if m is None:
        return []
    if not num_children(m):
        return [m]
    leaves = []
    for child in m.children():
        leaves.extend(flatten_model(child))
    return leaves


# def get_layer_groups(m): return [nn.Sequential(*flatten_model(m))]
def get_layer_groups(m):
    """Wrap every leaf module of ``m`` into one ``nn.ModuleList`` layer group."""
    leaves = flatten_model(m)
    return [nn.ModuleList(leaves)]

def get_voxeLO_net_layer_groups(net):
    """Return ``[whole_net_group, loss_group]`` layer groups for ``net``.

    NOTE(review): a second ``get_voxeLO_net_layer_groups`` is defined later in
    this module and replaces this one when the module is imported, so this
    version is not reachable by name. Only the groups built here are returned,
    so the function never references names that do not exist in its scope.
    """
    vfe_grp = get_layer_groups(net)

    # Bundle the loss modules so they form a single group of their own.
    other_grp = get_layer_groups(nn.Sequential(net._rotation_loss,
                                               net._translation_loss,
                                               net._pyramid_rotation_loss,
                                               net._pyramid_translation_loss,
                                               net._consistency_loss,
                                               ))

    return [vfe_grp, other_grp]


def get_voxeLO_net_layer_groups(net):
    """Split a VoxeLO-style network into four optimizer layer groups.

    The groups are, in order: the voxel feature extractor, the middle feature
    extractor, the odometry predictor, and all loss modules (rotation,
    translation, pyramid rotation, pyramid translation, consistency) bundled
    together. Each group is produced by ``get_layer_groups``.
    """
    backbone = (
        net.voxel_feature_extractor,
        net.middle_feature_extractor,
        net.odom_predictor,
    )
    groups = [get_layer_groups(part) for part in backbone]

    losses = nn.Sequential(
        net._rotation_loss,
        net._translation_loss,
        net._pyramid_rotation_loss,
        net._pyramid_translation_loss,
        net._consistency_loss,
    )
    groups.append(get_layer_groups(losses))
    return groups


def build(optimizer_config, net, name=None, mixed=False, loss_scale=512.0):
    """Build a fastai-style ``OptimWrapper`` for ``net`` from a config node.

    Args:
        optimizer_config: Config node whose first key names the optimizer type
            (``'rms_prop_optimizer'``, ``'momentum_optimizer'`` or
            ``'adam_optimizer'``). The sub-node under that key holds the
            optimizer's hyper-parameters. ``fixed_weight_decay`` and
            ``use_moving_average`` are also read from ``optimizer_config``.
        net: Module whose leaf modules form the single layer group.
        name: Name stored on the optimizer for the checkpoint system; defaults
            to the optimizer type.
        mixed, loss_scale: Accepted for interface compatibility; unused here.

    Returns:
        The created ``OptimWrapper`` with a ``name`` attribute set.

    Raises:
        ValueError: If the optimizer type is not supported, or if
            ``use_moving_average`` is requested.
    """
    optimizer_type = list(optimizer_config.keys())[0]
    print("Optimizer:", optimizer_type)

    if optimizer_type == 'rms_prop_optimizer':
        config = optimizer_config.rms_prop_optimizer
        optimizer_func = partial(
            torch.optim.RMSprop,
            alpha=config.decay,
            momentum=config.momentum_optimizer_value,
            eps=config.epsilon)
    elif optimizer_type == 'momentum_optimizer':
        config = optimizer_config.momentum_optimizer
        # torch.optim.SGD takes no ``eps`` argument; passing one raises TypeError.
        optimizer_func = partial(
            torch.optim.SGD,
            momentum=config.momentum_optimizer_value)
    elif optimizer_type == 'adam_optimizer':
        config = optimizer_config.adam_optimizer
        if optimizer_config.fixed_weight_decay:
            optimizer_func = partial(
                torch.optim.Adam, betas=(0.9, 0.99), amsgrad=config.amsgrad)
        else:
            # regular adam
            optimizer_func = partial(
                torch.optim.Adam, amsgrad=config.amsgrad)
    else:
        # Fail here: ``config`` and ``optimizer_func`` are unbound otherwise.
        raise ValueError('Optimizer %s not supported.' % optimizer_type)

    # 3e-3 is only an initial value; the learning-rate scheduler sets the
    # effective rate afterwards.
    optimizer = OptimWrapper.create(
        optimizer_func,
        3e-3,
        get_layer_groups(net),
        wd=config.weight_decay,
        true_wd=optimizer_config.fixed_weight_decay,
        bn_wd=True)
    print(hasattr(optimizer, "_amp_stash"), '_amp_stash')
    if optimizer is None:
        raise ValueError('Optimizer %s not supported.' % optimizer_type)

    if optimizer_config.use_moving_average:
        raise ValueError('torch don\'t support moving average')

    # assign a name to the optimizer for the checkpoint system
    optimizer.name = optimizer_type if name is None else name
    return optimizer


================================================
FILE: builder/rnnpose_builder.py
================================================
from builder import losses_builder
from model.RNNPose import get_posenet_class
import model.RNNPose


def build(model_cfg,
          measure_time=False, testing=False):
    """Instantiate the pose network named by ``model_cfg.network_class_name``.

    Loss criteria are built from ``model_cfg.loss`` first and then passed to
    the network constructor together with the full model config (``opt``).
    ``measure_time`` and ``testing`` are accepted but not used.
    """
    loss_criteria = losses_builder.build(model_cfg.loss)
    network_cls = get_posenet_class(model_cfg.network_class_name)
    return network_cls(criterions=loss_criteria, opt=model_cfg)


================================================
FILE: config/default.py
================================================
from yacs.config import CfgNode as CN
from utils.singleton import Singleton
import os

def _merge_a_into_b(a, b):
    """Merge config dictionary a into config dictionary b, clobbering the
    options in b whenever they are also specified in a.
    """
    # if type(a) is not dict:
    if not isinstance(a, (dict)):
        return

    for k, v in a.items():
        # a must specify keys that are in b
        if not k in b:
            raise KeyError('{} is not a valid config key'.format(k))

        # the types must match, too
        old_type = type(b[k])
        if old_type is not type(v):
            if isinstance(b[k], np.ndarray):
                v = np.array(v, dtype=b[k].dtype)
            else:
                raise ValueError(('Type mismatch ({} vs. {}) '
                                'for config key: {}').format(type(b[k]),
                                                            type(v), k))

        # recursively merge dicts
        # if type(v) is dict:
        if isinstance(v, (dict)):
            try:
                _merge_a_into_b(a[k], b[k])
            except:
                print('Error under config key: {}'.format(k))
                raise
        else:
            b[k] = v



class Config(metaclass=Singleton):
    """Holder of the global yacs configuration tree ``ROOT_CN``.

    ``Singleton`` (from ``utils.singleton``) is the metaclass, so repeated
    ``Config()`` calls are intended to share one instance. Defaults are set in
    ``__init__`` and can be overridden with :meth:`merge`.
    """

    def __init__(self):
        ##############  ↓  Basic   ↓  ##############
        self.ROOT_CN = CN()
        self.ROOT_CN.BASIC = CN()
        self.ROOT_CN.BASIC.input_size = [480, 640]  # h, w
        self.ROOT_CN.BASIC.crop_size = [320, 320]  # h, w
        self.ROOT_CN.BASIC.zoom_crop_size = [320, 320]  # h, w
        self.ROOT_CN.BASIC.render_image_size = [320, 320]  # h, w
        self.ROOT_CN.BASIC.patch_num = 64

        ##############  ↓  LM OPTIM   ↓  ##############
        self.ROOT_CN.LM = CN()
        self.ROOT_CN.LM.LM_LMBDA = 0.0001
        self.ROOT_CN.LM.EP_LMBDA = 100

        ##############  ↓  data   ↓  ##############
        self.ROOT_CN.DATA = CN()
        self.ROOT_CN.DATA.OBJ_ROOT = ""
        # Default: ``../EXPDATA/`` relative to the directory of this file.
        self.ROOT_CN.DATA.VOC_ROOT = f"{os.path.dirname(os.path.abspath(__file__))}/../EXPDATA/"

    def __get_item__(self, key):
        """Return the top-level config section ``key`` (e.g. ``'DATA'``)."""
        return self.ROOT_CN.__getitem__(key)

    def __getitem__(self, key):
        """Support ``Config()[key]``; equivalent to the legacy ``__get_item__``."""
        return self.ROOT_CN.__getitem__(key)

    def merge(self, config_dict, sub_key=None):
        """Merge ``config_dict`` into the whole tree, or into section ``sub_key``.

        Validation of keys and value types is done by ``_merge_a_into_b``.
        """
        if sub_key is not None:
            _merge_a_into_b(config_dict, self.ROOT_CN[sub_key])
        else:
            _merge_a_into_b(config_dict, self.ROOT_CN)

##############  ↓  Model  ↓  ##############
# _CN.model = CN()
# _CN.model.input_size=[480,640]
# _CN.model.crop_size=[320,320] 
# def get_cfg_defaults():

def get_cfg(Node=None):
    """Return the ``Config`` object, or one of its top-level sections.

    Args:
        Node: Optional top-level section name (e.g. ``'DATA'`` or ``'BASIC'``).
            When omitted, the ``Config`` instance itself is returned.

    No copy is made. ``Config`` uses a ``Singleton`` metaclass, so this is
    presumably the shared, live configuration that other callers also see.
    """
    cfg = Config()
    return cfg if Node is None else cfg.__get_item__(Node)


================================================
FILE: config/linemod/copy.sh
================================================
#!/usr/bin/env bash
# Generate one LINEMOD training config per object from template_fw0.5.yml by
# replacing the SEQ_NAME placeholder with the object name.
set -euo pipefail

# Work next to this script, so the template is found and the generated files
# land beside it regardless of the caller's working directory.
cd "$(dirname "${BASH_SOURCE[0]}")"

declare -a arr=("glue" "ape" "cat" "phone" "eggbox" "benchvise" "lamp" "camera" "can" "driller" "duck" "holepuncher" "iron")

# create training configs
for seq in "${arr[@]}"
do
   echo "$seq"
   sed "s/SEQ_NAME/$seq/g" template_fw0.5.yml > "${seq}_fw0.5.yml"
done


================================================
FILE: config/linemod/copy_occ.sh
================================================
#!/usr/bin/env bash
# Generate one LINEMOD-Occlusion config per object from template_fw0.5_occ.yml
# by replacing the SEQ_NAME placeholder with the object name.
set -euo pipefail

# Work next to this script, so the template is found and the generated files
# land beside it regardless of the caller's working directory.
cd "$(dirname "${BASH_SOURCE[0]}")"

declare -a arr=("glue" "ape" "cat" "phone" "eggbox" "benchvise" "lamp" "camera" "can" "driller" "duck" "holepuncher" "iron")

# create occlusion configs
for seq in "${arr[@]}"
do
   echo "$seq"
   sed "s/SEQ_NAME/$seq/g" template_fw0.5_occ.yml > "${seq}_fw0.5_occ.yml"
done




================================================
FILE: config/linemod/template_fw0.5.yml
================================================
vars:
  input_h: &input_h
    320 
  input_w: &input_w
    320 
  batch_size: &batch_size
    1
  descriptor_dim: &descriptor_dim
    32 
  correspondence_radius_threshold: &correspondence_radius_threshold
    0.01 #0.04 
  seq_name: &seq_name
    ["SEQ_NAME"]
BASIC:
  zoom_crop_size: [240,240]
model:
  input_h: *input_h
  input_w: *input_w
  batch_size: *batch_size
  seq_len: 2 


  network_class_name: RNNPose 
  descriptor_net:
    module_class_name: HybridFeaNet 
    
    keypoints_detector_2d:
      input_dim: 3
      descriptor_dim: *descriptor_dim 
      remove_borders: 4
      normalize_output: True 

    keypoints_detector_3d:
      #KPCONV configurations
      num_layers: 4
      KP_extent: 2.0
      batch_norm_momentum: 0.02
      use_batch_norm: true
      in_points_dim: 3
      fixed_kernel_points: 'center' #['center', 'verticals', 'none']
      KP_influence: 'linear'
      aggregation_mode: 'sum' #['closest', 'sum']
      modulated: false 
      first_subsampling_dl:  0.025
      conv_radius: 2.5
      deform_radius: 5
      in_features_dim: 1 #3
      first_feats_dim: 128
      num_kernel_points: 15
      final_feats_dim: *descriptor_dim #256 #32
      normalize_output: True 
      gnn_feats_dim: 128 #256
    context_fea_extractor_3d:
      #KPCONV configurations
      num_layers: 4
      KP_extent: 2.0
      batch_norm_momentum: 0.02
      use_batch_norm: true
      in_points_dim: 3
      fixed_kernel_points: 'center' #['center', 'verticals', 'none']
      KP_influence: 'linear'
      aggregation_mode: 'sum' #['closest', 'sum']
      modulated: false 
      first_subsampling_dl:  0.025
      conv_radius: 2.5
      deform_radius: 5
      in_features_dim: 1 #3
      first_feats_dim: 128
      num_kernel_points: 15
      final_feats_dim: 256 #*descriptor_dim #256 #32
      normalize_output: False 
      gnn_feats_dim: 128 #256
  motion_net:
    IS_CALIBRATED: True
    RESCALE_IMAGES: False
    ITER_COUNT: 4 
    RENDER_ITER_COUNT: 3 #2 #1 #3
    TRAIN_RESIDUAL_WEIGHT: 0 #0.1 
    TRAIN_FLOW_WEIGHT: 0.5 #0.1 #1 
    TRAIN_REPROJ_WEIGHT: 0 
    OPTIM_ITER_COUNT: 1
    FLOW_NET: 'raft' 
    SYN_OBSERVED: False
    ONLINE_CROP: True
    raft:
      small: False #True
      fea_net: "default"
      mixed_precision: True
      # pretrained_model: "/mnt/workspace/datasets/weights/models/raft-small.pth"
      pretrained_model: "/mnt/workspace/datasets/weights/models/raft-chairs.pth"
      input_dim: 3 
      iters: 1 
      
  loss:
    metric_loss:
      type: "normal" 
      pos_radius: *correspondence_radius_threshold # the radius used to find the positive correspondences
      safe_radius: 0.02 #0.13
      pos_margin: 0.1
      neg_margin: 1.4
      max_points: 256
      matchability_radius: 0.06
      weight: 0.001
    saliency_loss:
      loss_weight: 1
      reg_weight: 0.01
    geometric_loss:
      loss_weight: 1
      reg_weight: 0.5 #0.1 


train_config:

  optimizer: 
      adam_optimizer: 
        learning_rate: 
          one_cycle: 
            lr_maxs: []
            lr_max: 0.0001 # 
            moms: [0.95, 0.85]
            div_factor: 10.0
            pct_start: 0.01 #0.05
        amsgrad: false
        weight_decay: 0.0001 
      fixed_weight_decay: true
      use_moving_average: false

  steps: 200000 
  steps_per_eval:  10000
  loss_scale_factor: -1
  clear_metrics_every_epoch: true

train_input_reader:
  dataset:
    dataset_class_name: "LinemodDeepIMSynRealV2" 
    info_path: ["/home/RNNPose/Projects/Works/RNNPose_release/EXPDATA/data_info/deepim/linemod_orig_deepim.info.train",  "/home/RNNPose/Projects/Works/RNNPose_release/EXPDATA/data_info/deepim/linemod_syn_deepim.info.train", 
        "/home/RNNPose/Projects/Works/RNNPose_release/EXPDATA/data_info/linemod_fusesformatted_all10k_deepim.info.train",
    ] 
    root_path: ["/home/RNNPose/Projects/Works/RNNPose_release/EXPDATA/LM6d_converted/LM6d_refine",  "/home/RNNPose/Projects/Works/RNNPose_release/EXPDATA/LM6d_converted/LM6d_refine_syn",
     "/home/RNNPose/Projects/Works/RNNPose_release/EXPDATA/LINEMOD/fuse_formatted/"
     ]

    model_point_dim: 3
    max_points: 20000
    seq_names: *seq_name 

  batch_size: *batch_size 
  preprocess: 
    correspondence_radius_threshold: *correspondence_radius_threshold
    num_workers: 3 
    image_scale: 1
    crop_param:
      rand_crop: false
      margin_ratio: 0.85 
      output_size: *input_h 
      crop_with_init_pose: True
    

eval_input_reader:
  dataset:
    dataset_class_name: "LinemodDeepIMSynRealV2"
    info_path: ["/home/RNNPose/Projects/Works/RNNPose_release/EXPDATA/data_info/linemod_posecnn.info.eval" ] 
    root_path: ["/home/RNNPose/Projects/Works/RNNPose_release/EXPDATA/LM6d_converted/LM6d_refine" ]
    model_point_dim: 3
    max_points: 20000
    seq_names: *seq_name 
  batch_size: *batch_size 
  preprocess:
    correspondence_radius_threshold: *correspondence_radius_threshold
    num_workers: 3 
    image_scale: 1
    crop_param:
      rand_crop: false
      margin_ratio: 0.85 #0.5 
      output_size: *input_h #
      crop_with_init_pose: True
      


================================================
FILE: config/linemod/template_fw0.5_occ.yml
================================================
vars:
  input_h: &input_h
    320 
  input_w: &input_w
    320 
  batch_size: &batch_size
    1
  descriptor_dim: &descriptor_dim
    32 
  correspondence_radius_threshold: &correspondence_radius_threshold
    0.01 #0.04 
  seq_name: &seq_name
    ["SEQ_NAME"]
BASIC:
  zoom_crop_size: [240,240]
model:
  input_h: *input_h
  input_w: *input_w
  batch_size: *batch_size
  seq_len: 2 


  network_class_name: RNNPose 
  descriptor_net:
    module_class_name: HybridFeaNet 
    
    keypoints_detector_2d:
      input_dim: 3
      descriptor_dim: *descriptor_dim 
      remove_borders: 4
      normalize_output: True 

    keypoints_detector_3d:
      #KPCONV configurations
      num_layers: 4
      KP_extent: 2.0
      batch_norm_momentum: 0.02
      use_batch_norm: true
      in_points_dim: 3
      fixed_kernel_points: 'center' #['center', 'verticals', 'none']
      KP_influence: 'linear'
      aggregation_mode: 'sum' #['closest', 'sum']
      modulated: false 
      first_subsampling_dl:  0.025
      conv_radius: 2.5
      deform_radius: 5
      in_features_dim: 1 #3
      first_feats_dim: 128
      num_kernel_points: 15
      final_feats_dim: *descriptor_dim #256 #32
      normalize_output: True 
      gnn_feats_dim: 128 #256
    context_fea_extractor_3d:
      #KPCONV configurations
      num_layers: 4
      KP_extent: 2.0
      batch_norm_momentum: 0.02
      use_batch_norm: true
      in_points_dim: 3
      fixed_kernel_points: 'center' #['center', 'verticals', 'none']
      KP_influence: 'linear'
      aggregation_mode: 'sum' #['closest', 'sum']
      modulated: false 
      first_subsampling_dl:  0.025
      conv_radius: 2.5
      deform_radius: 5
      in_features_dim: 1 #3
      first_feats_dim: 128
      num_kernel_points: 15
      final_feats_dim: 256 #*descriptor_dim #256 #32
      normalize_output: False 
      gnn_feats_dim: 128 #256
  motion_net:
    IS_CALIBRATED: True
    RESCALE_IMAGES: False
    ITER_COUNT: 4 
    RENDER_ITER_COUNT: 3 #2 #1 #3
    TRAIN_RESIDUAL_WEIGHT: 0 #0.1 
    TRAIN_FLOW_WEIGHT: 0.5 #0.1 #1 
    TRAIN_REPROJ_WEIGHT: 0 
    OPTIM_ITER_COUNT: 1
    FLOW_NET: 'raft' 
    SYN_OBSERVED: False
    ONLINE_CROP: True
    raft:
      small: False #True
      fea_net: "default"
      mixed_precision: True
      # pretrained_model: "/mnt/workspace/datasets/weights/models/raft-small.pth"
      pretrained_model: "/mnt/workspace/datasets/weights/models/raft-chairs.pth"
      input_dim: 3 
      iters: 1 
      
  loss:
    metric_loss:
      type: "normal" 
      pos_radius: *correspondence_radius_threshold # the radius used to find the positive correspondences
      safe_radius: 0.02 #0.13
      pos_margin: 0.1
      neg_margin: 1.4
      max_points: 256
      matchability_radius: 0.06
      weight: 0.001
    saliency_loss:
      loss_weight: 1
      reg_weight: 0.01
    geometric_loss:
      loss_weight: 1
      reg_weight: 0.5 #0.1 


train_config:

  optimizer: 
      adam_optimizer: 
        learning_rate: 
          one_cycle: 
            lr_maxs: []
            lr_max: 0.0001 # 
            moms: [0.95, 0.85]
            div_factor: 10.0
            pct_start: 0.01 #0.05
        amsgrad: false
        weight_decay: 0.0001 
      fixed_weight_decay: true
      use_moving_average: false

  steps: 200000 
  steps_per_eval:  10000
  loss_scale_factor: -1
  clear_metrics_every_epoch: true

train_input_reader:
  dataset:
    dataset_class_name: "LinemodDeepIMSynRealV2" 
    info_path: ["/home/RNNPose/Projects/Works/RNNPose_release/EXPDATA/data_info/deepim/linemod_orig_deepim.info.train",  "/home/RNNPose/Projects/Works/RNNPose_release/EXPDATA/data_info/deepim/linemod_syn_deepim.info.train", 
        "/home/RNNPose/Projects/Works/RNNPose_release/EXPDATA/data_info/linemod_fusesformatted_all10k_deepim.info.train",
    ] 
    root_path: ["/home/RNNPose/Projects/Works/RNNPose_release/EXPDATA/LM6d_converted/LM6d_refine",  "/home/RNNPose/Projects/Works/RNNPose_release/EXPDATA/LM6d_converted/LM6d_refine_syn",
     "/home/RNNPose/Projects/Works/RNNPose_release/EXPDATA/LINEMOD/fuse_formatted/"
     ]

    model_point_dim: 3
    max_points: 20000
    seq_names: *seq_name 

  batch_size: *batch_size 
  preprocess: 
    correspondence_radius_threshold: *correspondence_radius_threshold
    num_workers: 3 
    image_scale: 1
    crop_param:
      rand_crop: false
      margin_ratio: 0.85 
      output_size: *input_h 
      crop_with_init_pose: True
    

eval_input_reader:
  dataset:
    dataset_class_name: "LinemodDeepIMSynRealV2"
    info_path:  ["/home/RNNPose/Projects/Works/RNNPose_release/EXPDATA/data_info/deepim/linemod_bop_lmocc_pvnetdr.info.eval"]
    root_path:  ["/home/RNNPose/Projects/Works/RNNPose_release/EXPDATA/lmo"] 
    init_post_type: "PVNET_LINEMOD_OCC"
    model_point_dim: 3
    max_points: 20000
    seq_names: *seq_name 
  batch_size: *batch_size 
  preprocess:
    correspondence_radius_threshold: *correspondence_radius_threshold
    num_workers: 3 
    image_scale: 1
    crop_param:
      rand_crop: false
      margin_ratio: 0.85 #0.5 
      output_size: *input_h #
      crop_with_init_pose: True
      
      


================================================
FILE: data/__init__.py
================================================
from . import dataset
from . import linemod_dataset

================================================
FILE: data/dataset.py
================================================
import pathlib
import pickle
import time
from functools import partial

import numpy as np


REGISTERED_DATASET_CLASSES = {}


def register_dataset(cls, name=None):
    """Class decorator: add ``cls`` to the dataset registry and return it.

    The registry key is ``name`` when given, otherwise ``cls.__name__``.
    Registering an already-used key fails an assertion.
    """
    key = cls.__name__ if name is None else name
    assert key not in REGISTERED_DATASET_CLASSES, f"exist class: {REGISTERED_DATASET_CLASSES}"
    REGISTERED_DATASET_CLASSES[key] = cls
    return cls


def get_dataset_class(name):
    """Return the dataset class registered under ``name`` (assertion if absent)."""
    assert name in REGISTERED_DATASET_CLASSES, f"available class: {REGISTERED_DATASET_CLASSES}"
    return REGISTERED_DATASET_CLASSES[name]


class Dataset(object):
    """Abstract interface for registered datasets.

    Indexing, ``len()``, ``_read_data`` and ``evaluation`` must be overridden;
    the base implementations only raise ``NotImplementedError``.
    """

    # -1 in the base class; presumably overridden by subclasses that use it.
    NumPointFeatures = -1

    def __len__(self):
        """Return the number of samples (abstract)."""
        raise NotImplementedError

    def __getitem__(self, index):
        """Return the sample at ``index`` (abstract)."""
        raise NotImplementedError

    def _read_data(self, query):
        """Load the raw data for ``query`` (abstract)."""
        raise NotImplementedError

    def evaluation(self, dt_annos, output_dir):
        """Dataset must provide a evaluation function to evaluate model."""
        raise NotImplementedError


================================================
FILE: data/linemod/linemod_config.py
================================================
import numpy as np
# Object diameters keyed by object name.
# NOTE(review): units are not stated here; the magnitudes (e.g. 'ape' ~9.7)
# look like centimetres -- confirm before comparing with metre-scale poses.
# 'cam' and 'camera' share the same value so either spelling can be looked up;
# 'bowl' and 'cup' are listed here but are not in linemod_cls_names below.
diameters = {
    'cat': 15.2633,
    'ape': 9.74298,
    'benchvise': 28.6908,
    'bowl': 17.1185,
    'cam': 17.1593,
    'camera': 17.1593,
    'can': 19.3416,
    'cup': 12.5961,
    'driller': 25.9425,
    'duck': 10.7131,
    'eggbox': 17.6364,
    'glue': 16.4857,
    'holepuncher': 14.8204,
    'iron': 30.3153,
    'lamp': 28.5155,
    'phone': 20.8394
}

# The 13 LINEMOD object names used here (the camera object is spelled 'cam').
linemod_cls_names = ['ape', 'cam', 'cat', 'duck', 'glue', 'iron', 'phone', 'benchvise', 'can', 'driller', 'eggbox', 'holepuncher', 'lamp']

# 3x3 matrix laid out as [[fx, 0, cx], [0, fy, cy], [0, 0, 1]]; presumably the
# intrinsics of the real LINEMOD captures -- confirm against the dataset docs.
linemod_K = np.array([[572.4114, 0., 325.2611],
                  [0., 573.57043, 242.04899],
                  [0., 0., 1.]])


# Same [[fx, 0, cx], [0, fy, cy], [0, 0, 1]] layout; presumably the intrinsics
# used for Blender-rendered (synthetic) images -- confirm with the renderer setup.
blender_K = np.array([[700., 0., 320.],
                    [0., 700., 240.],
                    [0., 0., 1.]])

================================================
FILE: data/linemod_dataset.py
================================================
import numpy as np 
import random
import os 
from data.dataset import Dataset, register_dataset
import pickle
import PIL
import cv2
import torch
import time
import scipy

from utils.geometric import range_to_depth, render_pointcloud
from .transforms import make_transforms 
from thirdparty.kpconv.lib.utils import square_distance
from utils.geometric import rotation_angle
# from utils.visualize import *
from transforms3d.quaternions import mat2quat, quat2mat, qmult
from transforms3d.euler import mat2euler, euler2mat, euler2quat, quat2euler
import math
from config.default import get_cfg

# Absolute directory containing this module.
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))

# pytorch3d is optional here: a missing or broken install (which can raise
# more than ImportError) only produces a warning. ``except Exception`` still
# lets KeyboardInterrupt/SystemExit propagate, unlike the bare ``except:``.
try:
    from pytorch3d.io import load_obj, load_ply
except Exception as err:
    print("Warning: error occurs when importing pytorch3d ", repr(err))


def se3_q2m(se3_q):
    """Convert a 7-element pose ``[quaternion(4), translation(3)]`` to a 3x4 ``[R | t]``.

    The first four entries go to ``transforms3d``'s ``quat2mat`` unchanged
    (that library expects w, x, y, z order). The last three fill the
    translation column. The result is a new float64 (3, 4) array.
    """
    assert se3_q.size == 7
    pose = np.zeros((3, 4))
    pose[:, :3], pose[:, 3] = quat2mat(se3_q[:4]), se3_q[4:]
    return pose

def info_convertor(info):
    """Reshape a ``{sequence_name: entries}`` info dict into parallel lists.

    Returns a dict with:

    * ``"seqs"``: the sequence names, in the dict's iteration order;
    * ``"seq_lengths"``: ``len`` of each sequence's entry;
    * ``"data"``: each sequence's entry itself (not copied).

    Each sequence name is printed as it is collected.
    """
    names = list(info.keys())
    lengths = [len(info[name]) for name in names]
    entries = []
    for name in names:
        print(name)
        entries.append(info[name])
    return {"seqs": names, "seq_lengths": lengths, "data": entries}

def resize(im, target_size, max_size, stride=0, interpolation=cv2.INTER_LINEAR):
    """Uniformly rescale ``im`` so its short side is ``target_size`` (capped by ``max_size``).

    The scale makes the short side equal ``target_size``. If that would make
    the rounded long side exceed ``max_size``, the long side is scaled to
    ``max_size`` instead. Both axes always use the same factor.

    :param im: image array accepted by ``cv2.resize`` (H x W or H x W x C)
    :param target_size: desired length of the short side
    :param max_size: maximum allowed length of the long side
    :param stride: if non-zero, zero-pad bottom/right so both spatial
        dimensions become multiples of ``stride``
    :param interpolation: OpenCV interpolation flag forwarded to ``cv2.resize``
    :return: ``(image, scale)``. When padding is applied, the image is a new
        array with ``np.zeros``' default dtype (float64).
    """
    im_size_min = np.min(im.shape[0:2])
    im_size_max = np.max(im.shape[0:2])
    im_scale = float(target_size) / float(im_size_min)
    # prevent the bigger axis from being more than max_size
    if np.round(im_scale * im_size_max) > max_size:
        im_scale = float(max_size) / float(im_size_max)
    im = cv2.resize(im, None, None, fx=im_scale, fy=im_scale, interpolation=interpolation)

    if stride == 0:
        return im, im_scale

    # pad to a multiple of stride
    im_height = int(np.ceil(im.shape[0] / float(stride)) * stride)
    im_width = int(np.ceil(im.shape[1] / float(stride)) * stride)
    # Carry over whatever trailing (channel) axes exist. 2-D grayscale images
    # have no shape[2]; indexing it unconditionally raised IndexError.
    padded_im = np.zeros((im_height, im_width) + im.shape[2:])
    padded_im[: im.shape[0], : im.shape[1]] = im
    return padded_im, im_scale
def sample_poses(pose_tgt):
    SYN_STD_ROTATION = 15
    SYN_STD_TRANSLATION = 0.01
    ANGLE_MAX=45
    pose_src = pose_tgt.copy()
    num = pose_tgt.shape[0]
    for i in range(num):
        euler = mat2euler(pose_tgt[i, :3, :3])
        euler += SYN_STD_ROTATION * np.random.randn(3) * math.pi / 180.0
        pose_src[i, :3, :3] = euler2mat(euler[0], euler[1], euler[2])

        pose_src[i, 0, 3] = pose_tgt[i, 0, 3]+ SYN_STD_TRANSLATION * np.random.randn(1)
        pose_src[i, 1, 3] = pose_tgt[i, 1, 3] + SYN_STD_TRANSLATION * np.random.randn(1)
        pose_src[i, 2, 3] = pose_tgt[i, 2, 3]  + 5 * SYN_STD_TRANSLATION * np.random.randn(1)

        r_dist = np.arccos((np.trace(pose_src[i, :3,:3].transpose(-1,-2) @ pose_tgt[i, :3,:3]) - 1 )/2)/math.pi*180

        while r_dist > ANGLE_MAX:#or not (16 < center_x < (640 - 16) and 16 < center_y < (480 - 16)):
            # print("r_dist > ANGLE_MAX, resampling...")
            print("Too large angular differences. Resample the pose...")
            euler = mat2euler(pose_tgt[i, :3, :3])
            euler += SYN_STD_ROTATION * np.random.randn(3) * math.pi / 180.0
            pose_src[i, :3, :3] = euler2mat(euler[0], euler[1], euler[2])

            pose_src[i, 0, 3] = pose_tgt[i, 0, 3]+ SYN_STD_TRANSLATION * np.random.randn(1)
            pose_src[i, 1, 3] = pose_tgt[i, 1, 3] + SYN_STD_TRANSLATION * np.random.randn(1)
            pose_src[i, 2, 3] = pose_tgt[i, 2, 3]  + 5 * SYN_STD_TRANSLATION * np.random.randn(1)

            r_dist = np.arccos((np.trace(pose_src[i, :3,:3].transpose(-1,-2) @ pose_tgt[i, :3,:3]) - 1 )/2)*math.pi/180
    return pose_src.squeeze()




@register_dataset
class LinemodDeepIMSynRealV2(Dataset):
    """LINEMOD dataset that mixes DeepIM synthetic and real data.

    The DeepIM 3D models (``textured.obj``) are used for geometric feature
    extraction. Several ``(root_path, info_path)`` pairs may be given. Their
    infos are merged, and every sequence records which root it came from in
    ``self.infos['dataset_idx']``.
    """

    # Message raised by the preprocessing function (see ``preprocess``) for
    # samples that cannot be used; such samples are skipped in __getitem__.
    _TOO_FEW_CORRESPONDENCES = "Too few correspondences are found!"

    def __init__(self, root_path,
                 info_path, model_point_dim,
                 is_train,
                 prep_func=None,
                 seq_names=None,
                 cfg=None
                 ):
        """
        Args:
            root_path: tuple/list of dataset root directories.
            info_path: tuple/list of pickled info files, one per root.
            model_point_dim: stored as ``self.model_point_dim``.
            is_train: training mode (random initial poses, random VOC
                backgrounds for synthetic data) vs evaluation mode (initial
                poses loaded from PoseCNN/PVNet results).
            prep_func: preprocessing callable applied in ``__getitem__``.
            seq_names: optional whitelist of sequence names to keep.
            cfg: optional dict. Key ``"init_post_type"`` (spelling kept for
                config compatibility) selects the initial-pose source used
                at evaluation time. Defaults to an empty dict.
        """
        super().__init__()
        cfg = {} if cfg is None else cfg

        assert info_path is not None
        assert isinstance(root_path, (tuple, list)) and isinstance(info_path, (tuple, list))
        assert len(root_path) == len(info_path)
        print("Info:", info_path)
        self.is_train = is_train
        self.VOC_ROOT = get_cfg('DATA').VOC_ROOT
        # Lazily filled cache of VOC background image ids.
        self._voc_bg_list = None

        infos = []
        for ipath in info_path:
            # NOTE: pickle can execute arbitrary code; only load trusted info files.
            with open(ipath, 'rb') as f:
                info = pickle.load(f)
            if seq_names is not None:
                for k in list(info.keys()):
                    if k not in seq_names:
                        del info[k]
            infos.append(info_convertor(info))

        # Merge multiple infos; 'dataset_idx' maps each sequence to its root path.
        self.infos = infos[0]
        self.infos['dataset_idx'] = [0] * len(self.infos['seqs'])
        for dataset_id, info in enumerate(infos[1:], start=1):
            for k in self.infos:
                if k == 'dataset_idx':
                    self.infos[k].extend([dataset_id] * len(info['seqs']))
                else:
                    self.infos[k].extend(info[k])

        self.root_paths = root_path
        self.model_point_dim = model_point_dim
        self.prep_func = prep_func
        self.rgb_transformer = make_transforms(None, is_train=is_train)
        print("dataset size:", self.__len__())

        self.init_pose_type = cfg.get("init_post_type", "POSECNN_LINEMOD")
        print("INIT_POSE_TYPE:", self.init_pose_type)
        if not self.is_train:
            init_pose_dir = f"{CURRENT_DIR}/../EXPDATA/init_poses"
            # PoseCNN results are always loaded: they are the fallback for PVNet.
            with open(f"{init_pose_dir}/linemod_posecnn_results.pkl", 'rb') as f:
                self.pose_cnn_results_test_posecnn = pickle.load(f)
            try:
                if self.init_pose_type == "POSECNN_LINEMOD":
                    self.pose_cnn_results_test = self.pose_cnn_results_test_posecnn
                elif self.init_pose_type == "PVNET_LINEMOD":
                    self.pose_cnn_results_test = np.load(f"{init_pose_dir}/pvnet/pvnet_linemod_test.npy", allow_pickle=True).flat[0]
                elif self.init_pose_type == "PVNET_LINEMOD_OCC":
                    self.pose_cnn_results_test = np.load(f"{init_pose_dir}/pvnet/pvnet_linemodocc_test.npy", allow_pickle=True).flat[0]
                else:
                    raise NotImplementedError
            except Exception:
                print("Loading posecnn results failed!")
                self.pose_cnn_results_test = None
            try:
                self.blender_to_bop_pose = np.load(f"{init_pose_dir}/pose_conversion/blender2bop_RT.npy", allow_pickle=True).flat[0]
            except Exception:
                print("Loading pose conversion matrix failed!")
                self.blender_to_bop_pose = None
        else:
            self.pose_cnn_results_test = None
            self.blender_to_bop_pose = None

    def _get_voc_bg_list(self, VOC_root):
        """Return (and cache) the VOC image ids marked positive in diningtable_trainval.txt."""
        if getattr(self, "_voc_bg_list", None) is None:
            VOC_bg_list_path = os.path.join(VOC_root, "ImageSets/Main", "diningtable_trainval.txt")
            with open(VOC_bg_list_path, "r") as f:
                self._voc_bg_list = [
                    line.strip("\r\n").split()[0] for line in f.readlines() if line.strip("\r\n").split()[1] == "1"
                ]
        return self._voc_bg_list

    def load_random_background(self, im_observed, mask):
        """Paste the foreground of ``im_observed`` onto a random VOC2012 image.

        The background is cropped to the aspect ratio of ``im_observed``,
        resized with ``resize`` and placed at the top-left of an image the
        size of ``im_observed``. Pixels where ``mask > 0`` keep their
        ``im_observed`` value.

        Args:
            im_observed: HxWxC uint8 RGB image.
            mask: array indexable like ``im_observed`` (the caller passes an
                HxWxC mask).

        Returns:
            HxWxC uint8 image.
        """
        VOC_root = os.path.join(self.VOC_ROOT, "VOCdevkit/VOC2012")
        VOC_bg_list = self._get_voc_bg_list(VOC_root)
        height, width, channel = im_observed.shape
        target_size = min(height, width)
        max_size = max(height, width)
        observed_hw_ratio = float(height) / float(width)

        k = random.randint(0, len(VOC_bg_list) - 1)
        bg_path = os.path.join(VOC_root, "JPEGImages/{}.jpg".format(VOC_bg_list[k]))
        bg_image = cv2.imread(bg_path, cv2.IMREAD_COLOR)
        if bg_image is None:
            raise FileNotFoundError(f"Cannot read VOC background image: {bg_path}")
        bg_image = bg_image[..., ::-1]  # BGR -> RGB
        bg_h, bg_w, bg_c = bg_image.shape
        bg_image_resize = np.zeros((height, width, channel), dtype="uint8")
        same_orientation = (observed_hw_ratio < 1 and float(bg_h) / float(bg_w) < 1) or (
            observed_hw_ratio >= 1 and float(bg_h) / float(bg_w) >= 1
        )
        # Crop the background to the observed aspect ratio.
        if same_orientation:
            if bg_h >= bg_w:
                bg_h_new = int(np.ceil(bg_w * observed_hw_ratio))
                bg_image_crop = bg_image[0:bg_h_new, 0:bg_w, :] if bg_h_new < bg_h else bg_image
            else:
                bg_w_new = int(np.ceil(bg_h / observed_hw_ratio))
                bg_image_crop = bg_image[0:bg_h, 0:bg_w_new, :] if bg_w_new < bg_w else bg_image
        else:
            if bg_h >= bg_w:
                bg_h_new = int(np.ceil(bg_w * observed_hw_ratio))
                bg_image_crop = bg_image[0:bg_h_new, 0:bg_w, :]
            else:  # bg_h < bg_w
                bg_w_new = int(np.ceil(bg_h / observed_hw_ratio))
                bg_image_crop = bg_image[0:bg_h, 0:bg_w_new, :]

        bg_image_resize_0, _ = resize(bg_image_crop, target_size, max_size)
        h, w, c = bg_image_resize_0.shape
        bg_image_resize[0:h, 0:w, :] = bg_image_resize_0

        # Composite: keep the observed pixels inside the mask.
        res_image = bg_image_resize.copy()
        res_image[mask > 0] = im_observed[mask > 0]
        return res_image

    def _blender_to_bop(self, pose, class_name):
        """Convert a PVNet (blender-frame) pose to the BOP model frame, in place.

        Computes ``T_pose @ inv(T_conv)`` with ``T_conv = self.blender_to_bop_pose[class_name]``.
        """
        conv = self.blender_to_bop_pose[class_name]
        pose[:3, :3] = pose[:3, :3] @ conv[:3, :3].T
        pose[:3, 3:] = -pose[:3, :3] @ conv[:3, 3:] + pose[:3, 3:]
        return pose

    def _load_init_pose(self, class_name, frame_idx):
        """Return the evaluation-time initial pose selected by ``self.init_pose_type``."""
        if self.init_pose_type == "PVNET_LINEMOD":
            try:
                # Copy so the cached results are not converted again on every read.
                return self._blender_to_bop(self.pose_cnn_results_test[class_name][frame_idx].copy(), class_name)
            except Exception:
                print("Warning: frame_idx is out of the range of self.pose_cnn_results_test!", flush=True)
                return se3_q2m(self.pose_cnn_results_test_posecnn[class_name][frame_idx]['pose'])
        if self.init_pose_type == "POSECNN_LINEMOD":
            return se3_q2m(self.pose_cnn_results_test_posecnn[class_name][frame_idx]['pose'])
        if self.init_pose_type == "PVNET_LINEMOD_OCC":
            return self._blender_to_bop(self.pose_cnn_results_test[class_name][frame_idx].copy(), class_name)
        raise NotImplementedError

    def _read_data(self, idx):
        """Load one raw sample, before ``prep_func`` is applied.

        Args:
            idx: global frame index, or ``(idx, seed)``; the seed is ignored.

        The global index is mapped to a (sequence, frame) pair through
        ``self.infos['seq_lengths']``. The per-frame record in
        ``self.infos['data'][seq]`` is read for 'rgb_observed_path',
        'depth_gt_observed_path', 'K' and 'gt_pose', plus the optional
        'rgb_noisy_rendered', 'depth_noisy_rendered' and 'pose_noisy_rendered'.

        Initial ("rendered") pose:
          * If a stored 'pose_noisy_rendered' exists, it is used.
          * Otherwise, during training, a random perturbation of the GT pose is used.
          * During evaluation with loaded init results, the PoseCNN/PVNet pose
            overrides both of the above.

        Returns:
            dict with model points, RGB/depth observations, intrinsics, the GT
            pose ``RT``, the initial pose ``rendered_RT`` and its rendered mask.

        Raises:
            RuntimeError: if no initial pose can be determined.
        """
        if isinstance(idx, (tuple, list)):
            idx, _ = idx

        seq_lengths_cum = np.insert(np.cumsum(self.infos['seq_lengths']), 0, 0)  # leading dummy 0
        seq_idx = np.nonzero(seq_lengths_cum > idx)[0][0] - 1
        frame_idx = idx - seq_lengths_cum[seq_idx]

        frame = self.infos["data"][seq_idx][frame_idx]
        root = self.root_paths[self.infos["dataset_idx"][seq_idx]]
        class_name = self.infos["seqs"][seq_idx]

        model_points_path = f'{os.path.dirname(__file__)}/../EXPDATA/LM6d_converted/models/{class_name}/textured.obj'  # TODO: need check

        rgb_path = os.path.join(root, frame['rgb_observed_path'])
        depth_path = os.path.join(root, frame['depth_gt_observed_path'])
        rgb_noisy_rendered_path = (os.path.join(root, frame['rgb_noisy_rendered'])
                                   if frame.get('rgb_noisy_rendered') is not None else None)
        depth_noisy_rendered_path = (os.path.join(root, frame['depth_noisy_rendered'])
                                     if frame.get('depth_noisy_rendered') is not None else None)

        rendered_RT = None
        if frame.get('pose_noisy_rendered') is not None:
            rendered_RT = frame['pose_noisy_rendered'].astype(np.float32)
        elif self.is_train:
            rendered_RT = sample_poses(frame['gt_pose'].astype(np.float32)[None])

        K = frame['K'].astype(np.float32)
        RT = frame['gt_pose'].astype(np.float32)  # [R,t]

        if self.is_train:
            posecnn_RT = np.zeros_like(RT)
        elif self.pose_cnn_results_test is None:
            print("Warning: fail to load cnn poses!", flush=True)
            posecnn_RT = np.zeros_like(RT)
        else:
            posecnn_RT = self._load_init_pose(class_name, frame_idx)
            rendered_RT = posecnn_RT

        if rendered_RT is None:
            raise RuntimeError(
                f"No initial pose for sample {idx}: 'pose_noisy_rendered' is missing "
                "and no PoseCNN/PVNet results were loaded.")

        # Project the rotation onto the nearest orthonormal matrix, R (R^T R)^{-1/2}.
        # sqrtm may return a complex array with negligible imaginary parts.
        rot = rendered_RT[:3, :3]
        rendered_RT[:3, :3] = np.real(rot @ np.linalg.inv(scipy.linalg.sqrtm(rot.T @ rot)))

        model_points, _, _ = load_obj(str(model_points_path))
        model_points = model_points.numpy()
        # NOTE(review): "visibility" is the last column of the loaded vertices -- confirm intent.
        visb = model_points[:, -1:]
        model_point_features = np.ones_like(model_points[:, :1]).astype(np.float32)

        with PIL.Image.open(rgb_path) as im:
            rgb = np.asarray(im)

        if depth_path.endswith('.npy'):
            depth = np.load(depth_path)  # blender
        else:
            depth = cv2.imread(depth_path, -1) / 1000.

        if self.is_train and "LM6d_refine_syn" in root:  # synthetic data
            rgb = self.load_random_background(rgb, mask=(depth > 0)[..., None].repeat(rgb.shape[-1], axis=-1))

        rgb_rendered = None
        if rgb_noisy_rendered_path is not None:
            with PIL.Image.open(rgb_noisy_rendered_path) as im:
                rgb_rendered = np.asarray(im)
        depth_rendered = None
        if depth_noisy_rendered_path is not None:
            with PIL.Image.open(depth_noisy_rendered_path) as im:
                depth_rendered = np.asarray(im) / 1000  # TODO: need check

        ren_mask = render_pointcloud(model_points, rendered_RT[None], K=K[None], render_image_size=rgb.shape[:2]).squeeze() > 0

        return {
            "class_name": class_name,
            "idx": idx,
            "model_points": model_points,
            "visibility": visb,
            "model_point_features": model_point_features,
            "image": rgb,
            "depth": depth,
            "mask": depth > 0,
            "rendered_image": rgb_rendered,
            "rendered_depth": depth_rendered,
            "K": K,
            "RT": RT,
            "rendered_RT": rendered_RT.astype(np.float32),
            "ren_mask": ren_mask,
            "POSECNN_RT": posecnn_RT.astype(np.float32),  # for test, TODO
            "scale": 1  # model_scale * scale = depth_scale
        }

    def __getitem__(self, idx):
        """Read and preprocess a sample.

        When preprocessing reports too few 2D-3D correspondences, the next
        index (wrapping around) is tried instead, keeping the seed if one was
        given.

        Raises:
            ValueError: any other preprocessing failure, chained to its cause.
            RuntimeError: if no sample in the dataset can be preprocessed.
        """
        for _ in range(len(self)):
            data = self._read_data(idx)
            try:
                return self.prep_func(data, rand_rgb_transformer=self.rgb_transformer,
                                      find_2d3d_correspondence=self.is_train)
            except Exception as e:
                if not (e.args and e.args[0] == self._TOO_FEW_CORRESPONDENCES):
                    raise ValueError(f"Preprocessing failed for sample {idx}: {e!r}") from e
            if isinstance(idx, (tuple, list)):
                idx = [(idx[0] + 1) % len(self), idx[1]]
            else:
                idx = (idx + 1) % len(self)
        raise RuntimeError("No sample in the dataset yields enough 2D-3D correspondences.")

    def __len__(self):
        """Total number of frames over all sequences."""
        return int(np.sum(self.infos['seq_lengths']))


================================================
FILE: data/preprocess.py
================================================
import open3d as o3d
import copy
import os

import pathlib
import pickle
import time
from collections import defaultdict
from functools import partial

import cv2
import numpy as np
import quaternion
from skimage import io as imgio
from utils.timer import simple_timer

import matplotlib.pyplot as plt
from collections.abc import Iterable
import torch
import torch.nn.functional as F
import quaternion

from functools import partial
import thirdparty.kpconv.cpp_wrappers.cpp_subsampling.grid_subsampling as cpp_subsampling
import thirdparty.kpconv.cpp_wrappers.cpp_neighbors.radius_neighbors as cpp_neighbors
from thirdparty.kpconv.lib.timer import Timer
from utils.geometric import range_to_depth, mask_depth_to_point_cloud
from utils.furthest_point_sample import fragmentation_fps
from utils.rand_utils import truncated_normal



def merge_batch(batch_list):
    """Collate a list of per-sample dicts into one batched dict.

    List-valued entries are treated as sequences and regrouped from
    ``[batch][key][seq]`` to ``[key][seq][batch]``. All other entries are
    regrouped from ``[batch][key]`` to ``[key][batch]``. Each key is then
    merged as follows:

    * point-cloud keys are concatenated along axis 0, and the per-sample point
      counts are stored in ``batched_model_point_lengths`` or
      ``batched_rand_model_point_lengths`` (int32);
    * ``model_point_features`` is concatenated along axis 0;
    * dense array keys are stacked along a new leading batch axis (the
      offending key is printed if stacking fails);
    * ``metrics`` and every other key are kept as plain lists.
    """
    point_keys = ('model_points', 'original_model_points', 'visibility')
    stacked_keys = ('image', 'depth', 'K', 'RT', 'original_RT', 'rand_RT',
                    'correspondences_2d3d', 'scale', 'POSECNN_RT',
                    'rendered_image', 'rendered_depth', 'rendered_RT',
                    '3d_keypoint_inds', '3d_keypoints', 'mask', 'ren_mask')

    grouped = defaultdict(list)
    for sample in batch_list:
        for key, value in sample.items():
            if not isinstance(value, list):
                grouped[key].append(value)
                continue
            if key not in grouped:
                grouped[key] = [[] for _ in value]
            for step, item in enumerate(value):
                grouped[key][step].append(item)

    merged = {}
    for key, elems in grouped.items():
        if key in point_keys:
            merged[key] = np.concatenate(elems, axis=0)
            merged['batched_model_point_lengths'] = np.array(
                [len(pts) for pts in elems], dtype=np.int32)
        elif key == 'rand_model_points':
            merged[key] = np.concatenate(elems, axis=0)
            merged['batched_rand_model_point_lengths'] = np.array(
                [len(pts) for pts in elems], dtype=np.int32)
        elif key == 'model_point_features':
            merged[key] = np.concatenate(elems, axis=0)
        elif key in stacked_keys:
            try:
                merged[key] = np.stack(elems, axis=0)
            except BaseException:
                print(key, flush=True)
                raise
        elif key == 'metrics':
            merged[key] = elems
        else:
            merged[key] = list(elems)
    return merged


def get_correspondences(src_pcd, tgt_pcd, search_voxel_size, K=None, trans=None):
    """Find point pairs lying within ``search_voxel_size`` of each other.

    Args:
        src_pcd: open3d PointCloud of query points. It is not modified.
        tgt_pcd: open3d PointCloud searched with a KD-tree.
        search_voxel_size: search radius.
        K: optional cap on the number of target matches kept per source point.
        trans: optional 4x4 transform applied to (a copy of) ``src_pcd``
            before searching.

    Returns:
        Array of ``[src_index, tgt_index]`` rows with shape (M, 2). If no pair
        is found, it is an empty array of shape (0,).
    """
    if trans is not None:
        # Transform a copy so the caller's point cloud is left untouched.
        src_pcd = o3d.geometry.PointCloud(src_pcd)
        src_pcd.transform(trans)
    pcd_tree = o3d.geometry.KDTreeFlann(tgt_pcd)

    correspondences = []
    for i, point in enumerate(src_pcd.points):
        _, idx, _ = pcd_tree.search_radius_vector_3d(point, search_voxel_size)
        if K is not None:
            idx = idx[:K]
        correspondences.extend([i, j] for j in idx)

    return np.array(correspondences)


def to_pcd(xyz):
    """Wrap an (N, 3) array of coordinates in an open3d PointCloud."""
    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(xyz)
    return cloud


def to_tsfm(rot, trans):
    """Assemble a 4x4 homogeneous transform from a 3x3 rotation and a translation.

    ``trans`` may be any array holding three values (e.g. (3,) or (3, 1)).
    """
    transform = np.identity(4)
    transform[:3, :3] = rot
    transform[:3, 3] = trans.reshape(-1)
    return transform


def CameraIntrinsicUpdate(old_K, aug_param):
    '''
    Return a copy of the intrinsics adjusted for a geometric image augmentation.

    old_K: array of shape (..., 3, 3), the old camera intrinsic parameters
           (not modified)
    aug_param: dict whose 'aug_type' is one of
        'crop'  -> 'crop/left_top_corner': (x, y) of the crop origin
                   (principal point is shifted by it),
        'scale' -> 'scale/scale': (s_x, s_y) (fx, cx scaled by s_x; fy, cy by s_y),
        'flip'  -> 'flip/width': image width w (horizontal flip: cx' = w - cx, fx' = -fx).
    '''
    mode = aug_param['aug_type']
    assert mode in ['crop', 'scale', 'flip']

    K = np.copy(old_K)
    if mode == 'flip':
        width = aug_param['flip/width']
        K[..., 0, 2] = width - K[..., 0, 2]  # px' = w - px
        K[..., 0, 0] = -K[..., 0, 0]  # fx' = -fx
    elif mode == 'scale':
        sx, sy = aug_param['scale/scale']
        # Row 0 holds (fx, _, cx); row 1 holds (_, fy, cy).
        K[..., 0, [0, 2]] = sx * K[..., 0, [0, 2]]
        K[..., 1, [1, 2]] = sy * K[..., 1, [1, 2]]
    else:  # crop
        dx, dy = aug_param['crop/left_top_corner']
        K[..., 0, 2] = K[..., 0, 2] - dx
        K[..., 1, 2] = K[..., 1, 2] - dy
    return K


def crop_transform(images, depths, Ks, crop_param, ):
    """Apply one shared crop to lists of images and depth maps, updating intrinsics.

    Args:
        images: list of images. All are cropped with the same operator, and
            the list is updated in place.
        depths: list of depth maps, same length as ``images`` (updated in place).
        Ks: list of intrinsic matrices, same length as ``images`` (updated in place).
        crop_param: dict with "crop_type" in {"fixed", "center", "random"},
            "crop_size", and "lt_corner" when crop_type is "fixed".

    Returns:
        (images, depths, Ks): the same list objects that were passed in.

    NOTE(review): ``transforms`` (Crop / CenterCrop / RandomCrop) is not
    imported in this file's visible import block. Confirm it is in scope,
    otherwise this function raises NameError.
    """
    assert(len(images) == len(depths) == len(Ks))

    crop_type = crop_param["crop_type"]
    assert(crop_type in ["fixed", "center", "random"])

    crop_size = crop_param["crop_size"]
    iheight, iwidth = images[0].shape[:2]

    if crop_type == "fixed":
        lt_corner = crop_param["lt_corner"]
        op = transforms.Crop(
            lt_corner[0], lt_corner[1], crop_size[0], crop_size[1])
    elif crop_type == "center":
        op = transforms.CenterCrop(crop_size)

        ci, cj, _, _ = op.get_params(images[0], crop_size)
        lt_corner = ci, cj

    elif crop_type == "random":
        op = transforms.RandomCrop((iheight, iwidth), crop_size)

        # presumably (i, j) = (row, col) of the randomly chosen crop origin -- TODO confirm
        lt_corner = op.i, op.j

    for i, _ in enumerate(images):
        images[i] = op(images[i])
        depths[i] = op(depths[i])

        # lt_corner is indexed as (row, col); swap to (x, y) for the intrinsics update.
        Ks[i] = CameraIntrinsicUpdate(Ks[i],
                                      {"aug_type": "crop", "crop/left_top_corner": (lt_corner[1], lt_corner[0])})

    return images, depths, Ks


# def patch_crop(image, depth, mask, K_old, margin_ratio=0.2, output_size=128, offset_ratio=(0,0),bbox=None, mask_depth=True):
def patch_crop(image, depth, mask, K_old, margin_ratio=0.2, output_size=128, offset_ratio=(0,0),bbox=None, mask_depth=False):
    '''
    Crop a square patch around an object and resize it to output_size x output_size.

    The square is centred on the object's bounding box, optionally shifted by
    ``offset_ratio`` times the box size. Its side is
    ``int(max(box_w, box_h) * (1 + 2 * margin_ratio))``. Its top-left corner
    is clamped to (0, 0), and parts that fall past the right/bottom image
    border are zero-padded. The intrinsics are updated to match the crop and
    the resize.

        image: HxWx3
        depth: HxW (not modified, even when mask_depth is True)
        mask: HxW object mask, used for the bounding box when bbox is None
        K_old: 3x3
        margin_ratio: relative margin added on each side of the box
        output_size: side length of the returned patches
        offset_ratio: (offset_h, offset_w), shift of the centre relative to the box height/width
        bbox: optional (y_min, x_min, y_max, x_max) used instead of the mask
        mask_depth: if True, depth outside the mask is zeroed in the output

    Returns:
        (patch, depth_patch, mask_patch, K_new). mask_patch holds 0/255
        values with depth's dtype. K_new keeps only fx, fy, cx, cy (skew is dropped).
    '''
    mask = mask.astype('uint8') * 255
    if bbox is None:
        _x, _y, _w, _h = cv2.boundingRect(mask)
    else:
        _x, _y, _w, _h = bbox[1], bbox[0], bbox[3]-bbox[1], bbox[2]-bbox[0]

    center = [_x + _w/2 + offset_ratio[1]*_w, _y + _h/2 + offset_ratio[0]*_h]

    L = int(max(_w, _h) * (1 + 2*margin_ratio))
    if L <= 0:
        # Empty mask/box (e.g. the object is not visible): use a fixed-size
        # crop instead of failing on a zero-sized resize / division by zero.
        # TODO: better fallback.
        print(mask.sum(), depth.sum(), '!!!', flush=True)
        L = 128

    x = max(0, int(center[0] - L/2))
    y = max(0, int(center[1] - L/2))

    if mask_depth:
        # Only keep the ROI depth; copy so the caller's array is not modified.
        depth = depth.copy()
        depth[mask < 1] = 0
    crop = image[y:y+L, x:x+L]
    depth_crop = depth[y:y+L, x:x+L]
    mask_crop = mask[y:y+L, x:x+L]

    # Crops truncated by the right/bottom border are zero-padded to L x L.
    patch = np.zeros([L, L, 3], dtype=image.dtype)
    depth_patch = np.zeros([L, L], dtype=depth.dtype)
    mask_patch = np.zeros([L, L], dtype=depth.dtype)
    patch[:crop.shape[0], :crop.shape[1]] = crop
    depth_patch[:crop.shape[0], :crop.shape[1]] = depth_crop
    mask_patch[:crop.shape[0], :crop.shape[1]] = mask_crop

    patch = cv2.resize(patch, (output_size, output_size),
                       interpolation=cv2.INTER_LINEAR)
    depth_patch = cv2.resize(
        depth_patch, (output_size, output_size), interpolation=cv2.INTER_NEAREST)
    mask_patch = cv2.resize(
        mask_patch, (output_size, output_size), interpolation=cv2.INTER_NEAREST)

    # Update the intrinsic parameters: shift by the crop origin, then scale.
    K_new = np.zeros_like(K_old)
    scale = output_size / L
    K_new[0, 2] = (K_old[0, 2]-x)*scale
    K_new[1, 2] = (K_old[1, 2]-y)*scale
    K_new[0, 0] = K_old[0, 0]*scale
    K_new[1, 1] = K_old[1, 1]*scale
    K_new[2, 2] = 1

    return patch, depth_patch, mask_patch, K_new


def preprocess_deepim(
    input_dict,
    max_points,
    correspondence_radius,
    normalize_model=True,
    rand_transform_model=False,
    rand_rgb_transformer=None,
    image_scale=None,
    patch_cropper=None,  # e.g. patch_crop
):
    """Preprocess a DeepIM-style sample (observed + rendered images).

    Works on a deep copy of ``input_dict`` and returns it after the steps below.

    1. Randomly subsample the model points (and their features) to at most
       ``max_points``.
    2. If ``normalize_model``: centre and scale the model to unit extent,
       fold the shift into the translation of ``RT``, set ``scale``, and keep
       the raw points in ``original_model_points``.
    3. If ``rand_transform_model``: add ``rand_model_points``, the model under
       a random rotation, and the matching ``rand_RT``.
    4. If ``patch_cropper`` is given: crop the observed and rendered
       image/depth around the observed depth mask.
    5. If ``image_scale`` is given: rescale the observed image/depth and
       ``K``. A non-integer scale is rounded to integer pixel sizes.
    6. Lift the observed depth into the model frame (``lifted_points``,
       ``depth_coords2d``) and find up to 5 model points within
       ``correspondence_radius`` per lifted point (``correspondences_2d3d``).
    7. Apply ``rand_rgb_transformer``. Convert ``image`` to float32 CxHxW
       divided by 255, and ``depth`` to float32 1xHxW.

    ``patch_cropper`` may return ``(image, depth, mask, K)`` like
    ``patch_crop`` or ``(image, depth, K)``. Only the first two entries and
    the last one are used.

    Raises:
        Exception: "Too few correspondences are found!" when fewer than 10
            correspondences exist (callers match this message to skip the sample).
    """
    output_dict = copy.deepcopy(input_dict)

    ################################### process 3D point clouds ###################################
    if output_dict['model_points'].shape[0] > max_points:
        idx = np.random.permutation(
            output_dict['model_points'].shape[0])[:max_points]
        print(idx, output_dict['model_points'].shape, flush=True)
        output_dict['model_points'] = output_dict['model_points'][idx]
        output_dict['model_point_features'] = output_dict['model_point_features'][idx]

    output_dict['original_RT'] = copy.deepcopy(output_dict['RT'])
    if normalize_model:
        points = output_dict['model_points']
        mean = points.mean(axis=0)
        scope = points.max(axis=0) - points.min(axis=0)
        # The model is shifted by -mean, so RT's translation absorbs R @ mean.
        output_dict['RT'][:, 3:] = output_dict['RT'][:, :3] @ mean[:, None] + output_dict['RT'][:, 3:]  # 3x3 @ 3x1 + 3x1
        output_dict['scale'] = scope.max()
        output_dict['original_model_points'] = output_dict['model_points']
        output_dict['model_points'] = (points - mean) / scope.max()

    if rand_transform_model:
        points = output_dict['model_points']
        rand_quat = np.random.randn(1, 4)
        rand_quat = rand_quat / np.linalg.norm(rand_quat, axis=-1)
        rand_rot = quaternion.as_rotation_matrix(
            quaternion.from_float_array(rand_quat)).squeeze()  # 3x3
        output_dict['rand_model_points'] = (rand_rot @ points.T).T.astype(np.float32)
        output_dict['rand_RT'] = copy.deepcopy(output_dict['RT'])
        output_dict['rand_RT'][:, :3] = rand_rot
        output_dict['rand_RT'][:, 3] = 0

    ################################### process 2D images ###################################
    if patch_cropper is not None:
        ref_mask = output_dict['depth'] > 0
        cropped = patch_cropper(
            output_dict['image'], output_dict['depth'], ref_mask, output_dict['K'])
        output_dict['image'], output_dict['depth'], output_dict['K'] = cropped[0], cropped[1], cropped[-1]

        # The rendered views are cropped with the same (observed) mask; their K is discarded.
        cropped = patch_cropper(
            output_dict['rendered_image'], output_dict['rendered_depth'], ref_mask, output_dict['K'].copy())
        output_dict['rendered_image'], output_dict['rendered_depth'] = cropped[0], cropped[1]

    if image_scale is not None:
        def _scaled_size(arr):
            # cv2.resize expects an integer (width, height) pair.
            return (int(round(arr.shape[1] * image_scale)),
                    int(round(arr.shape[0] * image_scale)))

        output_dict['image'] = cv2.resize(output_dict['image'], _scaled_size(output_dict['image']),
                                          interpolation=cv2.INTER_AREA)
        output_dict['depth'] = cv2.resize(output_dict['depth'], _scaled_size(output_dict['depth']),
                                          interpolation=cv2.INTER_NEAREST)
        output_dict['K'][:2] = output_dict['K'][:2] * image_scale

    # Lift depth and transform the points into the (normalized) model frame.
    depth = output_dict['depth'].squeeze()  # H,W
    depth_pts, depth_coords2d = mask_depth_to_point_cloud(
        depth != 0, depth, output_dict['K'])
    depth_pts = (output_dict['RT'][:, :3].T @ (depth_pts.T - output_dict['RT'][:, 3:])).T / output_dict['scale']

    # Find 2D-3D correspondences (at most 5 model points per lifted point).
    model_pcd = output_dict['model_points']
    correspondences_2d3d = get_correspondences(
        to_pcd(depth_pts), to_pcd(model_pcd), correspondence_radius, K=5)
    if len(correspondences_2d3d.shape) < 2 or len(correspondences_2d3d) < 10:
        print(depth_pts.shape, model_pcd.shape)
        print("correspondences_2d3d.shape:",
              correspondences_2d3d.shape, flush=True)
        raise Exception("Too few correspondences are found!")

    output_dict['depth_coords2d'] = depth_coords2d
    output_dict['lifted_points'] = depth_pts
    output_dict['correspondences_2d3d'] = correspondences_2d3d

    if rand_rgb_transformer is not None:
        output_dict['image'], _, _ = rand_rgb_transformer(output_dict['image'])
    # To tensors: image -> C,H,W in [0, 1] for 8-bit input; depth -> 1,H,W.
    output_dict['image'] = (output_dict['image'].astype(
        np.float32) / 255.0).transpose([2, 0, 1])
    output_dict['depth'] = output_dict['depth'].astype(np.float32)[None]

    return output_dict

def preprocess(
    input_dict,
    max_points,
    correspondence_radius,
    normalize_model=True,
    rand_transform_model=False, 
    rand_rgb_transformer=None,
    image_scale=None,
    crop_param=None,
    kp_3d_param=None,
    use_coords_as_3d_feat=False,
    find_2d3d_correspondence=True,
    
):
    """Preprocess one raw sample into network inputs.

    Works on a deep copy of ``input_dict`` and returns it after the steps below.

    1. Optionally use the xyz coordinates as the 3D point features
       (``use_coords_as_3d_feat``).
    2. Randomly subsample the model points and features to at most
       ``max_points``.
    3. Keep ``original_RT`` and ``original_model_points``.
    4. If ``normalize_model``: centre and scale the model, fold the shift
       into RT's translation, and set ``scale``.
    5. If ``rand_transform_model``: add ``rand_model_points`` and ``rand_RT``,
       the model under a random rotation.
    6. If ``crop_param`` is given: crop a square patch with ``patch_crop``
       using one of these modes:
       - around the initial-pose render mask (``crop_with_init_pose``);
       - with a random centre shift (``crop_with_rand_bbox_shift``, default True);
       - a plain crop;
       - a fully random margin/offset (``rand_crop``).
    7. If ``image_scale`` is given: rescale the image, depth and K.
    8. Lift the depth into the model frame (``lifted_points``,
       ``depth_coords2d``).
    9. If ``find_2d3d_correspondence``: radius-search up to 5 model points per
       lifted point. Otherwise a zero (10, 2) int64 placeholder is stored.
    10. Apply ``rand_rgb_transformer``. Convert the image to float32 CxHxW
        divided by 255, and the depth to float32 1xHxW.

    ``kp_3d_param`` is accepted but not used in this function.

    Raises:
        Exception: "Too few correspondences are found!" when fewer than 10
            correspondences are found (the dataset matches this message to
            skip the sample).
    """
    output_dict = copy.deepcopy(input_dict)

    ################################### process 3D point clouds ###################################
    if use_coords_as_3d_feat:
        output_dict['model_point_features'] = output_dict['model_points'][:,:3]

    if (output_dict['model_points'].shape[0] > max_points):
        # if(output_dict['model_points'].shape[0] > 20000):
        idx = np.random.permutation(
            output_dict['model_points'].shape[0])[:max_points]
        print(idx, output_dict['model_points'].shape, flush=True)
        output_dict['model_points'] = output_dict['model_points'][idx]
        output_dict['model_point_features'] = output_dict['model_point_features'][idx]

    output_dict['original_RT'] = copy.deepcopy(output_dict['RT'])
    output_dict['original_model_points'] = output_dict['model_points']
    if normalize_model:
        points = output_dict['model_points']
        mean = points.mean(axis=0)
        scope = points.max(axis=0)-points.min(axis=0)
        points_normalize = (points-mean)/scope.max()
        # modify the extrinsic parameters: the model is shifted by -mean, so
        # the translation absorbs R @ mean (rotation unchanged)
        output_dict['RT'][:, 3:] = output_dict['RT'][:, :3] @ mean[:,
                                                                   None] + output_dict['RT'][:, 3:]  # 3x3 @ 3x1 + 3x1
        output_dict['scale'] = scope.max()
        output_dict['model_points'] = points_normalize


    if rand_transform_model:
        points = output_dict['model_points']
        rand_quat = np.random.randn(1, 4)
        rand_quat = rand_quat/np.linalg.norm(rand_quat, axis=-1)
        rand_rot = quaternion.as_rotation_matrix(
            quaternion.from_float_array(rand_quat)).squeeze()  # 3x3
        output_dict['rand_model_points'] = (
            rand_rot@ points.T).T.astype(np.float32)
        output_dict['rand_RT'] = copy.deepcopy(output_dict['RT'])
        output_dict['rand_RT'][:, :3] = rand_rot
        output_dict['rand_RT'][:, 3] = 0

    ################################### process 2D images ###################################
    #crop image
    if crop_param is not None:# and output_dict['mask'].sum()>0:
        #without random cropping
        if not crop_param.rand_crop: 
            if crop_param.get("crop_with_init_pose", False):
                # crop around the mask rendered from the initial pose ('ren_mask')
                # bbox= output_dict.get('bbox', None)
                bbox=None
                output_dict['image'], output_dict['depth'], output_dict['mask'], output_dict['K'] = patch_crop(output_dict['image'], output_dict['depth'], mask=output_dict['ren_mask'],
                                K_old=output_dict['K'], margin_ratio=crop_param.margin_ratio, output_size=crop_param.output_size,  bbox=bbox
                                                )
            elif crop_param.get("crop_with_rand_bbox_shift", True): 
                # NOTE(review): this branch is the default when the key is absent,
                # so the crop centre is randomly shifted unless configured otherwise.
                # The local `bbox` is unused (patch_crop reads output_dict['bbox'] directly).
                bbox= output_dict.get('bbox', None)
                # offset_ratio= [truncated_normal(0,0.5,-1,1)*crop_param.max_rand_offset_ratio, truncated_normal(0,0.5,-1,1)*crop_param.max_rand_offset_ratio] 
                offset_ratio= [truncated_normal(0,0.33,-1,1)*1, truncated_normal(0,0.33,-1,1)*1] 
                output_dict['image'], output_dict['depth'], output_dict['mask'], output_dict['K'] = patch_crop(output_dict['image'], output_dict['depth'], mask=output_dict['mask'],
                                                    K_old=output_dict['K'], margin_ratio=crop_param.margin_ratio, output_size=crop_param.output_size, offset_ratio=offset_ratio, bbox=output_dict.get('bbox', None) 
                                                )
            else:
                bbox= output_dict.get('bbox', None)
                output_dict['image'], output_dict['depth'], output_dict['mask'], output_dict['K'] = patch_crop(output_dict['image'], output_dict['depth'], mask=output_dict['mask'],
                                                    K_old=output_dict['K'], margin_ratio=crop_param.margin_ratio, output_size=crop_param.output_size, bbox=output_dict.get('bbox', None) 
                                                )
        else:
            # random margin and centre offset; the order of the truncated_normal
            # draws determines the sampled crop for a given RNG state
            margin_ratio= truncated_normal(0.5, 0.5, 0, 1) *crop_param.max_rand_margin_ratio
            offset_ratio= [truncated_normal(0,0.5,-1,1)*crop_param.max_rand_offset_ratio, truncated_normal(0,0.5,-1,1)*crop_param.max_rand_offset_ratio] 
            output_dict['image'], output_dict['depth'], output_dict['mask'], output_dict['K'] = patch_crop(output_dict['image'], output_dict['depth'], mask=output_dict['mask'],
                                                K_old=output_dict['K'], margin_ratio=margin_ratio, output_size=crop_param.output_size, offset_ratio=offset_ratio,  bbox=output_dict.get('bbox', None) 
                                              )
            
    # rescale image
    # NOTE(review): the new size is shape * image_scale passed straight to
    # cv2.resize, which presumably requires integer sizes -- image_scale is
    # assumed to be an int here; confirm.
    if image_scale is not None:
        output_dict['image'] = cv2.resize(output_dict['image'],
                                          (output_dict['image'].shape[1]*image_scale,
                                           output_dict['image'].shape[0]*image_scale),
                                          interpolation=cv2.INTER_AREA)
        output_dict['depth'] = cv2.resize(output_dict['depth'],
                                          (output_dict['depth'].shape[1]*image_scale,
                                           output_dict['depth'].shape[0]*image_scale),
                                          interpolation=cv2.INTER_NEAREST)
        output_dict['K'][:2] = output_dict['K'][:2]*image_scale

    # lift depths
    depth = output_dict['depth'].squeeze()  # H,W
    depth_pts, depth_coords2d = mask_depth_to_point_cloud(
        depth != 0, depth, output_dict['K'])

    # inverse rigid transform R^T (p - t), then divide by the model scale
    depth_pts = (output_dict['RT'][:, :3].T@(depth_pts.T - output_dict['RT']
                                             [:, 3:])).T / output_dict['scale']  # transformed to the model frame

    # find 2d-3d correspondences
    if find_2d3d_correspondence:
        tsfm = np.eye(4)  # NOTE(review): unused
        tsfm[:3] = output_dict['RT']
        model_pcd = output_dict['model_points']
        # at most 5 model points per lifted depth point
        correspondences_2d3d = get_correspondences(
            to_pcd(depth_pts), to_pcd(model_pcd),  correspondence_radius, K=5)
        if len(correspondences_2d3d.shape) < 2 or len(correspondences_2d3d) < 10:# or ( "mask" in output_dict and output_dict['mask'].sum()<10 ) :
            print(depth_pts.shape, model_pcd.shape)
            print("correspondences_2d3d.shape:",
                  correspondences_2d3d.shape, flush=True)
            raise Exception("Too few correspondences are found!")

        output_dict['depth_coords2d'] = depth_coords2d
        output_dict['lifted_points'] = depth_pts
        output_dict['correspondences_2d3d'] = correspondences_2d3d
    else:
        output_dict['depth_coords2d'] = depth_coords2d
        output_dict['lifted_points'] = depth_pts
        # fixed-shape placeholder so batches can still be stacked in merge_batch
        output_dict['correspondences_2d3d'] = np.zeros([10,2], dtype=np.int64) 


    if rand_rgb_transformer is not None:
        output_dict['image'], _, _ = rand_rgb_transformer(output_dict['image'])
    # TO TENSORs
    output_dict['image'] = (output_dict['image'].astype(
        np.float32)/255.0).transpose([2, 0, 1])  # .mean(axis=0, keepdims=True) # 1,H,W
    output_dict['depth'] = output_dict['depth'].astype(np.float32)[
        None]  # 1,H,W

    return output_dict

def batch_grid_subsampling_kpconv(points, batches_len, features=None, labels=None, sampleDl=0.1, max_p=0, verbose=0, random_grid_orient=True):
    """
    Grid-subsample a stacked batch of point clouds with the compiled KPConv wrapper.

    Thin wrapper around ``cpp_subsampling.subsample_batch`` (barycenter
    subsampling for points and, when given, features).

    :param points: stacked points of all batch elements
    :param batches_len: number of points of each batch element
    :param features: optional per-point features, forwarded as ``features=``
    :param labels: optional per-point labels, forwarded as ``classes=``
    :param sampleDl: subsampling grid size
    :param max_p: forwarded unchanged to the C++ routine
    :param verbose: forwarded unchanged to the C++ routine
    :param random_grid_orient: accepted for API compatibility; not forwarded
    :return: tuple of tensors ``(points, lengths[, features][, labels])``,
        each converted from the C++ output with ``torch.from_numpy``
    """
    # Only pass the optional inputs that were actually provided; the C++
    # routine returns one extra array per optional input it receives.
    optional_inputs = {}
    if features is not None:
        optional_inputs['features'] = features
    if labels is not None:
        optional_inputs['classes'] = labels

    subsampled = cpp_subsampling.subsample_batch(points,
                                                 batches_len,
                                                 sampleDl=sampleDl,
                                                 max_p=max_p,
                                                 verbose=verbose,
                                                 **optional_inputs)
    return tuple(torch.from_numpy(arr) for arr in subsampled)


def batch_neighbors_kpconv(queries, supports, q_batches, s_batches, radius, max_neighbors):
    """
    Radius-search neighbors for a stacked batch of queries against supports.

    :param queries: (N1, 3) the query points
    :param supports: (N2, 3) the support points
    :param q_batches: (B) the list of lengths of batch elements in queries
    :param s_batches: (B) the list of lengths of batch elements in supports
    :param radius: float32 search radius
    :param max_neighbors: if > 0, only the first ``max_neighbors`` columns of
        the neighbor matrix are kept; otherwise all columns are returned
    :return: neighbor indices as a tensor (via ``torch.from_numpy``)
    """
    neighbor_idx = cpp_neighbors.batch_query(
        queries, supports, q_batches, s_batches, radius=radius)
    if max_neighbors > 0:
        # Truncate the neighbor matrix column-wise to the per-layer cap.
        neighbor_idx = neighbor_idx[:, :max_neighbors]
    return torch.from_numpy(neighbor_idx)


def collate_fn_descriptor(list_data, config, neighborhood_limits):
    """
    Collate a list of samples and build the multi-scale KPConv inputs.

    The samples are merged with ``merge_batch``. The stacked model points
    (plus, when present, ``rand_model_points`` appended as extra batch
    elements) are then walked through ``config.architecture`` up to the first
    ``global`` or ``upsample`` block. For every layer the convolution
    neighbor indices are computed. When the layer ends with a
    ``pool``/``strided`` block, the points are also grid-subsampled and the
    pooling/upsampling indices are computed. The next layer works on the
    subsampled points with a doubled radius.

    :param list_data: list of per-sample dicts (merged by ``merge_batch``)
    :param config: KPConv config; reads ``first_subsampling_dl``,
        ``conv_radius``, ``deform_radius`` and ``architecture``
    :param neighborhood_limits: per-layer neighbor cap, indexed by layer; a
        value <= 0 keeps all neighbors (see ``batch_neighbors_kpconv``)
    :return: dict with per-layer lists (``model_points``, ``neighbors``,
        ``pools``, ``upsamples``, ``stack_lengths``) and the remaining batch
        entries converted to tensors
    """
    ret = merge_batch(list_data)

    batched_points = torch.from_numpy(ret['model_points'])
    batched_lengths = torch.from_numpy(ret['batched_model_point_lengths'])
    batched_features = torch.from_numpy(ret['model_point_features'])

    if ret.get('rand_model_points', None) is not None:
        # Append the random model points as extra batch elements. Their
        # features are the model features repeated, which assumes the random
        # copies have the same point count/order as the originals -- confirm.
        batched_rand_points = torch.from_numpy(ret['rand_model_points'])
        batched_rand_lengths = torch.from_numpy(
            ret['batched_rand_model_point_lengths'])

        batched_points = torch.cat(
            [batched_points, batched_rand_points], dim=0)
        batched_lengths = torch.cat(
            [batched_lengths, batched_rand_lengths], dim=0)
        batched_features = torch.cat(
            [batched_features, batched_features], dim=0)

    # Starting radius of convolutions
    r_normal = config.first_subsampling_dl * config.conv_radius
    # Blocks collected for the current layer
    layer_blocks = []
    layer = 0

    # Lists of per-layer inputs
    input_points = []
    input_neighbors = []
    input_pools = []
    input_upsamples = []
    input_batches_len = []
    for block_i, block in enumerate(config.architecture):
        # Stop when meeting a global pooling or upsampling
        if 'global' in block or 'upsample' in block:
            break

        # Accumulate the blocks of the layer. A layer is closed by a
        # pool/strided block, by the block right before an upsample, or by
        # the last block of the architecture.
        if not ('pool' in block or 'strided' in block):
            layer_blocks += [block]
            if block_i < len(config.architecture) - 1 and not ('upsample' in config.architecture[block_i + 1]):
                continue

        # Convolution neighbors indices
        # *****************************

        if layer_blocks:
            # Convolutions are done in this layer, compute the neighbors with the good radius
            if any('deformable' in blck for blck in layer_blocks[:-1]):
                r = r_normal * config.deform_radius / config.conv_radius
            else:
                r = r_normal
            conv_i = batch_neighbors_kpconv(
                batched_points, batched_points, batched_lengths, batched_lengths, r, neighborhood_limits[layer])

        else:
            # This layer only perform pooling, no neighbors required
            conv_i = torch.zeros((0, 1), dtype=torch.int64)

        # Pooling neighbors indices
        # *************************

        if 'pool' in block or 'strided' in block:

            # New subsampling length
            dl = 2 * r_normal / config.conv_radius

            # Subsampled points
            pool_p, pool_b = batch_grid_subsampling_kpconv(
                batched_points, batched_lengths, sampleDl=dl)

            # Radius of pooled neighbors
            if 'deformable' in block:
                r = r_normal * config.deform_radius / config.conv_radius
            else:
                r = r_normal

            # Subsample indices
            pool_i = batch_neighbors_kpconv(
                pool_p, batched_points, pool_b, batched_lengths, r, neighborhood_limits[layer])

            # Upsample indices (with the radius of the next layer to keep wanted density)
            up_i = batch_neighbors_kpconv(
                batched_points, pool_p, batched_lengths, pool_b, 2 * r, neighborhood_limits[layer])

        else:
            # No pooling in the end of this layer, no pooling indices required
            pool_i = torch.zeros((0, 1), dtype=torch.int64)
            pool_p = torch.zeros((0, 3), dtype=torch.float32)
            pool_b = torch.zeros((0,), dtype=torch.int64)
            up_i = torch.zeros((0, 1), dtype=torch.int64)

        # Updating input lists
        input_points += [batched_points.float()]
        input_neighbors += [conv_i.long()]
        input_pools += [pool_i.long()]
        input_upsamples += [up_i.long()]
        input_batches_len += [batched_lengths]

        # New points for next layer
        batched_points = pool_p
        batched_lengths = pool_b

        # Update radius and reset blocks
        r_normal *= 2
        layer += 1
        layer_blocks = []

    ###############
    # Return inputs
    ###############
    dict_inputs = {
        "idx": ret["idx"],
        'model_points': input_points,
        'visibility': torch.from_numpy(ret['visibility']),
        'neighbors': input_neighbors,
        'pools': input_pools,
        'upsamples': input_upsamples,
        'model_point_features': batched_features.float(),
        'stack_lengths': input_batches_len,
        'image': torch.from_numpy(ret['image']),
        'depth': torch.from_numpy(ret['depth']),
        'mask': torch.from_numpy(ret['mask']),
        'ren_mask': torch.from_numpy(ret['ren_mask']),
        'K': torch.from_numpy(ret['K']),
        'RT': torch.from_numpy(ret['RT']),
        'original_RT': torch.from_numpy(ret['original_RT']),
        # Missing optional poses fall back to all-zero arrays shaped like RT.
        'POSECNN_RT': torch.from_numpy(ret.get('POSECNN_RT', np.zeros_like(ret['RT']))),
        'rand_RT': torch.from_numpy(ret.get('rand_RT', np.zeros_like(ret['RT']))),
        # Kept as lists of per-item tensors (presumably because their sizes
        # differ per sample -- confirm against merge_batch).
        "lifted_points": [torch.from_numpy(d) for d in ret['lifted_points']],
        'depth_coords2d': [torch.from_numpy(d) for d in ret['depth_coords2d']],
        "correspondences_2d3d": torch.from_numpy(ret['correspondences_2d3d']),
        "original_model_points": torch.from_numpy(ret['original_model_points']),
        "class_name": ret['class_name'],
        "3d_keypoint_inds": torch.from_numpy(ret['3d_keypoint_inds']),
        "3d_keypoints": torch.from_numpy(ret['3d_keypoints']),
    }

    return dict_inputs


def collate_fn_descriptor_deepim(list_data, config, neighborhood_limits):
    """
    Collate a list of samples and build the multi-scale KPConv inputs.

    The samples are merged with ``merge_batch``. The stacked model points
    (plus, when present, ``rand_model_points`` appended as extra batch
    elements) are then walked through ``config.architecture`` up to the first
    ``global`` or ``upsample`` block. For every layer the convolution
    neighbor indices are computed. When the layer ends with a
    ``pool``/``strided`` block, the points are also grid-subsampled and the
    pooling/upsampling indices are computed. The next layer works on the
    subsampled points with a doubled radius.

    Unlike ``collate_fn_descriptor``, the output has no ``mask`` or 3D
    keypoint entries and adds an optional ``rendered_RT``.

    :param list_data: list of per-sample dicts (merged by ``merge_batch``)
    :param config: KPConv config; reads ``first_subsampling_dl``,
        ``conv_radius``, ``deform_radius`` and ``architecture``
    :param neighborhood_limits: per-layer neighbor cap, indexed by layer; a
        value <= 0 keeps all neighbors (see ``batch_neighbors_kpconv``)
    :return: dict with per-layer lists (``model_points``, ``neighbors``,
        ``pools``, ``upsamples``, ``stack_lengths``) and the remaining batch
        entries converted to tensors
    """
    ret = merge_batch(list_data)

    batched_points = torch.from_numpy(ret['model_points'])
    batched_lengths = torch.from_numpy(ret['batched_model_point_lengths'])
    batched_features = torch.from_numpy(ret['model_point_features'])

    if ret.get('rand_model_points', None) is not None:
        # Append the random model points as extra batch elements. Their
        # features are the model features repeated, which assumes the random
        # copies have the same point count/order as the originals -- confirm.
        batched_rand_points = torch.from_numpy(ret['rand_model_points'])
        batched_rand_lengths = torch.from_numpy(
            ret['batched_rand_model_point_lengths'])

        batched_points = torch.cat(
            [batched_points, batched_rand_points], dim=0)
        batched_lengths = torch.cat(
            [batched_lengths, batched_rand_lengths], dim=0)
        batched_features = torch.cat(
            [batched_features, batched_features], dim=0)

    # Starting radius of convolutions
    r_normal = config.first_subsampling_dl * config.conv_radius
    # Blocks collected for the current layer
    layer_blocks = []
    layer = 0

    # Lists of per-layer inputs
    input_points = []
    input_neighbors = []
    input_pools = []
    input_upsamples = []
    input_batches_len = []
    for block_i, block in enumerate(config.architecture):
        # Stop when meeting a global pooling or upsampling
        if 'global' in block or 'upsample' in block:
            break

        # Accumulate the blocks of the layer. A layer is closed by a
        # pool/strided block, by the block right before an upsample, or by
        # the last block of the architecture.
        if not ('pool' in block or 'strided' in block):
            layer_blocks += [block]
            if block_i < len(config.architecture) - 1 and not ('upsample' in config.architecture[block_i + 1]):
                continue

        # Convolution neighbors indices
        # *****************************

        if layer_blocks:
            # Convolutions are done in this layer, compute the neighbors with the good radius
            if any('deformable' in blck for blck in layer_blocks[:-1]):
                r = r_normal * config.deform_radius / config.conv_radius
            else:
                r = r_normal
            conv_i = batch_neighbors_kpconv(
                batched_points, batched_points, batched_lengths, batched_lengths, r, neighborhood_limits[layer])

        else:
            # This layer only perform pooling, no neighbors required
            conv_i = torch.zeros((0, 1), dtype=torch.int64)

        # Pooling neighbors indices
        # *************************

        # If end of layer is a pooling operation
        if 'pool' in block or 'strided' in block:

            # New subsampling length
            dl = 2 * r_normal / config.conv_radius

            # Subsampled points
            pool_p, pool_b = batch_grid_subsampling_kpconv(
                batched_points, batched_lengths, sampleDl=dl)

            # Radius of pooled neighbors
            if 'deformable' in block:
                r = r_normal * config.deform_radius / config.conv_radius
            else:
                r = r_normal

            # Subsample indices
            pool_i = batch_neighbors_kpconv(
                pool_p, batched_points, pool_b, batched_lengths, r, neighborhood_limits[layer])

            # Upsample indices (with the radius of the next layer to keep wanted density)
            up_i = batch_neighbors_kpconv(
                batched_points, pool_p, batched_lengths, pool_b, 2 * r, neighborhood_limits[layer])

        else:
            # No pooling in the end of this layer, no pooling indices required
            pool_i = torch.zeros((0, 1), dtype=torch.int64)
            pool_p = torch.zeros((0, 3), dtype=torch.float32)
            pool_b = torch.zeros((0,), dtype=torch.int64)
            up_i = torch.zeros((0, 1), dtype=torch.int64)

        # Updating input lists
        input_points += [batched_points.float()]
        input_neighbors += [conv_i.long()]
        input_pools += [pool_i.long()]
        input_upsamples += [up_i.long()]
        input_batches_len += [batched_lengths]

        # New points for next layer
        batched_points = pool_p
        batched_lengths = pool_b

        # Update radius and reset blocks
        r_normal *= 2
        layer += 1
        layer_blocks = []

    ###############
    # Return inputs
    ###############
    dict_inputs = {
        "idx": ret["idx"],
        'model_points': input_points,
        'visibility': torch.from_numpy(ret['visibility']),
        'neighbors': input_neighbors,
        'pools': input_pools,
        'upsamples': input_upsamples,
        'model_point_features': batched_features.float(),
        'stack_lengths': input_batches_len,
        'image': torch.from_numpy(ret['image']),
        'depth': torch.from_numpy(ret['depth']),
        "ren_mask": torch.from_numpy(ret['ren_mask']),
        'K': torch.from_numpy(ret['K']),
        'RT': torch.from_numpy(ret['RT']),
        'original_RT': torch.from_numpy(ret['original_RT']),
        # Optional: None when the merged batch carries no rendered_RT.
        'rendered_RT': torch.from_numpy(ret['rendered_RT']) if ret.get('rendered_RT', None) is not None else None,
        # Missing optional poses fall back to all-zero arrays shaped like RT.
        'POSECNN_RT': torch.from_numpy(ret.get('POSECNN_RT', np.zeros_like(ret['RT']))),
        'rand_RT': torch.from_numpy(ret.get('rand_RT', np.zeros_like(ret['RT']))),
        # Kept as lists of per-item tensors (presumably because their sizes
        # differ per sample -- confirm against merge_batch).
        "lifted_points": [torch.from_numpy(d) for d in ret['lifted_points']],
        'depth_coords2d': [torch.from_numpy(d) for d in ret['depth_coords2d']],
        "correspondences_2d3d": torch.from_numpy(ret['correspondences_2d3d']),
        "original_model_points": torch.from_numpy(ret['original_model_points']),
        "class_name": ret['class_name'],
    }

    return dict_inputs


def calibrate_neighbors(dataset, config, collate_fn, keep_ratio=0.8, samples_threshold=2000):
    """
    Estimate a per-layer cap on KPConv neighborhood sizes from the data.

    Samples are collated one at a time with a provisional limit of ``hist_n``
    neighbors per layer. The number of valid neighbors of every point is
    accumulated into a per-layer histogram. Iteration stops once every layer
    has more than ``samples_threshold`` entries, or when the dataset is
    exhausted. For each layer, the returned limit is the number of histogram
    bins whose cumulative count stays below ``keep_ratio`` of the total,
    i.e. roughly the ``keep_ratio`` quantile of the neighbor count.

    :param dataset: indexable dataset supporting ``len()``
    :param config: KPConv config; reads ``deform_radius`` and ``num_layers``
        (plus whatever ``collate_fn`` reads)
    :param collate_fn: callable ``(list_data, config, neighborhood_limits=...)``
        returning a dict whose ``'neighbors'`` entry is a per-layer list of
        neighbor-index tensors
    :param keep_ratio: fraction of neighborhoods that should fit under the cap
    :param samples_threshold: per-layer entry count after which to stop early
    :return: numpy array with one neighbor limit per layer
    """
    timer = Timer()
    last_display = timer.total_time

    # From config parameter, compute higher bound of neighbors number in a neighborhood
    hist_n = int(np.ceil(4 / 3 * np.pi * (config.deform_radius + 1) ** 3))
    neighb_hists = np.zeros((config.num_layers, hist_n), dtype=np.int32)

    # Get histogram of neighborhood sizes i in 1 epoch max.
    for i in range(len(dataset)):
        timer.tic()

        # One provisional limit per layer. This used to be hard-coded to 5
        # entries, which raised an IndexError for architectures deeper than
        # five layers even though the histogram is sized by num_layers.
        batched_input = collate_fn(
            [dataset[i]], config, neighborhood_limits=[hist_n] * config.num_layers)
        # Indices >= the number of rows are treated as padding (presumably the
        # cpp_neighbors "shadow" index -- confirm), so they are not counted.
        counts = [torch.sum(neighb_mat < neighb_mat.shape[0], dim=1).numpy()
                  for neighb_mat in batched_input['neighbors']]

        hists = [np.bincount(c, minlength=hist_n)[:hist_n] for c in counts]
        neighb_hists += np.vstack(hists)
        timer.toc()

        if timer.total_time - last_display > 0.1:
            last_display = timer.total_time
            print(f"Calib Neighbors {i:08d}: timings {timer.total_time:4.2f}s")

        if np.min(np.sum(neighb_hists, axis=1)) > samples_threshold:
            break

    cumsum = np.cumsum(neighb_hists.T, axis=0)
    percentiles = np.sum(cumsum < (keep_ratio * cumsum[hist_n - 1, :]), axis=0)

    neighborhood_limits = percentiles
    print('\n')

    return neighborhood_limits


def get_dataloader(dataset, kpconv_config, batch_size=1, num_workers=4, shuffle=True, sampler=None, neighborhood_limits=None):
    """
    Build a ``DataLoader`` whose batches are collated by ``collate_fn_descriptor``.

    When ``neighborhood_limits`` is not given, it is first estimated from
    ``dataset`` with ``calibrate_neighbors``.

    :return: ``(dataloader, neighborhood_limits)`` with the limits actually
        used (either the ones passed in or the calibrated ones)
    """
    limits = (neighborhood_limits if neighborhood_limits is not None
              else calibrate_neighbors(dataset, kpconv_config, collate_fn=collate_fn_descriptor))
    print("neighborhood:", limits)

    # Bind the extra collate arguments so the loader can call it with just
    # the list of samples.
    # https://discuss.pytorch.org/t/supplying-arguments-to-collate-fn/25754/4
    collate = partial(collate_fn_descriptor, config=kpconv_config,
                      neighborhood_limits=limits)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate,
        sampler=sampler,
        drop_last=False,
    )
    return loader, limits

def get_dataloader_deepim(dataset, kpconv_config, batch_size=1, num_workers=4, shuffle=True, sampler=None, neighborhood_limits=None):
    """
    Build a ``DataLoader`` whose batches are collated by ``collate_fn_descriptor_deepim``.

    When ``neighborhood_limits`` is not given, it is first estimated from
    ``dataset`` with ``calibrate_neighbors``.

    :return: ``(dataloader, neighborhood_limits)`` with the limits actually
        used (either the ones passed in or the calibrated ones)
    """
    limits = (neighborhood_limits if neighborhood_limits is not None
              else calibrate_neighbors(dataset, kpconv_config, collate_fn=collate_fn_descriptor_deepim))
    print("Neighborhood:", limits)

    # Bind the extra collate arguments so the loader can call it with just
    # the list of samples.
    # https://discuss.pytorch.org/t/supplying-arguments-to-collate-fn/25754/4
    collate = partial(collate_fn_descriptor_deepim, config=kpconv_config,
                      neighborhood_limits=limits)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate,
        sampler=sampler,
        drop_last=False,
    )
    return loader, limits

if __name__ == '__main__':
    # No-op: running this module directly does nothing; it only provides
    # dataset / collate / dataloader helpers for import.
    pass


================================================
FILE: data/transforms.py
================================================
import numpy as np
import random
import torch
import torchvision
from torchvision.transforms import functional as F
import cv2
from PIL import Image


class Compose(object):
    """Chain transforms that each map an ``(img, kpts, mask)`` triple to a new triple."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, kpts=None, mask=None):
        # Thread the triple through every transform in order; each transform
        # must return exactly three values.
        state = (img, kpts, mask)
        for transform in self.transforms:
            new_img, new_kpts, new_mask = transform(*state)
            state = (new_img, new_kpts, new_mask)
        return state

    def __repr__(self):
        # One indented line per transform, e.g. "Compose(\n    A\n    B\n)".
        body = "".join("\n    {0}".format(t) for t in self.transforms)
        return self.__class__.__name__ + "(" + body + "\n)"


class ToTensor(object):
    """Convert an image to a float32 array scaled by 1/255; kpts and mask pass through unchanged."""

    def __call__(self, img, kpts, mask):
        # astype() always yields a fresh float32 array, so dividing in place
        # does not touch the caller's data.
        scaled = np.asarray(img).astype(np.float32)
        scaled /= 255.
        return scaled, kpts, mask


class Normalize(object):
    """
    Per-channel ``(img - mean) / std`` normalization.

    :param mean: per-channel mean (assumed to broadcast against the trailing
        channel axis of ``img`` -- confirm)
    :param std: per-channel std, broadcast like ``mean``
    :param to_bgr: despite its name, this flag does not reorder channels: it
        transposes the result with axes ``(2, 0, 1)`` (HWC -> CHW for an HWC
        image) and casts it to float32
    """

    def __init__(self, mean, std, to_bgr=True):
        self.mean = mean
        self.std = std
        self.to_bgr = to_bgr

    def __call__(self, img, kpts, mask):
        # Work on a floating-point copy. The previous in-place ``img -= mean``
        # mutated the caller's array and raised a casting error for integer
        # (e.g. uint8) images. Float inputs keep their dtype, as before.
        img = np.asarray(img)
        img = img.astype(np.result_type(img.dtype, np.float32))
        img -= self.mean
        img /= self.std
        if self.to_bgr:
            img = img.transpose(2, 0, 1).astype(np.float32)
        return img, kpts, mask


class ColorJitter(object):
    """
    Random brightness/contrast/saturation/hue jitter applied to an image array.

    Wraps ``torchvision.transforms.ColorJitter``; kpts and mask pass through
    unchanged. An argument left as ``None`` disables jitter for that property.
    """

    def __init__(self,
                 brightness=None,
                 contrast=None,
                 saturation=None,
                 hue=None,
                 ):
        # torchvision validates each argument as a number or a 2-sequence and
        # raises TypeError for None, so the None defaults used to make
        # ``ColorJitter()`` unusable. 0 is torchvision's own "no jitter" value.
        self.color_jitter = torchvision.transforms.ColorJitter(
            brightness=0 if brightness is None else brightness,
            contrast=0 if contrast is None else contrast,
            saturation=0 if saturation is None else saturation,
            hue=0 if hue is None else hue,)

    def __call__(self, image, kpts, mask):
        # The array is cast to contiguous uint8 before going through PIL, so it
        # is expected to already hold 0-255 values (float images in [0, 1]
        # would be truncated -- confirm callers pass uint8-range data).
        pil_image = Image.fromarray(np.ascontiguousarray(image, np.uint8))
        image = np.asarray(self.color_jitter(pil_image))
        return image, kpts, mask


class RandomBlur(object):
    """With probability ``prob``, Gaussian-blur the image using a random odd kernel size (3, 5, 7 or 9)."""

    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, image, kpts, mask):
        if not random.random() < self.prob:
            return image, kpts, mask
        # The drawn value is the (square) kernel size, not a sigma; 0 is
        # passed as the third GaussianBlur argument (sigmaX).
        ksize = np.random.choice([3, 5, 7, 9])
        blurred = cv2.GaussianBlur(image, (ksize, ksize), 0)
        return blurred, kpts, mask


def make_transforms(cfg, is_train):
    """
    Build the image augmentation pipeline.

    Only ``is_train is True`` (an identity check, so only the bool ``True``
    qualifies) selects the training pipeline: random blur (p=0.5) followed by
    color jitter. Any other value yields an empty ``Compose``, which returns
    its inputs unchanged. ``cfg`` is currently unused, and no
    ToTensor/Normalize step is applied here.
    """
    if is_train is not True:
        return Compose([])
    return Compose([
        RandomBlur(0.5),
        ColorJitter(0.1, 0.1, 0.05, 0.05),
    ])


================================================
FILE: data/ycb/basic.py
================================================
import mmcv 
# Object index -> YCB object class name.
bop_ycb_idx2class = {
    1: '002_master_chef_can',
    2: '003_cracker_box',
    3: '004_sugar_box',
    4: '005_tomato_soup_can',
    5: '006_mustard_bottle',
    6: '007_tuna_fish_can',
    7: '008_pudding_box',
    8: '009_gelatin_box',
    9: '010_potted_meat_can',
    10: '011_banana',
    11: '019_pitcher_base',
    12: '021_bleach_cleanser',
    13: '024_bowl',
    14: '025_mug',
    15: '035_power_drill',
    16: '036_wood_block',
    17: '037_scissors',
    18: '040_large_marker',
    19: '051_large_clamp',
    20: '052_extra_large_clamp',
    21: '061_foam_brick',
}
# Inverse mapping: YCB object class name -> object index.
bop_ycb_class2idx = {name: idx for idx, name in bop_ycb_idx2class.items()}




================================================
FILE: doc/prepare_data.md
================================================
# Data Preparation Tips
All the data needed for data preparation can be downloaded [here](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155139432_link_cuhk_edu_hk/EoXnZ96Tuy9PpYlZCvDN8vUBPdP1lP-PWQWiZH2KtIQoaQ?e=lpE472). Download it first, then follow the instructions below.



## Download Datasets 
First, download the following datasets and extract them to the folder *EXPDATA/*:

[LINEMOD](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155139432_link_cuhk_edu_hk/EYFaYrk0kcdBgC6WMtLJqP0B9Ar0_Nff9qhI2Cs95qDbdA?e=yYxexC)

[LINEMOD_OCC_TEST](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155139432_link_cuhk_edu_hk/EUKcRnwyy9RGu2ASwA3QDXsBnMRrFP-U4X4Eqq-g_MhmIQ?e=hv6H2s)

## Synthetic Data Generation

The preprocessed data following [DeepIM](https://github.com/liyi14/mx-DeepIM) and [PVNet](https://github.com/zju3dv/pvnet-rendering) can be downloaded from [LM6d_converted](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155139432_link_cuhk_edu_hk/EYFaYrk0kcdBgC6WMtLJqP0B9Ar0_Nff9qhI2Cs95qDbdA) and [raw_data](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155139432_link_cuhk_edu_hk/ESSFXi_7qs1AgNmty7_9y4AB8ffFsGJWOC3ikgD5BIeXHQ?e=qOmvds). 
After downloading, put the files into the folder *EXPDATA/* (located in the repository's root directory).
To create occluded objects during training, we follow [PVNet](https://github.com/zju3dv/pvnet-rendering) and generate random occlusions.
Run the following script to convert the data into the format expected by our dataloader:
```
    bash scripts/run_dataformatter.sh
```
The command above will automatically save the formatted data into *EXPDATA/*. 

## Download the Object CAD Models
You also need to download the [object models](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155139432_link_cuhk_edu_hk/EQScZuLrkPNPmN4eO3kePaUBjOe92EvbKb7kGJk2vKz-bA?e=8McAdh) and put the extracted folder *models* into *./EXPDATA/LM6d_converted/*.

## Download Background Images
[Pascal VOC 2012](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar) needs to be downloaded to the folder *EXPDATA/*. These images are used to generate random backgrounds during training.

## Download Initial Poses 
The initial poses estimated by PoseCNN and PVNet can be downloaded from [here](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155139432_link_cuhk_edu_hk/EQh5y0M_zHVMnbVszjEviCUBNAX_22MFN26Msa48XlJ5MQ?e=rfhT7k). 
The initial pose folder should also be put into the folder *EXPDATA/*.

## Generate the Information Files
Run the following script to generate the info files, which are saved into the folder *EXPDATA/data_info/*:

```
bash scripts/run_datainfo_generation.sh
```


After data preparation, the expected directory structure is:


```
./EXPDATA
    |──LM6d_converted 
    |        |──LM6d_refine 
    |        |──LM6d_refine_syn
    |        └──models
    |──LINEMOD
    |        └──fuse_formatted
    |──lmo
    |──VOCdevkit
    |──raw_data
    |──init_poses
    └──data_info
```



================================================
FILE: docker/Dockerfile
================================================
FROM nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04

# Swap the CUDA apt repository signing key (delete 7fa2af80, fetch 3bf863cc),
# then remove the cuda / nvidia-ml apt source lists so the `apt-get update`
# calls below no longer fetch from those repositories.
RUN apt-key del 7fa2af80
RUN apt-key adv --fetch-keys http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
RUN rm /etc/apt/sources.list.d/cuda.list
RUN rm /etc/apt/sources.list.d/nvidia-ml.list

# Dependencies for glvnd and X11.
# `apt-get update` is chained with `apt-get install` in a single RUN so the
# package index and the install always come from the same build layer; a
# separately cached `apt-get update` layer can go stale.
RUN apt-get update \
  && apt-get install -y -qq --no-install-recommends \
    libglvnd0 \
    libgl1 \
    libglx0 \
    libegl1 \
    libxext6 \
    libx11-6 \
  && rm -rf /var/lib/apt/lists/*
# Env vars for the nvidia-container-runtime.
ENV NVIDIA_VISIBLE_DEVICES all
ENV NVIDIA_DRIVER_CAPABILITIES graphics,utility,compute

#env vars for cuda
ENV CUDA_HOME /usr/local/cuda

#install miniconda
RUN apt-get update --fix-missing && \
    apt-get install -y wget bzip2 ca-certificates curl git && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*

# Install Miniconda3 (py37, 4.9.2) from the TUNA mirror into /opt/miniconda3
# and make interactive shells switch to the py37 environment.
RUN wget --quiet https://mirrors.tuna.tsinghua.edu.cn/anaconda/miniconda/Miniconda3-py37_4.9.2-Linux-x86_64.sh -O ~/miniconda.sh && \
    /bin/bash ~/miniconda.sh -b -p /opt/miniconda3 && \
    rm ~/miniconda.sh && \
    /opt/miniconda3/bin/conda clean -tipsy && \
    ln -s /opt/miniconda3/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \
    echo ". /opt/miniconda3/etc/profile.d/conda.sh" >> ~/.bashrc && \
    echo "conda activate base" >> ~/.bashrc && \
    echo "conda deactivate && conda activate py37" >> ~/.bashrc

#https://blog.csdn.net/Mao_Jonah/article/details/89502380
# Create the py37 environment from the package list in freeze.yml.
COPY freeze.yml freeze.yml
RUN /opt/miniconda3/bin/conda env create -n py37 -f freeze.yml

# WORKDIR /tmp/
# COPY config.jupyter.tar config.jupyter.tar
# RUN tar -xvf config.jupyter.tar -C /root/

#install apex
ENV TORCH_CUDA_ARCH_LIST "6.0 6.2 7.0 7.2"
# make sure we don't overwrite some existing directory called "apex"
WORKDIR /tmp/unique_for_apex
# uninstall Apex if present, twice to make absolutely sure :)
RUN /opt/miniconda3/envs/py37/bin/pip3 uninstall -y apex || :
RUN /opt/miniconda3/envs/py37/bin/pip3 uninstall -y apex || :
# SHA is something the user can touch to force recreation of this Docker layer,
# and therefore force cloning of the latest version of Apex
# NOTE(review): this clones apex without pinning a commit, so a later rebuild
# may fetch a version that no longer builds against the CUDA 10.2 / PyTorch
# toolchain of this image -- consider checking out a known-good commit.
RUN SHA=ToUcHMe git clone https://github.com/NVIDIA/apex.git
WORKDIR /tmp/unique_for_apex/apex
RUN /opt/miniconda3/envs/py37/bin/pip3 install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
#install pytorch3d
# pytorch3d 0.5.0 is listed in freeze.yml; the alternatives below are kept
# for reference only.
# RUN /opt/miniconda3/envs/py37/bin/pip install pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py37_cu102_pyt171/download.html
# RUN /opt/miniconda3/envs/py37/bin/pip install "git+https://github.com/facebookresearch/pytorch3d.git"
# RUN /opt/miniconda3/bin/conda install pytorch3d==0.5.0 -c pytorch3d -n py37



#other pkgs
RUN apt-get update \
  && apt-get install -y -qq --no-install-recommends \
  cmake build-essential vim xvfb unzip tmux psmisc  \
  libx11-dev libassimp-dev \
  mesa-common-dev freeglut3-dev \
  rsync \
  && apt-get clean \
  && rm -rf /var/lib/apt/lists/*

#create some directories
RUN mkdir -p /home/RNNPose

EXPOSE 8887 8888 8889 10000 10001 10002 
WORKDIR /home/RNNPose



================================================
FILE: docker/freeze.yml
================================================
name: py37_tmp
channels:
  - pytorch
  - pytorch3d
  - open3d-admin
  - bottler
  - iopath
  - fvcore
  - conda-forge
  - defaults
dependencies:
  - pytorch3d=0.5.0
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=4.5=1_gnu
  - anyio=2.2.0=py37h06a4308_1
  - argon2-cffi=20.1.0=py37h27cfd23_1
  - async_generator=1.10=py37h28b3542_0
  - attrs=21.2.0=pyhd3eb1b0_0
  - babel=2.9.1=pyhd3eb1b0_0
  - backcall=0.2.0=pyhd3eb1b0_0
  - blas=1.0=mkl
  - bleach=3.3.0=pyhd3eb1b0_0
  - brotlipy=0.7.0=py37h27cfd23_1003
  - ca-certificates=2021.5.30=ha878542_0
  - certifi=2021.5.30=py37h89c1867_0
  - cffi=1.14.5=py37h261ae71_0
  - charset-normalizer=2.0.4=pyhd3eb1b0_0
  - cryptography=3.4.7=py37hd23ed53_0
  - cudatoolkit=10.2.89=h8f6ccaa_8
  - cycler=0.10.0=py37_0
  - dbus=1.13.18=hb2f20db_0
  - defusedxml=0.7.1=pyhd3eb1b0_0
  - entrypoints=0.3=py37_0
  - expat=2.4.1=h2531618_2
  - fontconfig=2.13.1=h6c09931_0
  - freetype=2.10.4=h5ab3b9f_0
  - fvcore=0.1.5.post20210825=py37
  - glib=2.68.2=h36276a3_0
  - gst-plugins-base=1.14.0=h8213a91_2
  - gstreamer=1.14.0=h28cd5cc_2
  - icu=58.2=he6710b0_3
  - idna=3.2=pyhd3eb1b0_0
  - importlib-metadata=3.10.0=py37h06a4308_0
  - importlib_metadata=3.10.0=hd3eb1b0_0
  - intel-openmp=2021.2.0=h06a4308_610
  - iopath=0.1.9=py37
  - ipykernel=5.3.4=py37h5ca1d4c_0
  - ipython=7.22.0=py37hb070fc8_0
  - ipython_genutils=0.2.0=pyhd3eb1b0_1
  - ipywidgets=7.6.3=pyhd3eb1b0_1
  - jedi=0.17.0=py37_0
  - jinja2=3.0.0=pyhd3eb1b0_0
  - joblib=1.0.1=pyhd3eb1b0_0
  - jpeg=9b=h024ee3a_2
  - json5=0.9.6=pyhd3eb1b0_0
  - jsonschema=3.2.0=py_2
  - jupyter=1.0.0=py37h89c1867_6
  - jupyter_client=6.1.12=pyhd3eb1b0_0
  - jupyter_console=6.4.0=pyhd8ed1ab_0
  - jupyter_core=4.7.1=py37h06a4308_0
  - jupyter_server=1.4.1=py37h06a4308_0
  - jupyterlab=3.0.16=pyhd8ed1ab_0
  - jupyterlab_pygments=0.1.2=py_0
  - jupyterlab_server=2.7.1=pyhd3eb1b0_0
  - jupyterlab_widgets=1.0.0=pyhd3eb1b0_1
  - kiwisolver=1.3.1=py37h2531618_0
  - kornia=0.5.3=pyhd8ed1ab_0
  - lcms2=2.12=h3be6417_0
  - ld_impl_linux-64=2.35.1=h7274673_9
  - libffi=3.3=he6710b0_2
  - libgcc-ng=9.3.0=h5101ec6_17
  - libgfortran-ng=7.5.0=ha8ba4b0_17
  - libgfortran4=7.5.0=ha8ba4b0_17
  - libgomp=9.3.0=h5101ec6_17
  - libpng=1.6.37=hbc83047_0
  - libsodium=1.0.18=h7b6447c_0
  - libstdcxx-ng=9.3.0=hd4cf53a_17
  - libtiff=4.2.0=h85742a9_0
  - libuuid=1.0.3=h1bed415_2
  - libuv=1.40.0=h7b6447c_0
  - libwebp-base=1.2.0=h27cfd23_0
  - libxcb=1.14=h7b6447c_0
  - libxml2=2.9.10=hb55368b_3
  - lz4-c=1.9.3=h2531618_0
  - markupsafe=2.0.1=py37h27cfd23_0
  - matplotlib=3.3.4=py37h06a4308_0
  - matplotlib-base=3.3.4=py37h62a2d02_0
  - mistune=0.8.4=py37h14c3975_1001
  - mkl=2021.2.0=h06a4308_296
  - mkl-service=2.3.0=py37h27cfd23_1
  - mkl_fft=1.3.0=py37h42c9631_2
  - mkl_random=1.2.1=py37ha9443f7_2
  - nbclassic=0.2.6=pyhd3eb1b0_0
  - nbclient=0.5.3=pyhd3eb1b0_0
  - nbconvert=6.0.7=py37_0
  - nbformat=5.1.3=pyhd3eb1b0_0
  - ncurses=6.2=he6710b0_1
  - nest-asyncio=1.5.1=pyhd3eb1b0_0
  - ninja=1.10.2=hff7bd54_1
  - notebook=6.4.0=py37h06a4308_0
  - numpy=1.20.2=py37h2d18471_0
  - numpy-base=1.20.2=py37hfae3a4d_0
  - nvidiacub=1.10.0=0
  - olefile=0.46=py37_0
  - open3d=0.13.0=py37_0
  - openssl=1.1.1k=h7f98852_0
  - packaging=20.9=pyhd3eb1b0_0
  - pandas=1.2.4=py37h2531618_0
  - pandoc=2.12=h06a4308_0
  - pandocfilters=1.4.3=py37h06a4308_1
  - parso=0.8.2=pyhd3eb1b0_0
  - pcre=8.44=he6710b0_0
  - pexpect=4.8.0=pyhd3eb1b0_3
  - pickleshare=0.7.5=pyhd3eb1b0_1003
  - pillow=8.2.0=py37he98fc37_0
  - pip=21.1.2=py37h06a4308_0
  - plyfile=0.7.4=pyhd8ed1ab_0
  - portalocker=2.3.0=py37h06a4308_0
  - prometheus_client=0.11.0=pyhd3eb1b0_0
  - prompt-toolkit=3.0.17=pyh06a4308_0
  - prompt_toolkit=3.0.17=hd3eb1b0_0
  - ptyprocess=0.7.0=pyhd3eb1b0_2
  - pycparser=2.20=py_2
  - pygments=2.9.0=pyhd3eb1b0_0
  - pyopenssl=20.0.1=pyhd3eb1b0_1
  - pyparsing=2.4.7=pyhd3eb1b0_0
  - pyqt=5.9.2=py37h05f1152_2
  - pyrsistent=0.17.3=py37h7b6447c_0
  - pysocks=1.7.1=py37_1
  - python=3.7.10=h12debd9_4
  - python-dateutil=2.8.1=pyhd3eb1b0_0
  - python_abi=3.7=1_cp37m
  - pytorch=1.7.1=py3.7_cuda10.2.89_cudnn7.6.5_0
  - pytz=2021.1=pyhd3eb1b0_0
  - pyzmq=20.0.0=py37h2531618_1
  - qt=5.9.7=h5867ecd_1
  - qtconsole=5.1.1=pyhd8ed1ab_0
  - qtpy=1.10.0=pyhd8ed1ab_0
  - readline=8.1=h27cfd23_0
  - requests=2.26.0=pyhd3eb1b0_0
  - scikit-learn=0.24.2=py37ha9443f7_0
  - scipy=1.6.2=py37had2a1c9_1
  - send2trash=1.5.0=pyhd3eb1b0_1
  - setuptools=52.0.0=py37h06a4308_0
  - sip=4.19.8=py37hf484d3e_0
  - six=1.16.0=pyhd3eb1b0_0
  - sniffio=1.2.0=py37h06a4308_1
  - sqlite=3.35.4=hdfb4753_0
  - tabulate=0.8.9=py37h06a4308_0
  - terminado=0.9.4=py37h06a4308_0
  - testpath=0.4.4=pyhd3eb1b0_0
  - threadpoolctl=2.1.0=pyh5ca1d4c_0
  - tk=8.6.10=hbc83047_0
  - torchvision=0.8.2=py37_cu102
  - tornado=6.1=py37h27cfd23_0
  - traitlets=5.0.5=pyhd3eb1b0_0
  - typing_extensions=3.7.4.3=pyha847dfd_0
  - wcwidth=0.2.5=py_0
  - webencodings=0.5.1=py37_1
  - wheel=0.36.2=pyhd3eb1b0_0
  - widgetsnbextension=3.5.1=py37_0
  - xz=5.2.5=h7b6447c_0
  - yacs=0.1.6=py_0
  - yaml=0.2.5=h7b6447c_0
  - zeromq=4.3.4=h2531618_0
  - zipp=3.4.1=pyhd3eb1b0_0
  - zlib=1.2.11=h7b6447c_3
  - zstd=1.4.9=haebb681_0
  - pip:
    - absl-py==0.13.0
    - addict==2.4.0
    - anykeystore==0.2
    - cachetools==4.2.2
    - cryptacular==1.5.5
    - cython==0.29.24
    - decorator==4.4.2
    - easydict==1.9
    - einops==0.3.0
    - fire==0.4.0
    - flow-vis==0.1
    - freetype-py==2.2.0
    - future==0.18.2
    - glumpy==1.2.0
    - google-auth==1.31.0
    - google-auth-oauthlib==0.4.4
    - greenlet==1.1.0
    - grpcio==1.38.0
    - hupper==1.10.3
    - imageio==2.9.0
    - llvmlite==0.36.0
    - loguru==0.5.3
    - markdown==3.3.4
    - networkx==2.5.1
    - numba==0.53.1
    - numpy-quaternion==2021.6.9.13.34.11
    - oauthlib==3.1.1
    - opencv-python==4.5.2.54
    - pastedeploy==2.1.1
    - pbkdf2==1.3
    - plaster==1.0
    - plaster-pastedeploy==0.7
    - protobuf==3.17.3
    - pyasn1==0.4.8
    - pyasn1-modules==0.2.8
    - pyassimp==4.1.3
    - pyglet==1.5.17
    - pyopengl==3.1.5
    - pyopengl-accelerate==3.1.5
    - pyramid==2.0
    - pyramid-mailer==0.15.1
    - python3-openid==3.2.0
    - pywavelets==1.1.1
    - pyyaml==5.4.1
    - repoze-sendmail==4.4.1
    - requests-oauthlib==1.3.0
    - rsa==4.7.2
    - scikit-image==0.18.1
    - sqlalchemy==1.4.18
    - tensorboard==2.5.0
    - tensorboard-data-server==0.6.1
    - tensorboard-plugin-wit==1.8.0
    - tensorboardx==2.2
    - termcolor==1.1.0
    - tifffile==2021.6.14
    - tqdm==4.61.1
    - transaction==3.0.1
    - transforms3d==0.3.1
    - translationstring==1.4
    - triangle==20200424
    - urllib3==1.26.5
    - velruse==1.1.1
    - venusian==3.0.0
    - vispy==0.6.6
    - webob==1.8.7
    - werkzeug==2.0.1
    - zope-deprecation==4.4.0
    - zope-interface==5.4.0
    - mmcv 
prefix: /opt/miniconda3/envs/py37_tmp


================================================
FILE: geometry/__init__.py
================================================


================================================
FILE: geometry/cholesky.py
================================================
# import tensorflow as tf
import torch #as tf
import numpy as np
# from utils.einsum import einsum
from torch import einsum



class _cholesky_solve(torch.autograd.Function):
    @staticmethod
    def forward(ctx, H, b):
        chol = torch.cholesky(H)
        xx = torch.cholesky_solve(b, chol)
        ctx.save_for_backward(chol, xx)

        return xx

    # see OptNet: https://arxiv.org/pdf/1703.00443.pdf
    @staticmethod
    def backward(ctx, dx):
        chol, xx = ctx.saved_tensors

        dz = torch.cholesky_solve(dx, chol)
        xs = torch.squeeze(xx,  -1)
        zs = torch.squeeze(dz, -1)
        dH = -einsum('...i,...j->...ij', xs, zs)

        return dH, dz
def cholesky_solve(H, b):
    """Differentiably solve ``H x = b`` for symmetric positive-definite ``H``.

    Functional entry point for the ``_cholesky_solve`` autograd op.
    """
    solved = _cholesky_solve.apply(H, b)
    return solved

def solve(H, b, max_update=1.0):
    """Solve the linear system ``H x = b`` for positive-definite ``H``.

    The solve is done in float64 on whatever device the inputs already live on
    (nothing is moved to the CPU). NaN entries of the solution are replaced by
    zero and every entry is clamped to ``[-max_update, max_update]``.

    Args:
        H: system matrix, assumed (..., n, n) symmetric positive definite.
        b: right-hand side of shape (..., n); a column axis is added and
            removed internally.
        max_update: symmetric bound applied to each solution entry.

    Returns:
        The (..., n) solution, always cast to float32.
    """
    # Promote to double precision for a numerically safer factorisation.
    H = H.to(dtype=torch.float64)
    b = b.to(dtype=torch.float64)

    b = torch.unsqueeze(b, -1)
    x = cholesky_solve(H, b)

    # Replace NaNs and clip large updates.
    bad_values = torch.isnan(x)
    x = torch.where(bad_values, torch.zeros_like(x), x)
    x = torch.clamp(x, -max_update, max_update)

    x = torch.squeeze(x, -1)
    x = x.to(dtype=torch.float32)

    return x



def __test__():
    """Smoke-test ``solve`` and its gradients on a seeded random 3x3 SPD system."""
    import numpy as np
    np.random.seed(0)

    A = np.random.uniform(size=(3, 3))
    # A A^T is symmetric positive (semi-)definite.
    H = torch.tensor(A @ A.transpose(-1, -2), requires_grad=True)
    b = torch.tensor(np.random.uniform(size=(3,)), requires_grad=True)

    x = solve(H, b)
    x.backward(torch.ones_like(x))

    print(f"H={H}, b={b}, x={x}, grad={H.grad, b.grad}")

if __name__ == '__main__':
    # Executing this module directly runs the gradient smoke test.
    __test__()


================================================
FILE: geometry/diff_render.py
================================================
import torch
import torch.nn as nn 
import torch.nn.functional as F 

import numpy 

from pytorch3d.renderer import (
    OpenGLPerspectiveCameras, look_at_view_transform, look_at_rotation,
    RasterizationSettings, MeshRenderer, MeshRasterizer, BlendParams,
    camera_position_from_spherical_angles, HardPhongShader, PointLights,FoVPerspectiveCameras, PerspectiveCameras, SoftPhongShader, Materials
) 
try:
    from pytorch3d.structures import Meshes, Textures
    use_textures = True
except:
    from pytorch3d.structures import Meshes
    from pytorch3d.renderer import TexturesVertex
    from pytorch3d.renderer import TexturesVertex as Textures

    use_textures = False

import pytorch3d.renderer.mesh.utils as utils
from pytorch3d.io import load_obj, load_ply, load_objs_as_meshes
from pytorch3d.renderer.mesh.rasterizer import Fragments

from plyfile import PlyData
from utils.furthest_point_sample import fragmentation_fps
import time


def rasterize(R, T, meshes, rasterizer, blur_radius=0):
    """Rasterize ``meshes`` with ``rasterizer`` under the camera pose (R, T).

    R and T are handed to the rasterizer call so they take effect for this
    call (updating the rasterizer camera's R/T). With a positive
    ``blur_radius`` the barycentric coordinates are clipped and the z-buffer
    re-interpolated from them, as in pytorch3d's own source.
    """
    fragments = rasterizer(meshes, R=R, T=T)
    if not blur_radius > 0.0:
        return fragments

    bary = utils._clip_barycentric_coordinates(fragments.bary_coords)
    zbuf = utils._interpolate_zbuf(fragments.pix_to_face, bary, meshes)
    return Fragments(
        bary_coords=bary,
        zbuf=zbuf,
        dists=fragments.dists,
        pix_to_face=fragments.pix_to_face,
    )

def set_bary_coords_to_nearest(bary_coords_):
    """Snap barycentric weights to a one-hot vote for the dominant vertex.

    Along the last axis, the largest weight becomes 1 and the others 0. The
    negative weights of the input are then added back on top of that one-hot
    vector.

    Args:
        bary_coords_: tensor whose last axis holds barycentric weights
            (presumably pytorch3d fragment coordinates -- TODO confirm).
            Non-contiguous tensors are accepted.

    Returns:
        Tensor with the same shape as ``bary_coords_``.
    """
    ori_shape = bary_coords_.shape
    # Keep only the negative parts of the input.
    exr = bary_coords_ * (bary_coords_ < 0)
    # reshape (not view) so that non-contiguous inputs do not raise.
    flat = bary_coords_.reshape(-1, bary_coords_.shape[-1])
    arg_max_idx = flat.argmax(1)
    one_hot = torch.zeros_like(flat).scatter(1, arg_max_idx.unsqueeze(1), 1.0)
    return one_hot.reshape(ori_shape) + exr

class MeshRendererWithDepth(nn.Module):
    """Mesh renderer returning the shaded images together with the z-buffer.

    Args:
        rasterizer: callable ``rasterizer(meshes, **kwargs) -> fragments``.
        shader: callable ``shader(fragments, meshes, **kwargs) -> images``.
    """

    def __init__(self, rasterizer, shader):
        super().__init__()
        self.rasterizer = rasterizer
        self.shader = shader

    def to(self, *args, **kwargs):
        """Move the module; also forward the device to rasterizer and shader.

        Accepts the same arguments as ``nn.Module.to``. The rasterizer and
        shader have submodules which are not of type ``nn.Module``, so the
        device (``device`` keyword, else the first positional argument) is
        additionally passed to their own ``to`` methods.
        """
        device = kwargs.get('device', args[0] if args else None)
        super().to(*args, **kwargs)
        if device is not None:
            self.rasterizer.to(device)
            self.shader.to(device)
        return self

    def forward(self, meshes_world, **kwargs):
        """Rasterize and shade ``meshes_world``.

        Returns:
            tuple: ``(images, zbuf)`` -- the shader output and the fragments'
            z-buffer.
        """
        fragments = self.rasterizer(meshes_world, **kwargs)
        images = self.shader(fragments, meshes_world, **kwargs)
        return images, fragments.zbuf

class DiffRender(nn.Module):
    """Differentiable pytorch3d renderer for a single mesh.

    Offers shaded rendering (``render_mesh``), rasterization of arbitrary
    per-vertex attributes (``forward``), vertex-depth rendering
    (``render_depth``) and projections of 64 patch centres computed with
    ``fragmentation_fps`` on the mesh vertices.

    Poses are converted with ``cam_opencv2pytch3d`` (a 4x4 matrix flipping the
    x and y axes) before any pytorch3d rendering, so those paths need 4x4
    poses.

    Args:
        mesh_path: path to a ``.ply`` or ``.obj`` mesh.
        render_texture: if True, ``forward`` prepends the shaded RGB rendering
            to the rasterized attribute map.

    Raises:
        ValueError: if ``mesh_path`` ends neither in ``.ply`` nor ``.obj``.
    """

    def __init__(self, mesh_path, render_texture=False):
        super().__init__()

        if mesh_path.endswith('.ply'):
            verts, faces = load_ply(mesh_path)
            self.mesh = Meshes(verts=[verts], faces=[faces])
        elif mesh_path.endswith('.obj'):
            verts, faces, _ = load_obj(mesh_path)
            faces = faces.verts_idx
            self.mesh = load_objs_as_meshes([mesh_path])
        else:
            # Fail early with a clear message instead of an UnboundLocalError below.
            raise ValueError(
                f"Unsupported mesh format (expected .ply or .obj): {mesh_path}")

        self.verts = verts
        self.faces = faces
        # OpenCV -> pytorch3d camera frame: negate the x and y axes.
        self.cam_opencv2pytch3d = torch.tensor(
            [[-1, 0, 0, 0],
             [0, -1, 0, 0],
             [0, 0, 1, 0],
             [0, 0, 0, 1]], dtype=torch.float32
        )
        self.render_texture = render_texture

        # Patch partition of the vertices into 64 fragments.
        self.pat_centers, self.pat_center_inds, self.vert_frag_ids = fragmentation_fps(
            verts.detach().cpu().numpy(), 64)
        self.pat_centers = torch.from_numpy(self.pat_centers)
        self.pat_center_inds = torch.from_numpy(self.pat_center_inds)
        self.vert_frag_ids = torch.from_numpy(self.vert_frag_ids)[..., None]  # Nx1

    def to(self, *args, **kwargs):
        """Move the module and its plain-tensor attributes.

        The mesh, vertices, faces and patch tensors are ordinary attributes,
        not registered buffers, so they are moved explicitly. The device is
        taken from the ``device`` keyword, else the first positional argument;
        if neither is given only ``nn.Module.to`` is applied.
        """
        device = kwargs.get('device', args[0] if args else None)
        super().to(*args, **kwargs)
        if device is None:
            return self
        self.mesh = self.mesh.to(device)
        self.verts = self.verts.to(device)
        self.faces = self.faces.to(device)
        self.pat_centers = self.pat_centers.to(device)
        self.pat_center_inds = self.pat_center_inds.to(device)
        self.vert_frag_ids = self.vert_frag_ids.to(device)

        return self

    def get_patch_center_depths(self, T, K):
        """Project the patch centres with pose ``T`` and intrinsics ``K``.

        No OpenCV->pytorch3d conversion is applied since pytorch3d is not used.

        Args:
            T: (B,3,4) or (B,4,4) pose (a single unbatched pose also works).
            K: (3,3) or (B,3,3) intrinsics.

        Returns:
            depth: (B,P,1) camera-frame z of every patch centre.
            img_coords: (B,P,2) pixel coordinates of every patch centre.
        """
        # X_cam = X_world R + t
        R = T[..., :3, :3].transpose(-1, -2)
        t = T[..., :3, 3]

        # t[..., None, :] keeps each pose's translation on its own batch entry.
        X_cam = self.pat_centers @ R + t[..., None, :]  # BxPx3
        depth = X_cam[..., 2:]  # BxPx1
        x = X_cam @ K.transpose(-1, -2)  # BxPx3
        x = x / x[..., -1:]
        img_coords = x[..., :2]

        return depth, img_coords

    # Calculate interpolated maps -> [n, c, h, w]
    # face_memory.shape: [n_face, 3, c]
    @staticmethod
    def forward_interpolate(R, t, meshes, face_memory, rasterizer, blur_radius=0, mode='bilinear', return_depth=True):
        """Rasterize ``meshes`` and interpolate ``face_memory`` per pixel.

        ``mode='nearest'`` snaps the barycentric weights to the dominant vertex
        before interpolation; any other value interpolates them as is.

        Returns:
            The (N,C,H,W) attribute map, plus the permuted z-buffer when
            ``return_depth`` is True.
        """
        fragments = rasterize(R, t, meshes, rasterizer, blur_radius=blur_radius)

        # [n, h, w, 1, d]
        if mode == 'nearest':
            out_map = utils.interpolate_face_attributes(fragments.pix_to_face, set_bary_coords_to_nearest(fragments.bary_coords), face_memory)
        else:
            out_map = utils.interpolate_face_attributes(fragments.pix_to_face, fragments.bary_coords, face_memory)

        out_map = out_map.squeeze(dim=3)
        out_map = out_map.transpose(3, 2).transpose(2, 1)
        if return_depth:
            return out_map, fragments.zbuf.permute(0, 3, 1, 2)  # depth
        else:
            return out_map

    def render_mesh(self, T, K, render_image_size, near=0.1, far=6, lights=(1, 1, -1)):
        """Render the shaded mesh with a SoftPhongShader and one point light.

        Args:
            T: (B,4,4) poses in the OpenCV convention.
            K: (B,3,3) intrinsics.
            render_image_size (tuple): (h, w).
            near, far: currently unused.
            lights: location of the point light.

        Returns:
            image: (B,3,H,W) RGB rendering.
            depth: (B,1,H,W) z-buffer (one face per pixel).
        """
        B = T.shape[0]

        device = T.device
        T = self.cam_opencv2pytch3d.to(device=T.device) @ T

        # X_cam = X_world R + t
        R = T[..., :3, :3].transpose(-1, -2)
        t = T[..., :3, 3]

        cameras = PerspectiveCameras(
            focal_length=torch.stack([K[:, 0, 0], K[:, 1, 1]], dim=-1),
            principal_point=K[:, :2, 2], R=R, T=t,
            image_size=[render_image_size] * B, in_ndc=False, device=device)
        point_lights = PointLights(device=device, location=[lights])

        raster_settings = RasterizationSettings(
            image_size=render_image_size,
            blur_radius=0.0,
            faces_per_pixel=1,
            bin_size=None,  # 0
            perspective_correct=True
        )
        materials = Materials(
            device=device,
            shininess=0
        )
        renderer = MeshRendererWithDepth(
            rasterizer=MeshRasterizer(
                cameras=cameras,
                raster_settings=raster_settings
            ),
            shader=SoftPhongShader(
                device=device,
                cameras=cameras,
                lights=point_lights,
                blend_params=BlendParams(1e-4, 1e-4, (0, 0, 0))
            )
        )
        image, depth = renderer(self.mesh, lights=point_lights, materials=materials)

        return image.permute(0, 3, 1, 2)[:, :3], depth.permute(0, 3, 1, 2)  # to BCHW

    def render_offset_map(self, T, K, render_image_size, near=0.1, far=6):
        """Offsets from every pixel to every projected patch centre.

        Args:
            T: (B,3,4) or (B,4,4) poses; no pytorch3d conversion is applied.
            K: (3,3) or (B,3,3) intrinsics.
            render_image_size (tuple): (h, w).
            near, far: unused.

        Returns:
            (B,P,2,H,W) tensor of ``projected_centre - (x, y)``.
        """
        yy, xx = torch.meshgrid(torch.arange(render_image_size[0], device=T.device), torch.arange(render_image_size[1], device=T.device))
        # HxWx2 pixel grid in (x, y) order.
        coords_grid = torch.stack([xx.to(dtype=torch.float32), yy.to(dtype=torch.float32)], dim=-1)

        # X_cam = X_world R + t
        R = T[..., :3, :3].transpose(-1, -2)
        t = T[..., :3, 3]

        # Project the patch centres; t[..., None, :] broadcasts per pose.
        X_cam = self.pat_centers @ R + t[..., None, :]
        x = X_cam @ K.transpose(-1, -2)  # BxPx3
        x = x / x[..., -1:]

        offset = x[..., None, None, :2] - coords_grid  # BxPx1x1x2 - HxWx2

        return offset.permute(0, 1, 4, 2, 3)  # BxPx2xHxW

    def forward(self, vert_attribute, T, K, render_image_size, near=0.1, far=6, mode='bilinear'):
        """Rasterize a per-vertex attribute into an image-space map.

        Args:
            vert_attribute: (N,C) per-vertex attribute; if None the shaded
                rendering of ``render_mesh`` is returned instead.
            T: (B,4,4) poses in the OpenCV convention.
            K: (B,3,3) intrinsics.
            render_image_size (tuple): (h, w).
            near, far: forwarded to ``render_mesh``.
            mode: 'nearest' snaps to the dominant vertex, otherwise the
                attribute is interpolated barycentrically.

        Returns:
            (map, depth): map is (B,C,H,W), or (B,3+C,H,W) with the RGB
            rendering prepended when ``render_texture`` is set; depth is the
            (B,1,H,W) z-buffer.
        """
        if vert_attribute is None:
            return self.render_mesh(T, K, render_image_size, near=near, far=far)
        if self.render_texture:
            ren_tex = self.render_mesh(T, K, render_image_size, near=near, far=far)

        B = T.shape[0]
        face_attribute = vert_attribute[self.faces.long()]  # F x 3 x C

        device = T.device

        T = self.cam_opencv2pytch3d.to(device=T.device) @ T

        # X_cam = X_world R + t
        R = T[..., :3, :3].transpose(-1, -2)
        t = T[..., :3, 3]

        # R and t are supplied per call in rasterize(), not to the camera here.
        cameras = PerspectiveCameras(
            focal_length=torch.stack([K[:, 0, 0], K[:, 1, 1]], dim=-1),
            principal_point=K[:, :2, 2], image_size=[render_image_size] * B,
            in_ndc=False, device=device)

        raster_settings = RasterizationSettings(
            image_size=render_image_size,
            blur_radius=0.0,
            faces_per_pixel=1,
            bin_size=None,  # 0
            perspective_correct=True
        )

        rasterizer = MeshRasterizer(
            cameras=cameras,
            raster_settings=raster_settings
        )

        out_map, out_depth = self.forward_interpolate(R, t, self.mesh, face_attribute, rasterizer, blur_radius=0, mode=mode)

        if not self.render_texture:
            return out_map, out_depth
        return torch.cat([ren_tex[0], out_map], dim=1), out_depth

    def render_depth(self, T, K, render_image_size, near=0.1, far=6, mode='neareast'):
        """Render the camera-frame z of the nearest vertex at every pixel.

        Args:
            T: (1,4,4) pose in the OpenCV convention; the vertex-depth
                computation below assumes a single pose.
            K: (1,3,3) intrinsics.
            render_image_size (tuple): (h, w).
            near, far: unused.
            mode: ignored -- nearest-vertex interpolation is always used.

        Returns:
            (1,1,H,W) depth map.
        """
        B = T.shape[0]
        device = T.device

        T = self.cam_opencv2pytch3d.to(device=T.device) @ T

        # X_cam = X_world R + t
        R = T[..., :3, :3].transpose(-1, -2)
        t = T[..., :3, 3]
        cameras = PerspectiveCameras(
            focal_length=torch.stack([K[:, 0, 0], K[:, 1, 1]], dim=-1),
            principal_point=K[:, :2, 2], image_size=[render_image_size] * B,
            in_ndc=False, device=device)

        raster_settings = RasterizationSettings(
            image_size=render_image_size,
            blur_radius=0.0,
            faces_per_pixel=1,
            bin_size=0
        )

        rasterizer = MeshRasterizer(
            cameras=cameras,
            raster_settings=raster_settings
        )

        # Per-vertex z (unaffected by the x/y flip), gathered per face.
        vert_depths = (self.verts @ R + t).squeeze(0)[..., 2:]
        face_depths = vert_depths[self.faces.long()]
        out_depth = self.forward_interpolate(R, t, self.mesh, face_depths, rasterizer, blur_radius=0, mode='nearest', return_depth=False)

        return out_depth


class DiffRendererWrapper(nn.Module):
    """Batch front-end that routes each sample to the renderer of its model.

    One ``DiffRender`` is built per entry of ``obj_paths``. ``cls2idx`` must be
    set by the caller to map a model name to its index in ``obj_paths``. Every
    method processes the batch one sample at a time and concatenates the
    per-sample results along dim 0.
    """

    def __init__(self, obj_paths, device="cuda", render_texture=False):
        super().__init__()

        self.renderers = nn.ModuleList(
            [DiffRender(obj_path, render_texture).to(device=device) for obj_path in obj_paths]
        )
        self.cls2idx = None  # could be updated outside

    def _renderer(self, model_name):
        """Look up the DiffRender registered for ``model_name`` via ``cls2idx``."""
        return self.renderers[self.cls2idx[model_name]]

    def get_patch_center_depths(self, model_names, T, K):
        """Patch-centre depths (B,P,1) and pixel coordinates (B,P,2).

        ``K`` may be one shared matrix ((3,3) or (1,3,3)) or per-sample
        (B,3,3); per-sample intrinsics are sliced like in the other methods.
        """
        # Only slice a genuinely batched K; shared intrinsics are used as-is.
        per_sample_K = K.dim() == 3 and K.shape[0] > 1

        depths = []
        image_coords = []
        for b, name in enumerate(model_names):
            K_b = K[b:b + 1] if per_sample_K else K
            depth, img_coord = self._renderer(name).get_patch_center_depths(T[b:b + 1], K_b)
            depths.append(depth)
            image_coords.append(img_coord)

        return torch.cat(depths, dim=0), torch.cat(image_coords, dim=0)

    def render_offset_map(self, model_names, T, K, render_image_size, near=0.1, far=6):
        """Per-sample pixel-to-patch-centre offsets, concatenated over the batch."""
        offsets = [
            self._renderer(name).render_offset_map(T[b:b + 1], K[b:b + 1], render_image_size, near, far)
            for b, name in enumerate(model_names)
        ]
        return torch.cat(offsets, dim=0)

    def render_pat_id(self, model_names, T, K, render_image_size, near=0.1, far=6):
        """Render per-pixel patch ids; uncovered pixels end up as -1."""
        pat_ids = []
        for b, name in enumerate(model_names):
            renderer = self._renderer(name)
            # Shift ids by +1 so the zero background becomes -1 after the -1 below.
            pat_id, _ = renderer.forward(renderer.vert_frag_ids.float() + 1, T[b:b + 1], K[b:b + 1], render_image_size, near, far, 'nearest')
            pat_ids.append(pat_id - 1)

        return torch.cat(pat_ids, dim=0)

    def render_depth(self, model_names, T, K, render_image_size, near=0.1, far=6):
        """Per-sample nearest-vertex depth maps, concatenated over the batch."""
        depth_outputs = [
            self._renderer(name).render_depth(T[b:b + 1], K[b:b + 1], render_image_size, near, far, 'nearest')
            for b, name in enumerate(model_names)
        ]
        return torch.cat(depth_outputs, dim=0)

    def forward(self, model_names, vert_attribute, T, K, render_image_size, near=0.1, far=6):
        """Rasterize ``vert_attribute[b]`` for every sample.

        Returns:
            (maps, depths), each concatenated over the batch.
        """
        map_outputs = []
        depth_outputs = []
        for b, name in enumerate(model_names):
            feamap, depth = self._renderer(name)(vert_attribute[b], T[b:b + 1], K[b:b + 1], render_image_size, near, far)
            map_outputs.append(feamap)
            depth_outputs.append(depth)
        return torch.cat(map_outputs, dim=0), torch.cat(depth_outputs, dim=0)




================================================
FILE: geometry/diff_render_optim.py
================================================
## Speed optimized: sharing the rasterization among different rendering process

import torch
import torch.nn as nn 
import torch.nn.functional as F 

import numpy 

from pytorch3d.renderer import (
    OpenGLPerspectiveCameras, look_at_view_transform, look_at_rotation,
    RasterizationSettings, MeshRenderer, MeshRasterizer, BlendParams,
    camera_position_from_spherical_angles, HardPhongShader, PointLights,FoVPerspectiveCameras, PerspectiveCameras, SoftPhongShader, Materials 
) 
try:
    from pytorch3d.structures import Meshes, Textures
    use_textures = True
except:
    from pytorch3d.structures import Meshes
    from pytorch3d.renderer import TexturesVertex
    from pytorch3d.renderer import TexturesVertex as Textures

    use_textures = False

import pytorch3d.renderer.mesh.utils as utils
from pytorch3d.io import load_obj, load_ply, load_objs_as_meshes
from pytorch3d.renderer.mesh.rasterizer import Fragments

from plyfile import PlyData
from utils.furthest_point_sample import fragmentation_fps



def rasterize(R, T, meshes, rasterizer, blur_radius=0):
    """Return the fragments of ``meshes`` rasterized from pose (R, T).

    R and T are passed to the rasterizer call so they apply to its camera for
    this call. For ``blur_radius > 0`` the barycentric coordinates are clipped
    and the z-buffer rebuilt from them, mirroring pytorch3d's own source.
    """
    raw = rasterizer(meshes, R=R, T=T)
    if blur_radius > 0.0:
        clipped = utils._clip_barycentric_coordinates(raw.bary_coords)
        raw = Fragments(
            bary_coords=clipped,
            zbuf=utils._interpolate_zbuf(raw.pix_to_face, clipped, meshes),
            dists=raw.dists,
            pix_to_face=raw.pix_to_face,
        )
    return raw

def set_bary_coords_to_nearest(bary_coords_):
    """Snap barycentric weights to a one-hot vote for the dominant vertex.

    Along the last axis, the largest weight becomes 1 and the others 0. The
    negative weights of the input are then added back on top of that one-hot
    vector.

    Args:
        bary_coords_: tensor whose last axis holds barycentric weights
            (presumably pytorch3d fragment coordinates -- TODO confirm).
            Non-contiguous tensors are accepted.

    Returns:
        Tensor with the same shape as ``bary_coords_``.
    """
    ori_shape = bary_coords_.shape
    # Keep only the negative parts of the input.
    exr = bary_coords_ * (bary_coords_ < 0)
    # reshape (not view) so that non-contiguous inputs do not raise.
    flat = bary_coords_.reshape(-1, bary_coords_.shape[-1])
    arg_max_idx = flat.argmax(1)
    one_hot = torch.zeros_like(flat).scatter(1, arg_max_idx.unsqueeze(1), 1.0)
    return one_hot.reshape(ori_shape) + exr

class MeshRendererWithDepth(nn.Module):
    """Mesh renderer returning the shaded images together with the z-buffer.

    Args:
        rasterizer: callable ``rasterizer(meshes, **kwargs) -> fragments``.
        shader: callable ``shader(fragments, meshes, **kwargs) -> images``.
    """

    def __init__(self, rasterizer, shader):
        super().__init__()
        self.rasterizer = rasterizer
        self.shader = shader

    def to(self, *args, **kwargs):
        """Move the module; also forward the device to rasterizer and shader.

        Accepts the same arguments as ``nn.Module.to``. The rasterizer and
        shader have submodules which are not of type ``nn.Module``, so the
        device (``device`` keyword, else the first positional argument) is
        additionally passed to their own ``to`` methods.
        """
        device = kwargs.get('device', args[0] if args else None)
        super().to(*args, **kwargs)
        if device is not None:
            self.rasterizer.to(device)
            self.shader.to(device)
        return self

    def forward(self, meshes_world, **kwargs):
        """Rasterize and shade ``meshes_world``.

        Returns:
            tuple: ``(images, zbuf)`` -- the shader output and the fragments'
            z-buffer.
        """
        fragments = self.rasterizer(meshes_world, **kwargs)
        images = self.shader(fragments, meshes_world, **kwargs)
        return images, fragments.zbuf

class MeshRendererWithDepth_v2(nn.Module):
    """Depth-returning mesh renderer that can reuse precomputed fragments.

    Passing ``fragments=<fragments>`` to ``forward`` skips rasterization, so
    several shading passes can share one rasterization (which occupies most
    of the rendering time).

    Args:
        rasterizer: callable ``rasterizer(meshes, **kwargs) -> fragments``.
        shader: callable ``shader(fragments, meshes, **kwargs) -> images``.
    """

    def __init__(self, rasterizer, shader):
        super().__init__()
        self.rasterizer = rasterizer
        self.shader = shader

    def to(self, *args, **kwargs):
        """Move the module; also forward the device to rasterizer and shader.

        Accepts the same arguments as ``nn.Module.to``. The device (``device``
        keyword, else the first positional argument) is additionally passed to
        the rasterizer and shader, whose submodules are not of type
        ``nn.Module``. Calls without a device (e.g. ``to(dtype=...)``) no
        longer raise.
        """
        device = kwargs.get('device', args[0] if args else None)
        super().to(*args, **kwargs)
        if device is not None:
            self.rasterizer.to(device)
            self.shader.to(device)
        return self

    def forward(self, meshes_world, **kwargs):
        """Shade ``meshes_world``, rasterizing only if no fragments are given.

        Args:
            meshes_world: meshes to render.
            **kwargs: forwarded to the rasterizer and shader; an optional
                ``fragments`` entry (non-None) replaces rasterization and is
                not forwarded.

        Returns:
            tuple: ``(images, zbuf)``.
        """
        # Share fragment results with other passes for speed when provided.
        fragments = kwargs.pop('fragments', None)
        if fragments is None:
            fragments = self.rasterizer(meshes_world, **kwargs)

        images = self.shader(fragments, meshes_world, **kwargs)
        return images, fragments.zbuf

class DiffRender(nn.Module):
    """Renderer for a single object mesh built on pytorch3d.

    Poses ``T`` map mesh (world) coordinates into the camera frame in the
    OpenCV convention; points are transformed row-wise as
    ``X_cam = X_world @ T[:3, :3]^T + T[:3, 3]``. Methods that rasterize with
    pytorch3d first left-multiply ``T`` by the 4x4 ``cam_opencv2pytch3d``
    (which negates x and y), so they need ``T`` of shape (B,4,4). The purely
    projective helpers (``get_patch_center_depths``, ``render_offset_map``,
    ``render_pointcloud``) use the OpenCV pose directly.
    """

    def __init__(self, mesh_path, render_texture=False):
        """
        Args:
            mesh_path (str): path to a '.ply' or '.obj' mesh file.
            render_texture (bool): default for ``forward``'s ``render_texture``.

        Raises:
            ValueError: if ``mesh_path`` is neither a '.ply' nor an '.obj' file.
        """
        super().__init__()

        if mesh_path.endswith('.ply'):
            verts, faces = load_ply(mesh_path)
            self.mesh = Meshes(verts=[verts], faces=[faces])
        elif mesh_path.endswith('.obj'):
            verts, faces, _ = load_obj(mesh_path)
            faces = faces.verts_idx
            self.mesh = load_objs_as_meshes([mesh_path])
        else:
            # Without this branch `verts` / `faces` would be unbound below.
            raise ValueError('Unsupported mesh format (expected .ply or .obj): %s' % mesh_path)

        self.verts = verts
        self.faces = faces
        # Negates the x and y axes: OpenCV camera frame -> pytorch3d camera frame.
        self.cam_opencv2pytch3d = torch.tensor(
            [[-1, 0, 0, 0],
             [0, -1, 0, 0],
             [0, 0, 1, 0],
             [0, 0, 0, 1]], dtype=torch.float32
        )
        self.render_texture = render_texture

        # Patch information from fragmentation_fps with 64 patches. The names
        # suggest the patch centers (Px3), their vertex indices and the patch
        # id of every vertex. NOTE(review): confirm against fragmentation_fps.
        self.pat_centers, self.pat_center_inds, self.vert_frag_ids = fragmentation_fps(
            verts.detach().cpu().numpy(), 64)
        self.pat_centers = torch.from_numpy(self.pat_centers)
        self.pat_center_inds = torch.from_numpy(self.pat_center_inds)
        self.vert_frag_ids = torch.from_numpy(self.vert_frag_ids)[..., None]  # Nx1

    def to(self, *args, **kwargs):
        """Move the module, the mesh and the vertex/patch tensors to a device.

        These tensors are plain attributes rather than registered buffers, so
        ``nn.Module.to`` alone would leave them in place.
        ``cam_opencv2pytch3d`` stays on the CPU and is moved at each use.

        Raises:
            TypeError: if no device is given.
        """
        if 'device' in kwargs:
            device = kwargs['device']
        elif args:
            device = args[0]
        else:
            raise TypeError('to() requires a device argument')
        super().to(device)
        self.mesh = self.mesh.to(device)
        self.verts = self.verts.to(device)
        self.faces = self.faces.to(device)
        self.pat_centers = self.pat_centers.to(device)
        self.pat_center_inds = self.pat_center_inds.to(device)
        self.vert_frag_ids = self.vert_frag_ids.to(device)
        return self

    def get_patch_center_depths(self, T, K):
        """Project the patch centers with pose ``T`` and intrinsics ``K``.

        No pytorch3d rendering is involved, so ``T`` is used in the OpenCV
        convention without conversion.

        Returns:
            depth: (B,P,1) camera-frame z of each of the P patch centers.
            img_coords: (B,P,2) pixel coordinates of the patch centers.
        """
        R = T[..., :3, :3].transpose(-1, -2)
        t = T[..., :3, 3]

        X_cam = self.pat_centers @ R + t  # BxPx3
        depth = X_cam[..., 2:]  # BxPx1
        x = X_cam @ K.transpose(-1, -2)  # BxPx3
        x = x / x[..., -1:]
        img_coords = x[..., :2]

        return depth, img_coords

    # Calculate interpolated maps -> [n, c, h, w]
    # face_memory.shape: [n_face, 3, c]
    @staticmethod
    def forward_interpolate(R, t, meshes, face_memory, rasterizer, blur_radius=0, mode='bilinear', return_depth=True):
        """Rasterize ``meshes`` and interpolate per-face-vertex attributes.

        ``mode='nearest'`` replaces the barycentric weights by a one-hot on
        the largest weight (``set_bary_coords_to_nearest``); any other value
        interpolates with the barycentric weights.

        Returns:
            ``(out_map, zbuf, fragments)`` with out_map as [n, c, h, w] and
            zbuf permuted to [n, 1, h, w], or ``(out_map, fragments)`` when
            ``return_depth`` is False.
        """
        fragments = rasterize(R, t, meshes, rasterizer, blur_radius=blur_radius)

        # [n, h, w, 1, d]
        if mode == 'nearest':
            out_map = utils.interpolate_face_attributes(fragments.pix_to_face, set_bary_coords_to_nearest(fragments.bary_coords), face_memory)
        else:
            out_map = utils.interpolate_face_attributes(fragments.pix_to_face, fragments.bary_coords, face_memory)
        out_map = out_map.squeeze(dim=3)
        out_map = out_map.transpose(3, 2).transpose(2, 1)
        if return_depth:
            return out_map, fragments.zbuf.permute(0, 3, 1, 2), fragments  # depth
        else:
            return out_map, fragments

    def render_mesh(self, T, K, render_image_size, near=0.1, far=6, lights=(1, 1, -1), fragments=None):
        """Render the Phong-shaded mesh.

        Args:
            T: (B,4,4) poses in the OpenCV convention (converted here).
            K: (B,3,3) intrinsics.
            render_image_size (tuple): (h, w).
            near, far: unused.
            lights: location passed to ``PointLights``.
            fragments: optional rasterization output to reuse. When given,
                rasterization is skipped.

        Returns:
            (image, depth): (B,3,H,W) first three shader channels and the
            z-buffer permuted to BCHW.
        """
        B = T.shape[0]
        device = T.device
        T = self.cam_opencv2pytch3d.to(device=T.device) @ T

        ## X_cam = X_world R + t
        R = T[..., :3, :3].transpose(-1, -2)
        t = T[..., :3, 3]

        cameras = PerspectiveCameras(focal_length=torch.stack([K[:, 0, 0], K[:, 1, 1]], dim=-1),
            principal_point=K[:, :2, 2], R=R, T=t, image_size=[render_image_size] * B, in_ndc=False, device=device)
        point_lights = PointLights(device=device, location=[lights])

        raster_settings = RasterizationSettings(
            image_size=render_image_size,
            blur_radius=0.0,
            faces_per_pixel=1,
            bin_size=None,
            perspective_correct=True
        )
        materials = Materials(
            device=device,
            shininess=0
        )
        renderer = MeshRendererWithDepth_v2(
            rasterizer=MeshRasterizer(
                cameras=cameras,
                raster_settings=raster_settings
            ),
            shader=SoftPhongShader(
                device=device,
                cameras=cameras,
                lights=point_lights,
                blend_params=BlendParams(1e-4, 1e-4, (0, 0, 0))
            )
        )
        image, depth = renderer(self.mesh, lights=point_lights, materials=materials, fragments=fragments)

        return image.permute(0, 3, 1, 2)[:, :3], depth.permute(0, 3, 1, 2)  # to BCHW

    def render_offset_map(self, T, K, render_image_size, near=0.1, far=6):
        """Per-pixel 2D offsets to every projected patch center.

        ``T`` is used in the OpenCV convention (no pytorch3d rendering).
        ``near``/``far`` are unused.

        Returns:
            (B,P,2,H,W) tensor: projected patch center minus pixel (x, y).
        """
        yy, xx = torch.meshgrid(torch.arange(render_image_size[0], device=T.device), torch.arange(render_image_size[1], device=T.device))
        coords_grid = torch.stack([xx.to(dtype=torch.float32), yy.to(dtype=torch.float32)], dim=-1)

        ## X_cam = X_world R + t
        R = T[..., :3, :3].transpose(-1, -2)
        t = T[..., :3, 3]

        X_cam = self.pat_centers @ R + t
        x = X_cam @ K.transpose(-1, -2)  # BxPx3
        x = x / x[..., -1:]

        offset = x[..., None, None, :2] - coords_grid  # BxPx1x1x2 - HxWx2

        return offset.permute(0, 1, 4, 2, 3)  # BxPx2xHxW

    def forward(self, vert_attribute, T, K, render_image_size, near=0.1, far=6, render_texture=None, mode='bilinear'):
        """Rasterize per-vertex attributes into an image-space map.

        Args:
            vert_attribute: (N,C) attribute per mesh vertex. If None, only the
                shaded mesh is rendered (result of ``render_mesh``).
            T: (B,4,4) poses in the OpenCV convention. A 4x4 matrix is
                required because it is left-multiplied by a 4x4 conversion.
            K: (B,3,3) intrinsics.
            render_image_size (tuple): (h,w)
            near (float, optional): forwarded but unused. Defaults to 0.1.
            far (int, optional): forwarded but unused. Defaults to 6.
            render_texture: if truthy, the 3 shaded channels of
                ``render_mesh`` are prepended to the output. None uses the
                value given at construction.
            mode: 'bilinear' or 'nearest' (see ``forward_interpolate``).

        Returns:
            (out_map, depth): (B,C,H,W), or (B,3+C,H,W) with texture, and the
            (B,1,H,W) z-buffer.
        """
        # use default rendering settings
        if render_texture is None:
            render_texture = self.render_texture

        if vert_attribute is None:
            # only render the rgb image
            return self.render_mesh(T, K, render_image_size, near=near, far=far)

        B = T.shape[0]
        face_attribute = vert_attribute[self.faces.long()]

        device = T.device

        # Keep the OpenCV pose: render_mesh applies the conversion itself.
        T_p3d = self.cam_opencv2pytch3d.to(device=device) @ T

        ## X_cam = X_world R + t
        R = T_p3d[..., :3, :3].transpose(-1, -2)
        t = T_p3d[..., :3, 3]

        cameras = PerspectiveCameras(focal_length=torch.stack([K[:, 0, 0], K[:, 1, 1]], dim=-1),
            principal_point=K[:, :2, 2], image_size=[render_image_size] * B, in_ndc=False, device=device)

        raster_settings = RasterizationSettings(
            image_size=render_image_size,
            blur_radius=0.0,
            faces_per_pixel=1,
            bin_size=None,
            perspective_correct=True
        )

        rasterizer = MeshRasterizer(
            cameras=cameras,
            raster_settings=raster_settings
        )

        out_map, out_depth, fragments = self.forward_interpolate(R, t, self.mesh, face_attribute, rasterizer, blur_radius=0, mode=mode)

        if not render_texture:
            return out_map, out_depth

        # Pass the OpenCV pose: render_mesh converts it to pytorch3d itself,
        # so passing T_p3d would convert it twice.
        ren_tex = self.render_mesh(T, K, render_image_size, near=near, far=far, fragments=fragments)

        # The first 3 channels contain the rendered textures
        return torch.cat([ren_tex[0], out_map], dim=1), out_depth

    def render_depth(self, T, K, render_image_size, near=0.1, far=6, mode='neareast'):
        """Render per-pixel camera-frame depth of the mesh.

        Args:
            T: (1,4,4) pose in the OpenCV convention. The depth computation
                squeezes and indexes the first batch dim, so only a single
                pose is supported.
            K: (1,3,3) intrinsics.
            render_image_size (tuple): (h,w)
            near (float, optional): unused. Defaults to 0.1.
            far (int, optional): unused. Defaults to 6.
            mode: currently ignored; vertex depths are always propagated
                with 'nearest' interpolation.

        Returns:
            (1,1,H,W) depth map.
        """
        B = T.shape[0]
        device = T.device

        T = self.cam_opencv2pytch3d.to(device=T.device) @ T

        ## X_cam = X_world R + t
        R = T[..., :3, :3].transpose(-1, -2)
        t = T[..., :3, 3]
        cameras = PerspectiveCameras(focal_length=torch.stack([K[:, 0, 0], K[:, 1, 1]], dim=-1),
            principal_point=K[:, :2, 2], image_size=[render_image_size] * B, in_ndc=False, device=device)

        raster_settings = RasterizationSettings(
            image_size=render_image_size,
            blur_radius=0.0,
            faces_per_pixel=1,
            bin_size=0
        )

        rasterizer = MeshRasterizer(
            cameras=cameras,
            raster_settings=raster_settings
        )

        # The x/y flip of the conversion leaves z unchanged, so this is the
        # OpenCV camera-frame depth of every vertex.
        vert_depths = (self.verts @ R + t).squeeze(0)[..., 2:]
        face_depths = vert_depths[self.faces.long()]
        out_depth, _ = self.forward_interpolate(R, t, self.mesh, face_depths, rasterizer, blur_radius=0, mode='nearest', return_depth=False)

        return out_depth

    def render_pointcloud(self, T, K, render_image_size, near=0.1, far=6):
        """Splat the projected mesh vertices into a sparse depth image.

        Args:
            T: (1,3,4) or (1,4,4) pose in the OpenCV convention (no
                conversion is applied). Only batch element 0 is rendered.
            K: (1,3,3) intrinsics.
            render_image_size (tuple): (h,w)
            near (float, optional): unused. Defaults to 0.1.
            far (int, optional): unused. Defaults to 6.

        Returns:
            (1,1,H,W) tensor holding vertex depths at their rounded pixel
            locations (clamped to the image border) and zero elsewhere. No
            z-test is performed: when several vertices hit the same pixel,
            which value is written is unspecified.
        """
        ## X_cam = X_world R + t
        R = T[..., :3, :3].transpose(-1, -2)
        t = T[..., :3, 3]

        X_cam = self.verts @ R + t

        x = X_cam @ K.transpose(-1, -2)  # BxNx3
        depth = x[..., -1]
        x = x / x[..., -1:]

        out = torch.zeros([1, 1, *render_image_size], dtype=R.dtype, device=R.device)
        out[:, :,
            torch.round(x[0, :, 1]).long().clamp(0, out.shape[2] - 1),
            torch.round(x[0, :, 0]).long().clamp(0, out.shape[3] - 1)] = depth

        return out  # 1x1xHxW


class DiffRendererWrapper(nn.Module):
    """Batch front-end over one ``DiffRender`` per object model.

    ``cls2idx`` must be assigned by the caller after construction. It maps a
    model name to the index of that model's renderer (the order of
    ``obj_paths``). Every method renders sample ``b`` with the renderer of
    ``model_names[b]`` and concatenates the per-sample results along dim 0.
    """

    def __init__(self, obj_paths, device="cuda", render_texture=False):
        super().__init__()
        self.renderers = nn.ModuleList(
            [DiffRender(obj_path, render_texture).to(device=device) for obj_path in obj_paths]
        )
        self.cls2idx = None  # updated outside

    def _renderer(self, model_name):
        """Return the DiffRender registered for ``model_name`` in ``cls2idx``."""
        return self.renderers[self.cls2idx[model_name]]

    def get_patch_center_depths(self, model_names, T, K):
        """Per-sample patch-center depths and pixel coordinates.

        ``K`` may be shared by the batch (3x3 or 1x3x3) or given per sample
        (Bx3x3).
        """
        # A per-sample K must be sliced like T. Passing all B intrinsics for a
        # single pose broadcasts that sample's image coordinates to B rows.
        per_sample_K = K.dim() == 3 and K.shape[0] > 1
        depths = []
        image_coords = []
        for b, name in enumerate(model_names):
            K_b = K[b:b + 1] if per_sample_K else K
            depth, img_coord = self._renderer(name).get_patch_center_depths(T[b:b + 1], K_b)
            depths.append(depth)
            image_coords.append(img_coord)

        return torch.cat(depths, dim=0), torch.cat(image_coords, dim=0)

    def render_offset_map(self, model_names, T, K, render_image_size, near=0.1, far=6):
        """Per-sample pixel-to-patch-center offsets, (B,P,2,H,W)."""
        offsets = [
            self._renderer(name).render_offset_map(T[b:b + 1], K[b:b + 1], render_image_size, near, far)
            for b, name in enumerate(model_names)
        ]
        return torch.cat(offsets, dim=0)

    def render_pat_id(self, model_names, T, K, render_image_size, near=0.1, far=6):
        """Render the patch id of the visible surface at each pixel.

        Ids are rendered shifted by +1 and shifted back afterwards, so that
        invalid (background) pixels become -1. This assumes the interpolation
        yields 0 off the mesh.
        """
        pat_ids = []
        for b, name in enumerate(model_names):
            renderer = self._renderer(name)
            # Pass mode / render_texture by keyword: the 7th positional
            # parameter of DiffRender.forward is render_texture, not mode.
            # Ids must use nearest-vertex interpolation and no texture channels.
            pat_id, _ = renderer(renderer.vert_frag_ids.float() + 1, T[b:b + 1], K[b:b + 1],
                                 render_image_size, near, far,
                                 render_texture=False, mode='nearest')
            pat_ids.append(pat_id - 1)

        return torch.cat(pat_ids, dim=0)

    def render_depth(self, model_names, T, K, render_image_size, near=0.1, far=6):
        """Per-sample rendered depth maps, (B,1,H,W)."""
        depth_outputs = [
            self._renderer(name).render_depth(T[b:b + 1], K[b:b + 1], render_image_size, near, far, 'nearest')
            for b, name in enumerate(model_names)
        ]
        return torch.cat(depth_outputs, dim=0)

    def render_mesh(self, model_names, T, K, render_image_size, near=0.1, far=6):
        """Per-sample shaded mesh images, (B,3,H,W)."""
        outputs = [
            self._renderer(name).render_mesh(T[b:b + 1], K[b:b + 1], render_image_size, near, far)[0]
            for b, name in enumerate(model_names)
        ]
        return torch.cat(outputs, dim=0)

    def render_pointcloud(self, model_names, T, K, render_image_size, near=0.1, far=6):
        """Per-sample splatted vertex depth images, (B,1,H,W)."""
        outputs = [
            self._renderer(name).render_pointcloud(T[b:b + 1], K[b:b + 1], render_image_size, near, far)
            for b, name in enumerate(model_names)
        ]
        return torch.cat(outputs, dim=0)

    def forward(self, model_names, vert_attribute, T, K, render_image_size, near=0.1, far=6, render_tex=False):
        """Rasterize per-sample vertex attributes (see DiffRender.forward).

        ``vert_attribute[b]`` is the per-vertex attribute of sample b's model.

        Returns:
            (feature maps, depth maps) concatenated over the batch.
        """
        map_outputs = []
        depth_outputs = []
        for b, name in enumerate(model_names):
            feamap, depth = self._renderer(name)(vert_attribute[b], T[b:b + 1], K[b:b + 1], render_image_size, near, far, render_texture=render_tex)
            map_outputs.append(feamap)
            depth_outputs.append(depth)
        return torch.cat(map_outputs, dim=0), torch.cat(depth_outputs, dim=0)




================================================
FILE: geometry/einsum.py
================================================
# import tensorflow as torch
import torch as torch

import numpy as np
import re
import string

def einsum(equation, *inputs):
    """Evaluate an Einstein-summation ``equation`` with ``torch.einsum``.

    Ellipses (``...``) are expanded into explicit single-letter labels before
    the call, following numpy broadcasting conventions:

    * each operand's ellipsis gets as many labels as it has unlabeled dims,
      right-aligned against the longest ellipsis, so shorter operands
      broadcast;
    * in implicit mode (no ``->``) the ellipsis dims lead the output,
      followed by the labels that appear exactly once, sorted.

    Args:
        equation: subscripts such as ``'...ij,i->...ij'``; spaces are ignored.
        *inputs: one tensor per comma-separated input term.

    Returns:
        The ``torch.einsum`` result.

    Raises:
        ValueError: malformed equation, operand-count mismatch, or an
            unresolvable ellipsis.
    """
    equation = equation.replace(' ', '')
    input_shapes = [x.shape for x in inputs]
    match = re.match('^([a-zA-Z,.]+)(->[a-zA-Z.]*)?$', equation)
    if not match:
        raise ValueError('Indices have incorrect format: %s' % equation)

    input_axis_labels = match.group(1).split(',')
    output_axis_labels = match.group(2)[2:] if match.group(2) else None

    if len(input_shapes) != len(input_axis_labels):
        raise ValueError('Got %d arguments for equation "%s", expecting %d' %
                        (len(input_shapes), equation, len(input_axis_labels)))

    if '...' not in equation:
        return torch.einsum(equation, *inputs)

    # Labels for the ellipsis dims are taken from letters the equation does not use.
    unused = ''.join([c for c in string.ascii_lowercase
                      if c not in ''.join(input_axis_labels)])
    ellipsis_axes = ''
    explicit_labels = ''  # labels written in the equation (excludes ellipsis expansions)
    for i, ax in enumerate(input_axis_labels):
        if '...' not in ax:
            explicit_labels += ax
            continue
        parts = ax.split('...')
        if len(parts) != 2:
            raise ValueError('Unable to resolve ellipsis. Excess number found.')
        explicit_labels += ''.join(parts)

        n = len(input_shapes[i]) - len(''.join(parts))
        if n < 0:
            raise ValueError('Ellipses lengths do not match.')
        if len(unused) < n:
            raise ValueError(
                'Unable to resolve ellipsis, too many distinct labels.')
        # Taking the last n letters right-aligns shorter ellipses with the longest one.
        replace_axes = unused[-n:] if n > 0 else ''
        input_axis_labels[i] = ax.replace('...', replace_axes)
        if len(replace_axes) > len(ellipsis_axes):
            ellipsis_axes = replace_axes

    if output_axis_labels is None:
        once = sorted(c for c in set(explicit_labels) if explicit_labels.count(c) == 1)
        output_axis_labels = ellipsis_axes + ''.join(once)
    else:
        output_axis_labels = output_axis_labels.replace('...', ellipsis_axes)

    # Rebuild from the per-operand labels so every operand keeps its own ellipsis length.
    equation = ','.join(input_axis_labels) + '->' + output_axis_labels
    return torch.einsum(equation, *inputs)


================================================
FILE: geometry/intrinsics.py
================================================
import torch
import numpy as np
# from utils.einsum import einsum
from .einsum import einsum

def intrinsics_vec_to_matrix(kvec):
    """Build pinhole intrinsics matrices from ``[..., (fx, fy, cx, cy)]``.

    Returns a ``kvec.shape[:-1] + (3, 3)`` tensor
    ``[[fx, 0, cx], [0, fy, cy], [0, 0, 1]]``.
    """
    focal_x, focal_y, center_x, center_y = torch.unbind(kvec, dim=-1)
    zero = torch.zeros_like(focal_x)
    one = torch.ones_like(focal_x)
    rows = [
        torch.stack([focal_x, zero, center_x], dim=-1),
        torch.stack([zero, focal_y, center_y], dim=-1),
        torch.stack([zero, zero, one], dim=-1),
    ]
    return torch.stack(rows, dim=-2)

def intrinsics_matrix_to_vec(kmat):
    """Extract ``[..., (fx, fy, cx, cy)]`` from ``[..., 3, 3]`` intrinsics."""
    entries = [(0, 0), (1, 1), (0, 2), (1, 2)]  # fx, fy, cx, cy
    return torch.stack([kmat[..., row, col] for row, col in entries], dim=-1)

def update_intrinsics(intrinsics, delta_focal):
    """Scale the focal lengths of ``intrinsics`` by ``exp(delta_focal)``.

    Args:
        intrinsics: [..., 3, 3] intrinsics matrices.
        delta_focal: log-scale focal update applied to both fx and fy; a
            trailing singleton dimension is squeezed.

    Returns:
        [..., 3, 3] pinhole intrinsics rebuilt from (fx, fy, cx, cy) with fx
        and fy scaled. Entries other than these four (e.g. skew) are not
        carried over.
    """
    kvec = intrinsics_matrix_to_vec(intrinsics)
    # torch has no `unstack` (a TensorFlow name); `unbind` is the equivalent.
    fx, fy, cx, cy = torch.unbind(kvec, dim=-1)
    df = torch.squeeze(delta_focal, -1)

    # update the focal lengths
    fx = torch.exp(df) * fx
    fy = torch.exp(df) * fy

    kvec = torch.stack([fx, fy, cx, cy], dim=-1)
    return intrinsics_vec_to_matrix(kvec)

def rescale_depth(depth, downscale=4):
    """Nearest-neighbour downsample a batch of depth maps ``[N, H, W]``.

    Returns ``[N, H // downscale, W // downscale]``.
    """
    with_channel = depth.unsqueeze(1)
    target_size = [with_channel.shape[-2] // downscale, with_channel.shape[-1] // downscale]
    resized = torch.nn.functional.interpolate(with_channel, target_size, mode='nearest')
    return resized.squeeze(1)

def rescale_depth_and_intrinsics(depth, intrinsics, downscale=4):
    """Downsample ``depth`` and adapt ``intrinsics`` to the new resolution.

    The first two rows of every intrinsics matrix are divided by
    ``downscale``; the last row is left unchanged.
    """
    row_scale = torch.tensor([1.0 / downscale, 1.0 / downscale, 1.0],
                             dtype=torch.float32, device=depth.device)
    scaled_intrinsics = einsum('...ij,i->...ij', intrinsics, row_scale)
    scaled_depth = rescale_depth(depth, downscale=downscale)
    return scaled_depth, scaled_intrinsics

def rescale_depths_and_intrinsics(depth, intrinsics, downscale=4):
    """Frame-batched variant of ``rescale_depth_and_intrinsics``.

    ``depth`` is ``[B, F, H, W]``. Frames are folded into the batch dim for
    resizing and unfolded again afterwards.
    """
    batch, frames = depth.shape[0], depth.shape[1]
    height, width = depth.shape[2], depth.shape[3]
    folded = torch.reshape(depth, [batch * frames, height, width])
    folded, intrinsics = rescale_depth_and_intrinsics(folded, intrinsics, downscale)
    unfolded = torch.reshape(folded, [batch, frames] + list(folded.shape)[1:])
    return unfolded, intrinsics


================================================
FILE: geometry/projective_ops.py
================================================
import numpy as np
import torch 

# from utils.einsum import einsum
from torch import einsum


# MIN_DEPTH = 0.1
MIN_DEPTH = 0.01

def normalize_coords_grid(coords):
    """Map pixel coordinates into the [-1, 1] range.

    Args:
        coords: BxKxHxWx2 pixel coordinates (x, y). Not modified.

    Returns:
        A new tensor with x mapped via ``2*x/(W-1)-1`` and y via ``2*y/(H-1)-1``.
    """
    normalized = coords.clone()
    _, _, height, width, _ = normalized.shape
    for channel, extent in ((0, width), (1, height)):
        normalized[..., channel] = 2 * normalized[..., channel] / (extent - 1) - 1
    return normalized

def coords_grid(ref, homogeneous=True):
    """Pixel-coordinate grid matching the last two dims of ``ref``.

    Returns a float32 tensor of shape ``ref.shape[:-2] + (H, W, C)`` whose last
    dim holds (x, y), or (x, y, 1) when ``homogeneous``.
    """
    shape = ref.shape
    height, width = shape[-2], shape[-1]
    yy, xx = torch.meshgrid(torch.arange(height, device=ref.device), torch.arange(width, device=ref.device))

    channels = [xx.to(dtype=torch.float32), yy.to(dtype=torch.float32)]
    if homogeneous:
        channels.append(torch.ones_like(channels[0]))
    grid = torch.stack(channels, dim=-1)  # H x W x C

    leading = list(shape[:-2])
    grid = torch.reshape(grid, [1] * len(leading) + [height, width, -1])
    return grid.repeat(leading + [1, 1, 1])


def extract_and_reshape_intrinsics(intrinsics, shape=None):
    """Return ``(fx, fy, cx, cy)`` from Bx3x3 ``intrinsics``.

    When ``shape`` is given, each value is reshaped to ``[B, 1, ..., 1]`` with
    ``len(shape) - 1`` trailing singleton dims, so it broadcasts against a
    tensor of that shape.
    """
    params = (
        intrinsics[:, 0, 0],
        intrinsics[:, 1, 1],
        intrinsics[:, 0, 2],
        intrinsics[:, 1, 2],
    )
    if shape is None:
        return params

    k_shape = list(params[0].shape[:1]) + [1] * len(shape[1:])
    return tuple(torch.reshape(p, k_shape) for p in params)


def backproject(depth, intrinsics, jacobian=False, depth_coords=None):
    """Back-project a depth map to camera-frame 3D points.

    Args:
        depth: BxKxHxW depths (Z).
        intrinsics: Bx3x3 camera matrices.
        jacobian: if True, also return the intrinsics jacobian.
        depth_coords: optional BxKxHxWx2 pixel coordinates (x, y) of the
            depth samples. Defaults to the regular pixel grid.

    Returns:
        points (BxKxHxWx3), or ``(points, jacobian_intrinsics)`` where
        jacobian_intrinsics is BxKxHxWx4x1 with entries (-X/fx, -Y/fy, 0, 0).
    """
    if depth_coords is not None:
        x, y = torch.unbind(depth_coords, dim=-1)
    else:
        x, y, _ = torch.unbind(coords_grid(depth, homogeneous=True), dim=-1)

    fx, fy, cx, cy = extract_and_reshape_intrinsics(intrinsics, x.shape)  # Bx1x1x1

    Z = depth
    X = Z * (x - cx) / fx
    Y = Z * (y - cy) / fy
    points = torch.stack([X, Y, Z], dim=-1)

    if not jacobian:
        return points

    zeros = torch.zeros_like(Z)
    entries = [-X / fx, -Y / fy, zeros, zeros]
    jacobian_intrinsics = torch.stack([e.unsqueeze(-1) for e in entries], dim=-2)
    return points, jacobian_intrinsics


def project(points, intrinsics, jacobian=False):
    """Project camera-frame points (..., 3) onto the image plane.

    Depths are clamped to at least ``MIN_DEPTH`` before dividing.

    Returns:
        coords (..., 2), or ``(coords, (jacobian_points, jacobian_intrinsics))``.
        jacobian_points (..., 2, 3) is w.r.t. (X, Y, Z) and jacobian_intrinsics
        (..., 2, 1) is w.r.t. (fx, fy). Both are zeroed where the clamped
        depth is <= MIN_DEPTH + 0.01.
    """
    X, Y, Z = torch.unbind(points, dim=-1)
    Z = torch.clamp(Z, min=MIN_DEPTH)

    fx, fy, cx, cy = extract_and_reshape_intrinsics(intrinsics, X.shape)

    x = fx * (X / Z) + cx
    y = fy * (Y / Z) + cy
    coords = torch.stack([x, y], dim=-1)

    if not jacobian:
        return coords

    too_close = Z <= MIN_DEPTH + .01
    inv_z = torch.where(too_close, torch.zeros_like(Z), 1.0 / Z)
    inv_z_sq = torch.where(too_close, torch.zeros_like(Z), 1.0 / Z**2)
    zero = torch.zeros_like(x)

    # jacobian w.r.t (X, Y, Z)
    d_x = torch.stack([fx * inv_z, zero, -fx * X * inv_z_sq], dim=-1)
    d_y = torch.stack([zero, fy * inv_z, -fy * Y * inv_z_sq], dim=-1)
    jacobian_points = torch.stack([d_x, d_y], dim=-2)

    # jacobian w.r.t (fx, fy)
    jacobian_intrinsics = torch.stack([(X * inv_z).unsqueeze(-1), (Y * inv_z).unsqueeze(-1)], dim=-2)

    return coords, (jacobian_points, jacobian_intrinsics)


================================================
FILE: geometry/se3.py
================================================
"""
SO3 and SE3 operations, exponentials and logarithms adapted from Sophus
"""

import numpy as np
import torch
from .einsum import einsum


MIN_THETA = 1e-4

def matdotv(A, b):
    """Batched matrix-vector product: ``A`` [..., M, N] times ``b`` [..., N] -> [..., M]."""
    # torch has no `expand_dims` (a TensorFlow name); `unsqueeze` is the equivalent.
    return torch.squeeze(torch.matmul(A, b.unsqueeze(-1)), -1)

def hat(a):
    """Skew-symmetric cross-product matrix of ``a`` ([..., 3] -> [..., 3, 3]).

    ``hat(a) @ v`` equals the cross product ``a x v``.
    """
    x, y, z = torch.split(a, [1, 1, 1], dim=-1)
    zero = torch.zeros_like(x)
    rows = (
        (zero, -z, y),
        (z, zero, -x),
        (-y, x, zero),
    )
    return torch.stack([torch.cat(row, dim=-1) for row in rows], dim=-2)
    

### quaternion functions ###

def quaternion_rotate_point(q, pt, eq=None):
    """Rotate points ``pt`` [..., 3] by unit quaternions ``q`` [..., 4] (w first).

    Uses p' = p + 2w (v x p) + 2 v x (v x p), where v is the vector part.
    When ``eq`` is given, the matrix-vector products are computed with
    ``einsum(eq, matrix, vector)`` so the caller controls broadcasting.
    """
    # torch.split takes `dim`; the TensorFlow-style `axis` keyword is rejected.
    w, vec = torch.split(q, [1, 3], dim=-1)
    if eq is None:
        uv = 2 * matdotv(hat(vec), pt)
        return pt + w * uv + matdotv(hat(vec), uv)

    uv1 = 2 * einsum(eq, hat(w * vec), pt)
    uv2 = 2 * einsum(eq, hat(vec), pt)
    return pt + uv1 + einsum(eq, hat(vec), uv2)

def quaternion_rotate_matrix(q, mat, eq=None):
    """Left-multiply ``mat`` by the rotation of unit quaternions ``q`` (w first).

    Applies p' = p + 2w (v x p) + 2 v x (v x p) to every column of ``mat``.
    When ``eq`` is given, products are computed with ``einsum(eq, ...)``.
    """
    # torch.split takes `dim`; the TensorFlow-style `axis` keyword is rejected.
    w, vec = torch.split(q, [1, 3], dim=-1)
    if eq is None:
        uv = 2 * torch.matmul(hat(vec), mat)
        return mat + w * uv + torch.matmul(hat(vec), uv)

    uv1 = 2 * einsum(eq, hat(w * vec), mat)
    uv2 = 2 * einsum(eq, hat(vec), mat)
    return mat + uv1 + einsum(eq, hat(vec), uv2)

def quaternion_inverse(q):
    """Conjugate of quaternions ``q`` [..., 4] in (w, x, y, z) order.

    The vector part is negated. This equals the inverse only for unit
    quaternions.
    """
    # A Tensor cannot be multiplied by a Python list, so build the sign mask
    # as a tensor with q's dtype and device.
    signs = q.new_tensor([1.0, -1.0, -1.0, -1.0])
    return q * signs

def quaternion_multiply(a, b):
    """Hamilton product ``a * b`` of quaternions [..., 4] in (w, x, y, z) order."""
    # torch.split takes `dim`; the TensorFlow-style `axis` keyword is rejected.
    aw, ax, ay, az = torch.split(a, [1, 1, 1, 1], dim=-1)
    bw, bx, by, bz = torch.split(b, [1, 1, 1, 1], dim=-1)

    return torch.cat([
        aw * bw - ax * bx - ay * by - az * bz,
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by + ay * bw + az * bx - ax * bz,
        aw * bz + az * bw + ax * by - ay * bx,
    ], dim=-1)

def quaternion_to_matrix(q):
    """Rotation matrices [..., 3, 3] from unit quaternions [..., 4] (w, x, y, z)."""
    # torch.split takes `dim`; the TensorFlow-style `axis` keyword is rejected.
    w, x, y, z = torch.split(q, [1, 1, 1, 1], dim=-1)

    r11 = 1 - 2 * y**2 - 2 * z**2
    r12 = 2 * x * y - 2 * w * z
    r13 = 2 * z * x + 2 * w * y

    r21 = 2 * x * y + 2 * w * z
    r22 = 1 - 2 * x**2 - 2 * z**2
    r23 = 2 * y * z - 2 * w * x

    r31 = 2 * z * x - 2 * w * y
    r32 = 2 * y * z + 2 * w * x
    r33 = 1 - 2 * x**2 - 2 * y**2

    return torch.stack([
        torch.cat([r11, r12, r13], dim=-1),
        torch.cat([r21, r22, r23], dim=-1),
        torch.cat([r31, r32, r33], dim=-1)
    ], dim=-2)

def rotation_matrix_to_quaternion(R):
    """Convert rotation matrices [..., 3, 3] to unit quaternions (w, x, y, z).

    A symmetric 4x4 matrix is built from R, with only its lower triangle
    filled (``torch.linalg.eigh`` reads the lower triangle by default). The
    quaternion is its eigenvector for the largest eigenvalue, normalized and
    sign-fixed so that w >= 0.
    """
    Rxx, Ryx, Rzx = R[..., 0, 0], R[..., 0, 1], R[..., 0, 2]
    Rxy, Ryy, Rzy = R[..., 1, 0], R[..., 1, 1], R[..., 1, 2]
    Rxz, Ryz, Rzz = R[..., 2, 0], R[..., 2, 1], R[..., 2, 2]

    zz = torch.zeros_like(Rxx)
    k1 = torch.stack([Rxx - Ryy - Rzz, zz, zz, zz], dim=-1)
    k2 = torch.stack([Ryx + Rxy, Ryy - Rxx - Rzz, zz, zz], dim=-1)
    k3 = torch.stack([Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, zz], dim=-1)
    k4 = torch.stack([Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz], dim=-1)

    K = torch.stack([k1, k2, k3, k4], dim=-2)
    # Eigenvalues come out ascending, so the last eigenvector column belongs
    # to the largest eigenvalue.
    _, eigvecs = torch.linalg.eigh(K)

    # torch has no `reduce_sum`, and torch.split takes `dim`, not `axis`.
    x, y, z, w = torch.split(eigvecs[..., -1], [1, 1, 1, 1], dim=-1)
    qvec = torch.cat([w, x, y, z], dim=-1)
    qvec = qvec / torch.sqrt(torch.sum(qvec**2, dim=-1, keepdim=True))

    # torch.sign(0) == 0 would zero the quaternion for 180-degree rotations
    # (w == 0), so treat w == 0 as positive.
    sign = torch.where(w < 0, -torch.ones_like(w), torch.ones_like(w))
    return qvec * sign

def so3_expm_and_theta(omega):
    """Exponential map so(3) -> unit quaternion.

    Args:
        omega: [..., 3] axis-angle vectors.

    Returns:
        (quat, theta): quaternions [..., 4] in (w, x, y, z) order and the
        rotation angles theta = |omega| with shape [...]. For
        theta <= MIN_THETA, Taylor expansions replace sin/cos of theta/2.
    """
    # torch has no `reduce_sum` (a TensorFlow name); `sum` is the equivalent.
    theta_sq = torch.sum(omega**2, dim=-1)
    theta = torch.sqrt(theta_sq)
    half_theta = 0.5 * theta

    ### small ###
    imag_factor = torch.where(theta > MIN_THETA,
        torch.sin(half_theta) / (theta + 1e-12),
        0.5 - (1.0 / 48.0) * theta_sq + (1.0 / 3840.0) * theta_sq * theta_sq)

    real_factor = torch.where(theta > MIN_THETA, torch.cos(half_theta),
        1.0 - (1.0 / 8.0) * theta_sq + (1.0 / 384.0) * theta_sq * theta_sq)

    qw = real_factor
    qx = imag_factor * omega[..., 0]
    qy = imag_factor * omega[..., 1]
    qz = imag_factor * omega[..., 2]

    quat = torch.stack([qw, qx, qy, qz], dim=-1)
    return quat, theta
        
def so3_logm_and_theta(so3):
    """Logarithm map unit quaternion -> so(3).

    Args:
        so3: [..., 4] unit quaternions in (w, x, y, z) order.

    Returns:
        (omega, theta): axis-angle vectors [..., 3] and angles [..., 1].

    As in Sophus, omega = (2 * atan(n / w) / n) * vec with n = |vec|. For
    n < MIN_THETA the factor uses its expansion 2/w - (2/3) n^2 / w^3.
    """
    # torch.split takes `dim`; torch has no `reduce_sum`.
    w, vec = torch.split(so3, [1, 3], dim=-1)
    squared_n = torch.sum(vec**2, dim=-1, keepdim=True)
    n = torch.sqrt(squared_n)

    # 2*atan(n/w)/n = 2/w - (2/3) n^2/w^3 + O(n^4) (second-order term as in Sophus).
    two_atan_nbyw_by_n = torch.where(n < MIN_THETA,
        2 / w - (2.0 / 3.0) * squared_n / (w * w * w),
        2 * torch.atan(n / w) / (n + 1e-12))

    theta = two_atan_nbyw_by_n * n
    omega = two_atan_nbyw_by_n * vec
    return omega, theta

def se3_expm(xi):
    """Exponential map se(3) -> (unit quaternion, translation).

    Args:
        xi: [..., 6] twists, translation part ``tau`` first, then rotation
            ``omega``.

    Returns:
        (q, t): quaternions [..., 4] (w, x, y, z) and translations
        [..., 3] = V(omega) @ tau.
    """
    # The TensorFlow-only calls (tf.newaxis, tf.tile with a shape tensor,
    # tf.shape, eye(batch_shape=...)) are replaced by torch broadcasting.
    tau, omega = torch.split(xi, [3, 3], dim=-1)
    q, theta = so3_expm_and_theta(omega)

    theta = theta[..., None, None]  # [...] -> [..., 1, 1], broadcasts against 3x3
    theta_sq = theta * theta
    Omega = hat(omega)
    Omega_sq = torch.matmul(Omega, Omega)
    eye = torch.eye(3, dtype=xi.dtype, device=xi.device).expand(*xi.shape[:-1], 3, 3)

    Vs = eye + \
         (1 - torch.cos(theta)) / (theta_sq + 1e-12) * Omega + \
         (theta - torch.sin(theta)) / (theta_sq * theta + 1e-12) * Omega_sq

    # For tiny angles the rotation matrix stands in for V (both are I to first order).
    V = torch.where(theta < MIN_THETA, quaternion_to_matrix(q), Vs)
    t = torch.matmul(V, tau.unsqueeze(-1)).squeeze(-1)
    return q, t

def se3_logm(so3, t):
    """Logarithm map (unit quaternion, translation) -> se(3) twist.

    Args:
        so3: [..., 4] unit quaternions (w, x, y, z).
        t: [..., 3] translations.

    Returns:
        [..., 6] twists (tau, omega) with tau = V(omega)^-1 @ t.
    """
    # The TensorFlow-only calls (tf.newaxis, tf.tile, tf.shape,
    # eye(batch_shape=...)) are replaced by torch broadcasting.
    omega, theta = so3_logm_and_theta(so3)
    Omega = hat(omega)
    Omega_sq = torch.matmul(Omega, Omega)

    theta = theta[..., None]  # [..., 1] -> [..., 1, 1], broadcasts against 3x3
    half_theta = 0.5 * theta
    eye = torch.eye(3, dtype=omega.dtype, device=omega.device).expand(*omega.shape[:-1], 3, 3)

    Vinv_approx = eye - 0.5 * Omega + (1.0 / 12.0) * Omega_sq

    Vinv_exact = eye - \
        0.5 * Omega + (1 - theta * torch.cos(half_theta) /
        (2 * torch.sin(half_theta) + 1e-12)) / (theta * theta + 1e-12) * Omega_sq

    Vinv = torch.where(theta < MIN_THETA, Vinv_approx, Vinv_exact)
    tau = torch.matmul(Vinv, t.unsqueeze(-1)).squeeze(-1)

    return torch.cat([tau, omega], dim=-1)


### matrix functions ###

def se3_matrix_inverse(G):
    """Invert rigid transforms ``G`` [..., 4, 4].

    For G = [[R, t], [0, 1]] returns [[R^T, -R^T t], [0, 1]] with the input's
    shape.
    """
    flat = torch.reshape(G, [-1, 4, 4])
    count = flat.shape[0]

    rot_inv = flat[:, :3, :3].transpose(1, 2)
    trans_inv = -(rot_inv @ flat[:, :3, 3:])

    bottom_row = torch.tensor([0.0, 0.0, 0.0, 1.0], device=G.device, dtype=G.dtype)
    bottom_row = bottom_row.view(1, 1, 4).repeat([count, 1, 1])

    top_rows = torch.cat([rot_inv, trans_inv], dim=-1)
    return torch.reshape(torch.cat([top_rows, bottom_row], dim=-2), G.shape)


def _se3_matrix_expm_grad(grad):
    grad_upsilon_omega = torch.stack([
        grad[..., 0, 3],
        grad[..., 1, 3],
        grad[..., 2, 3],
        grad[..., 2, 1] - grad[..., 1, 2],
        grad[..., 0, 2] - grad[..., 2, 0],
        grad[..., 1, 0] - grad[..., 0, 1]
    ], axis=-1)

    return grad_upsilon_omega

def _se3_matrix_expm_shape(op):
    return [op.inputs[0].get_shape().as_list()[:-1] + [4, 4]]


def _se3_matrix_expm(upsilon_omega):
    """se(3) -> SE(3) matrix exponential for arbitrary batch dimensions.

    Args:
        upsilon_omega: [..., 6] twists (v, w): translation part first, then
            rotation.

    Returns:
        [..., 4, 4] rigid transforms [[R, V v], [0, 1]].

    For theta = |w| < MIN_THETA the coefficients use Taylor expansions.
    Note: when wrapped by ``SE3_Matrix_Expm`` the gradient is overridden by
    ``_se3_matrix_expm_grad``, which approximates it for small upsilon_omega.
    """
    eps = 1e-12
    out_shape = list(upsilon_omega.shape)[:-1] + [4, 4]

    upsilon_omega = torch.reshape(upsilon_omega, [-1, 6])
    batch = upsilon_omega.shape[0]
    v, w = torch.split(upsilon_omega, [3, 3], dim=-1)

    theta_sq = torch.sum(w**2, dim=1)
    theta_sq = torch.reshape(theta_sq, [-1, 1, 1])

    theta = torch.sqrt(theta_sq)
    theta_po4 = theta_sq * theta_sq

    wx = hat(w)
    wx_sq = torch.matmul(wx, wx)
    I = torch.eye(3, dtype=upsilon_omega.dtype, device=upsilon_omega.device).repeat([batch, 1, 1])

    ### taylor approximations ###
    #   sin(th)/th        = 1   - th^2/6   + th^4/120
    #   (1-cos(th))/th^2  = 1/2 - th^2/24  + th^4/720
    #   (th-sin(th))/th^3 = 1/6 - th^2/120 + th^4/5040
    R1 = I + (1.0 - (1.0 / 6.0) * theta_sq + (1.0 / 120.0) * theta_po4) * wx + \
        (0.5 - (1.0 / 24.0) * theta_sq + (1.0 / 720.0) * theta_po4) * wx_sq

    V1 = I + (0.5 - (1.0 / 24.0) * theta_sq + (1.0 / 720.0) * theta_po4) * wx + \
        ((1.0 / 6.0) - (1.0 / 120.0) * theta_sq + (1.0 / 5040.0) * theta_po4) * wx_sq

    ### exact values ###
    R2 = I + (torch.sin(theta) / (theta + eps)) * wx + \
        ((1 - torch.cos(theta)) / (theta_sq + eps)) * wx_sq

    V2 = I + ((1 - torch.cos(theta)) / (theta_sq + eps)) * wx + \
        ((theta - torch.sin(theta)) / (theta_sq * theta + eps)) * wx_sq

    R = torch.where(theta < MIN_THETA, R1, R2)
    V = torch.where(theta < MIN_THETA, V1, V2)

    t = torch.matmul(V, v[..., None])

    # Match the input dtype so float64 inputs produce a float64 transform.
    fill = torch.tensor([0, 0, 0, 1], dtype=R.dtype, device=R.device)
    fill = torch.reshape(fill, [1, 1, 4])
    fill = fill.repeat([batch, 1, 1])

    G = torch.cat([R, t], dim=2)
    G = torch.cat([G, fill], dim=1)
    return torch.reshape(G, out_shape)


class SE3_Matrix_Expm(torch.autograd.Function):
    """Autograd wrapper around ``_se3_matrix_expm``.

    The backward pass does not differentiate the exact exponential. It maps
    the incoming gradient with ``_se3_matrix_expm_grad``, a first-order
    approximation.
    """

    @staticmethod
    def forward(ctx, upsilon_omega):
        return _se3_matrix_expm(upsilon_omega)

    @staticmethod
    def backward(ctx, grad_output):
        # Nothing is saved in forward: the approximate gradient only needs
        # grad_output.
        return _se3_matrix_expm_grad(grad_output)


def se3_matrix_expm(upsilon_omega):
    """Map twist vector(s) to homogeneous 4x4 rigid transforms.

    The last dim of ``upsilon_omega`` holds 6 values: translation part first,
    rotation part second. The result has shape ``upsilon_omega.shape[:-1] + (4, 4)``.
    Routed through ``SE3_Matrix_Expm`` so that its custom backward is used.
    """
    expm_fn = SE3_Matrix_Expm.apply
    return expm_fn(upsilon_omega)


def se3_matrix_increment(G, upsilon_omega):
    """Left increment of a rigid body transformation: returns ``expm(upsilon_omega) @ G``.

    Args:
        G: transform matrices, matmul-compatible with the (..., 4, 4) increment.
        upsilon_omega: twist(s) whose exponential is composed on the left of ``G``.
    """
    return se3_matrix_expm(upsilon_omega) @ G

================================================
FILE: geometry/transformation.py
================================================
import torch  
import numpy as np

# from core.config import cfg
from config.default import get_cfg
from .se3 import *
from .intrinsics import *
from . import projective_ops as pops
from . import cholesky

from .einsum import einsum

# Alias used to solve the damped normal equations (H, b) in
# SE3Sequence.reprojction_optim.
cholesky_solve = cholesky.solve


# Points whose depth (z) is not greater than MIN_DEPTH are treated as invalid
# (see SE3.transform and SE3Sequence.reprojction_optim).
MIN_DEPTH = 0.1
# NOTE(review): MAX_RESIDUAL is not referenced in this module's visible code;
# confirm it is still needed elsewhere.
MAX_RESIDUAL = 250.0

# Default internal representation for SE3. Only 'matrix' is implemented;
# SE3.__init__ raises NotImplementedError for any other value.
DEFAULT_INTERNAL = 'matrix'


def clip_dangerous_gradients(x):
    """Return ``x`` unchanged (the very same object, not a copy).

    The name suggests gradient clipping was intended here, but no clipping
    is performed; this is currently an identity pass-through.
    """
    passthrough = x
    return passthrough


def jac_local_perturb(pt, fill=False):
    """Jacobian of 3-D points w.r.t. a small twist ``(v, w)`` applied to them.

    For each point p = (X, Y, Z) the Jacobian is ``[I | -[p]_x]``: columns 1-3
    are the identity (translation part) and columns 4-6 are minus the
    skew-symmetric matrix of p (rotation part).

    Args:
        pt: tensor whose last dim is 3 (X, Y, Z).
        fill: if True, append a zero fourth row (homogeneous coordinate).

    Returns:
        Tensor of shape ``pt.shape[:-1] + (3, 6)``, or ``(4, 6)`` when ``fill``.
    """
    X, Y, Z = torch.split(pt, [1, 1, 1], dim=-1)
    zero = torch.zeros_like(X)
    one = torch.ones_like(X)

    # Each tuple is one column of the Jacobian, listed as its (x, y, z) rows.
    columns = [
        (one, zero, zero),
        (zero, one, zero),
        (zero, zero, one),
        (zero, -Z, Y),
        (Z, zero, -X),
        (-Y, X, zero),
    ]
    if fill:
        columns = [col + (zero,) for col in columns]

    return torch.stack([torch.cat(col, dim=-1) for col in columns], dim=-1)


def cond_transform(cond, T1, T2):
    """Return a new transform built from ``T1`` if ``cond`` is true, else from ``T2``.

    ``cond`` must be a Python bool or a single-element boolean tensor; it is
    evaluated eagerly with ``bool()``. The result is always an instance of
    ``type(T1)`` with ``T1.internal``.

    Raises:
        NotImplementedError: if ``T1.internal`` is neither 'matrix' nor 'quaternion'.
    """
    # Eager Python selection: torch.cond is absent in PyTorch 1.x and in 2.x
    # has a different (compile-oriented) signature than the TF-style call
    # this function originally used.
    take_first = bool(cond)

    if T1.internal == 'matrix':
        mat = T1.matrix() if take_first else T2.matrix()
        return T1.__class__(matrix=mat, internal=T1.internal)

    if T1.internal == 'quaternion':
        src = T1 if take_first else T2
        return T1.__class__(so3=src.so3, translation=src.translation,
                            internal=T1.internal)

    raise NotImplementedError(
        f"cond_transform: unsupported internal representation {T1.internal!r}")


class SE3:
    """Rigid-body transform(s) stored as homogeneous 4x4 matrices.

    Only the 'matrix' internal representation is implemented; methods raise
    NotImplementedError for anything else.

    Attributes:
        G: tensor of transforms whose last two dims are 4x4.
        eq: einsum equation used by ``__call__`` to apply ``G`` to points. It is
            carried over to every transform derived from this one.
        internal: name of the internal representation.
    """

    def __init__(self, upsilon=None, matrix=None, so3=None, translation=None, eq=None, internal=DEFAULT_INTERNAL):
        self.eq = eq
        self.internal = internal

        if internal == 'matrix':
            # Built either from a twist (exponential map) or an explicit matrix.
            # NOTE(review): if neither is given, G is left unset and later
            # method calls fail with AttributeError.
            if upsilon is not None:
                self.G = se3_matrix_expm(upsilon)
            elif matrix is not None:
                self.G = matrix
        else:
            raise NotImplementedError

    def _from_matrix(self, G):
        """Wrap ``G`` in a new instance of this class, keeping ``internal`` and ``eq``."""
        return self.__class__(matrix=G, internal=self.internal, eq=self.eq)

    def __call__(self, pt, jacobian=False):
        """Transform a set of 3-D points ``pt`` (last dim = 3).

        Points are made homogeneous and contracted with ``G[..., :3, :]`` using
        ``self.eq``. With ``jacobian=True`` also returns ``jac_local_perturb``
        evaluated at the transformed points.
        """
        if self.internal != 'matrix':
            raise NotImplementedError

        pt_h = torch.cat([pt, torch.ones_like(pt[..., :1])], dim=-1)  # homogeneous
        pt = einsum(self.eq, self.G[..., :3, :], pt_h)

        if jacobian:
            return pt, jac_local_perturb(pt)
        return pt

    def __mul__(self, other):
        """Compose transforms: the result's matrix is ``self.G @ other.G``."""
        if self.internal != 'matrix':
            raise NotImplementedError
        return self._from_matrix(torch.matmul(self.G, other.G))

    def identity_(self):
        """Reset ``G`` in place to identities with the same leading shape, dtype and device."""
        if self.internal != 'matrix':
            raise NotImplementedError
        G = self.G
        # Keep G's dtype: the optimiser below works in float64.
        self.G = torch.eye(4, dtype=G.dtype, device=G.device).repeat([*G.shape[:-2], 1, 1])

    def increment(self, upsilon):
        """Return ``expm(upsilon) * self`` (left increment) as a new transform."""
        if self.internal != 'matrix':
            raise NotImplementedError
        return self._from_matrix(se3_matrix_increment(self.G, upsilon))

    def concat(self, other, axis=0):
        """Return a new transform whose matrices are ``self.G`` and ``other.G`` concatenated along ``axis``."""
        if self.internal != 'matrix':
            raise NotImplementedError
        return self._from_matrix(torch.cat([self.G, other.G], dim=axis))

    def copy(self, stop_gradients=False):
        """Return a new transform referencing ``G`` (detached when ``stop_gradients``)."""
        if self.internal != 'matrix':
            raise NotImplementedError
        G = self.G.detach() if stop_gradients else self.G
        return self._from_matrix(G)

    def to_vec(self):
        """Concatenate ``so3`` and ``translation`` (non-matrix representations only)."""
        if self.internal == 'matrix':
            # The matrix representation has no so3/translation attributes.
            raise NotImplementedError
        return torch.cat([self.so3, self.translation], dim=-1)

    def inv(self):
        """Return the inverse transform(s)."""
        if self.internal != 'matrix':
            raise NotImplementedError
        return self._from_matrix(se3_matrix_inverse(self.matrix()))

    def adj(self):
        """Return the 6x6 adjoint ``[[R, hat(t) @ R], [0, R]]`` of each transform."""
        if self.internal != 'matrix':
            raise NotImplementedError
        R = self.G[..., :3, :3]
        t = self.G[..., :3, 3]
        top = torch.cat([R, torch.matmul(hat(t), R)], dim=-1)
        bottom = torch.cat([torch.zeros_like(R), R], dim=-1)
        return torch.cat([top, bottom], dim=-2)

    def logm(self):
        """Logarithm map (non-matrix representations only)."""
        if self.internal == 'matrix':
            # The matrix representation has no so3/translation attributes.
            raise NotImplementedError
        return se3_logm(self.so3, self.translation)

    def shape(self):
        """Return the first two dims of ``G`` (e.g. ``(batch, frames)``)."""
        if self.internal != 'matrix':
            raise NotImplementedError
        my_shape = self.G.shape
        return (my_shape[0], my_shape[1])

    def matrix(self, fill=True):
        """Return the stored matrices; ``fill`` is accepted for compatibility and ignored."""
        if self.internal != 'matrix':
            raise NotImplementedError
        return self.G

    def transform(self, depth, intrinsics, valid_mask=False, return3d=False):
        """Back-project ``depth``, apply this transform and re-project.

        Returns the projected coordinates, plus the transformed 3-D points if
        ``return3d``. Otherwise, if ``valid_mask``, also returns a float mask
        with a trailing singleton dim that is 1 where the depth before and
        after the transform exceeds ``MIN_DEPTH``.
        """
        pt = pops.backproject(depth, intrinsics)
        pt_new = self.__call__(pt)
        coords = pops.project(pt_new, intrinsics)
        if return3d:
            return coords, pt_new
        if valid_mask:
            vmask = (pt[..., -1] > MIN_DEPTH) & (pt_new[..., -1] > MIN_DEPTH)
            vmask = vmask.to(dtype=torch.float32)[..., :, :, None]
            return coords, vmask
        return coords

    def induced_flow(self, depth, intrinsics, valid_mask=False):
        """Return the 2-D flow induced by this transform (and the validity mask if requested)."""
        coords0 = pops.coords_grid(depth, homogeneous=False)

        if valid_mask:
            coords1, vmask = self.transform(
                depth, intrinsics, valid_mask=valid_mask)
            return coords1 - coords0, vmask
        coords1 = self.transform(depth, intrinsics, valid_mask=valid_mask)
        return coords1 - coords0

    def depth_change(self, depth, intrinsics):
        """Return the per-pixel change of z caused by this transform."""
        pt = pops.backproject(depth, intrinsics)
        pt_new = self.__call__(pt)
        return pt_new[..., -1] - pt[..., -1]

    def identity(self):
        """Return identity transforms of shape (batch, 1, 4, 4) with G's dtype and device.

        Only the first entry of ``self.shape()`` is used.
        """
        batch, _ = self.shape()
        if self.internal != 'matrix':
            raise NotImplementedError
        I = torch.eye(4, dtype=self.G.dtype, device=self.G.device).repeat(
            [batch, 1, 1, 1])
        return self.__class__(matrix=I, internal=self.internal, eq=self.eq)




class SE3Sequence(SE3):
    """Collection of SE3 transforms, presumably laid out as (batch, frames, 4, 4).

    The default einsum equation ``"aijk,ai...k->ai...j"`` contracts ``G``
    (indexed a, i, j, k) with points indexed (a, i, ..., k).
    """

    def __init__(self, upsilon=None, matrix=None, so3=None, translation=None, eq= "aijk,ai...k->ai...j",internal=DEFAULT_INTERNAL):
        super().__init__(
            upsilon, matrix, so3, translation, internal=internal, eq=eq)

    def __call__(self, pt, inds=None, jacobian=False):
        """Transform points; ``inds`` is accepted for API compatibility but ignored."""
        if self.internal != 'matrix':
            raise NotImplementedError
        return super().__call__(pt, jacobian=jacobian)

    def gather(self, inds):
        """Select transforms along dim 1 with ``torch.index_select``."""
        if self.internal != 'matrix':
            raise NotImplementedError
        G = torch.index_select(self.G, index=inds, dim=1)
        return SE3Sequence(matrix=G, internal=self.internal, eq=self.eq)

    def reprojction_optim(self,
                          target,
                          weight,
                          depth,
                          intrinsics,
                          num_iters=2,
                          depth_img_coords=None
                          ):
        """Refine this transform by damped Gauss-Newton on the reprojection residual.

        ``depth`` is back-projected with ``intrinsics`` to points X0. Each of the
        ``num_iters`` iterations does the following:
          * transforms X0 by the current estimate and projects it to x1;
          * builds ``H = sum(v*w*J^T J)`` and ``b = sum(v*w*J^T (target - x1))``;
          * damps H by ``EP_LMBDA * I + LM_LMBDA * diag(H)``, both read from ``get_cfg("LM")``;
          * solves for a twist increment and applies it on the left.
        Here v zeroes points whose depth before or after the transform is not
        above MIN_DEPTH. ``target``, ``weight``, J, H and b are float64.

        Side effect: ``self.G`` is overwritten with the refined matrices.

        Returns:
            A new SE3Sequence holding the refined matrices.
        """
        target = clip_dangerous_gradients(target).to(dtype=torch.float64)
        weight = clip_dangerous_gradients(weight).to(dtype=torch.float64)

        X0 = pops.backproject(depth, intrinsics, depth_coords=depth_img_coords)
        w = weight[..., None]

        lm_lmbda = get_cfg("LM").LM_LMBDA
        ep_lmbda = get_cfg("LM").EP_LMBDA

        T = self.copy(stop_gradients=False)
        for _ in range(num_iters):
            # Jacobians of the transformation and of the projection.
            X1, jtran = T(X0, jacobian=True)
            x1, (jproj, _) = pops.project(X1, intrinsics, jacobian=True)

            v = (X0[..., -1] > MIN_DEPTH) & (X1[..., -1] > MIN_DEPTH)
            v = v.to(dtype=torch.float64)[..., None, None]

            # Weighted Gauss-Newton normal equations (chain rule: d proj / d twist).
            J = einsum('...ij,...jk->...ik', jproj.to(dtype=torch.float64), jtran.to(dtype=torch.float64))
            wJ = v * w * J
            H = einsum('ai...j,ai...k->aijk', wJ, J)
            b = einsum('ai...j,ai...->aij', wJ, target - x1)

            # Levenberg-Marquardt-style damping: constant + diagonal-proportional.
            eye6 = torch.eye(6, dtype=H.dtype, device=H.device)
            H += ep_lmbda * eye6 + lm_lmbda * H * eye6
            delta_upsilon = cholesky_solve(H, b)
            T = T.increment(delta_upsilon)

        if self.internal != 'matrix':
            raise NotImplementedError
        self.G = T.matrix()
        return SE3Sequence(matrix=T.matrix(), internal=self.internal, eq=self.eq)

    def transform(self, depth, intrinsics, valid_mask=False, return3d=False):
        """Same as ``SE3.transform``."""
        return super().transform(depth, intrinsics, valid_mask, return3d)


================================================
FILE: model/CFNet.py
================================================


import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from thirdparty.raft.update import BasicUpdateBlock
from thirdparty.raft.extractor import BasicEncoder
from thirdparty.raft.corr import CorrBlock, AlternateCorrBlock
from thirdparty.raft.utils.utils import bilinear_sampler, coords_grid, upflow

try:
    autocast = torch.cuda.amp.autocast
except AttributeError:
    # PyTorch < 1.6 has no torch.cuda.amp: fall back to a no-op context manager.
    class autocast:
        """No-op stand-in for ``torch.cuda.amp.autocast``."""

        def __init__(self, enabled=True, *args, **kwargs):
            pass

        def __enter__(self):
            pass

        def __exit__(self, *args):
            pass


class ImageFeaEncoder(nn.Module):
    """RAFT-style image feature encoder (``BasicEncoder`` with instance norm).

    Args:
        input_dim: number of input image channels.
        output_dim: number of output feature channels.
        pretrained: if True (default, the previous behaviour), load
            ``../weights/img_fea_enc.pth`` relative to this module's directory
            with ``strict=True``. If False, keep the random initialisation.
    """

    def __init__(self, input_dim=3, output_dim=256, pretrained=True):
        super().__init__()
        self.fnet = BasicEncoder(output_dim=output_dim, norm_fn='instance', dropout=False, input_dim=input_dim)

        if pretrained:
            print("Loading the weights of RAFT...")
            import os
            self.load_state_dict(
                torch.load(f"{os.path.dirname(os.path.abspath(__file__)) }/../weights/img_fea_enc.pth", map_location='cpu'),
                strict=True,
            )
        else:
            print("ImageFeaEncoder will be trained from scratch...")

    def forward(self, image1, image2):
        """Encode two images with the shared ``fnet`` and return both feature maps.

        Pixels are rescaled by ``2 * (x / 255) - 1``, which assumes input in
        [0, 255]; confirm at call sites. The encoder runs under
        ``autocast(enabled=True)``, so on CUDA the outputs may be half precision.
        """
        image1 = 2 * (image1 / 255.0) - 1.0
        image2 = 2 * (image2 / 255.0) - 1.0

        image1 = image1.contiguous()
        image2 = image2.contiguous()
        with autocast(enabled=True):
            fmap1, fmap2 = self.fnet([image1, image2])
        return fmap1, fmap2


class GRU_CFUpdator(nn.Module):
    """RAFT-style GRU operator that iteratively refines a correspondence flow field.

    Args:
        args: attribute-style config. Read fields: ``pretrained_model``,
            ``mixed_precision``, ``fea_net`` and optionally ``alternate_corr``
            (defaulted to False). ``corr_levels`` and ``corr_radius`` are
            overwritten with 4 before the update block is built.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args

        self.hidden_dim = hdim = 128
        self.context_dim = 128
        args.corr_levels = 4
        args.corr_radius = 4

        if 'alternate_corr' not in self.args:
            self.args.alternate_corr = False

        self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)

        if self.args.pretrained_model is not None:
            # NOTE(review): the value of pretrained_model is only used as a
            # switch; the weights always come from ../weights/gru_update.pth.
            print("Loading the weights of RAFT...")
            import os
            self.load_state_dict(
                torch.load(f"{os.path.dirname(os.path.abspath(__file__)) }/../weights/gru_update.pth", map_location='cpu'),
                strict=True,
            )
        else:
            print("GRU_CFUpdator will be trained from scratch...")

    def freeze_bn(self):
        """Put every BatchNorm2d submodule in eval mode."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def initialize_flow(self, img, downsample_rate=8):
        """Return two identical coordinate grids at ``img`` resolution // ``downsample_rate``.

        Flow is represented as their difference, ``coords1 - coords0``.
        """
        N, C, H, W = img.shape
        coords0 = coords_grid(N, H//downsample_rate, W//downsample_rate).to(img.device)
        coords1 = coords_grid(N, H//downsample_rate, W//downsample_rate).to(img.device)
        return coords0, coords1

    def upsample_flow(self, flow, mask, upsample_scale=8):
        """Upsample flow (N, 2, H, W) -> (N, 2, s*H, s*W) by convex combination.

        Each fine pixel is a softmax(mask)-weighted average of the 3x3 coarse
        neighbourhood, with flow values scaled by ``s = upsample_scale``.
        """
        N, _, H, W = flow.shape
        mask = mask.view(N, 1, 9, upsample_scale, upsample_scale, H, W)
        mask = torch.softmax(mask, dim=2)

        up_flow = F.unfold(upsample_scale * flow, [3,3], padding=1)
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, 2, upsample_scale*H, upsample_scale*W)

    def forward(self, fmap1, fmap2, iters=1, flow_init=None, upsample=True, test_mode=False, context_fea=None, update_corr_fn=True):
        """Estimate the flow between two feature maps with ``iters`` GRU updates.

        With ``update_corr_fn=True`` the correlation volume and GRU state are
        (re)built from ``fmap1``/``fmap2``/``context_fea``. With False, the
        state cached by the previous call is reused; callers must keep it
        consistent. ``flow_init`` (optional) is not modified. When it is None,
        iteration starts from zero flow at ``fmap1`` resolution. ``upsample``
        is accepted for compatibility and unused.

        Returns:
            The list of upsampled flow predictions. With ``test_mode``, instead
            returns ``(coarse_flow, last_upsampled_flow)``, which requires
            ``iters >= 1``.
        """
        hdim = self.hidden_dim
        cdim = self.context_dim

        if update_corr_fn:
            self.fmap1 = fmap1.float()
            self.fmap2 = fmap2.float()
            corr_cls = AlternateCorrBlock if self.args.alternate_corr else CorrBlock
            self.corr_fn = corr_cls(self.fmap1, self.fmap2, radius=self.args.corr_radius)

            # Context features -> initial hidden state (net) and GRU input (inp).
            with autocast(enabled=self.args.mixed_precision):
                assert context_fea is not None
                ds = context_fea.shape[-1]//self.fmap1.shape[-1]
                cnet = F.interpolate(context_fea, scale_factor=1/ds, mode='bilinear', align_corners=True)

                self.net, self.inp = torch.split(cnet, [hdim, cdim], dim=1)
                self.net = torch.tanh(self.net)
                self.inp = torch.relu(self.inp)

        if flow_init is not None:
            coords0, coords1 = self.initialize_flow(flow_init)
            ds = flow_init.shape[-1]//coords0.shape[-1]
            if ds != 1:
                # Out-of-place: rescaling must not modify the caller's tensor.
                flow_init = flow_init / ds
                flow_init = F.interpolate(flow_init, scale_factor=1/ds, mode='bilinear', align_corners=True)
            coords1 = coords1 + flow_init
        else:
            # No initial flow: identity grid at the correlation (fmap1) resolution.
            coords0, coords1 = self.initialize_flow(self.fmap1, downsample_rate=1)

        flow_predictions = []
        for _ in range(iters):
            coords1 = coords1.detach()
            corr = self.corr_fn(coords1)  # index correlation volume

            flow = coords1 - coords0
            with autocast(enabled=self.args.mixed_precision):
                self.net, up_mask, delta_flow = self.update_block(self.net, self.inp, corr, flow)

            # F(t+1) = F(t) + Delta(t)
            coords1 = coords1 + delta_flow

            upsample_scale = 4 if self.args.fea_net in ["bigdx4"] else 8
            if up_mask is None:
                flow_up = upflow(coords1 - coords0, scale=upsample_scale)
            else:
                flow_up = self.upsample_flow(coords1 - coords0, up_mask, upsample_scale=upsample_scale)

            flow_predictions.append(flow_up)

        if test_mode:
            return coords1 - coords0, flow_up

        return flow_predictions



================================================
FILE: model/HybridNet.py
================================================

import torch 
import torch.nn as nn 

from thirdparty.kpconv.kpconv_blocks import *
import torch.nn.functional as F
import numpy as np
from kpconv.lib.utils import square_distance
from model.descriptor2D import  SuperPoint2D
from model.descriptor3D import  KPSuperpoint3Dv2



REGISTERED_HYBRID_NET_CLASSES = {}


def register_hybrid_net(cls, name=None):
    """Add ``cls`` to ``REGISTERED_HYBRID_NET_CLASSES`` and return it unchanged.

    Usable as a bare class decorator. The key is ``name`` when given, otherwise
    ``cls.__name__``. Re-using a key raises AssertionError (with assertions enabled).
    """
    key = cls.__name__ if name is None else name
    assert key not in REGISTERED_HYBRID_NET_CLASSES, f"exist class: {REGISTERED_HYBRID_NET_CLASSES}"
    REGISTERED_HYBRID_NET_CLASSES[key] = cls
    return cls


def get_hybrid_net(name):
    """Return the class registered under ``name`` (AssertionError if unknown)."""
    registry = REGISTERED_HYBRID_NET_CLASSES
    assert name in registry, f"available class: {registry}"
    return registry[name]

class ContextFeatureNet(nn.Module):
    """Extract context features for the 3-D model point cloud with ``KPSuperpoint3Dv2``.

    Args:
        config: mapping whose ``'context_fea_extractor_3d'`` entry configures
            the 3-D extractor.
    """

    def __init__(self, config):
        super().__init__()
        self.context_fea_extractor_3d = KPSuperpoint3Dv2(config['context_fea_extractor_3d'])

    def forward(self, batch):
        """Run the 3-D extractor on the model-point fields of ``batch``.

        Reads ``model_points``, ``neighbors``, ``pools``, ``upsamples``,
        ``model_point_features`` and ``stack_lengths`` from ``batch``.

        Returns:
            ``{"ctx_fea_3d": <extractor output>}``.
        """
        batch3d = {
            'points': batch['model_points'],
            'neighbors': batch['neighbors'],
            'pools': batch['pools'],
            'upsamples': batch['upsamples'],
            'features': batch['model_point_features'],
            'stack_lengths': batch['stack_lengths'],
        }
        ctx_descriptors_3d = self.context_fea_extractor_3d(batch3d)

        return {
            "ctx_fea_3d": ctx_descriptors_3d,
        }



@register_hybrid_net
class HybridDescNet(nn.Module):
    """Independent 2-D (image) and 3-D (model point cloud) descriptor networks.

    3-D descriptors are cached per class name in ``self.descriptors_3d``. They
    are recomputed on every training-mode forward; in eval mode they are
    computed once per class and then reused. The cache is cleared whenever
    the train/eval mode changes.
    """

    def __init__(self, config):
        super().__init__()

        self.corr_fea_extractor_2d = SuperPoint2D(config['keypoints_detector_2d'])
        self.corr_fea_extractor_3d = KPSuperpoint3Dv2(config['keypoints_detector_3d'])
        self.descriptors_3d = {}

    def train(self, mode=True):
        """Switch train/eval mode, dropping cached 3-D descriptors on a mode change.

        Without this, eval would reuse the descriptors from the last training
        forward. Those were computed with the weights from before the final
        optimizer step, and (if grad was enabled) they keep that step's
        autograd graph alive.
        """
        if mode != self.training:
            self.descriptors_3d.clear()
        return super().train(mode)

    def forward(self, batch):
        """Compute 2-D image descriptors and (cached) 3-D model descriptors.

        All samples in ``batch`` must share one ``class_name``.

        Returns:
            dict with ``descriptors_2d``, ``descriptors_3d`` and the placeholders
            ``scores_saliency_3d``/``scores_overlap_3d`` (both None).
        """
        assert len(set(batch['class_name']))==1, "A batch should contain data of the same class."
        class_name = batch['class_name'][0]
        image = batch['image']

        batch3d = {
            'points': batch['model_points'],
            'neighbors': batch['neighbors'],
            'pools': batch['pools'],
            'upsamples': batch['upsamples'],
            'features': batch['model_point_features'],
            'stack_lengths': batch['stack_lengths'],
        }
        # Always refresh while training; in eval, compute once per class.
        if self.training or class_name not in self.descriptors_3d:
            self.descriptors_3d[class_name] = self.corr_fea_extractor_3d(batch3d)

        descriptors_2d = self.corr_fea_extractor_2d(image)['descriptors']

        return {
            "descriptors_2d": descriptors_2d,
            "descriptors_3d": self.descriptors_3d[class_name],
            "scores_saliency_3d": None,
            "scores_overlap_3d": None,
        }


================================================
FILE: model/PoseRefiner.py
================================================
import os 
import time
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import numpy as np
from easydict import EasyDict as edict
from functools import partial

from geometry.transformation import *
from geometry.intrinsics import *
from geometry.projective_ops import coords_grid, normalize_coords_grid
from model.CFNet import GRU_CFUpdator , ImageFeaEncoder
from utils.pose_utils import pose_padding
from config.default import get_cfg



EPS = 1e-5
MIN_DEPTH = 0.1
MAX_ERROR = 100.0

# Ground-truth flows with magnitude >= MAX_FLOW are excluded from the loss
# (extremely large displacements).
MAX_FLOW = 400


def raft_sequence_flow_loss(flow_preds, flow_gt, valid, gamma=0.8, max_flow=MAX_FLOW):
    """Exponentially weighted L1 loss over a sequence of flow predictions (RAFT-style).

    Prediction ``i`` of ``n`` is weighted by ``gamma ** (n - i - 1)``, so the
    last one has weight 1. A pixel contributes only where ``valid >= 0.5`` and
    the ground-truth magnitude is below ``max_flow``. Each L1 term is averaged
    over all elements, with masked pixels contributing zero.

    Args:
        flow_preds: non-empty sequence of flows, each shaped like ``flow_gt``.
        flow_gt: ground-truth flow with its components along dim 1.
        valid: per-pixel validity, shaped like ``flow_gt`` without dim 1.
        gamma: decay factor for earlier predictions.
        max_flow: magnitude threshold for excluding ground-truth pixels.

    Returns:
        ``(loss, metrics)``. ``metrics`` holds the last prediction's mean
        end-point error over kept pixels ('epe') and the fraction of those
        pixels with EPE below 1, 3 and 5 px. All four are NaN when no pixel is kept.
    """
    num_preds = len(flow_preds)
    gt_norm = flow_gt.pow(2).sum(dim=1).sqrt()
    keep = (valid >= 0.5) & (gt_norm < max_flow)

    flow_loss = 0.0
    for idx, pred in enumerate(flow_preds):
        weight = gamma ** (num_preds - idx - 1)
        abs_err = (pred - flow_gt).abs()
        flow_loss += weight * (keep[:, None] * abs_err).mean()

    last_epe = (flow_preds[-1] - flow_gt).pow(2).sum(dim=1).sqrt()
    last_epe = last_epe.view(-1)[keep.view(-1)]

    metrics = {'epe': last_epe.mean().item()}
    for thresh, key in ((1, '1px'), (3, '3px'), (5, '5px')):
        metrics[key] = (last_epe < thresh).float().mean().item()

    return flow_loss, metrics




class PoseRefiner(nn.Module):
    def __init__(self, cfg,
                 reuse=False,
                 schedule=None,
                 use_regressor=True,
                 is_calibrated=True,
                 bn_is_training=False,
                 is_training=True,
                 renderer=None,
                 ):

        super().__init__()

        self.legacy=True
        self.cfg = cfg
        self.reuse = reuse
        self.sigma=nn.ParameterList( [nn.Parameter(torch.ones(1)*1 )] )
        self.with_corr_weight = self.cfg.get("with_corr_weight", True)
        if not self.with_corr_weight:
            print("Warning: the correlation weighting is disabled.")

        self.is_calibrated = cfg.IS_CALIBRATED
       
Download .txt
gitextract_p7a82okx/

├── .gitignore
├── LICENSE.md
├── README.md
├── builder/
│   ├── __init__.py
│   ├── dataset_builder.py
│   ├── input_reader_builder.py
│   ├── losses_builder.py
│   ├── lr_scheduler_builder.py
│   ├── optimizer_builder.py
│   └── rnnpose_builder.py
├── config/
│   ├── default.py
│   └── linemod/
│       ├── copy.sh
│       ├── copy_occ.sh
│       ├── template_fw0.5.yml
│       └── template_fw0.5_occ.yml
├── data/
│   ├── __init__.py
│   ├── dataset.py
│   ├── linemod/
│   │   └── linemod_config.py
│   ├── linemod_dataset.py
│   ├── preprocess.py
│   ├── transforms.py
│   └── ycb/
│       └── basic.py
├── doc/
│   └── prepare_data.md
├── docker/
│   ├── Dockerfile
│   └── freeze.yml
├── geometry/
│   ├── __init__.py
│   ├── cholesky.py
│   ├── diff_render.py
│   ├── diff_render_optim.py
│   ├── einsum.py
│   ├── intrinsics.py
│   ├── projective_ops.py
│   ├── se3.py
│   └── transformation.py
├── model/
│   ├── CFNet.py
│   ├── HybridNet.py
│   ├── PoseRefiner.py
│   ├── RNNPose.py
│   ├── descriptor2D.py
│   ├── descriptor3D.py
│   └── losses.py
├── scripts/
│   ├── compile_3rdparty.sh
│   ├── eval.sh
│   ├── eval_lmocc.sh
│   ├── run_dataformatter.sh
│   ├── run_datainfo_generation.sh
│   └── train.sh
├── thirdparty/
│   ├── kpconv/
│   │   ├── __init__.py
│   │   ├── cpp_wrappers/
│   │   │   ├── compile_wrappers.sh
│   │   │   ├── cpp_neighbors/
│   │   │   │   ├── build.bat
│   │   │   │   ├── neighbors/
│   │   │   │   │   ├── neighbors.cpp
│   │   │   │   │   └── neighbors.h
│   │   │   │   ├── setup.py
│   │   │   │   └── wrapper.cpp
│   │   │   ├── cpp_subsampling/
│   │   │   │   ├── build.bat
│   │   │   │   ├── grid_subsampling/
│   │   │   │   │   ├── grid_subsampling.cpp
│   │   │   │   │   └── grid_subsampling.h
│   │   │   │   ├── setup.py
│   │   │   │   └── wrapper.cpp
│   │   │   └── cpp_utils/
│   │   │       ├── cloud/
│   │   │       │   ├── cloud.cpp
│   │   │       │   └── cloud.h
│   │   │       └── nanoflann/
│   │   │           └── nanoflann.hpp
│   │   ├── kernels/
│   │   │   ├── dispositions/
│   │   │   │   └── k_015_center_3D.ply
│   │   │   └── kernel_points.py
│   │   ├── kpconv_blocks.py
│   │   └── lib/
│   │       ├── __init__.py
│   │       ├── ply.py
│   │       ├── timer.py
│   │       └── utils.py
│   ├── nn/
│   │   ├── _ext.c
│   │   ├── nn_utils.py
│   │   ├── setup.py
│   │   └── src/
│   │       ├── ext.h
│   │       └── nearest_neighborhood.cu
│   ├── raft/
│   │   ├── corr.py
│   │   ├── extractor.py
│   │   ├── update.py
│   │   └── utils/
│   │       ├── __init__.py
│   │       ├── augmentor.py
│   │       ├── flow_viz.py
│   │       ├── frame_utils.py
│   │       └── utils.py
│   └── vsd/
│       └── inout.py
├── tools/
│   ├── eval.py
│   ├── generate_data_info_deepim_0_orig.py
│   ├── generate_data_info_deepim_1_syn.py
│   ├── generate_data_info_deepim_2_posecnnval.py
│   ├── generate_data_info_v2_deepim.py
│   ├── train.py
│   └── transform_data_format.py
├── torchplus/
│   ├── __init__.py
│   ├── metrics.py
│   ├── nn/
│   │   ├── __init__.py
│   │   ├── functional.py
│   │   └── modules/
│   │       ├── __init__.py
│   │       ├── common.py
│   │       └── normalization.py
│   ├── ops/
│   │   ├── __init__.py
│   │   └── array_ops.py
│   ├── tools.py
│   └── train/
│       ├── __init__.py
│       ├── checkpoint.py
│       ├── common.py
│       ├── fastai_optim.py
│       ├── learning_schedules.py
│       ├── learning_schedules_fastai.py
│       └── optim.py
├── utils/
│   ├── __init__.py
│   ├── config_io.py
│   ├── distributed_utils.py
│   ├── eval_metric.py
│   ├── furthest_point_sample.py
│   ├── geometric.py
│   ├── img_utils.py
│   ├── log_tool.py
│   ├── pose_utils.py
│   ├── pose_utils_np.py
│   ├── progress_bar.py
│   ├── rand_utils.py
│   ├── singleton.py
│   ├── timer.py
│   ├── util.py
│   └── visualize.py
└── weights/
    ├── gru_update.pth
    ├── img_fea_enc.pth
    └── superpoint_v1.pth
Download .txt
SYMBOL INDEX (1087 symbols across 83 files)

FILE: builder/dataset_builder.py
  function build (line 8) | def build(input_reader_config,

FILE: builder/input_reader_builder.py
  class DatasetWrapper (line 7) | class DatasetWrapper(Dataset):
    method __init__ (line 11) | def __init__(self, dataset):
    method __len__ (line 14) | def __len__(self):
    method __getitem__ (line 17) | def __getitem__(self, idx):
    method dataset (line 21) | def dataset(self):
  function build (line 25) | def build(input_reader_config,

FILE: builder/losses_builder.py
  function build (line 4) | def build(loss_config):

FILE: builder/lr_scheduler_builder.py
  function build (line 6) | def build(optimizer_config, optimizer, total_step):
  function _create_learning_rate_scheduler (line 28) | def _create_learning_rate_scheduler(learning_rate_config, optimizer, tot...

FILE: builder/optimizer_builder.py
  function children (line 9) | def children(m: nn.Module):
  function num_children (line 14) | def num_children(m: nn.Module) -> int:
  function flatten_model (line 21) | def flatten_model(m):
  function get_layer_groups (line 29) | def get_layer_groups(m): return [nn.ModuleList(flatten_model(m))]
  function get_voxeLO_net_layer_groups (line 31) | def get_voxeLO_net_layer_groups(net):
  function get_voxeLO_net_layer_groups (line 44) | def get_voxeLO_net_layer_groups(net):
  function build (line 64) | def build(optimizer_config, net, name=None, mixed=False, loss_scale=512.0):

FILE: builder/rnnpose_builder.py
  function build (line 6) | def build(model_cfg,

FILE: config/default.py
  function _merge_a_into_b (line 5) | def _merge_a_into_b(a, b):
  class Config (line 41) | class Config(metaclass=Singleton):
    method __init__ (line 42) | def __init__(self):
    method __get_item__ (line 62) | def __get_item__(self, key):
    method merge (line 65) | def merge(self, config_dict, sub_key=None):
  function get_cfg (line 78) | def get_cfg(Node=None):

FILE: data/dataset.py
  function register_dataset (line 12) | def register_dataset(cls, name=None):
  function get_dataset_class (line 21) | def get_dataset_class(name):
  class Dataset (line 27) | class Dataset(object):
    method __getitem__ (line 30) | def __getitem__(self, index):
    method __len__ (line 33) | def __len__(self):
    method _read_data (line 36) | def _read_data(self, query):
    method evaluation (line 40) | def evaluation(self, dt_annos, output_dir):

FILE: data/linemod_dataset.py
  function se3_q2m (line 31) | def se3_q2m(se3_q):
  function info_convertor (line 41) | def info_convertor(info,):
  function resize (line 60) | def resize(im, target_size, max_size, stride=0, interpolation=cv2.INTER_...
  function sample_poses (line 89) | def sample_poses(pose_tgt):
  class LinemodDeepIMSynRealV2 (line 124) | class LinemodDeepIMSynRealV2(Dataset):
    method __init__ (line 126) | def __init__(self, root_path,
    method load_random_background (line 205) | def load_random_background(self, im_observed, mask):
    method _read_data (line 259) | def _read_data(self, idx):
    method __getitem__ (line 421) | def __getitem__(self, idx):
    method __len__ (line 439) | def __len__(self):

FILE: data/preprocess.py
  function merge_batch (line 33) | def merge_batch(batch_list):
  function get_correspondences (line 84) | def get_correspondences(src_pcd, tgt_pcd, search_voxel_size, K=None, tra...
  function to_pcd (line 103) | def to_pcd(xyz):
  function to_tsfm (line 109) | def to_tsfm(rot, trans):
  function CameraIntrinsicUpdate (line 116) | def CameraIntrinsicUpdate(old_K, aug_param):
  function crop_transform (line 146) | def crop_transform(images, depths, Ks, crop_param, ):
  function patch_crop (line 181) | def patch_crop(image, depth, mask, K_old, margin_ratio=0.2, output_size=...
  function preprocess_deepim (line 257) | def preprocess_deepim(
  function preprocess (line 367) | def preprocess(
  function batch_grid_subsampling_kpconv (line 503) | def batch_grid_subsampling_kpconv(points, batches_len, features=None, la...
  function batch_neighbors_kpconv (line 544) | def batch_neighbors_kpconv(queries, supports, q_batches, s_batches, radi...
  function collate_fn_descriptor (line 564) | def collate_fn_descriptor(list_data, config, neighborhood_limits):
  function collate_fn_descriptor_deepim (line 707) | def collate_fn_descriptor_deepim(list_data, config, neighborhood_limits):
  function calibrate_neighbors (line 856) | def calibrate_neighbors(dataset, config, collate_fn, keep_ratio=0.8, sam...
  function get_dataloader (line 894) | def get_dataloader(dataset, kpconv_config, batch_size=1, num_workers=4, ...
  function get_dataloader_deepim (line 913) | def get_dataloader_deepim(dataset, kpconv_config, batch_size=1, num_work...

FILE: data/transforms.py
  class Compose (line 10) | class Compose(object):
    method __init__ (line 12) | def __init__(self, transforms):
    method __call__ (line 15) | def __call__(self, img, kpts=None, mask=None):
    method __repr__ (line 20) | def __repr__(self):
  class ToTensor (line 29) | class ToTensor(object):
    method __call__ (line 31) | def __call__(self, img, kpts, mask):
  class Normalize (line 35) | class Normalize(object):
    method __init__ (line 37) | def __init__(self, mean, std, to_bgr=True):
    method __call__ (line 42) | def __call__(self, img, kpts, mask):
  class ColorJitter (line 50) | class ColorJitter(object):
    method __init__ (line 52) | def __init__(self,
    method __call__ (line 64) | def __call__(self, image, kpts, mask):
  class RandomBlur (line 69) | class RandomBlur(object):
    method __init__ (line 71) | def __init__(self, prob=0.5):
    method __call__ (line 74) | def __call__(self, image, kpts, mask):
  function make_transforms (line 81) | def make_transforms(cfg, is_train):

FILE: geometry/cholesky.py
  class _cholesky_solve (line 9) | class _cholesky_solve(torch.autograd.Function):
    method forward (line 11) | def forward(ctx, H, b):
    method backward (line 20) | def backward(ctx, dx):
  function cholesky_solve (line 29) | def cholesky_solve(H, b):
  function solve (line 32) | def solve(H, b, max_update=1.0):
  function __test__ (line 54) | def __test__():

FILE: geometry/diff_render.py
  function rasterize (line 31) | def rasterize(R, T, meshes, rasterizer, blur_radius=0):
  function set_bary_coords_to_nearest (line 51) | def set_bary_coords_to_nearest(bary_coords_):
  class MeshRendererWithDepth (line 58) | class MeshRendererWithDepth(nn.Module):
    method __init__ (line 59) | def __init__(self, rasterizer, shader):
    method to (line 64) | def to(self, device):
    method forward (line 70) | def forward(self, meshes_world, **kwargs) -> torch.Tensor:
  class DiffRender (line 75) | class DiffRender(nn.Module):
    method __init__ (line 76) | def __init__(self, mesh_path, render_texture=False):
    method to (line 107) | def to(self, *args, **kwargs):
    method get_patch_center_depths (line 122) | def get_patch_center_depths(self, T, K):
    method forward_interpolate (line 143) | def forward_interpolate(R, t, meshes, face_memory, rasterizer, blur_ra...
    method render_mesh (line 160) | def render_mesh(self,  T, K, render_image_size, near=0.1, far=6, light...
    method render_offset_map (line 202) | def render_offset_map(self,  T, K, render_image_size, near=0.1, far=6):
    method forward (line 224) | def forward(self, vert_attribute, T, K, render_image_size, near=0.1, f...
    method render_depth (line 276) | def render_depth(self, T, K, render_image_size, near=0.1, far=6, mode=...
  class DiffRendererWrapper (line 319) | class DiffRendererWrapper(nn.Module):
    method __init__ (line 320) | def __init__(self, obj_paths, device="cuda", render_texture=False ):
    method get_patch_center_depths (line 332) | def get_patch_center_depths(self, model_names, T, K):
    method render_offset_map (line 345) | def render_offset_map(self, model_names,  T, K, render_image_size, nea...
    method render_pat_id (line 355) | def render_pat_id(self, model_names,  T, K, render_image_size, near=0....
    method render_depth (line 366) | def render_depth(self, model_names,  T, K, render_image_size, near=0.1...
    method forward (line 377) | def forward(self, model_names,  vert_attribute, T, K, render_image_siz...

FILE: geometry/diff_render_optim.py
  function rasterize (line 33) | def rasterize(R, T, meshes, rasterizer, blur_radius=0):
  function set_bary_coords_to_nearest (line 53) | def set_bary_coords_to_nearest(bary_coords_):
  class MeshRendererWithDepth (line 60) | class MeshRendererWithDepth(nn.Module):
    method __init__ (line 61) | def __init__(self, rasterizer, shader):
    method to (line 66) | def to(self, device):
    method forward (line 72) | def forward(self, meshes_world, **kwargs) -> torch.Tensor:
  class MeshRendererWithDepth_v2 (line 77) | class MeshRendererWithDepth_v2(nn.Module):
    method __init__ (line 78) | def __init__(self, rasterizer, shader):
    method to (line 84) | def to(self, *args, **kwargs):
    method forward (line 95) | def forward(self, meshes_world, **kwargs) -> torch.Tensor:
  class DiffRender (line 108) | class DiffRender(nn.Module):
    method __init__ (line 109) | def __init__(self, mesh_path, render_texture=False):
    method to (line 144) | def to(self, *args, **kwargs):
    method get_patch_center_depths (line 163) | def get_patch_center_depths(self, T, K):
    method forward_interpolate (line 185) | def forward_interpolate(R, t, meshes, face_memory, rasterizer, blur_ra...
    method render_mesh (line 201) | def render_mesh(self,  T, K, render_image_size, near=0.1, far=6, light...
    method render_offset_map (line 246) | def render_offset_map(self,  T, K, render_image_size, near=0.1, far=6):
    method forward (line 269) | def forward(self, vert_attribute, T, K, render_image_size, near=0.1, f...
    method render_depth (line 327) | def render_depth(self, T, K, render_image_size, near=0.1, far=6, mode=...
    method render_pointcloud (line 369) | def render_pointcloud(self, T, K, render_image_size, near=0.1, far=6):
  class DiffRendererWrapper (line 405) | class DiffRendererWrapper(nn.Module):
    method __init__ (line 406) | def __init__(self, obj_paths, device="cuda", render_texture=False ):
    method get_patch_center_depths (line 418) | def get_patch_center_depths(self, model_names, T, K):
    method render_offset_map (line 431) | def render_offset_map(self, model_names,  T, K, render_image_size, nea...
    method render_pat_id (line 441) | def render_pat_id(self, model_names,  T, K, render_image_size, near=0....
    method render_depth (line 453) | def render_depth(self, model_names,  T, K, render_image_size, near=0.1...
    method render_mesh (line 463) | def render_mesh(self, model_names,  T, K, render_image_size, near=0.1,...
    method render_pointcloud (line 474) | def render_pointcloud(self, model_names, T, K, render_image_size, near...
    method forward (line 483) | def forward(self, model_names,  vert_attribute, T, K, render_image_siz...

FILE: geometry/einsum.py
  function einsum (line 8) | def einsum(equation, *inputs):

FILE: geometry/intrinsics.py
  function intrinsics_vec_to_matrix (line 6) | def intrinsics_vec_to_matrix(kvec):
  function intrinsics_matrix_to_vec (line 15) | def intrinsics_matrix_to_vec(kmat):
  function update_intrinsics (line 22) | def update_intrinsics(intrinsics, delta_focal):
  function rescale_depth (line 35) | def rescale_depth(depth, downscale=4):
  function rescale_depth_and_intrinsics (line 41) | def rescale_depth_and_intrinsics(depth, intrinsics, downscale=4):
  function rescale_depths_and_intrinsics (line 47) | def rescale_depths_and_intrinsics(depth, intrinsics, downscale=4):

FILE: geometry/projective_ops.py
  function normalize_coords_grid (line 11) | def normalize_coords_grid(coords):
  function coords_grid (line 25) | def coords_grid(ref, homogeneous=True):
  function extract_and_reshape_intrinsics (line 47) | def extract_and_reshape_intrinsics(intrinsics, shape=None):
  function backproject (line 68) | def backproject(depth, intrinsics, jacobian=False, depth_coords=None):
  function project (line 103) | def project(points, intrinsics, jacobian=False):

FILE: geometry/se3.py
  function matdotv (line 12) | def matdotv(A,b):
  function hat (line 15) | def hat(a):
  function quaternion_rotate_point (line 30) | def quaternion_rotate_point(q, pt, eq=None):
  function quaternion_rotate_matrix (line 41) | def quaternion_rotate_matrix(q, mat, eq=None):
  function quaternion_inverse (line 52) | def quaternion_inverse(q):
  function quaternion_multiply (line 55) | def quaternion_multiply(a, b):
  function quaternion_to_matrix (line 68) | def quaternion_to_matrix(q):
  function rotation_matrix_to_quaternion (line 91) | def rotation_matrix_to_quaternion(R):
  function so3_expm_and_theta (line 111) | def so3_expm_and_theta(omega):
  function so3_logm_and_theta (line 133) | def so3_logm_and_theta(so3):
  function se3_expm (line 146) | def se3_expm(xi):
  function se3_logm (line 168) | def se3_logm(so3, t):
  function se3_matrix_inverse (line 194) | def se3_matrix_inverse(G):
  function _se3_matrix_expm_grad (line 212) | def _se3_matrix_expm_grad(grad):
  function _se3_matrix_expm_shape (line 224) | def _se3_matrix_expm_shape(op):
  function _se3_matrix_expm (line 228) | def _se3_matrix_expm(upsilon_omega):
  class SE3_Matrix_Expm (line 284) | class SE3_Matrix_Expm(torch.autograd.Function):
    method forward (line 287) | def forward(ctx, upsilon_omega):
    method backward (line 293) | def backward(ctx, grad_output):
  function se3_matrix_expm (line 299) | def se3_matrix_expm(upsilon_omega):
  function se3_matrix_increment (line 303) | def se3_matrix_increment(G, upsilon_omega):

FILE: geometry/transformation.py
  function clip_dangerous_gradients (line 23) | def clip_dangerous_gradients(x):
  function jac_local_perturb (line 27) | def jac_local_perturb(pt, fill=False):
  function cond_transform (line 49) | def cond_transform(cond, T1, T2):
  class SE3 (line 65) | class SE3:
    method __init__ (line 66) | def __init__(self, upsilon=None, matrix=None, so3=None, translation=No...
    method __call__ (line 78) | def __call__(self, pt, jacobian=False):
    method __mul__ (line 95) | def __mul__(self, other):
    method identity_ (line 102) | def identity_(self):
    method increment (line 110) | def increment(self, upsilon):
    method concat (line 117) | def concat(self, other, axis=0):
    method copy (line 124) | def copy(self, stop_gradients=False):
    method to_vec (line 136) | def to_vec(self):
    method inv (line 139) | def inv(self):
    method adj (line 146) | def adj(self):
    method logm (line 165) | def logm(self):
    method shape (line 168) | def shape(self):
    method matrix (line 177) | def matrix(self, fill=True):
    method transform (line 184) | def transform(self, depth, intrinsics, valid_mask=False, return3d=False):
    method induced_flow (line 200) | def induced_flow(self, depth, intrinsics, valid_mask=False):
    method depth_change (line 210) | def depth_change(self, depth, intrinsics):
    method identity (line 215) | def identity(self):
  class SE3Sequence (line 230) | class SE3Sequence(SE3):
    method __init__ (line 233) | def __init__(self, upsilon=None, matrix=None, so3=None, translation=No...
    method __call__ (line 238) | def __call__(self, pt, inds=None, jacobian=False):
    method gather (line 245) | def gather(self, inds):
    method reprojction_optim (line 265) | def reprojction_optim(self,
    method transform (line 319) | def transform(self, depth, intrinsics, valid_mask=False, return3d=False):

FILE: model/CFNet.py
  class autocast (line 17) | class autocast:
    method __init__ (line 18) | def __init__(self, enabled):
    method __enter__ (line 20) | def __enter__(self):
    method __exit__ (line 22) | def __exit__(self, *args):
  class ImageFeaEncoder (line 26) | class ImageFeaEncoder(nn.Module):
    method __init__ (line 27) | def __init__(self, input_dim=3, output_dim=256):
    method forward (line 41) | def forward(self, image1, image2):
  class GRU_CFUpdator (line 52) | class GRU_CFUpdator(nn.Module):
    method __init__ (line 53) | def __init__(self, args):
    method freeze_bn (line 81) | def freeze_bn(self):
    method initialize_flow (line 86) | def initialize_flow(self, img, downsample_rate=8):
    method upsample_flow (line 95) | def upsample_flow(self, flow, mask, upsample_scale=8):
    method forward (line 109) | def forward(self, fmap1, fmap2, iters=1, flow_init=None, upsample=True...

FILE: model/HybridNet.py
  function register_hybrid_net (line 15) | def register_hybrid_net(cls, name=None):
  function get_hybrid_net (line 24) | def get_hybrid_net(name):
  class ContextFeatureNet (line 29) | class ContextFeatureNet(nn.Module):
    method __init__ (line 30) | def __init__(self, config):
    method forward (line 34) | def forward(self, batch):
  class HybridDescNet (line 62) | class HybridDescNet(nn.Module):
    method __init__ (line 64) | def __init__(self, config):
    method forward (line 72) | def forward(self, batch):

FILE: model/PoseRefiner.py
  function raft_sequence_flow_loss (line 29) | def raft_sequence_flow_loss(flow_preds, flow_gt, valid, gamma=0.8, max_f...
  class PoseRefiner (line 60) | class PoseRefiner(nn.Module):
    method __init__ (line 61) | def __init__(self, cfg,
    method _clear (line 105) | def _clear(self,):
    method __len__ (line 116) | def __len__(self):
    method render (line 119) | def render(self, params, render_tex=False):
    method get_affine_transformation (line 145) | def get_affine_transformation(self, mask, crop_center, with_intrinsic_...
    method gen_zoom_crop_grids (line 207) | def gen_zoom_crop_grids(self, fg_mask, K, T, output_size, model_center...
    method forward (line 221) | def forward(self, image, Ts, intrinsics, fea_3d=None, Tj_gt=None, obj_...
    method compute_loss (line 378) | def compute_loss(self, Tij_gts, depths, intrinsics, loss='l1', log_err...

FILE: model/RNNPose.py
  function register_posenet (line 32) | def register_posenet(cls, name=None):
  function get_posenet_class (line 41) | def get_posenet_class(name):
  class RNNPose (line 50) | class RNNPose(nn.Module):
    method __init__ (line 51) | def __init__(self,
    method update_global_step (line 87) | def update_global_step(self):
    method get_global_step (line 90) | def get_global_step(self):
    method clear_global_step (line 93) | def clear_global_step(self):
    method sample_poses (line 96) | def sample_poses(self, pose_tgt):
    method _render_init (line 126) | def _render_init(self, config):
    method forward (line 157) | def forward(self, sample):
    method loss (line 225) | def loss(self, sample, preds_dict):

FILE: model/descriptor2D.py
  class SuperPoint2D (line 9) | class SuperPoint2D(nn.Module):
    method __init__ (line 27) | def __init__(self, config):
    method load_state_dict (line 100) | def load_state_dict(self,state_dict, strict=True):
    method forward_encoder (line 113) | def forward_encoder(self, x):
    method forward_decoder (line 134) | def forward_decoder(self, x, x_skip):
    method forward (line 166) | def forward(self, data):

FILE: model/descriptor3D.py
  class KPSuperpoint3Dv2 (line 10) | class KPSuperpoint3Dv2(nn.Module):
    method __init__ (line 12) | def __init__(self, config):
    method regular_score (line 142) | def regular_score(self,score):
    method forward_encoder (line 147) | def forward_encoder(self, batch):
    method forward_middle (line 172) | def forward_middle(self, x):
    method forward_decoder (line 180) | def forward_decoder(self, x, skip_x, batch ):
    method forward (line 195) | def forward(self, batch):

FILE: model/losses.py
  class Loss (line 21) | class Loss(nn.Module):
    method __init__ (line 25) | def __init__(self, loss_weight=1):
    method forward (line 30) | def forward(self,
    method _compute_loss (line 65) | def _compute_loss(self, prediction_tensor, target_tensor, **params):
  class L2Loss (line 82) | class L2Loss(Loss):
    method __init__ (line 84) | def __init__(self, loss_weight=1):
    method _compute_loss (line 87) | def _compute_loss(self, prediction_tensor, target_tensor, mask=None):
  class AdaptiveWeightedL2Loss (line 110) | class AdaptiveWeightedL2Loss(Loss):
    method __init__ (line 112) | def __init__(self, init_alpha, learn_alpha=True, loss_weight=1, focal_...
    method _compute_loss (line 120) | def _compute_loss(self, prediction_tensor, target_tensor, mask=None, a...
  class MetricLoss (line 158) | class MetricLoss(nn.Module):
    method __init__ (line 163) | def __init__(self, configs, log_scale=16, pos_optimal=0.1, neg_optimal...
    method get_circle_loss (line 179) | def get_circle_loss(self, coords_dist, feats_dist):
    method get_recall (line 222) | def get_recall(self, coords_dist, feats_dist):
    method get_weighted_bce_loss (line 236) | def get_weighted_bce_loss(self, prediction, gt):
    method forward (line 257) | def forward(self, src_pcd, tgt_pcd, src_feats, tgt_feats, corresponden...
  class PointAlignmentLoss (line 307) | class PointAlignmentLoss(nn.Module):
    method __init__ (line 308) | def __init__(self, loss_weight=1, ):
    method forward (line 312) | def forward(self, R_pred, t_pred, R_tgt, t_tgt, points):
    method _compute_loss (line 315) | def _compute_loss(self, R_pred, t_pred, R_tgt, t_tgt, points, ):

FILE: thirdparty/kpconv/cpp_wrappers/cpp_neighbors/neighbors/neighbors.cpp
  function brute_neighbors (line 5) | void brute_neighbors(vector<PointXYZ>& queries, vector<PointXYZ>& suppor...
  function ordered_neighbors (line 58) | void ordered_neighbors(vector<PointXYZ>& queries,
  function batch_ordered_neighbors (line 125) | void batch_ordered_neighbors(vector<PointXYZ>& queries,
  function batch_nanoflann_neighbors (line 211) | void batch_nanoflann_neighbors(vector<PointXYZ>& queries,

FILE: thirdparty/kpconv/cpp_wrappers/cpp_neighbors/wrapper.cpp
  type PyModuleDef (line 35) | struct PyModuleDef
  function PyMODINIT_FUNC (line 48) | PyMODINIT_FUNC PyInit_radius_neighbors(void)
  function PyObject (line 58) | static PyObject* batch_neighbors(PyObject* self, PyObject* args, PyObjec...

FILE: thirdparty/kpconv/cpp_wrappers/cpp_subsampling/grid_subsampling/grid_subsampling.cpp
  function grid_subsampling (line 5) | void grid_subsampling(vector<PointXYZ>& original_points,
  function batch_grid_subsampling (line 109) | void batch_grid_subsampling(vector<PointXYZ>& original_points,

FILE: thirdparty/kpconv/cpp_wrappers/cpp_subsampling/grid_subsampling/grid_subsampling.h
  function class (line 10) | class SampledData
  function update_all (line 42) | void update_all(const PointXYZ p, vector<float>::iterator f_begin, vecto...
  function update_features (line 55) | void update_features(const PointXYZ p, vector<float>::iterator f_begin)
  function update_classes (line 62) | void update_classes(const PointXYZ p, vector<int>::iterator l_begin)
  function update_points (line 74) | void update_points(const PointXYZ p)

FILE: thirdparty/kpconv/cpp_wrappers/cpp_subsampling/wrapper.cpp
  type PyModuleDef (line 39) | struct PyModuleDef
  function PyMODINIT_FUNC (line 52) | PyMODINIT_FUNC PyInit_grid_subsampling(void)
  function PyObject (line 62) | static PyObject* batch_subsampling(PyObject* self, PyObject* args, PyObj...
  function PyObject (line 338) | static PyObject* cloud_subsampling(PyObject* self, PyObject* args, PyObj...

FILE: thirdparty/kpconv/cpp_wrappers/cpp_utils/cloud/cloud.cpp
  function PointXYZ (line 27) | PointXYZ max_point(std::vector<PointXYZ> points)
  function PointXYZ (line 48) | PointXYZ min_point(std::vector<PointXYZ> points)

FILE: thirdparty/kpconv/cpp_wrappers/cpp_utils/cloud/cloud.h
  function class (line 40) | class PointXYZ
  function const (line 58) | float operator [] (int i) const
  function dot (line 66) | float dot(const PointXYZ P) const
  function sq_norm (line 71) | float sq_norm()
  function PointXYZ (line 76) | PointXYZ cross(const PointXYZ P) const
  function PointXYZ (line 110) | inline PointXYZ operator + (const PointXYZ A, const PointXYZ B)
  function PointXYZ (line 115) | inline PointXYZ operator - (const PointXYZ A, const PointXYZ B)
  function PointXYZ (line 120) | inline PointXYZ operator * (const PointXYZ P, const float a)
  function PointXYZ (line 125) | inline PointXYZ operator * (const float a, const PointXYZ P)
  function operator (line 135) | inline bool operator == (const PointXYZ A, const PointXYZ B)
  function PointXYZ (line 140) | inline PointXYZ floor(const PointXYZ P)
  function kdtree_get_pt (line 150) | struct PointCloud

FILE: thirdparty/kpconv/cpp_wrappers/cpp_utils/nanoflann/nanoflann.hpp
  function T (line 79) | T pi_const() {
  type has_resize (line 87) | struct has_resize : std::false_type {}
  type has_resize<T, decltype((void)std::declval<T>().resize(1), 0)> (line 90) | struct has_resize<T, decltype((void)std::declval<T>().resize(1), 0)>
  type has_assign (line 93) | struct has_assign : std::false_type {}
  type has_assign<T, decltype((void)std::declval<T>().assign(1, 0), 0)> (line 96) | struct has_assign<T, decltype((void)std::declval<T>().assign(1, 0), 0)>
  function resize (line 103) | inline typename std::enable_if<has_resize<Container>::value, void>::type
  function resize (line 113) | inline typename std::enable_if<!has_resize<Container>::value, void>::type
  function assign (line 123) | inline typename std::enable_if<has_assign<Container>::value, void>::type
  function assign (line 132) | inline typename std::enable_if<!has_assign<Container>::value, void>::type
  class KNNResultSet (line 142) | class KNNResultSet {
    method KNNResultSet (line 155) | inline KNNResultSet(CountType capacity_)
    method init (line 158) | inline void init(IndexType *indices_, DistanceType *dists_) {
    method CountType (line 166) | inline CountType size() const { return count; }
    method full (line 168) | inline bool full() const { return count == capacity; }
    method addPoint (line 175) | inline bool addPoint(DistanceType dist, IndexType index) {
    type IndexDist_Sorter (line 208) | struct IndexDist_Sorter {
    class RadiusResultSet (line 220) | class RadiusResultSet {
      method RadiusResultSet (line 230) | inline RadiusResultSet(
      method init (line 237) | inline void init() { clear(); }
      method clear (line 238) | inline void clear() { m_indices_dists.clear(); }
      method size (line 240) | inline size_t size() const { return m_indices_dists.size(); }
      method full (line 242) | inline bool full() const { return true; }
      method addPoint (line 249) | inline bool addPoint(DistanceType dist, IndexType index) {
      method DistanceType (line 255) | inline DistanceType worstDist() const { return radius; }
      method worst_item (line 261) | std::pair<IndexType, DistanceType> worst_item() const {
    method save_value (line 279) | void save_value(FILE *stream, const T &value, size_t count = 1) {
    method save_value (line 284) | void save_value(FILE *stream, const std::vector<T> &value) {
    method load_value (line 291) | void load_value(FILE *stream, T &value, size_t count = 1) {
    method load_value (line 298) | void load_value(FILE *stream, std::vector<T> &value) {
    type Metric (line 315) | struct Metric {}
    type L1_Adaptor (line 324) | struct L1_Adaptor {
      method L1_Adaptor (line 330) | L1_Adaptor(const DataSource &_data_source) : data_source(_data_sourc...
      method DistanceType (line 332) | inline DistanceType evalMetric(const T *a, const size_t b_idx, size_...
      method DistanceType (line 363) | inline DistanceType accum_dist(const U a, const V b, const size_t) c...
    type L2_Adaptor (line 375) | struct L2_Adaptor {
      method L2_Adaptor (line 381) | L2_Adaptor(const DataSource &_data_source) : data_source(_data_sourc...
      method DistanceType (line 383) | inline DistanceType evalMetric(const T *a, const size_t b_idx, size_...
      method DistanceType (line 411) | inline DistanceType accum_dist(const U a, const V b, const size_t) c...
    type L2_Simple_Adaptor (line 423) | struct L2_Simple_Adaptor {
      method L2_Simple_Adaptor (line 429) | L2_Simple_Adaptor(const DataSource &_data_source)
      method DistanceType (line 432) | inline DistanceType evalMetric(const T *a, const size_t b_idx,
      method DistanceType (line 443) | inline DistanceType accum_dist(const U a, const V b, const size_t) c...
    type SO2_Adaptor (line 455) | struct SO2_Adaptor {
      method SO2_Adaptor (line 461) | SO2_Adaptor(const DataSource &_data_source) : data_source(_data_sour...
      method DistanceType (line 463) | inline DistanceType evalMetric(const T *a, const size_t b_idx,
      method DistanceType (line 471) | inline DistanceType accum_dist(const U a, const V b, const size_t) c...
    type SO3_Adaptor (line 489) | struct SO3_Adaptor {
      method SO3_Adaptor (line 495) | SO3_Adaptor(const DataSource &_data_source)
      method DistanceType (line 498) | inline DistanceType evalMetric(const T *a, const size_t b_idx,
      method DistanceType (line 504) | inline DistanceType accum_dist(const U a, const V b, const size_t id...
    type metric_L1 (line 510) | struct metric_L1 : public Metric {
      type traits (line 511) | struct traits {
    type metric_L2 (line 516) | struct metric_L2 : public Metric {
      type traits (line 517) | struct traits {
    type metric_L2_Simple (line 522) | struct metric_L2_Simple : public Metric {
      type traits (line 523) | struct traits {
    type metric_SO2 (line 528) | struct metric_SO2 : public Metric {
      type traits (line 529) | struct traits {
    type metric_SO3 (line 534) | struct metric_SO3 : public Metric {
      type traits (line 535) | struct traits {
    type KDTreeSingleIndexAdaptorParams (line 546) | struct KDTreeSingleIndexAdaptorParams {
      method KDTreeSingleIndexAdaptorParams (line 547) | KDTreeSingleIndexAdaptorParams(size_t _leaf_max_size = 10)
    type SearchParams (line 554) | struct SearchParams {
      method SearchParams (line 557) | SearchParams(int checks_IGNORED_ = 32, float eps_ = 0, bool sorted_ ...
    method T (line 578) | inline T *allocate(size_t count = 1) {
    class PooledAllocator (line 601) | class PooledAllocator {
      method internal_init (line 612) | void internal_init() {
      method PooledAllocator (line 626) | PooledAllocator() { internal_init(); }
      method free_all (line 634) | void free_all() {
      method T (line 702) | T *allocate(const size_t count = 1) {
    type array_or_vector_selector (line 715) | struct array_or_vector_selector {
    type array_or_vector_selector<-1, T> (line 719) | struct array_or_vector_selector<-1, T> {
    class KDTreeBaseClass (line 739) | class KDTreeBaseClass {
      method freeIndex (line 744) | void freeIndex(Derived &obj) {
      type Node (line 754) | struct Node {
        type leaf (line 758) | struct leaf {
        type nonleaf (line 761) | struct nonleaf {
      type Interval (line 771) | struct Interval {
      method size (line 813) | size_t size(const Derived &obj) const { return obj.m_size; }
      method veclen (line 816) | size_t veclen(const Derived &obj) {
      method ElementType (line 821) | inline ElementType dataset_get(const Derived &obj, size_t idx,
      method usedMemory (line 830) | size_t usedMemory(Derived &obj) {
      method computeMinMax (line 836) | void computeMinMax(const Derived &obj, IndexType *ind, IndexType count,
      method NodePtr (line 857) | NodePtr divideTree(Derived &obj, const IndexType left, const IndexTy...
      method middleSplit_ (line 909) | void middleSplit_(Derived &obj, IndexType *ind, IndexType count,
      method planeSplit (line 967) | void planeSplit(Derived &obj, IndexType *ind, const IndexType count,
      method DistanceType (line 1005) | DistanceType computeInitialDistances(const Derived &obj,
      method save_tree (line 1024) | void save_tree(Derived &obj, FILE *stream, NodePtr tree) {
      method load_tree (line 1034) | void load_tree(Derived &obj, FILE *stream, NodePtr &tree) {
      method saveIndex_ (line 1050) | void saveIndex_(Derived &obj, FILE *stream) {
      method loadIndex_ (line 1064) | void loadIndex_(Derived &obj, FILE *stream) {
    class KDTreeSingleIndexAdaptor (line 1116) | class KDTreeSingleIndexAdaptor
      method KDTreeSingleIndexAdaptor (line 1122) | KDTreeSingleIndexAdaptor(
      method KDTreeSingleIndexAdaptor (line 1170) | KDTreeSingleIndexAdaptor(const int dimensionality,
      method buildIndex (line 1190) | void buildIndex() {
      method findNeighbors (line 1221) | bool findNeighbors(RESULTSET &result, const ElementType *vec,
      method knnSearch (line 1254) | size_t knnSearch(const ElementType *query_point, const size_t num_cl...
      method radiusSearch (line 1279) | size_t
      method radiusSearchCustomCallback (line 1297) | size_t radiusSearchCustomCallback(
      method init_vind (line 1309) | void init_vind() {
      method computeBoundingBox (line 1318) | void computeBoundingBox(BoundingBox &bbox) {
      method searchLevel (line 1348) | bool searchLevel(RESULTSET &result_set, const ElementType *vec,
      method saveIndex (line 1419) | void saveIndex(FILE *stream) { this->saveIndex_(*this, stream); }
      method loadIndex (line 1426) | void loadIndex(FILE *stream) { this->loadIndex_(*this, stream); }
    class KDTreeSingleIndexDynamicAdaptor_ (line 1468) | class KDTreeSingleIndexDynamicAdaptor_
      method KDTreeSingleIndexDynamicAdaptor_ (line 1519) | KDTreeSingleIndexDynamicAdaptor_(
      method KDTreeSingleIndexDynamicAdaptor_ (line 1536) | KDTreeSingleIndexDynamicAdaptor_
      method buildIndex (line 1555) | void buildIndex() {
      method findNeighbors (line 1584) | bool findNeighbors(RESULTSET &result, const ElementType *vec,
      method knnSearch (line 1616) | size_t knnSearch(const ElementType *query_point, const size_t num_cl...
      method radiusSearch (line 1641) | size_t
      method radiusSearchCustomCallback (line 1659) | size_t radiusSearchCustomCallback(
      method computeBoundingBox (line 1669) | void computeBoundingBox(BoundingBox &bbox) {
      method searchLevel (line 1699) | void searchLevel(RESULTSET &result_set, const ElementType *vec,
      method saveIndex (line 1765) | void saveIndex(FILE *stream) { this->saveIndex_(*this, stream); }
      method loadIndex (line 1772) | void loadIndex(FILE *stream) { this->loadIndex_(*this, stream); }
    class KDTreeSingleIndexDynamicAdaptor (line 1791) | class KDTreeSingleIndexDynamicAdaptor {
      method First0Bit (line 1825) | int First0Bit(IndexType num) {
      method init (line 1835) | void init() {
      method KDTreeSingleIndexDynamicAdaptor (line 1860) | KDTreeSingleIndexDynamicAdaptor(const int dimensionality,
      method KDTreeSingleIndexDynamicAdaptor (line 1881) | KDTreeSingleIndexDynamicAdaptor(
      method addPoints (line 1886) | void addPoints(IndexType start, IndexType end) {
      method removePoint (line 1909) | void removePoint(size_t idx) {
      method findNeighbors (line 1929) | bool findNeighbors(RESULTSET &result, const ElementType *vec,
    type KDTreeEigenMatrixAdaptor (line 1957) | struct KDTreeEigenMatrixAdaptor {
      method KDTreeEigenMatrixAdaptor (line 1971) | KDTreeEigenMatrixAdaptor(const size_t dimensionality,
      method KDTreeEigenMatrixAdaptor (line 1990) | KDTreeEigenMatrixAdaptor(const self_t &) = delete;
      method query (line 2002) | inline void query(const num_t *query_point, const size_t num_closest,
      method self_t (line 2013) | const self_t &derived() const { return *this; }
      method self_t (line 2014) | self_t &derived() { return *this; }
      method kdtree_get_point_count (line 2017) | inline size_t kdtree_get_point_count() const {
      method num_t (line 2022) | inline num_t kdtree_get_pt(const IndexType idx, size_t dim) const {
      method kdtree_get_bbox (line 2031) | bool kdtree_get_bbox(BBOX & /*bb*/) const {

FILE: thirdparty/kpconv/kernels/kernel_points.py
  function create_3D_rotations (line 33) | def create_3D_rotations(axis, angle):
  function spherical_Lloyd (line 67) | def spherical_Lloyd(radius, num_cells, dimension=3, fixed='center', appr...
  function kernel_point_optimization_debug (line 247) | def kernel_point_optimization_debug(radius, num_points, num_kernels=1, d...
  function load_kernels (line 391) | def load_kernels(radius, num_kpoints, dimension, fixed, lloyd=False):

FILE: thirdparty/kpconv/kpconv_blocks.py
  function gather (line 29) | def gather(x, idx, method=2):
  function radius_gaussian (line 63) | def radius_gaussian(sq_r, sig, eps=1e-9):
  function closest_pool (line 73) | def closest_pool(x, inds):
  function max_pool (line 88) | def max_pool(x, inds):
  function global_average (line 107) | def global_average(x, batch_lengths):
  class KPConv (line 137) | class KPConv(nn.Module):
    method __init__ (line 139) | def __init__(self, kernel_size, p_dim, in_channels, out_channels, KP_e...
    method reset_parameters (line 210) | def reset_parameters(self):
    method init_KP (line 216) | def init_KP(self):
    method forward (line 231) | def forward(self, q_pts, s_pts, neighb_inds, x):
    method __repr__ (line 379) | def __repr__(self):
  function block_decider (line 390) | def block_decider(block_name,
  class BatchNormBlock (line 442) | class BatchNormBlock(nn.Module):
    method __init__ (line 444) | def __init__(self, in_dim, use_bn, bn_momentum):
    method reset_parameters (line 462) | def reset_parameters(self):
    method forward (line 465) | def forward(self, x):
    method __repr__ (line 476) | def __repr__(self):
  class UnaryBlock (line 482) | class UnaryBlock(nn.Module):
    method __init__ (line 484) | def __init__(self, in_dim, out_dim, use_bn, bn_momentum, no_relu=False):
    method forward (line 505) | def forward(self, x, batch=None):
    method __repr__ (line 512) | def __repr__(self):
  class LastUnaryBlock (line 519) | class LastUnaryBlock(nn.Module):
    method __init__ (line 521) | def __init__(self, in_dim, out_dim, use_bn, bn_momentum, no_relu=False):
    method forward (line 536) | def forward(self, x, batch=None):
    method __repr__ (line 540) | def __repr__(self):
  class SimpleBlock (line 545) | class SimpleBlock(nn.Module):
    method __init__ (line 547) | def __init__(self, block_name, in_dim, out_dim, radius, layer_ind, con...
    method forward (line 587) | def forward(self, x, batch):
  class ResnetBottleneckBlock (line 602) | class ResnetBottleneckBlock(nn.Module):
    method __init__ (line 604) | def __init__(self, block_name, in_dim, out_dim, radius, layer_ind, con...
    method forward (line 659) | def forward(self, features, batch):
  class GlobalAverageBlock (line 690) | class GlobalAverageBlock(nn.Module):
    method __init__ (line 692) | def __init__(self):
    method forward (line 699) | def forward(self, x, batch):
  class NearestUpsampleBlock (line 703) | class NearestUpsampleBlock(nn.Module):
    method __init__ (line 705) | def __init__(self, layer_ind):
    method forward (line 713) | def forward(self, x, batch):
    method __repr__ (line 716) | def __repr__(self):
  class MaxPoolBlock (line 721) | class MaxPoolBlock(nn.Module):
    method __init__ (line 723) | def __init__(self, layer_ind):
    method forward (line 731) | def forward(self, x, batch):

FILE: thirdparty/kpconv/lib/ply.py
  function parse_header (line 60) | def parse_header(plyfile, ext):
  function parse_mesh_header (line 80) | def parse_mesh_header(plyfile, ext):
  function read_ply (line 113) | def read_ply(filename, triangular_mesh=False):
  function header_properties (line 195) | def header_properties(field_list, field_names):
  function write_ply (line 212) | def write_ply(filename, field_list, field_names, triangular_faces=None):
  function describe_element (line 326) | def describe_element(name, df):

FILE: thirdparty/kpconv/lib/timer.py
  class AverageMeter (line 4) | class AverageMeter(object):
    method __init__ (line 7) | def __init__(self):
    method reset (line 10) | def reset(self):
    method update (line 17) | def update(self, val, n=1):
  class Timer (line 26) | class Timer(object):
    method __init__ (line 29) | def __init__(self):
    method reset (line 36) | def reset(self):
    method tic (line 43) | def tic(self):
    method toc (line 48) | def toc(self, average=True):

FILE: thirdparty/kpconv/lib/utils.py
  class Logger (line 20) | class Logger:
    method __init__ (line 21) | def __init__(self, path):
    method write (line 25) | def write(self, text):
    method close (line 29) | def close(self):
  function save_obj (line 32) | def save_obj(obj, path ):
  function load_obj (line 39) | def load_obj(path):
  function load_config (line 46) | def load_config(path):
  function setup_seed (line 63) | def setup_seed(seed):
  function square_distance (line 73) | def square_distance(src, dst, normalised = False):
  function validate_gradient (line 96) | def validate_gradient(model):
  function natural_key (line 109) | def natural_key(string_):

FILE: thirdparty/nn/_ext.c
  type _cffi_global_s (line 127) | struct _cffi_global_s {
  type _cffi_getconst_s (line 135) | struct _cffi_getconst_s {
  type _cffi_struct_union_s (line 141) | struct _cffi_struct_union_s {
  type _cffi_field_s (line 157) | struct _cffi_field_s {
  type _cffi_enum_s (line 164) | struct _cffi_enum_s {
  type _cffi_typename_s (line 171) | struct _cffi_typename_s {
  type _cffi_type_context_s (line 177) | struct _cffi_type_context_s {
  type _cffi_parse_info_s (line 193) | struct _cffi_parse_info_s {
  type _cffi_externpy_s (line 201) | struct _cffi_externpy_s {
  type _cffi_parse_info_s (line 208) | struct _cffi_parse_info_s
  type _cffi_type_context_s (line 209) | struct _cffi_type_context_s
  type _cffi_type_context_s (line 211) | struct _cffi_type_context_s
  type __int8 (line 221) | typedef __int8 int8_t;
  type __int16 (line 222) | typedef __int16 int16_t;
  type __int32 (line 223) | typedef __int32 int32_t;
  type __int64 (line 224) | typedef __int64 int64_t;
  type __int8 (line 229) | typedef __int8 int_least8_t;
  type __int16 (line 230) | typedef __int16 int_least16_t;
  type __int32 (line 231) | typedef __int32 int_least32_t;
  type __int64 (line 232) | typedef __int64 int_least64_t;
  type uint_least8_t (line 233) | typedef unsigned __int8 uint_least8_t;
  type uint_least16_t (line 234) | typedef unsigned __int16 uint_least16_t;
  type uint_least32_t (line 235) | typedef unsigned __int32 uint_least32_t;
  type uint_least64_t (line 236) | typedef unsigned __int64 uint_least64_t;
  type __int8 (line 237) | typedef __int8 int_fast8_t;
  type __int16 (line 238) | typedef __int16 int_fast16_t;
  type __int32 (line 239) | typedef __int32 int_fast32_t;
  type __int64 (line 240) | typedef __int64 int_fast64_t;
  type uint_fast8_t (line 241) | typedef unsigned __int8 uint_fast8_t;
  type uint_fast16_t (line 242) | typedef unsigned __int16 uint_fast16_t;
  type uint_fast32_t (line 243) | typedef unsigned __int32 uint_fast32_t;
  type uint_fast64_t (line 244) | typedef unsigned __int64 uint_fast64_t;
  type __int64 (line 245) | typedef __int64 intmax_t;
  type uintmax_t (line 246) | typedef unsigned __int64 uintmax_t;
  type _Bool (line 252) | typedef unsigned char _Bool;
  type _Bool (line 270) | typedef bool _Bool;
  type _cffi_ctypedescr (line 374) | struct _cffi_ctypedescr
  function PyObject (line 382) | static PyObject *_cffi_init(const char *module_name, Py_ssize_t version,
  type wchar_t (line 415) | typedef wchar_t _cffi_wchar_t;
  type _cffi_wchar_t (line 417) | typedef uint16_t _cffi_wchar_t;
  function _CFFI_UNUSED_FN (line 420) | _CFFI_UNUSED_FN static uint16_t _cffi_to_c_char16_t(PyObject *o)
  function _CFFI_UNUSED_FN (line 428) | _CFFI_UNUSED_FN static PyObject *_cffi_from_c_char16_t(uint16_t x)
  function _CFFI_UNUSED_FN (line 436) | _CFFI_UNUSED_FN static int _cffi_to_c_char32_t(PyObject *o)
  function _CFFI_UNUSED_FN (line 444) | _CFFI_UNUSED_FN static PyObject *_cffi_from_c_char32_t(int x)
  type _cffi_externpy_s (line 456) | struct _cffi_externpy_s
  function _cffi_d_findNearestPointIdxLauncher (line 513) | static void _cffi_d_findNearestPointIdxLauncher(float * x0, float * x1, ...
  function PyObject (line 518) | static PyObject *
  type _cffi_global_s (line 609) | struct _cffi_global_s
  type _cffi_type_context_s (line 613) | struct _cffi_type_context_s
  function PyMODINIT_FUNC (line 634) | PyMODINIT_FUNC
  function PyInit__ext (line 646) | PyInit__ext(void) { return NULL; }
  function init_ext (line 648) | init_ext(void) { }
  function PyMODINIT_FUNC (line 652) | PyMODINIT_FUNC
  function PyMODINIT_FUNC (line 658) | PyMODINIT_FUNC

FILE: thirdparty/nn/nn_utils.py
  function find_nearest_point_idx (line 6) | def find_nearest_point_idx(ref_pts, que_pts):

FILE: thirdparty/raft/corr.py
  class CorrBlock (line 12) | class CorrBlock:
    method __init__ (line 13) | def __init__(self, fmap1, fmap2, num_levels=4, radius=4, downsample_ra...
    method __call__ (line 36) | def __call__(self, coords):
    method corr (line 60) | def corr(fmap1, fmap2):
  class AlternateCorrBlock (line 70) | class AlternateCorrBlock:
    method __init__ (line 71) | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
    method __call__ (line 81) | def __call__(self, coords):

FILE: thirdparty/raft/extractor.py
  class ResidualBlock (line 6) | class ResidualBlock(nn.Module):
    method __init__ (line 7) | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
    method forward (line 48) | def forward(self, x):
  class BottleneckBlock (line 60) | class BottleneckBlock(nn.Module):
    method __init__ (line 61) | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
    method forward (line 107) | def forward(self, x):
  class BasicEncoder (line 118) | class BasicEncoder(nn.Module):
    method __init__ (line 119) | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0, input...
    method _make_layer (line 183) | def _make_layer(self, dim, stride=1):
    method forward (line 192) | def forward(self, x):
  class BasicEncoder_dx4 (line 234) | class BasicEncoder_dx4(nn.Module):
    method __init__ (line 235) | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
    method _make_layer (line 276) | def _make_layer(self, dim, stride=1):
    method forward (line 285) | def forward(self, x):
  class SmallEncoder (line 312) | class SmallEncoder(nn.Module):
    method __init__ (line 313) | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
    method _make_layer (line 352) | def _make_layer(self, dim, stride=1):
    method forward (line 361) | def forward(self, x):
  class SmallEncoder_dx4 (line 386) | class SmallEncoder_dx4(nn.Module):
    method __init__ (line 387) | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
    method _make_layer (line 427) | def _make_layer(self, dim, stride=1):
    method forward (line 436) | def forward(self, x):

FILE: thirdparty/raft/update.py
  class FlowHead (line 6) | class FlowHead(nn.Module):
    method __init__ (line 7) | def __init__(self, input_dim=128, hidden_dim=256):
    method forward (line 13) | def forward(self, x):
  class ConvGRU (line 16) | class ConvGRU(nn.Module):
    method __init__ (line 17) | def __init__(self, hidden_dim=128, input_dim=192+128):
    method forward (line 23) | def forward(self, h, x):
  class SepConvGRU (line 33) | class SepConvGRU(nn.Module):
    method __init__ (line 34) | def __init__(self, hidden_dim=128, input_dim=192+128):
    method forward (line 45) | def forward(self, h, x):
  class SmallMotionEncoder (line 62) | class SmallMotionEncoder(nn.Module):
    method __init__ (line 63) | def __init__(self, args):
    method forward (line 71) | def forward(self, flow, corr):
  class BasicMotionEncoder (line 79) | class BasicMotionEncoder(nn.Module):
    method __init__ (line 80) | def __init__(self, args):
    method forward (line 89) | def forward(self, flow, corr):
  class BasicMotionEncoderGeo (line 99) | class BasicMotionEncoderGeo(nn.Module):
    method __init__ (line 100) | def __init__(self, args):
    method forward (line 111) | def forward(self, flow, corr, corr_geo):
  class SmallUpdateBlock (line 128) | class SmallUpdateBlock(nn.Module):
    method __init__ (line 129) | def __init__(self, args, hidden_dim=96):
    method forward (line 135) | def forward(self, net, inp, corr, flow):
  class SmallUpdateBlockUpMask (line 143) | class SmallUpdateBlockUpMask(nn.Module):
    method __init__ (line 144) | def __init__(self, args, hidden_dim=96):
    method forward (line 154) | def forward(self, net, inp, corr, flow):
  class BasicUpdateBlock (line 164) | class BasicUpdateBlock(nn.Module):
    method __init__ (line 165) | def __init__(self, args, hidden_dim=128, input_dim=128, downsample_sca...
    method forward (line 178) | def forward(self, net, inp, corr, flow, upsample=True):
  class BasicUpdateBlockGeo (line 190) | class BasicUpdateBlockGeo(nn.Module):
    method __init__ (line 192) | def __init__(self, args, hidden_dim=128, input_dim=128, downsample_sca...
    method forward (line 205) | def forward(self, net, inp, corr, geo_corr, flow, upsample=True):

FILE: thirdparty/raft/utils/augmentor.py
  class FlowAugmentor (line 15) | class FlowAugmentor:
    method __init__ (line 16) | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=T...
    method color_transform (line 36) | def color_transform(self, img1, img2):
    method eraser_transform (line 52) | def eraser_transform(self, img1, img2, bounds=[50, 100]):
    method spatial_transform (line 67) | def spatial_transform(self, img1, img2, flow):
    method __call__ (line 111) | def __call__(self, img1, img2, flow):
  class SparseFlowAugmentor (line 122) | class SparseFlowAugmentor:
    method __init__ (line 123) | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=F...
    method color_transform (line 142) | def color_transform(self, img1, img2):
    method eraser_transform (line 148) | def eraser_transform(self, img1, img2):
    method resize_sparse_flow_map (line 161) | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
    method spatial_transform (line 195) | def spatial_transform(self, img1, img2, flow, valid):
    method __call__ (line 236) | def __call__(self, img1, img2, flow, valid):

FILE: thirdparty/raft/utils/flow_viz.py
  function make_colorwheel (line 20) | def make_colorwheel():
  function flow_uv_to_colors (line 70) | def flow_uv_to_colors(u, v, convert_to_bgr=False):
  function flow_to_image (line 109) | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):

FILE: thirdparty/raft/utils/frame_utils.py
  function readFlow (line 12) | def readFlow(fn):
  function readPFM (line 33) | def readPFM(file):
  function writeFlow (line 70) | def writeFlow(filename,uv,v=None):
  function readFlowKITTI (line 102) | def readFlowKITTI(filename):
  function readDispKITTI (line 109) | def readDispKITTI(filename):
  function writeFlowKITTI (line 116) | def writeFlowKITTI(filename, uv):
  function read_gen (line 123) | def read_gen(file_name, pil=False):

FILE: thirdparty/raft/utils/utils.py
  class InputPadder (line 7) | class InputPadder:
    method __init__ (line 9) | def __init__(self, dims, mode='sintel'):
    method pad (line 18) | def pad(self, *inputs):
    method unpad (line 21) | def unpad(self,x):
  function forward_interpolate (line 26) | def forward_interpolate(flow):
  function bilinear_sampler (line 57) | def bilinear_sampler(img, coords, mode='bilinear', mask=False):
  function coords_grid (line 74) | def coords_grid(batch, ht, wd):
  function upflow8 (line 80) | def upflow8(flow, mode='bilinear'):
  function upflow (line 84) | def upflow(flow, mode='bilinear', scale=8):

FILE: thirdparty/vsd/inout.py
  function load_depth (line 10) | def load_depth(path):
  function load_ply (line 16) | def load_ply(path):

FILE: tools/eval.py
  function load_example_to_device (line 45) | def load_example_to_device(example,
  function build_network (line 60) | def build_network(model_cfg, measure_time=False, testing=False):
  function _worker_init_fn (line 66) | def _worker_init_fn(worker_id):
  function freeze_params (line 73) | def freeze_params(params: dict, include: str = None, exclude: str = None):
  function freeze_params_v2 (line 93) | def freeze_params_v2(params: dict, include: str = None, exclude: str = N...
  function filter_param_dict (line 110) | def filter_param_dict(state_dict: dict, include: str = None, exclude: st...
  function chk_rank (line 129) | def chk_rank(rank_, use_dist=False):
  function get_rank (line 139) | def get_rank(use_dist=False):
  function get_world (line 149) | def get_world(use_dist):
  function get_ngpus_per_node (line 157) | def get_ngpus_per_node():
  function multi_proc_train (line 162) | def multi_proc_train(
  function train_worker (line 227) | def train_worker(rank, params):
  function eval (line 257) | def eval(

FILE: tools/generate_data_info_deepim_0_orig.py
  function parse_pose_file (line 11) | def parse_pose_file(file):
  function parse_calib_file (line 24) | def parse_calib_file(file):
  function create_data_info (line 41) | def create_data_info(data_root, saving_path, with_assertion=True):

FILE: tools/generate_data_info_deepim_1_syn.py
  function parse_pose_file (line 11) | def parse_pose_file(file):
  function parse_calib_file (line 24) | def parse_calib_file(file):
  function create_data_info (line 38) | def create_data_info(data_root, saving_path, with_assertion=False ):

FILE: tools/generate_data_info_deepim_2_posecnnval.py
  function parse_pose_file (line 13) | def parse_pose_file(file):
  function parse_calib_file (line 26) | def parse_calib_file(file):
  function create_data_info (line 40) | def create_data_info(data_root, saving_path, with_assertion=True):

FILE: tools/generate_data_info_v2_deepim.py
  function parse_pose_file (line 11) | def parse_pose_file(file):
  function parse_calib_file (line 24) | def parse_calib_file(file):
  function create_data_info (line 40) | def create_data_info(data_root, saving_path, training_data_ratio=0.8, sh...

FILE: tools/train.py
  function load_example_to_device (line 45) | def load_example_to_device(example,
  function build_network (line 64) | def build_network(model_cfg, measure_time=False, testing=False):
  function _worker_init_fn (line 70) | def _worker_init_fn(worker_id):
  function freeze_params (line 77) | def freeze_params(params: dict, include: str = None, exclude: str = None):
  function freeze_params_v2 (line 97) | def freeze_params_v2(params: dict, include: str = None, exclude: str = N...
  function filter_param_dict (line 114) | def filter_param_dict(state_dict: dict, include: str = None, exclude: st...
  function chk_rank (line 133) | def chk_rank(rank_, use_dist=False):
  function get_rank (line 143) | def get_rank(use_dist=False):
  function get_world (line 153) | def get_world(use_dist):
  function get_ngpus_per_node (line 161) | def get_ngpus_per_node():
  function get_logger (line 164) | def get_logger():
  function multi_proc_train (line 176) | def multi_proc_train(
  function train_worker (line 242) | def train_worker(rank, params):
  function train (line 279) | def train(
  function eval_once (line 666) | def eval_once(net,

FILE: tools/transform_data_format.py
  function range_to_depth (line 19) | def range_to_depth(mask, range, K):
  function crop (line 36) | def crop(image, depth, mask, K_old, margin_ratio=0.1, output_size=128 ):
  class DataFormatter (line 87) | class DataFormatter(object):
    method __init__ (line 88) | def __init__(self, data_type, data_info_path, crop_param=None ):
    method process (line 97) | def process(self, data_root,depth_root, save_root):
    method _proc_LM_SYN_PVNET (line 110) | def _proc_LM_SYN_PVNET(self, data_info, data_root, save_root):
    method _proc_LM_SYN_PVNET_LMK (line 162) | def _proc_LM_SYN_PVNET_LMK(self, data_info, data_root, save_root):
    method _proc_LM_FUSE_PVNET (line 216) | def _proc_LM_FUSE_PVNET(self, data_info, data_root, depth_root, save_r...
    method _proc_LM_FUSE_SINGLE_PVNET (line 327) | def _proc_LM_FUSE_SINGLE_PVNET(self, data_info, data_root, depth_root,...
  function run (line 438) | def run(data_type,data_info_path, image_root, depth_root, save_dir, crop...

FILE: torchplus/metrics.py
  class Scalar (line 7) | class Scalar(nn.Module):
    method __init__ (line 8) | def __init__(self):
    method forward (line 13) | def forward(self, scalar):
    method value (line 20) | def value(self):
    method clear (line 23) | def clear(self):
  class Accuracy (line 27) | class Accuracy(nn.Module):
    method __init__ (line 28) | def __init__(self,
    method forward (line 41) | def forward(self, labels, preds, weights=None):
    method value (line 68) | def value(self):
    method clear (line 71) | def clear(self):
  class Precision (line 76) | class Precision(nn.Module):
    method __init__ (line 77) | def __init__(self, dim=1, ignore_idx=-1, threshold=0.5):
    method forward (line 85) | def forward(self, labels, preds, weights=None):
    method value (line 119) | def value(self):
    method clear (line 121) | def clear(self):
  class Recall (line 126) | class Recall(nn.Module):
    method __init__ (line 127) | def __init__(self, dim=1, ignore_idx=-1, threshold=0.5):
    method forward (line 135) | def forward(self, labels, preds, weights=None):
    method value (line 167) | def value(self):
    method clear (line 169) | def clear(self):
  function _calc_binary_metrics (line 174) | def _calc_binary_metrics(labels,
  class PrecisionRecall (line 195) | class PrecisionRecall(nn.Module):
    method __init__ (line 196) | def __init__(self,
    method forward (line 221) | def forward(self, labels, preds, weights=None):
    method value (line 267) | def value(self):
    method thresholds (line 274) | def thresholds(self):
    method clear (line 277) | def clear(self):

FILE: torchplus/nn/functional.py
  function one_hot (line 3) | def one_hot(tensor, depth, dim=-1, on_value=1.0, dtype=torch.float32):

FILE: torchplus/nn/modules/common.py
  class Empty (line 8) | class Empty(torch.nn.Module):
    method __init__ (line 9) | def __init__(self, *args, **kwargs):
    method forward (line 13) | def forward(self, *args, **kwargs):
  class Sequential (line 21) | class Sequential(torch.nn.Module):
    method __init__ (line 53) | def __init__(self, *args, **kwargs):
    method __getitem__ (line 68) | def __getitem__(self, idx):
    method __len__ (line 78) | def __len__(self):
    method add (line 81) | def add(self, module, name=None):
    method forward (line 88) | def forward(self, input):

FILE: torchplus/nn/modules/normalization.py
  class GroupNorm (line 4) | class GroupNorm(torch.nn.GroupNorm):
    method __init__ (line 5) | def __init__(self, num_channels, num_groups, eps=1e-5, affine=True):

FILE: torchplus/ops/array_ops.py
  function scatter_nd (line 8) | def scatter_nd(indices, updates, shape):
  function gather_nd (line 24) | def gather_nd(params, indices):
  function roll (line 34) | def roll(x: torch.Tensor, shift: int, dim: int = -1, fill_pad: Optional[...

FILE: torchplus/tools.py
  function get_pos_to_kw_map (line 11) | def get_pos_to_kw_map(func):
  function get_kw_to_default_map (line 22) | def get_kw_to_default_map(func):
  function change_default_args (line 47) | def change_default_args(**kwargs):
  function torch_to_np_dtype (line 61) | def torch_to_np_dtype(ttype):

FILE: torchplus/train/checkpoint.py
  class DelayedKeyboardInterrupt (line 10) | class DelayedKeyboardInterrupt(object):
    method __enter__ (line 11) | def __enter__(self):
    method handler (line 15) | def handler(self, sig, frame):
    method __exit__ (line 19) | def __exit__(self, type, value, traceback):
  function latest_checkpoint (line 25) | def latest_checkpoint(model_dir, model_name):
  function _ordered_unique (line 49) | def _ordered_unique(seq):
  function save (line 54) | def save(model_dir,
  function restore (line 118) | def restore(ckpt_path, model, map_func=None, map_location='cpu'):
  function _check_model_names (line 129) | def _check_model_names(models):
  function _get_name_to_model_map (line 140) | def _get_name_to_model_map(models):
  function try_restore_latest_checkpoints (line 149) | def try_restore_latest_checkpoints(model_dir, models, map_func=None, map...
  function restore_latest_checkpoints (line 157) | def restore_latest_checkpoints(model_dir, models, map_func=None,  map_lo...
  function restore_models (line 167) | def restore_models(model_dir, models, global_step, map_func=None, map_lo...
  function save_models (line 175) | def save_models(model_dir,
  function gpu_to_cpu (line 186) | def gpu_to_cpu(models):
  function cpu_to_gpu (line 197) | def cpu_to_gpu(models):
  function save_models_cpu (line 208) | def save_models_cpu(model_dir,

FILE: torchplus/train/common.py
  function create_folder (line 5) | def create_folder(prefix, add_time=True, add_str=None, delete=False):

FILE: torchplus/train/fastai_optim.py
  function split_bn_bias (line 14) | def split_bn_bias(layer_groups):
  function get_master (line 28) | def get_master(layer_groups, flat_master: bool = False):
  function model_g2master_g (line 55) | def model_g2master_g(model_params, master_params,
  function master2model (line 75) | def master2model(model_params, master_params,
  function listify (line 92) | def listify(p=None, q=None):
  function trainable_params (line 107) | def trainable_params(m: nn.Module):
  function is_tuple (line 114) | def is_tuple(x) -> bool:
  class OptimWrapper (line 119) | class OptimWrapper(torch.optim.Optimizer):
    method __init__ (line 122) | def __init__(self, opt, wd, true_wd: bool = False, bn_wd: bool = True):
    method create (line 131) | def create(cls, opt_func, lr, layer_groups, **kwargs):
    method new (line 161) | def new(self, layer_groups):
    method __repr__ (line 177) | def __repr__(self) -> str:
    method step (line 181) | def step(self) -> None:
    method zero_grad (line 196) | def zero_grad(self) -> None:
    method __getstate__ (line 201) | def __getstate__(self):
    method __setstate__ (line 204) | def __setstate__(self, state):
    method state_dict (line 207) | def state_dict(self):
    method load_state_dict (line 210) | def load_state_dict(self, state_dict):
    method add_param_group (line 213) | def add_param_group(self, param_group):
    method clear (line 216) | def clear(self):
    method param_groups (line 223) | def param_groups(self):
    method defaults (line 227) | def defaults(self):
    method state (line 231) | def state(self):
    method lr (line 236) | def lr(self) -> float:
    method lr (line 240) | def lr(self, val: float) -> None:
    method mom (line 244) | def mom(self) -> float:
    method mom (line 248) | def mom(self, val: float) -> None:
    method beta (line 256) | def beta(self) -> float:
    method beta (line 260) | def beta(self, val: float) -> None:
    method wd (line 271) | def wd(self) -> float:
    method wd (line 275) | def wd(self, val: float) -> None:
    method read_defaults (line 283) | def read_defaults(self) -> None:
    method set_val (line 297) | def set_val(self, key: str, val, bn_groups: bool = True):
    method read_val (line 308) | def read_val(self, key: str):
  class FastAIMixedOptim (line 316) | class FastAIMixedOptim(OptimWrapper):
    method create (line 318) | def create(cls,
    method step (line 345) | def step(self):

FILE: torchplus/train/learning_schedules.py
  class _LRSchedulerStep (line 6) | class _LRSchedulerStep(object):
    method __init__ (line 7) | def __init__(self, optimizer, last_step=-1):
    method get_lr (line 32) | def get_lr(self):
    method _get_lr_per_group (line 36) | def _get_lr_per_group(self, base_lr):
    method step (line 39) | def step(self, step=None):
  class Constant (line 47) | class Constant(_LRSchedulerStep):
    method __init__ (line 48) | def __init__(self, optimizer, last_step=-1):
    method _get_lr_per_group (line 51) | def _get_lr_per_group(self, base_lr):
  class ManualStepping (line 55) | class ManualStepping(_LRSchedulerStep):
    method __init__ (line 60) | def __init__(self, optimizer, boundaries, rates, last_step=-1):
    method _get_lr_per_group (line 79) | def _get_lr_per_group(self, base_lr):
  class ExponentialDecayWithBurnin (line 90) | class ExponentialDecayWithBurnin(_LRSchedulerStep):
    method __init__ (line 94) | def __init__(self,
    method _get_lr_per_group (line 108) | def _get_lr_per_group(self, base_lr):
  class ExponentialDecay (line 120) | class ExponentialDecay(_LRSchedulerStep):
    method __init__ (line 121) | def __init__(self,
    method _get_lr_per_group (line 133) | def _get_lr_per_group(self, base_lr):
  class CosineDecayWithWarmup (line 145) | class CosineDecayWithWarmup(_LRSchedulerStep):
    method __init__ (line 146) | def __init__(self,
    method _get_lr_per_group (line 161) | def _get_lr_per_group(self, base_lr):
  class OneCycle (line 181) | class OneCycle(_LRSchedulerStep):
    method __init__ (line 182) | def __init__(self,
    method _get_lr_per_group (line 202) | def _get_lr_per_group(self, base_lr):

FILE: torchplus/train/learning_schedules_fastai.py
  class LRSchedulerStep (line 7) | class LRSchedulerStep(object):
    method __init__ (line 8) | def __init__(self, fai_optimizer, total_step, lr_phases, mom_phases):
    method step (line 43) | def step(self, step):
    method learning_rate (line 64) | def learning_rate(self):
  function annealing_cos (line 68) | def annealing_cos(start, end, pct):
  class OneCycle (line 75) | class OneCycle(LRSchedulerStep):
    method __init__ (line 76) | def __init__(self, fai_optimizer, total_step, lr_max, moms, div_factor,
  class ExponentialDecayWarmup (line 97) | class ExponentialDecayWarmup(LRSchedulerStep):
    method __init__ (line 98) | def __init__(self,
  class ExponentialDecay (line 135) | class ExponentialDecay(LRSchedulerStep):
    method __init__ (line 136) | def __init__(self,
  class ManualStepping (line 167) | class ManualStepping(LRSchedulerStep):
    method __init__ (line 168) | def __init__(self, fai_optimizer, total_step, boundaries, rates):
  class FakeOptim (line 179) | class FakeOptim:
    method __init__ (line 180) | def __init__(self):

FILE: torchplus/train/optim.py
  function param_fp32_copy (line 10) | def param_fp32_copy(params):
  function set_grad (line 18) | def set_grad(params, params_with_grad, scale=1.0):
  class MixedPrecisionWrapper (line 31) | class MixedPrecisionWrapper(object):
    method __init__ (line 42) | def __init__(self,
    method __getstate__ (line 70) | def __getstate__(self):
    method __setstate__ (line 73) | def __setstate__(self, state):
    method __repr__ (line 76) | def __repr__(self):
    method state_dict (line 79) | def state_dict(self):
    method load_state_dict (line 82) | def load_state_dict(self, state_dict):
    method zero_grad (line 85) | def zero_grad(self):
    method step (line 88) | def step(self, closure=None):

FILE: utils/config_io.py
  function mkdir_if_not_exists (line 8) | def mkdir_if_not_exists(path):
  function read_yaml (line 17) | def read_yaml(filename):
  function copy_file (line 31) | def copy_file(src_file, tgt_file):
  function update_dict (line 40) | def update_dict(dict1, dict2, intersection=False):
  function merge_cfg (line 68) | def merge_cfg(cfg_files, intersection=False):
  function write_cfg (line 88) | def write_cfg(default, custom, f, level_cnt=0):
  function save_cfg (line 123) | def save_cfg(cfg_files, file_path):

FILE: utils/distributed_utils.py
  class ParallelWrapper (line 12) | class ParallelWrapper(Module):
    method __init__ (line 13) | def __init__(self, net, parallel_mode='none'):
    method forward (line 27) | def forward(self, *inputs, **kwargs):
    method train (line 30) | def train(self, mode=True):
  class DistModule (line 35) | class DistModule(Module):
    method __init__ (line 36) | def __init__(self, module):
    method forward (line 41) | def forward(self, *inputs, **kwargs):
    method train (line 44) | def train(self, mode=True):
  function gradients_multiply (line 48) | def gradients_multiply(model, multiplier=1):
  function average_gradients (line 53) | def average_gradients(model):
  function broadcast_params (line 68) | def broadcast_params(model):
  function dist_init (line 74) | def dist_init(port):
  class DistributedSequatialSampler (line 117) | class DistributedSequatialSampler(Sampler):
    method __init__ (line 135) | def __init__(self, dataset, num_replicas=None, rank=None):
    method __iter__ (line 154) | def __iter__(self):
    method __len__ (line 171) | def __len__(self):
    method set_epoch (line 174) | def set_epoch(self, epoch):
  class DistributedGivenIterationSampler (line 178) | class DistributedGivenIterationSampler(Sampler):
    method __init__ (line 179) | def __init__(self, dataset, total_iter, batch_size, world_size=None, r...
    method __iter__ (line 197) | def __iter__(self):
    method gen_new_list (line 206) | def gen_new_list(self):
    method __len__ (line 227) | def __len__(self):
    method set_epoch (line 234) | def set_epoch(self, epoch):
  class DistributedGivenIterationSamplerEpoch (line 238) | class DistributedGivenIterationSamplerEpoch(Sampler):
    method __init__ (line 239) | def __init__(self, dataset, total_iter, batch_size, world_size=None, r...
    method __iter__ (line 257) | def __iter__(self):
    method gen_new_list (line 266) | def gen_new_list(self):
    method __len__ (line 306) | def __len__(self):
    method set_epoch (line 313) | def set_epoch(self, epoch):

FILE: utils/eval_metric.py
  function get_ply_model (line 18) | def get_ply_model(model_path, scale=1):
  function project (line 28) | def project(xyz, K, RT):
  function find_nearest_point_idx (line 40) | def find_nearest_point_idx(ref_pts, que_pts):
  class LineMODEvaluator (line 59) | class LineMODEvaluator:
    method __init__ (line 60) | def __init__(self, class_name, result_dir, icp_refine=False):
    method projection_2d (line 102) | def projection_2d(self, pose_pred, pose_targets, K, icp=False, thresho...
    method projection_2d_sym (line 112) | def projection_2d_sym(self, pose_pred, pose_targets, K, threshold=5):
    method add2_metric (line 120) | def add2_metric(self, pose_pred, pose_targets, icp=False, syn=False, p...
    method add5_metric (line 140) | def add5_metric(self, pose_pred, pose_targets, icp=False, syn=False, p...
    method add_metric (line 161) | def add_metric(self, pose_pred, pose_targets, icp=False, syn=False, pe...
    method cm_degree_5_metric (line 181) | def cm_degree_5_metric(self, pose_pred, pose_targets, icp=False):
    method mask_iou (line 194) | def mask_iou(self, output, batch):
    method icp_refine (line 201) | def icp_refine(self, pose_pred, anno, output, K):
    method icp_refine_ (line 220) | def icp_refine_(self, pose, anno, output):
    method summarize (line 261) | def summarize(self):
    method evaluate_rnnpose (line 306) | def evaluate_rnnpose(self, preds_dict, example): # sample_corresponden...

FILE: utils/furthest_point_sample.py
  function fragmentation_fps (line 6) | def fragmentation_fps(vertices, num_frags):

FILE: utils/geometric.py
  function range_to_depth (line 4) | def range_to_depth(mask, range, K):
  function mask_depth_to_point_cloud (line 22) | def mask_depth_to_point_cloud(mask,depth,K):
  function chordal_distance (line 36) | def chordal_distance(R1,R2):
  function rotation_angle (line 39) | def rotation_angle(R1, R2):
  function render_pointcloud (line 42) | def render_pointcloud(pc, T, K, render_image_size):

FILE: utils/img_utils.py
  function read_depth (line 9) | def read_depth(path):
  function unnormalize_img (line 19) | def unnormalize_img(img, mean, std, in_gpu=True):
  function draw_seg_th (line 32) | def draw_seg_th(seg, num_cls=-1):
  function draw_seg_prob_th (line 50) | def draw_seg_prob_th(seg_prob):
  function draw_vertex_th (line 59) | def draw_vertex_th(vertex):
  function visualize_coco_bbox (line 70) | def visualize_coco_bbox(img, boxes):
  function visualize_heatmap (line 84) | def visualize_heatmap(img, hm):
  function visualize_coco_img_mask (line 101) | def visualize_coco_img_mask(img, mask):
  function visualize_color_aug (line 108) | def visualize_color_aug(orig_img, aug_img):
  function visualize_coco_ann (line 115) | def visualize_coco_ann(coco, img, ann):
  function bgr_to_rgb (line 121) | def bgr_to_rgb(img):

FILE: utils/log_tool.py
  function _flat_nested_json_dict (line 9) | def _flat_nested_json_dict(json_dict, flatted, sep=".", start=""):
  function flat_nested_json_dict (line 17) | def flat_nested_json_dict(json_dict, sep=".") -> dict:
  function metric_to_str (line 29) | def metric_to_str(metrics, sep='.'):
  class SimpleModelLog (line 46) | class SimpleModelLog:
    method __init__ (line 57) | def __init__(self, model_dir, disable=False):
    method open (line 70) | def open(self):
    method close (line 89) | def close(self):
    method log_text (line 102) | def log_text(self, text, step, tag="regular log"):
    method log_metrics (line 119) | def log_metrics(self, metrics: dict, step):
    method log_images (line 140) | def log_images(self, images: dict, step, prefix=''):
    method log_histograms (line 147) | def log_histograms(self, vals: dict, step, prefix=''):

FILE: utils/pose_utils.py
  function pose_padding (line 22) | def pose_padding(P):
  function vdot (line 35) | def vdot(v1, v2):
  function normalize (line 47) | def normalize(x, p=2, dim=0):
  function qmult (line 60) | def qmult(q1, q2):
  function qinv (line 81) | def qinv(q):
  function qexp_t (line 91) | def qexp_t(q):
  function qlog_t (line 105) | def qlog_t(q):
  function qexp_t_safe (line 118) | def qexp_t_safe(q):
  function qlog_t_safe (line 130) | def qlog_t_safe(q):
  function rotate_vec_by_q (line 142) | def rotate_vec_by_q(t, q):
  function compose_pose_quaternion (line 157) | def compose_pose_quaternion(p1, p2):
  function invert_pose_quaternion (line 173) | def invert_pose_quaternion(p):
  function calc_vo (line 185) | def calc_vo(p0, p1):
  function calc_vo_logq (line 195) | def calc_vo_logq(p0, p1):
  function calc_vo_relative (line 210) | def calc_vo_relative(p0, p1):
  function calc_vo_relative_logq (line 221) | def calc_vo_relative_logq(p0, p1):
  function calc_vo_relative_logq_safe (line 236) | def calc_vo_relative_logq_safe(p0, p1):
  function calc_vo_logq_safe (line 251) | def calc_vo_logq_safe(p0, p1):
  function calc_vos_simple (line 267) | def calc_vos_simple(poses):
  function calc_vos (line 283) | def calc_vos(poses):
  function calc_vos_relative (line 298) | def calc_vos_relative(poses):
  function calc_vos_safe (line 313) | def calc_vos_safe(poses):
  function calc_vos_safe_fc (line 328) | def calc_vos_safe_fc(poses):
  function qlog (line 348) | def qlog(q):
  function qexp (line 361) | def qexp(q):
  function process_poses (line 372) | def process_poses(poses_in, mean_t, std_t, align_R, align_t, align_s):
  function log_quaternion_angular_error (line 403) | def log_quaternion_angular_error(q1, q2):
  function quaternion_angular_error (line 407) | def quaternion_angular_error(q1, q2):
  function skew (line 420) | def skew(x):
  function dpq_q (line 430) | def dpq_q(p):
  function dpsq_q (line 444) | def dpsq_q(p):
  function dpsq_p (line 458) | def dpsq_p(q):
  function dqstq_q (line 472) | def dqstq_q(q, t):
  function dqstq_t (line 487) | def dqstq_t(q):
  function m_rot (line 498) | def m_rot(x):
  class PoseGraph (line 512) | class PoseGraph:
    method __init__ (line 513) | def __init__(self):
    method jacobian (line 522) | def jacobian(self, L_ax, L_aq, L_rx, L_rq):
    method residuals (line 564) | def residuals(self, poses, vos, L_ax, L_aq, L_rx, L_rq):
    method update_on_manifold (line 605) | def update_on_manifold(self, x):
    method optimize (line 630) | def optimize(self, poses, vos, sax=1, saq=1, srx=1, srq=1, n_iters=10):
  class PoseGraphFC (line 669) | class PoseGraphFC:
    method __init__ (line 670) | def __init__(self):
    method jacobian (line 680) | def jacobian(self, L_ax, L_aq, L_rx, L_rq):
    method residuals (line 723) | def residuals(self, poses, vos, L_ax, L_aq, L_rx, L_rq):
    method update_on_manifold (line 767) | def update_on_manifold(self, x):
    method optimize (line 792) | def optimize(self, poses, vos, sax=1, saq=1, srx=1, srq=1, n_iters=10):
  function optimize_poses (line 831) | def optimize_poses(pred_poses, vos=None, fc_vos=False, target_poses=None,
  function align_3d_pts (line 865) | def align_3d_pts(x1, x2):
  function align_2d_pts (line 914) | def align_2d_pts(x1, x2):
  function align_3d_pts_noscale (line 963) | def align_3d_pts_noscale(x1, x2):
  function align_2d_pts_noscale (line 1012) | def align_2d_pts_noscale(x1, x2):
  function align_camera_poses (line 1062) | def align_camera_poses(o1, o2, R1, R2, use_rotation_constraint=True):
  function test_align_3d_pts (line 1137) | def test_align_3d_pts():
  function test_align_camera_poses (line 1158) | def test_align_camera_poses():
  function pgo_test_poses (line 1189) | def pgo_test_poses():
  function pgo_test_poses1 (line 1214) | def pgo_test_poses1():
  function print_poses (line 1240) | def print_poses(poses):
  function test_pgo (line 1249) | def test_pgo():
  function test_pose_utils (line 1267) | def test_pose_utils():
  function test_q_error (line 1327) | def test_q_error():
  function test_log_q_error (line 1340) | def test_log_q_error():

FILE: utils/pose_utils_np.py
  function tq2RT (line 24) | def tq2RT(poses, square=False):
  function RT2tq (line 41) | def RT2tq(poses, square=False):
  function pose_interp (line 58) | def pose_interp(poses, timestamps_in, timestamps_out, r_interp='slerp'):
  function vdot (line 114) | def vdot(v1, v2):
  function normalize (line 127) | def normalize(x, p=2, dim=0, eps=1e-6):
  function qmult (line 144) | def qmult(q1, q2):
  function qinv (line 166) | def qinv(q):
  function qexp_t (line 177) | def qexp_t(q):
  function qlog_t (line 191) | def qlog_t(q):
  function qexp_t_safe (line 204) | def qexp_t_safe(q):
  function qlog_t_safe (line 216) | def qlog_t_safe(q):
  function rotate_vec_by_q (line 228) | def rotate_vec_by_q(t, q):
  function compose_pose_quaternion (line 246) | def compose_pose_quaternion(p1, p2):
  function invert_pose_quaternion (line 264) | def invert_pose_quaternion(p):
  function calc_vo (line 276) | def calc_vo(p0, p1):
  function calc_vo_logq (line 286) | def calc_vo_logq(p0, p1):
  function calc_vo_relative (line 301) | def calc_vo_relative(p0, p1):
  function calc_vo_relative_logq (line 312) | def calc_vo_relative_logq(p0, p1):
  function calc_vo_relative_logq_safe (line 327) | def calc_vo_relative_logq_safe(p0, p1):
  function calc_vo_logq_safe (line 342) | def calc_vo_logq_safe(p0, p1):
  function calc_vos_simple (line 358) | def calc_vos_simple(poses):
  function calc_vos (line 374) | def calc_vos(poses):
  function calc_vos_relative (line 389) | def calc_vos_relative(poses):
  function calc_vos_safe (line 404) | def calc_vos_safe(poses):
  function calc_vos_safe_fc (line 419) | def calc_vos_safe_fc(poses):
  function qlog (line 439) | def qlog(q):
  function qexp (line 452) | def qexp(q):
  function process_poses (line 463) | def process_poses(poses_in, mean_t, std_t, align_R, align_t, align_s):
  function log_quaternion_angular_error (line 494) | def log_quaternion_angular_error(q1, q2):
  function quaternion_angular_error (line 498) | def quaternion_angular_error(q1, q2):
  function skew (line 511) | def skew(x):
  function dpq_q (line 521) | def dpq_q(p):
  function dpsq_q (line 535) | def dpsq_q(p):
  function dpsq_p (line 549) | def dpsq_p(q):
  function dqstq_q (line 563) | def dqstq_q(q, t):
  function dqstq_t (line 578) | def dqstq_t(q):
  function m_rot (line 589) | def m_rot(x):
  class PoseGraph (line 603) | class PoseGraph:
    method __init__ (line 604) | def __init__(self):
    method jacobian (line 613) | def jacobian(self, L_ax, L_aq, L_rx, L_rq):
    method residuals (line 655) | def residuals(self, poses, vos, L_ax, L_aq, L_rx, L_rq):
    method update_on_manifold (line 696) | def update_on_manifold(self, x):
    method optimize (line 721) | def optimize(self, poses, vos, sax=1, saq=1, srx=1, srq=1, n_iters=10):
  class PoseGraphFC (line 760) | class PoseGraphFC:
    method __init__ (line 761) | def __init__(self):
    method jacobian (line 771) | def jacobian(self, L_ax, L_aq, L_rx, L_rq):
    method residuals (line 814) | def residuals(self, poses, vos, L_ax, L_aq, L_rx, L_rq):
    method update_on_manifold (line 858) | def update_on_manifold(self, x):
    method optimize (line 883) | def optimize(self, poses, vos, sax=1, saq=1, srx=1, srq=1, n_iters=10):
  function optimize_poses (line 922) | def optimize_poses(pred_poses, vos=None, fc_vos=False, target_poses=None,
  function align_3d_pts (line 956) | def align_3d_pts(x1, x2):
  function align_2d_pts (line 1005) | def align_2d_pts(x1, x2):
  function align_3d_pts_noscale (line 1054) | def align_3d_pts_noscale(x1, x2):
  function align_2d_pts_noscale (line 1103) | def align_2d_pts_noscale(x1, x2):
  function align_camera_poses (line 1153) | def align_camera_poses(o1, o2, R1, R2, use_rotation_constraint=True):
  function test_align_3d_pts (line 1228) | def test_align_3d_pts():
  function test_align_camera_poses (line 1249) | def test_align_camera_poses():
  function pgo_test_poses (line 1280) | def pgo_test_poses():
  function pgo_test_poses1 (line 1305) | def pgo_test_poses1():
  function print_poses (line 1331) | def print_poses(poses):
  function test_pgo (line 1340) | def test_pgo():
  function test_pose_utils (line 1358) | def test_pose_utils():
  function test_q_error (line 1418) | def test_q_error():
  function test_log_q_error (line 1431) | def test_log_q_error():

FILE: utils/progress_bar.py
  function progress_str (line 9) | def progress_str(val, *args, width=20, with_ptg=True):
  function second_to_time_str (line 29) | def second_to_time_str(second, omit_hours_if_possible=True):
  function progress_bar_iter (line 39) | def progress_bar_iter(task_list, width=20, with_ptg=True, step_time_aver...
  function enumerate_bar (line 67) | def enumerate_bar(task_list, width=20, with_ptg=True, step_time_average=...
  function max_point_str (line 93) | def max_point_str(val, max_point):
  class Unit (line 107) | class Unit(enum.Enum):
  function convert_size (line 112) | def convert_size(size_bytes):
  class ProgressBar (line 123) | class ProgressBar:
    method __init__ (line 124) | def __init__(self,
    method start (line 137) | def start(self, total_size):
    method print_bar (line 146) | def print_bar(self, finished_size=1, pre_string=None, post_string=None):

FILE: utils/rand_utils.py
  function truncated_normal (line 3) | def truncated_normal(u, sigma, min, max, shape=None):

FILE: utils/singleton.py
  class Singleton (line 3) | class Singleton(type):
    method __call__ (line 5) | def __call__(cls, *args, **kwargs):

FILE: utils/timer.py
  function simple_timer (line 6) | def simple_timer(name=''):
  function singleton (line 12) | def singleton(class_):
  class timming (line 21) | class timming(object):
    method __init__ (line 22) | def __init__(self):
    method start (line 26) | def start(self, item_name):
    method end (line 46) | def end(self, item_name):
    method summarize (line 57) | def summarize(self):

FILE: utils/util.py
  function freeze_params (line 7) | def freeze_params(params: dict, include: str = None, exclude: str = None):
  function freeze_params_v2 (line 27) | def freeze_params_v2(params: dict, include: str = None, exclude: str = N...
  function filter_param_dict (line 44) | def filter_param_dict(state_dict: dict, include: str = None, exclude: st...
  function modify_parameter_name_with_map (line 63) | def modify_parameter_name_with_map(state_dict, parameteter_name_map=None):
  function load_pretrained_model_map_func (line 73) | def load_pretrained_model_map_func(state_dict,parameteter_name_map = Non...
  function list_recursive_op (line 79) | def list_recursive_op(input_list, op):
  function dict_recursive_op (line 93) | def dict_recursive_op(input_dict, op):

FILE: utils/visualize.py
  function vis_pointclouds_cv2 (line 5) | def vis_pointclouds_cv2(pc, K, win_size, init_transform=None, color=None...
  function vis_2d_keypoints_cv2 (line 30) | def vis_2d_keypoints_cv2(img, keypoints, color=None):
  function get_model_corners (line 45) | def get_model_corners(model):
  function vis_pose_box (line 61) | def vis_pose_box(RT,K, model, background=None,fig=None, ax=None, title='...
Condensed preview — 126 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (861K chars).
[
  {
    "path": ".gitignore",
    "chars": 2563,
    "preview": "**/*.old\n**/*.bak\n\n.DS_Store\n# Created by https://www.toptal.com/developers/gitignore/api/vscode,python,jupyternotebooks"
  },
  {
    "path": "LICENSE.md",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "README.md",
    "chars": 6577,
    "preview": "# RNNPose: Recurrent 6-DoF Object Pose Refinement with Robust Correspondence Field Estimation and Pose Optimization\n\n[Ya"
  },
  {
    "path": "builder/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "builder/dataset_builder.py",
    "chars": 1766,
    "preview": "\nfrom data.dataset import get_dataset_class\nimport numpy as np\nfrom functools import partial\nfrom data.preprocess import"
  },
  {
    "path": "builder/input_reader_builder.py",
    "chars": 663,
    "preview": "\nfrom torch.utils.data import Dataset\n\nfrom builder import dataset_builder\n\n\nclass DatasetWrapper(Dataset):\n    \"\"\" conv"
  },
  {
    "path": "builder/losses_builder.py",
    "chars": 257,
    "preview": "\nfrom model import losses\n\ndef build(loss_config):\n\n    criterions = {}\n  \n    criterions[\"metric_loss\"] =losses.MetricL"
  },
  {
    "path": "builder/lr_scheduler_builder.py",
    "chars": 3404,
    "preview": "\nfrom torchplus.train import learning_schedules_fastai as lsf\nimport torch\nimport numpy as np \n\ndef build(optimizer_conf"
  },
  {
    "path": "builder/optimizer_builder.py",
    "chars": 3830,
    "preview": "from torchplus.train import learning_schedules\nfrom torchplus.train import optim\nimport torch\nfrom torch import nn\nfrom "
  },
  {
    "path": "builder/rnnpose_builder.py",
    "chars": 397,
    "preview": "from builder import losses_builder\nfrom model.RNNPose import get_posenet_class\nimport model.RNNPose\n\n\ndef build(model_cf"
  },
  {
    "path": "config/default.py",
    "chars": 2841,
    "preview": "from yacs.config import CfgNode as CN\nfrom utils.singleton import Singleton\nimport os\n\ndef _merge_a_into_b(a, b):\n    \"\""
  },
  {
    "path": "config/linemod/copy.sh",
    "chars": 315,
    "preview": "declare -a arr=(\"glue\" \"ape\" \"cat\" \"phone\" \"eggbox\" \"benchvise\" \"lamp\" \"camera\" \"can\" \"driller\" \"duck\" \"holepuncher\" \"ir"
  },
  {
    "path": "config/linemod/copy_occ.sh",
    "chars": 305,
    "preview": "declare -a arr=(\"glue\" \"ape\" \"cat\" \"phone\" \"eggbox\" \"benchvise\" \"lamp\" \"camera\" \"can\" \"driller\" \"duck\" \"holepuncher\" \"ir"
  },
  {
    "path": "config/linemod/template_fw0.5.yml",
    "chars": 5090,
    "preview": "vars:\n  input_h: &input_h\n    320 \n  input_w: &input_w\n    320 \n  batch_size: &batch_size\n    1\n  descriptor_dim: &descr"
  },
  {
    "path": "config/linemod/template_fw0.5_occ.yml",
    "chars": 5131,
    "preview": "vars:\n  input_h: &input_h\n    320 \n  input_w: &input_w\n    320 \n  batch_size: &batch_size\n    1\n  descriptor_dim: &descr"
  },
  {
    "path": "data/__init__.py",
    "chars": 51,
    "preview": "from . import dataset\nfrom . import linemod_dataset"
  },
  {
    "path": "data/dataset.py",
    "chars": 1021,
    "preview": "import pathlib\nimport pickle\nimport time\nfrom functools import partial\n\nimport numpy as np\n\n\nREGISTERED_DATASET_CLASSES "
  },
  {
    "path": "data/linemod/linemod_config.py",
    "chars": 767,
    "preview": "import numpy as np\ndiameters = {\n    'cat': 15.2633,\n    'ape': 9.74298,\n    'benchvise': 28.6908,\n    'bowl': 17.1185,\n"
  },
  {
    "path": "data/linemod_dataset.py",
    "chars": 18580,
    "preview": "import numpy as np \nimport random\nimport os \nfrom data.dataset import Dataset, register_dataset\nimport pickle\nimport PIL"
  },
  {
    "path": "data/preprocess.py",
    "chars": 39523,
    "preview": "import open3d as o3d\nimport copy\nimport os\n\nimport pathlib\nimport pickle\nimport time\nfrom collections import defaultdict"
  },
  {
    "path": "data/transforms.py",
    "chars": 2592,
    "preview": "import numpy as np\nimport random\nimport torch\nimport torchvision\nfrom torchvision.transforms import functional as F\nimpo"
  },
  {
    "path": "data/ycb/basic.py",
    "chars": 785,
    "preview": "import mmcv \nbop_ycb_idx2class={\n        1: '002_master_chef_can', \n        2: '003_cracker_box',\n        3: '004_sugar_"
  },
  {
    "path": "doc/prepare_data.md",
    "chars": 3052,
    "preview": "# Data Preparation Tips\nAll the related data for data preparation can be downloaded [here](https://mycuhk-my.sharepoint."
  },
  {
    "path": "docker/Dockerfile",
    "chars": 3089,
    "preview": "FROM nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04\n\nRUN apt-key del 7fa2af80\nRUN apt-key adv --fetch-keys http://developer.d"
  },
  {
    "path": "docker/freeze.yml",
    "chars": 6866,
    "preview": "name: py37_tmp\nchannels:\n  - pytorch\n  - pytorch3d\n  - open3d-admin\n  - bottler\n  - iopath\n  - fvcore\n  - conda-forge\n  "
  },
  {
    "path": "geometry/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "geometry/cholesky.py",
    "chars": 1655,
    "preview": "# import tensorflow as tf\nimport torch #as tf\nimport numpy as np\n# from utils.einsum import einsum\nfrom torch import ein"
  },
  {
    "path": "geometry/diff_render.py",
    "chars": 14051,
    "preview": "import torch\nimport torch.nn as nn \nimport torch.nn.functional as F \n\nimport numpy \n\nfrom pytorch3d.renderer import (\n  "
  },
  {
    "path": "geometry/diff_render_optim.py",
    "chars": 18390,
    "preview": "## Speed optimized: sharing the rasterization among different rendering process\n\nimport torch\nimport torch.nn as nn \nimp"
  },
  {
    "path": "geometry/einsum.py",
    "chars": 2306,
    "preview": "# import tensorflow as torch\nimport torch as torch\n\nimport numpy as np\nimport re\nimport string\n\ndef einsum(equation, *in"
  },
  {
    "path": "geometry/intrinsics.py",
    "chars": 1902,
    "preview": "import torch\nimport numpy as np\n# from utils.einsum import einsum\nfrom .einsum import einsum\n\ndef intrinsics_vec_to_matr"
  },
  {
    "path": "geometry/projective_ops.py",
    "chars": 3744,
    "preview": "import numpy as np\nimport torch \n\n# from utils.einsum import einsum\nfrom torch import einsum\n\n\n# MIN_DEPTH = 0.1\nMIN_DEP"
  },
  {
    "path": "geometry/se3.py",
    "chars": 9345,
    "preview": "\"\"\"\nSO3 and SE3 operations, exponentials and logarithms adapted from Sophus\n\"\"\"\n\nimport numpy as np\nimport torch\nfrom .e"
  },
  {
    "path": "geometry/transformation.py",
    "chars": 10700,
    "preview": "import torch  \nimport numpy as np\n\n# from core.config import cfg\nfrom config.default import get_cfg\nfrom .se3 import *\nf"
  },
  {
    "path": "model/CFNet.py",
    "chars": 6589,
    "preview": "\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom thirdparty.raft.update imp"
  },
  {
    "path": "model/HybridNet.py",
    "chars": 3495,
    "preview": "\nimport torch \nimport torch.nn as nn \n\nfrom thirdparty.kpconv.kpconv_blocks import *\nimport torch.nn.functional as F\nimp"
  },
  {
    "path": "model/PoseRefiner.py",
    "chars": 18193,
    "preview": "import os \nimport time\nimport cv2\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.distri"
  },
  {
    "path": "model/RNNPose.py",
    "chars": 12088,
    "preview": "#\nimport time\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\nimport torch.distributed as dist\nim"
  },
  {
    "path": "model/descriptor2D.py",
    "chars": 6228,
    "preview": "from easydict import EasyDict as edict\nfrom pathlib import Path\nimport torch\nfrom torch import nn\nfrom torchplus.nn.modu"
  },
  {
    "path": "model/descriptor3D.py",
    "chars": 7155,
    "preview": "import torch \nimport torch.nn as nn \n\n\nfrom kpconv.kpconv_blocks import *\nimport torch.nn.functional as F\nimport numpy a"
  },
  {
    "path": "model/losses.py",
    "chars": 12574,
    "preview": "from sklearn.metrics import precision_recall_fscore_support\nfrom thirdparty.kpconv.lib.utils import square_distance\nfrom"
  },
  {
    "path": "scripts/compile_3rdparty.sh",
    "chars": 236,
    "preview": "#!/usr/bin/bash\n\nSCRIPT_DIR=\"$( cd \"$( dirname \"${BASH_SOURCE[0]}\" )\" &> /dev/null && pwd )\"\n\ncd $SCRIPT_DIR/../thirdpar"
  },
  {
    "path": "scripts/eval.sh",
    "chars": 796,
    "preview": "export PROJECT_ROOT_PATH=/home/RNNPose/Projects/Works/RNNPose_release\n\nexport PYTHONPATH=\"$PROJECT_ROOT_PATH:$PYTHONPATH"
  },
  {
    "path": "scripts/eval_lmocc.sh",
    "chars": 828,
    "preview": "export PROJECT_ROOT_PATH=/home/RNNPose/Projects/Works/RNNPose_release\n\nexport PYTHONPATH=\"$PROJECT_ROOT_PATH:$PYTHONPATH"
  },
  {
    "path": "scripts/run_dataformatter.sh",
    "chars": 445,
    "preview": "SCRIPT_DIR=\"$( cd \"$( dirname \"${BASH_SOURCE[0]}\" )\" &> /dev/null && pwd )\"\nPROJ_ROOT=$SCRIPT_DIR/../\n\npython $PROJ_ROOT"
  },
  {
    "path": "scripts/run_datainfo_generation.sh",
    "chars": 1141,
    "preview": "\nSCRIPT_DIR=\"$( cd \"$( dirname \"${BASH_SOURCE[0]}\" )\" &> /dev/null && pwd )\"\nPROJ_ROOT=$SCRIPT_DIR/../\nexport PYTHONPATH"
  },
  {
    "path": "scripts/train.sh",
    "chars": 801,
    "preview": "export PROJECT_ROOT_PATH=/home/RNNPose/Projects/Works/RNNPose_release\n\nexport PYTHONPATH=\"$PROJECT_ROOT_PATH:$PYTHONPATH"
  },
  {
    "path": "thirdparty/kpconv/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "thirdparty/kpconv/cpp_wrappers/compile_wrappers.sh",
    "chars": 185,
    "preview": "#!/bin/bash\n\n# Compile cpp subsampling\ncd cpp_subsampling\npython3 setup.py build_ext --inplace\ncd ..\n\n# Compile cpp neig"
  },
  {
    "path": "thirdparty/kpconv/cpp_wrappers/cpp_neighbors/build.bat",
    "chars": 49,
    "preview": "@echo off\npy setup.py build_ext --inplace\n\n\npause"
  },
  {
    "path": "thirdparty/kpconv/cpp_wrappers/cpp_neighbors/neighbors/neighbors.cpp",
    "chars": 7454,
    "preview": "\n#include \"neighbors.h\"\n\n\nvoid brute_neighbors(vector<PointXYZ>& queries, vector<PointXYZ>& supports, vector<int>& neigh"
  },
  {
    "path": "thirdparty/kpconv/cpp_wrappers/cpp_neighbors/neighbors/neighbors.h",
    "chars": 1032,
    "preview": "\n\n#include \"../../cpp_utils/cloud/cloud.h\"\n#include \"../../cpp_utils/nanoflann/nanoflann.hpp\"\n\n#include <set>\n#include <"
  },
  {
    "path": "thirdparty/kpconv/cpp_wrappers/cpp_neighbors/setup.py",
    "chars": 619,
    "preview": "from distutils.core import setup, Extension\nimport numpy.distutils.misc_util\n\n# Adding OpenCV to project\n# *************"
  },
  {
    "path": "thirdparty/kpconv/cpp_wrappers/cpp_neighbors/wrapper.cpp",
    "chars": 7178,
    "preview": "#include <Python.h>\n#include <numpy/arrayobject.h>\n#include \"neighbors/neighbors.h\"\n#include <string>\n\n\n\n// docstrings f"
  },
  {
    "path": "thirdparty/kpconv/cpp_wrappers/cpp_subsampling/build.bat",
    "chars": 49,
    "preview": "@echo off\npy setup.py build_ext --inplace\n\n\npause"
  },
  {
    "path": "thirdparty/kpconv/cpp_wrappers/cpp_subsampling/grid_subsampling/grid_subsampling.cpp",
    "chars": 7076,
    "preview": "\n#include \"grid_subsampling.h\"\n\n\nvoid grid_subsampling(vector<PointXYZ>& original_points,\n                      vector<P"
  },
  {
    "path": "thirdparty/kpconv/cpp_wrappers/cpp_subsampling/grid_subsampling/grid_subsampling.h",
    "chars": 2464,
    "preview": "\n\n#include \"../../cpp_utils/cloud/cloud.h\"\n\n#include <set>\n#include <cstdint>\n\nusing namespace std;\n\nclass SampledData\n{"
  },
  {
    "path": "thirdparty/kpconv/cpp_wrappers/cpp_subsampling/setup.py",
    "chars": 633,
    "preview": "from distutils.core import setup, Extension\nimport numpy.distutils.misc_util\n\n# Adding OpenCV to project\n# *************"
  },
  {
    "path": "thirdparty/kpconv/cpp_wrappers/cpp_subsampling/wrapper.cpp",
    "chars": 17331,
    "preview": "#include <Python.h>\n#include <numpy/arrayobject.h>\n#include \"grid_subsampling/grid_subsampling.h\"\n#include <string>\n\n\n\n/"
  },
  {
    "path": "thirdparty/kpconv/cpp_wrappers/cpp_utils/cloud/cloud.cpp",
    "chars": 962,
    "preview": "//\n//\n//\t\t0==========================0\n//\t\t|    Local feature test    |\n//\t\t0==========================0\n//\n//\t\tversion "
  },
  {
    "path": "thirdparty/kpconv/cpp_wrappers/cpp_utils/cloud/cloud.h",
    "chars": 3381,
    "preview": "//\n//\n//\t\t0==========================0\n//\t\t|    Local feature test    |\n//\t\t0==========================0\n//\n//\t\tversion "
  },
  {
    "path": "thirdparty/kpconv/cpp_wrappers/cpp_utils/nanoflann/nanoflann.hpp",
    "chars": 73133,
    "preview": "/***********************************************************************\n * Software License Agreement (BSD License)\n *\n"
  },
  {
    "path": "thirdparty/kpconv/kernels/kernel_points.py",
    "chars": 17287,
    "preview": "\n#\n#\n#      0=================================0\n#      |    Kernel Point Convolutions    |\n#      0====================="
  },
  {
    "path": "thirdparty/kpconv/kpconv_blocks.py",
    "chars": 27117,
    "preview": "#\n#\n#      0=================================0\n#      |    Kernel Point Convolutions    |\n#      0======================"
  },
  {
    "path": "thirdparty/kpconv/lib/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "thirdparty/kpconv/lib/ply.py",
    "chars": 10301,
    "preview": "#\n#\n#      0===============================0\n#      |    PLY files reader/writer    |\n#      0=========================="
  },
  {
    "path": "thirdparty/kpconv/lib/timer.py",
    "chars": 1335,
    "preview": "import time\n\n\nclass AverageMeter(object):\n    \"\"\"Computes and stores the average and current value\"\"\"\n\n    def __init__("
  },
  {
    "path": "thirdparty/kpconv/lib/utils.py",
    "chars": 2656,
    "preview": "\"\"\"\nGeneral utility functions\n\nAuthor: Shengyu Huang\nLast modified: 30.11.2020\n\"\"\"\n\nimport os,re,sys,json,yaml,random, a"
  },
  {
    "path": "thirdparty/nn/_ext.c",
    "chars": 22742,
    "preview": "#define _CFFI_\n\n/* We try to define Py_LIMITED_API before including Python.h.\n\n   Mess: we can only define it if Py_DEBU"
  },
  {
    "path": "thirdparty/nn/nn_utils.py",
    "chars": 763,
    "preview": "# from lib.csrc.nn._ext import lib, ffi\nfrom thirdparty.nn._ext import lib, ffi\nimport numpy as np\n\n\ndef find_nearest_po"
  },
  {
    "path": "thirdparty/nn/setup.py",
    "chars": 727,
    "preview": "import os\n\ncuda_include=os.path.join(os.environ.get('CUDA_HOME'), 'include')\nos.system('nvcc src/nearest_neighborhood.cu"
  },
  {
    "path": "thirdparty/nn/src/ext.h",
    "chars": 215,
    "preview": "void findNearestPointIdxLauncher(\n    float* ref_pts,   // [b,pn1,dim]\n    float* que_pts,   // [b,pn2,dim]\n    int* idx"
  },
  {
    "path": "thirdparty/nn/src/nearest_neighborhood.cu",
    "chars": 4312,
    "preview": "#include <float.h>\n#include <stdio.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <cuda_runtime_api.h>\n#include"
  },
  {
    "path": "thirdparty/raft/corr.py",
    "chars": 3556,
    "preview": "import torch\nimport torch.nn.functional as F\nfrom .utils.utils import bilinear_sampler, coords_grid\n\ntry:\n    import alt"
  },
  {
    "path": "thirdparty/raft/extractor.py",
    "chars": 15444,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass ResidualBlock(nn.Module):\n    def __init__(se"
  },
  {
    "path": "thirdparty/raft/update.py",
    "chars": 8438,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass FlowHead(nn.Module):\n    def __init__(self, i"
  },
  {
    "path": "thirdparty/raft/utils/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "thirdparty/raft/utils/augmentor.py",
    "chars": 9108,
    "preview": "import numpy as np\nimport random\nimport math\nfrom PIL import Image\n\nimport cv2\ncv2.setNumThreads(0)\ncv2.ocl.setUseOpenCL"
  },
  {
    "path": "thirdparty/raft/utils/flow_viz.py",
    "chars": 4318,
    "preview": "# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization\n\n\n# MIT License\n#\n# Copyright "
  },
  {
    "path": "thirdparty/raft/utils/frame_utils.py",
    "chars": 4024,
    "preview": "import numpy as np\nfrom PIL import Image\nfrom os.path import *\nimport re\n\nimport cv2\ncv2.setNumThreads(0)\ncv2.ocl.setUse"
  },
  {
    "path": "thirdparty/raft/utils/utils.py",
    "chars": 2643,
    "preview": "import torch\nimport torch.nn.functional as F\nimport numpy as np\nfrom scipy import interpolate\n\n\nclass InputPadder:\n    \""
  },
  {
    "path": "thirdparty/vsd/inout.py",
    "chars": 5934,
    "preview": "# Author: Tomas Hodan (hodantom@cmp.felk.cvut.cz)\n# Center for Machine Perception, Czech Technical University in Prague\n"
  },
  {
    "path": "tools/eval.py",
    "chars": 19686,
    "preview": "#CERTIFICATED\nimport torch\nimport numpy as np \nimport tensorboard\nfrom pathlib import Path\nimport json\nimport random\nimp"
  },
  {
    "path": "tools/generate_data_info_deepim_0_orig.py",
    "chars": 6736,
    "preview": "import os\nimport numpy as np\nimport copy\nimport pickle\nimport fire\nimport glob\nimport re\nfrom data.linemod import linemo"
  },
  {
    "path": "tools/generate_data_info_deepim_1_syn.py",
    "chars": 7473,
    "preview": "import os\nimport numpy as np\nimport copy\nimport pickle\nimport fire\nimport glob\nimport re\nfrom data.linemod import linemo"
  },
  {
    "path": "tools/generate_data_info_deepim_2_posecnnval.py",
    "chars": 7224,
    "preview": "import os\nimport numpy as np\nimport copy\nimport pickle\nimport fire\nimport glob\nimport re\nfrom data.linemod import linemo"
  },
  {
    "path": "tools/generate_data_info_v2_deepim.py",
    "chars": 5748,
    "preview": "#The version compatible with deepim   \nimport os\nimport numpy as np\nimport copy\nimport pickle\nimport fire\nimport glob\nim"
  },
  {
    "path": "tools/train.py",
    "chars": 27902,
    "preview": "import numpy as np \nimport torch\n\nfrom pathlib import Path\nimport json\nimport random\nimport re\nimport torch.backends.cud"
  },
  {
    "path": "tools/transform_data_format.py",
    "chars": 19966,
    "preview": "import numpy as np \nimport cv2\nimport pickle\nimport fire\nimport os\nimport argparse\n\n\nlinemod_K = np.array([[572.4114, 0."
  },
  {
    "path": "torchplus/__init__.py",
    "chars": 190,
    "preview": "from . import train\r\nfrom . import nn\r\nfrom . import metrics\r\nfrom . import tools\r\n\r\nfrom .tools import change_default_a"
  },
  {
    "path": "torchplus/metrics.py",
    "chars": 10431,
    "preview": "import numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\n\nclass Scalar(nn.Module):\n    def "
  },
  {
    "path": "torchplus/nn/__init__.py",
    "chars": 159,
    "preview": "from torchplus.nn.functional import one_hot\nfrom torchplus.nn.modules.common import Empty, Sequential\nfrom torchplus.nn."
  },
  {
    "path": "torchplus/nn/functional.py",
    "chars": 286,
    "preview": "import torch\n\ndef one_hot(tensor, depth, dim=-1, on_value=1.0, dtype=torch.float32):\n    tensor_onehot = torch.zeros(\n  "
  },
  {
    "path": "torchplus/nn/modules/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "torchplus/nn/modules/common.py",
    "chars": 2934,
    "preview": "import sys\nfrom collections import OrderedDict\n\nimport torch\nfrom torch.nn import functional as F\n\n\nclass Empty(torch.nn"
  },
  {
    "path": "torchplus/nn/modules/normalization.py",
    "chars": 273,
    "preview": "import torch\n\n\nclass GroupNorm(torch.nn.GroupNorm):\n    def __init__(self, num_channels, num_groups, eps=1e-5, affine=Tr"
  },
  {
    "path": "torchplus/ops/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "torchplus/ops/array_ops.py",
    "chars": 1962,
    "preview": "import ctypes\nimport math\nimport time\nimport torch\nfrom typing import Optional\n\n\ndef scatter_nd(indices, updates, shape)"
  },
  {
    "path": "torchplus/tools.py",
    "chars": 2209,
    "preview": "import functools\nimport inspect\nimport sys\nfrom collections import OrderedDict\n\nimport numba\nimport numpy as np\nimport t"
  },
  {
    "path": "torchplus/train/__init__.py",
    "chars": 486,
    "preview": "from torchplus.train.checkpoint import (latest_checkpoint, restore,\n                                        restore_late"
  },
  {
    "path": "torchplus/train/checkpoint.py",
    "chars": 7879,
    "preview": "import json\nimport logging\nimport os\nimport signal\nfrom pathlib import Path\n\nimport torch\n\n\nclass DelayedKeyboardInterru"
  },
  {
    "path": "torchplus/train/common.py",
    "chars": 717,
    "preview": "import datetime\nimport os\nimport shutil\n\ndef create_folder(prefix, add_time=True, add_str=None, delete=False):\n    addit"
  },
  {
    "path": "torchplus/train/fastai_optim.py",
    "chars": 12244,
    "preview": "from collections import Iterable, defaultdict\nfrom copy import deepcopy\nfrom itertools import chain\n\nimport torch\nfrom t"
  },
  {
    "path": "torchplus/train/learning_schedules.py",
    "chars": 7996,
    "preview": "\"\"\"PyTorch edition of TensorFlow learning schedule in tensorflow object\ndetection API. \n\"\"\"\nimport numpy as np\nfrom torc"
  },
  {
    "path": "torchplus/train/learning_schedules_fastai.py",
    "chars": 7234,
    "preview": "import numpy as np\nimport math\nfrom functools import partial\nimport torch\n\n\nclass LRSchedulerStep(object):\n    def __ini"
  },
  {
    "path": "torchplus/train/optim.py",
    "chars": 4081,
    "preview": "from collections import defaultdict, Iterable\n\nimport torch\nfrom copy import deepcopy\nfrom itertools import chain\nfrom t"
  },
  {
    "path": "utils/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "utils/config_io.py",
    "chars": 4175,
    "preview": "\nfrom easydict import EasyDict as edict\nimport os\nimport shutil\nimport yaml\n\n\ndef mkdir_if_not_exists(path):\n    \"\"\"Make"
  },
  {
    "path": "utils/distributed_utils.py",
    "chars": 10771,
    "preview": "from torch.utils.data import Sampler\nimport math\nimport os\nimport pdb\nimport torch\nimport torch.distributed as dist\nfrom"
  },
  {
    "path": "utils/eval_metric.py",
    "chars": 30219,
    "preview": "import os\nimport numpy as np\nfrom plyfile import PlyData\n# from utils import icp_utils\nfrom data.linemod import linemod_"
  },
  {
    "path": "utils/furthest_point_sample.py",
    "chars": 1958,
    "preview": "\nimport numpy as np\nfrom scipy import spatial\n\n\ndef fragmentation_fps(vertices, num_frags):\n  \"\"\"Fragmentation by the fu"
  },
  {
    "path": "utils/geometric.py",
    "chars": 2243,
    "preview": "import numpy as np \n\n\ndef range_to_depth(mask, range, K):\n    '''\n       Transform the range image to depth image\n    ''"
  },
  {
    "path": "utils/img_utils.py",
    "chars": 5601,
    "preview": "import torch\nfrom matplotlib import cm\nimport matplotlib.pyplot as plt\nimport matplotlib.patches as patches\nimport numpy"
  },
  {
    "path": "utils/log_tool.py",
    "chars": 5774,
    "preview": "import numpy as np\nfrom tensorboardX import SummaryWriter\nimport json\nfrom pathlib import Path\nimport logging\n\n\n\ndef _fl"
  },
  {
    "path": "utils/pose_utils.py",
    "chars": 37832,
    "preview": "\"\"\"\nCopyright (C) 2018 NVIDIA Corporation.  All rights reserved.\nLicensed under the CC BY-NC-SA 4.0 license (https://cre"
  },
  {
    "path": "utils/pose_utils_np.py",
    "chars": 40506,
    "preview": "\"\"\"\nCopyright (C) 2018 NVIDIA Corporation.  All rights reserved.\nLicensed under the CC BY-NC-SA 4.0 license (https://cre"
  },
  {
    "path": "utils/progress_bar.py",
    "chars": 6187,
    "preview": "import contextlib\r\nimport enum\r\nimport math\r\nimport time\r\n\r\nimport numpy as np\r\n\r\n\r\ndef progress_str(val, *args, width=2"
  },
  {
    "path": "utils/rand_utils.py",
    "chars": 696,
    "preview": "import numpy as np  \n\ndef truncated_normal(u, sigma, min, max, shape=None):\n    \"\"\" Generate data following truncated no"
  },
  {
    "path": "utils/singleton.py",
    "chars": 253,
    "preview": "# import h5py\n\nclass Singleton(type):\n    _instances = {}\n    def __call__(cls, *args, **kwargs):\n        if cls not in "
  },
  {
    "path": "utils/timer.py",
    "chars": 1894,
    "preview": "import time \nfrom contextlib import contextmanager\n\n\n@contextmanager\ndef simple_timer(name=''):\n    t = time.time()\n    "
  },
  {
    "path": "utils/util.py",
    "chars": 3238,
    "preview": "import torch\nimport numpy as np\nimport collections\n\n\n\ndef freeze_params(params: dict, include: str = None, exclude: str "
  },
  {
    "path": "utils/visualize.py",
    "chars": 2845,
    "preview": "import numpy as np \nimport cv2 \nimport copy\n\ndef vis_pointclouds_cv2(pc, K, win_size, init_transform=None, color=None, i"
  }
]

// ... and 4 more files (download for full content)

About this extraction

This page contains the full source code of the DecaYale/RNNPose GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 126 files (12.7 MB), approximately 223.6k tokens, and a symbol index with 1087 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!