[
  {
    "path": ".gitignore",
    "content": "**/*.old\n**/*.bak\n\n.DS_Store\n# Created by https://www.toptal.com/developers/gitignore/api/vscode,python,jupyternotebooks\n# Edit at https://www.toptal.com/developers/gitignore?templates=vscode,python,jupyternotebooks\n\n### JupyterNotebooks ###\n# gitignore template for Jupyter Notebooks\n# website: http://jupyter.org/\n\n.ipynb_checkpoints\n*/.ipynb_checkpoints/*\n\n# IPython\nprofile_default/\nipython_config.py\n\n# Remove previous ipynb_checkpoints\n#   git rm -r .ipynb_checkpoints/\n\n### Python ###\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\n#lib/\n#lib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\npytestdebug.log\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\ndoc/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n\n# IPython\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\npythonenv*\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# profiling data\n.prof\n\n### vscode ###\n.vscode/*\n!.vscode/settings.json\n!.vscode/tasks.json\n!.vscode/launch.json\n!.vscode/extensions.json\n*.code-workspace\n\n# End of https://www.toptal.com/developers/gitignore/api/vscode,python,jupyternotebooks\n"
  },
  {
    "path": "LICENSE.md",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. 
For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. 
Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. (Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "# RNNPose: Recurrent 6-DoF Object Pose Refinement with Robust Correspondence Field Estimation and Pose Optimization\n\n[Yan Xu](https://decayale.github.io/), [Kwan-Yee Lin](https://kwanyeelin.github.io/), [Guofeng Zhang](http://www.cad.zju.edu.cn/home/gfzhang/), [Xiaogang Wang](https://www.ee.cuhk.edu.hk/en-gb/people/academic-staff/professors/prof-xiaogang-wang), [Hongsheng Li](http://www.ee.cuhk.edu.hk/~hsli/). \n\n*Conference on Computer Vision and Pattern Recognition (CVPR), 2022.*\n\n[[Paper]](https://scholar.google.com/scholar?hl=zh-CN&as_sdt=0%2C5&q=RNNPose%3A+Recurrent+6-DoF+Object+Pose+Refinement+with+Robust+Correspondence+Field+Estimation+and+Pose+Optimization&btnG=)\n\n\n\n\n## 1. Framework \nThe basic pipeline of our proposed RNNPose.  (a) Before refinement, a reference image is rendered according to the object initial pose (shown in a fused view).\n(b) Our RNN-based framework recurrently refines the object pose based on the estimated correspondence field between the reference and target images. The pose is optimized to be consistent with the reliable correspondence estimations highlighted by the similarity score map (built from learned 3D-2D descriptors) via differentiable LM optimization.  (c) The output refined pose.  \n\n<!-- ![image info](./demo/framework.png) -->\n<p align=\"center\">\n<img src=\"./demo/idea.png\" alt=\"alt text\" width=\"450\"/>\n</p>\n\n## 2. Pose Estimation with Occlusions and Erroneous Pose Initializations\n\n\n### Estimated Poses and Intermediate System Outputs from Different Recurrent Iterations. \n\n<p align=\"center\">\n <img src=\"demo/ape_short_small.gif\" alt=\"animated\" height=400/><img src=\"demo/driller_short_small.gif\" alt=\"animated\" height=400/>\n</p>\n\n\n### Pose Estimates with Erroneous Pose Initializations\nVisualization of our pose estimations (first row) on Occlusion LINEMOD dataset and the similarity score maps (second row) for downweighting unreliable correspondences during pose optimization. \nFor pose visualization, the white boxes represent the erroneous initial poses, the red boxes are estimated by our algorithm and the ground-truth boxes are in blue. Here, the initial poses for pose refinement are originally from PVNet but added with significant disturbances for robustness testing. \n<center class=\"half\">\n  <img src=\"./demo/est_vis.png\" height=200 > \n</center>\n\n\n## 3. Installation \n### Install the Docker \nA dockerfile is provided to help with the environment setup. \nYou need to install [docker](https://docs.docker.com/get-docker/) and [nvidia-docker2](https://github.com/NVIDIA/nvidia-docker) first and then set up the docker image and start up a container with the following commands: \n\n```\ncd RNNPose/docker\nsudo docker build -t rnnpose .    \nsudo docker run  -it  --runtime=nvidia --ipc=host  --volume=\"HOST_VOLUME_YOU_WANT_TO_MAP:DOCKER_VOLUME\"  -e DISPLAY=$DISPLAY -e QT_X11_NO_MITSHM=1  rnnpose bash\n\n```\nIf you are not familiar with [docker](https://docs.docker.com/get-docker/), you could also install the dependencies manually following the provided dockerfile.  \n\n### Compile the Dependencies\n```\ncd RNNPose/scripts\nbash compile_3rdparty.sh\n```\n\n\n## 4. Data Preparation\nWe follow [DeepIM](https://github.com/liyi14/mx-DeepIM) and [PVNet](https://github.com/zju3dv/pvnet-rendering) to preprocess the training data for LINEMOD. \nYou could follow the steps [here](doc/prepare_data.md) for data preparation. \n\n\n\n## 5. 
## 5. Test with the Pretrained Models\nWe train our model with a mixture of real and synthetic data on the LINEMOD dataset. \n<!-- and evaluate the trained models on the test set of LINEMOD and LINEMOD OCCLUSION datasets.  -->\nThe trained models on the LINEMOD dataset have been uploaded to [OneDrive](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155139432_link_cuhk_edu_hk/ESPTVyUryHdGl65fRAxN51gBBayJJb9NpCqWA-tY2CFKJQ?e=R9bcLW). \nYou can download them and put them into the directory *weights/* for testing. \n\n\nAn example bash script is provided below for reference. \n\n```\nexport PYTHONPATH=\"$PROJECT_ROOT_PATH:$PYTHONPATH\"\nexport PYTHONPATH=\"$PROJECT_ROOT_PATH/thirdparty:$PYTHONPATH\"\n\nseq=cat\ngpu=1\nstart_gpu_id=0\nmkdir -p $model_dir\n\neval_file=$PROJECT_ROOT_PATH/tools/eval.py\nconfig_path=$PROJECT_ROOT_PATH/config/linemod/\"$seq\"_fw0.5.yml\npretrain=$PROJECT_ROOT_PATH/weights/trained_models/\"$seq\".tckpt\n\npython -u $eval_file multi_proc_train  \\\n        --config_path $config_path \\\n        --model_dir $model_dir/results \\\n        --use_dist True \\\n        --dist_port 10000 \\\n        --gpus_per_node $gpu \\\n        --optim_eval True \\\n        --use_apex True \\\n        --world_size $gpu \\\n        --start_gpu_id $start_gpu_id \\\n        --pretrained_path $pretrain \n\n```\n\nNote that before executing the commands you need to specify PROJECT_ROOT_PATH (the absolute path of the project folder *RNNPose*) and model_dir (the output directory), and modify the respective data paths in the configuration files to the locations of the downloaded data. You can also refer to the commands below for evaluation with our provided scripts.\n\n### Evaluation on LINEMOD\n```\nbash scripts/eval.sh \n```\n\n### Evaluation on LINEMOD OCCLUSION\n```\nbash scripts/eval_lmocc.sh\n\n```\n\n## 6. Training from Scratch\nAn example training script is provided. \n```\nbash scripts/train.sh \n```\n\n\n\n## 7. Citation\nIf you find our code useful, please cite our papers. \n```\n@inproceedings{xu2022rnnpose,\n  title={RNNPose: Recurrent 6-DoF Object Pose Refinement with Robust Correspondence Field Estimation and Pose Optimization},\n  author={Xu, Yan and Lin, Kwan-Yee and Zhang, Guofeng and Wang, Xiaogang and Li, Hongsheng},\n  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},\n  year={2022}\n}\n\n@article{xu2024rnnpose,\n  title={RNNPose: 6-DoF Object Pose Estimation via Recurrent Correspondence Field Estimation and Pose Optimization},\n  author={Xu, Yan and Lin, Kwan-Yee and Zhang, Guofeng and Wang, Xiaogang and Li, Hongsheng},\n  journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},\n  year={2024},\n  publisher={IEEE}\n}\n```\n\n\n## 8. Acknowledgement\n\nThe skeleton of this code is borrowed from [RSLO](https://github.com/DecaYale/RSLO). We would also like to thank the public codebases [PVNet](https://github.com/zju3dv/pvnet), [RAFT](https://github.com/princeton-vl/RAFT), [SuperGlue](https://github.com/magicleap/SuperGluePretrainedNetwork) and [DeepV2D](https://github.com/princeton-vl/DeepV2D). \n\n<!-- ## TODO List and ETA\n- [x] Inference code and pretrained models (25/12/2021)\n- [ ] Training code\n- [ ] Code cleaning and improvement -->\n"
  },
  {
    "path": "builder/__init__.py",
    "content": ""
  },
  {
    "path": "builder/dataset_builder.py",
    "content": "\nfrom data.dataset import get_dataset_class\nimport numpy as np\nfrom functools import partial\nfrom data.preprocess import preprocess, preprocess_deepim, patch_crop\n\n\ndef build(input_reader_config,\n          training,\n          ):\n\n    prep_cfg = input_reader_config.preprocess\n    dataset_cfg = input_reader_config.dataset\n    cfg = input_reader_config\n    \n    dataset_cls = get_dataset_class(dataset_cfg.dataset_class_name)\n\n    if 0:# 'DeepIM' in dataset_cfg.dataset_class_name:\n        # patch_cropper = partial(patch_crop, margin_ratio=0.2, output_size=256 )\n        patch_cropper = None \n        prep_func = partial(preprocess_deepim, \n                        max_points=dataset_cfg.max_points,\n                        correspondence_radius=prep_cfg.correspondence_radius_threshold,\n                        patch_cropper=patch_cropper,\n                        image_scale=prep_cfg.get('image_scale', 1),\n        ) \n\n    else:\n        prep_func = partial(preprocess, \n                        max_points=dataset_cfg.max_points,\n                        correspondence_radius=prep_cfg.correspondence_radius_threshold,\n                        image_scale=prep_cfg.get('image_scale', 1),\n                        crop_param=prep_cfg.get('crop_param', None),\n                        kp_3d_param=prep_cfg.get('kp_3d_param', {\"kp_num\":30} ),\n                        use_coords_as_3d_feat=prep_cfg.get('use_coords_as_3d_feat', False)\n        ) \n   \n\n    dataset = dataset_cls(\n        info_path=dataset_cfg.info_path,\n        root_path=dataset_cfg.root_path,\n        model_point_dim=dataset_cfg.model_point_dim,\n        is_train=training,\n        prep_func=prep_func,\n        seq_names=dataset_cfg.get('seq_names', None),\n        cfg=dataset_cfg\n    )\n\n    return dataset\n"
  },
  {
    "path": "builder/input_reader_builder.py",
    "content": "\nfrom torch.utils.data import Dataset\n\nfrom builder import dataset_builder\n\n\nclass DatasetWrapper(Dataset):\n    \"\"\" convert our dataset to Dataset class in pytorch.\n    \"\"\"\n\n    def __init__(self, dataset):\n        self._dataset = dataset\n\n    def __len__(self):\n        return len(self._dataset)\n\n    def __getitem__(self, idx):\n        return self._dataset[idx]\n\n    @property\n    def dataset(self):\n        return self._dataset\n\n\ndef build(input_reader_config,\n          training,\n          ) -> DatasetWrapper:\n\n    dataset = dataset_builder.build(\n        input_reader_config,\n        training,\n    )\n    dataset = DatasetWrapper(dataset)\n    return dataset\n"
  },
  {
    "path": "builder/losses_builder.py",
    "content": "\nfrom model import losses\n\ndef build(loss_config):\n\n    criterions = {}\n  \n    criterions[\"metric_loss\"] =losses.MetricLoss(configs=loss_config.metric_loss,)\n\n    criterions[\"pose_loss\"] = losses.PointAlignmentLoss(loss_weight=1)\n    \n    return criterions\n"
  },
  {
    "path": "builder/lr_scheduler_builder.py",
    "content": "\nfrom torchplus.train import learning_schedules_fastai as lsf\nimport torch\nimport numpy as np \n\ndef build(optimizer_config, optimizer, total_step):\n\n    optimizer_type = list(optimizer_config.keys())[0]\n\n    if optimizer_type == 'rms_prop_optimizer':\n        config = optimizer_config.rms_prop_optimizer\n        lr_scheduler = _create_learning_rate_scheduler(\n            config.learning_rate, optimizer, total_step=total_step)\n\n    if optimizer_type == 'momentum_optimizer':\n        config = optimizer_config.momentum_optimizer\n        lr_scheduler = _create_learning_rate_scheduler(\n            config.learning_rate, optimizer, total_step=total_step)\n\n    if optimizer_type == 'adam_optimizer':\n        config = optimizer_config.adam_optimizer\n        lr_scheduler = _create_learning_rate_scheduler(\n            config.learning_rate, optimizer, total_step=total_step)\n\n    return lr_scheduler\n\n\ndef _create_learning_rate_scheduler(learning_rate_config, optimizer, total_step):\n    \"\"\"Create optimizer learning rate scheduler based on config.\n\n    Args:\n      learning_rate_config: A LearningRate proto message.\n\n    Returns:\n      A learning rate.\n\n    Raises:\n      ValueError: when using an unsupported input data type.\n    \"\"\"\n    lr_scheduler = None\n    # learning_rate_type = learning_rate_config.WhichOneof('learning_rate')\n    learning_rate_type = list(learning_rate_config.keys())[0]\n\n    if learning_rate_type == 'multi_phase':\n        config = learning_rate_config.multi_phase\n        lr_phases = []\n        mom_phases = []\n        for phase_cfg in config.phases:\n            lr_phases.append((phase_cfg.start, phase_cfg.lambda_func))\n            mom_phases.append(\n                (phase_cfg.start, phase_cfg.momentum_lambda_func))\n        lr_scheduler = lsf.LRSchedulerStep(\n            optimizer, total_step, lr_phases, mom_phases)\n\n\n\n    if learning_rate_type == 'one_cycle':\n        config = learning_rate_config.one_cycle\n\n        if len(config.lr_maxs)>1:\n          assert(len(config.lr_maxs)==4 )    \n          lr_max=[]\n          # for i in range(len(config.lr_maxs)):\n          #   lr_max += [config.lr_maxs[i]]*optimizer.param_segs[i] \n\n          lr_max = np.array(list(config.lr_maxs) )\n        else:\n          lr_max = config.lr_max\n\n        lr_scheduler = lsf.OneCycle(\n            optimizer, total_step, lr_max, list(config.moms), config.div_factor, config.pct_start)\n    if learning_rate_type == 'exponential_decay':\n        config = learning_rate_config.exponential_decay\n        lr_scheduler = lsf.ExponentialDecay(\n            optimizer, total_step, config.initial_learning_rate, config.decay_length, config.decay_factor, config.staircase)\n    if learning_rate_type == 'exponential_decay_warmup':\n        config = learning_rate_config.exponential_decay_warmup\n        lr_scheduler = lsf.ExponentialDecayWarmup(\n            optimizer, total_step, config.initial_learning_rate, config.decay_length, config.decay_factor,   config.div_factor,\n            config.pct_start, config.staircase)\n    if learning_rate_type == 'manual_stepping':\n        config = learning_rate_config.manual_stepping\n        lr_scheduler = lsf.ManualStepping(\n            optimizer, total_step, list(config.boundaries), list(config.rates))\n\n    if lr_scheduler is None:\n        raise ValueError('Learning_rate %s not supported.' %\n                         learning_rate_type)\n\n    return lr_scheduler\n"
  },
  {
    "path": "builder/optimizer_builder.py",
    "content": "from torchplus.train import learning_schedules\nfrom torchplus.train import optim\nimport torch\nfrom torch import nn\nfrom torchplus.train.fastai_optim import OptimWrapper, FastAIMixedOptim\nfrom functools import partial\n\n\ndef children(m: nn.Module):\n    \"Get children of `m`.\"\n    return list(m.children())\n\n\ndef num_children(m: nn.Module) -> int:\n    \"Get number of children modules in `m`.\"\n    return len(children(m))\n\n# return a list of smallest modules dy\n\n\ndef flatten_model(m):\n    if m is None:\n        return []\n    return sum(\n        map(flatten_model, m.children()), []) if num_children(m) else [m]\n\n\n# def get_layer_groups(m): return [nn.Sequential(*flatten_model(m))]\ndef get_layer_groups(m): return [nn.ModuleList(flatten_model(m))]\n\ndef get_voxeLO_net_layer_groups(net):\n    vfe_grp = get_layer_groups(net)#[0]\n\n    other_grp = get_layer_groups(nn.Sequential(net._rotation_loss,\n                    net._translation_loss,\n                    net._pyramid_rotation_loss,\n                        net._pyramid_translation_loss,\n                     net._consistency_loss, \n                     ))\n\n    return [vfe_grp, mfe_grp, op_grp,other_grp]\n\n\ndef get_voxeLO_net_layer_groups(net):\n    vfe_grp = get_layer_groups(net.voxel_feature_extractor)#[0]\n    mfe_grp = get_layer_groups(net.middle_feature_extractor)#[0]\n    op_grp = get_layer_groups(net.odom_predictor)#[0]\n\n    # other_grp = get_layer_groups(net._rotation_loss) +  \\\n    #     get_layer_groups(net._translation_loss) \\\n    #         + get_layer_groups(net._pyramid_rotation_loss) \\\n    #             + get_layer_groups(net._pyramid_translation_loss) \\\n    #                 + get_layer_groups(net._consistency_loss)\\\n    other_grp = get_layer_groups(nn.Sequential(net._rotation_loss,\n                    net._translation_loss,\n                    net._pyramid_rotation_loss,\n                        net._pyramid_translation_loss,\n                     net._consistency_loss, \n                     ))\n\n    return [vfe_grp, mfe_grp, op_grp,other_grp]\n\n\ndef build(optimizer_config, net, name=None, mixed=False, loss_scale=512.0):\n\n    optimizer_type = list(optimizer_config.keys())[0]\n    print(\"Optimizer:\", optimizer_type)\n    \n    optimizer=None\n\n    if optimizer_type == 'rms_prop_optimizer':\n        config=optimizer_config.rms_prop_optimizer\n        optimizer_func=partial(\n            torch.optim.RMSprop,\n            alpha=config.decay,\n            momentum=config.momentum_optimizer_value,\n            eps=config.epsilon)\n\n    if optimizer_type == 'momentum_optimizer':\n        config=optimizer_config.momentum_optimizer\n        optimizer_func=partial(\n            torch.optim.SGD,\n            momentum=config.momentum_optimizer_value,\n            eps=config.epsilon)\n\n    if optimizer_type == 'adam_optimizer':\n        config=optimizer_config.adam_optimizer\n        if optimizer_config.fixed_weight_decay:\n            optimizer_func=partial(\n                torch.optim.Adam, betas=(0.9, 0.99), amsgrad=config.amsgrad)\n        else:\n            # regular adam\n            optimizer_func=partial(\n                torch.optim.Adam, amsgrad=config.amsgrad)\n\n    optimizer=OptimWrapper.create(\n        optimizer_func,\n        3e-3,\n        get_layer_groups(net),\n        # get_voxeLO_net_layer_groups(net),\n        wd=config.weight_decay,\n        true_wd=optimizer_config.fixed_weight_decay,\n        bn_wd=True)\n    print(hasattr(optimizer, 
\"_amp_stash\"), '_amp_stash')\n    if optimizer is None:\n        raise ValueError('Optimizer %s not supported.' % optimizer_type)\n\n    if optimizer_config.use_moving_average:\n        raise ValueError('torch don\\'t support moving average')\n    if name is None:\n        # assign a name to optimizer for checkpoint system\n        optimizer.name=optimizer_type\n    else:\n        optimizer.name=name\n    return optimizer\n"
  },
  {
    "path": "builder/rnnpose_builder.py",
    "content": "from builder import losses_builder\nfrom model.RNNPose import get_posenet_class\nimport model.RNNPose\n\n\ndef build(model_cfg,\n          measure_time=False, testing=False):\n    \"\"\"build second pytorch instance.\n    \"\"\"\n\n    criterions=losses_builder.build(model_cfg.loss)\n\n    net = get_posenet_class(model_cfg.network_class_name)(\n        criterions=criterions,\n        opt=model_cfg)\n    return net\n"
  },
  {
    "path": "config/default.py",
    "content": "from yacs.config import CfgNode as CN\nfrom utils.singleton import Singleton\nimport os\n\ndef _merge_a_into_b(a, b):\n    \"\"\"Merge config dictionary a into config dictionary b, clobbering the\n    options in b whenever they are also specified in a.\n    \"\"\"\n    # if type(a) is not dict:\n    if not isinstance(a, (dict)):\n        return\n\n    for k, v in a.items():\n        # a must specify keys that are in b\n        if not k in b:\n            raise KeyError('{} is not a valid config key'.format(k))\n\n        # the types must match, too\n        old_type = type(b[k])\n        if old_type is not type(v):\n            if isinstance(b[k], np.ndarray):\n                v = np.array(v, dtype=b[k].dtype)\n            else:\n                raise ValueError(('Type mismatch ({} vs. {}) '\n                                'for config key: {}').format(type(b[k]),\n                                                            type(v), k))\n\n        # recursively merge dicts\n        # if type(v) is dict:\n        if isinstance(v, (dict)):\n            try:\n                _merge_a_into_b(a[k], b[k])\n            except:\n                print('Error under config key: {}'.format(k))\n                raise\n        else:\n            b[k] = v\n\n\n\nclass Config(metaclass=Singleton):\n    def __init__(self):\n        ##############  ↓  Basic   ↓  ##############\n        self.ROOT_CN = CN()\n        self.ROOT_CN.BASIC = CN()\n        self.ROOT_CN.BASIC.input_size=[480,640] #h,w\n        self.ROOT_CN.BASIC.crop_size=[320,320] #h,w\n        self.ROOT_CN.BASIC.zoom_crop_size=[320,320] #h,w\n        self.ROOT_CN.BASIC.render_image_size=[320,320]#h,w\n        self.ROOT_CN.BASIC.patch_num=64#h,w\n\n        ##############  ↓  LM OPTIM   ↓  ##############\n        self.ROOT_CN.LM=CN()\n        self.ROOT_CN.LM.LM_LMBDA= 0.0001\n        self.ROOT_CN.LM.EP_LMBDA=100\n\n        ##############  ↓  data   ↓  ##############\n        self.ROOT_CN.DATA=CN()\n        self.ROOT_CN.DATA.OBJ_ROOT=\"\" #h,w\n        self.ROOT_CN.DATA.VOC_ROOT=f\"{os.path.dirname(os.path.abspath(__file__)) }/../EXPDATA/\" #h,w\n\n    def __get_item__(self, key):\n        return self.ROOT_CN.__getitem__(key)\n    \n    def merge(self, config_dict, sub_key=None):\n\n        if sub_key is not None:\n            _merge_a_into_b(config_dict, self.ROOT_CN[sub_key])\n        else:\n            _merge_a_into_b(config_dict, self.ROOT_CN)\n\n##############  ↓  Model  ↓  ##############\n# _CN.model = CN()\n# _CN.model.input_size=[480,640]\n# _CN.model.crop_size=[320,320] \n# def get_cfg_defaults():\n\ndef get_cfg(Node=None):\n    \"\"\"Get a yacs CfgNode object with default values for my_project.\"\"\"\n    # Return a clone so that the defaults will not be altered\n    # This is for the \"local variable\" use pattern\n    # return _CN.clone()\n    if Node is None:\n        return  Config()\n    else:\n        return Config().__get_item__(Node)\n"
  },
  {
    "path": "config/linemod/copy.sh",
    "content": "declare -a arr=(\"glue\" \"ape\" \"cat\" \"phone\" \"eggbox\" \"benchvise\" \"lamp\" \"camera\" \"can\" \"driller\" \"duck\" \"holepuncher\" \"iron\"  )\n\n#create training scripts\nfor seq in \"${arr[@]}\"\ndo\n   echo \"$seq\"\n   cat template_fw0.5.yml > \"$seq\"_fw0.5.yml\n   sed -i \"s/SEQ_NAME/$seq/g\" \"$seq\"_fw0.5.yml\ndone\n\narraylength=${#arr[@]}\n"
  },
  {
    "path": "config/linemod/copy_occ.sh",
    "content": "declare -a arr=(\"glue\" \"ape\" \"cat\" \"phone\" \"eggbox\" \"benchvise\" \"lamp\" \"camera\" \"can\" \"driller\" \"duck\" \"holepuncher\" \"iron\"  )\n\n#create training scripts\nfor seq in \"${arr[@]}\"\ndo\n   echo \"$seq\"\n   cat template_fw0.5_occ.yml > \"$seq\"_fw0.5_occ.yml\n   sed -i \"s/SEQ_NAME/$seq/g\" \"$seq\"_fw0.5_occ.yml\ndone\n\n\n"
  },
  {
    "path": "config/linemod/template_fw0.5.yml",
    "content": "vars:\n  input_h: &input_h\n    320 \n  input_w: &input_w\n    320 \n  batch_size: &batch_size\n    1\n  descriptor_dim: &descriptor_dim\n    32 \n  correspondence_radius_threshold: &correspondence_radius_threshold\n    0.01 #0.04 \n  seq_name: &seq_name\n    [\"SEQ_NAME\"]\nBASIC:\n  zoom_crop_size: [240,240]\nmodel:\n  input_h: *input_h\n  input_w: *input_w\n  batch_size: *batch_size\n  seq_len: 2 \n\n\n  network_class_name: RNNPose \n  descriptor_net:\n    module_class_name: HybridFeaNet \n    \n    keypoints_detector_2d:\n      input_dim: 3\n      descriptor_dim: *descriptor_dim \n      remove_borders: 4\n      normalize_output: True \n\n    keypoints_detector_3d:\n      #KPCONV configurations\n      num_layers: 4\n      KP_extent: 2.0\n      batch_norm_momentum: 0.02\n      use_batch_norm: true\n      in_points_dim: 3\n      fixed_kernel_points: 'center' #['center', 'verticals', 'none']\n      KP_influence: 'linear'\n      aggregation_mode: 'sum' #['closest', 'sum']\n      modulated: false \n      first_subsampling_dl:  0.025\n      conv_radius: 2.5\n      deform_radius: 5\n      in_features_dim: 1 #3\n      first_feats_dim: 128\n      num_kernel_points: 15\n      final_feats_dim: *descriptor_dim #256 #32\n      normalize_output: True \n      gnn_feats_dim: 128 #256\n    context_fea_extractor_3d:\n      #KPCONV configurations\n      num_layers: 4\n      KP_extent: 2.0\n      batch_norm_momentum: 0.02\n      use_batch_norm: true\n      in_points_dim: 3\n      fixed_kernel_points: 'center' #['center', 'verticals', 'none']\n      KP_influence: 'linear'\n      aggregation_mode: 'sum' #['closest', 'sum']\n      modulated: false \n      first_subsampling_dl:  0.025\n      conv_radius: 2.5\n      deform_radius: 5\n      in_features_dim: 1 #3\n      first_feats_dim: 128\n      num_kernel_points: 15\n      final_feats_dim: 256 #*descriptor_dim #256 #32\n      normalize_output: False \n      gnn_feats_dim: 128 #256\n  motion_net:\n    IS_CALIBRATED: True\n    RESCALE_IMAGES: False\n    ITER_COUNT: 4 \n    RENDER_ITER_COUNT: 3 #2 #1 #3\n    TRAIN_RESIDUAL_WEIGHT: 0 #0.1 \n    TRAIN_FLOW_WEIGHT: 0.5 #0.1 #1 \n    TRAIN_REPROJ_WEIGHT: 0 \n    OPTIM_ITER_COUNT: 1\n    FLOW_NET: 'raft' \n    SYN_OBSERVED: False\n    ONLINE_CROP: True\n    raft:\n      small: False #True\n      fea_net: \"default\"\n      mixed_precision: True\n      # pretrained_model: \"/mnt/workspace/datasets/weights/models/raft-small.pth\"\n      pretrained_model: \"/mnt/workspace/datasets/weights/models/raft-chairs.pth\"\n      input_dim: 3 \n      iters: 1 \n      \n  loss:\n    metric_loss:\n      type: \"normal\" \n      pos_radius: *correspondence_radius_threshold # the radius used to find the positive correspondences\n      safe_radius: 0.02 #0.13\n      pos_margin: 0.1\n      neg_margin: 1.4\n      max_points: 256\n      matchability_radius: 0.06\n      weight: 0.001\n    saliency_loss:\n      loss_weight: 1\n      reg_weight: 0.01\n    geometric_loss:\n      loss_weight: 1\n      reg_weight: 0.5 #0.1 \n\n\ntrain_config:\n\n  optimizer: \n      adam_optimizer: \n        learning_rate: \n          one_cycle: \n            lr_maxs: []\n            lr_max: 0.0001 # \n            moms: [0.95, 0.85]\n            div_factor: 10.0\n            pct_start: 0.01 #0.05\n        amsgrad: false\n        weight_decay: 0.0001 \n      fixed_weight_decay: true\n      use_moving_average: false\n\n  steps: 200000 \n  steps_per_eval:  10000\n  loss_scale_factor: -1\n  clear_metrics_every_epoch: true\n\ntrain_input_reader:\n  
dataset:\n    dataset_class_name: \"LinemodDeepIMSynRealV2\" \n    info_path: [\"/home/RNNPose/Projects/Works/RNNPose_release/EXPDATA/data_info/deepim/linemod_orig_deepim.info.train\",  \"/home/RNNPose/Projects/Works/RNNPose_release/EXPDATA/data_info/deepim/linemod_syn_deepim.info.train\", \n        \"/home/RNNPose/Projects/Works/RNNPose_release/EXPDATA/data_info/linemod_fusesformatted_all10k_deepim.info.train\",\n    ] \n    root_path: [\"/home/RNNPose/Projects/Works/RNNPose_release/EXPDATA/LM6d_converted/LM6d_refine\",  \"/home/RNNPose/Projects/Works/RNNPose_release/EXPDATA/LM6d_converted/LM6d_refine_syn\",\n     \"/home/RNNPose/Projects/Works/RNNPose_release/EXPDATA/LINEMOD/fuse_formatted/\"\n     ]\n\n    model_point_dim: 3\n    max_points: 20000\n    seq_names: *seq_name \n\n  batch_size: *batch_size \n  preprocess: \n    correspondence_radius_threshold: *correspondence_radius_threshold\n    num_workers: 3 \n    image_scale: 1\n    crop_param:\n      rand_crop: false\n      margin_ratio: 0.85 \n      output_size: *input_h \n      crop_with_init_pose: True\n    \n\neval_input_reader:\n  dataset:\n    dataset_class_name: \"LinemodDeepIMSynRealV2\"\n    info_path: [\"/home/RNNPose/Projects/Works/RNNPose_release/EXPDATA/data_info/linemod_posecnn.info.eval\" ] \n    root_path: [\"/home/RNNPose/Projects/Works/RNNPose_release/EXPDATA/LM6d_converted/LM6d_refine\" ]\n    model_point_dim: 3\n    max_points: 20000\n    seq_names: *seq_name \n  batch_size: *batch_size \n  preprocess:\n    correspondence_radius_threshold: *correspondence_radius_threshold\n    num_workers: 3 \n    image_scale: 1\n    crop_param:\n      rand_crop: false\n      margin_ratio: 0.85 #0.5 \n      output_size: *input_h #\n      crop_with_init_pose: True\n      \n"
  },
  {
    "path": "config/linemod/template_fw0.5_occ.yml",
    "content": "vars:\n  input_h: &input_h\n    320 \n  input_w: &input_w\n    320 \n  batch_size: &batch_size\n    1\n  descriptor_dim: &descriptor_dim\n    32 \n  correspondence_radius_threshold: &correspondence_radius_threshold\n    0.01 #0.04 \n  seq_name: &seq_name\n    [\"SEQ_NAME\"]\nBASIC:\n  zoom_crop_size: [240,240]\nmodel:\n  input_h: *input_h\n  input_w: *input_w\n  batch_size: *batch_size\n  seq_len: 2 \n\n\n  network_class_name: RNNPose \n  descriptor_net:\n    module_class_name: HybridFeaNet \n    \n    keypoints_detector_2d:\n      input_dim: 3\n      descriptor_dim: *descriptor_dim \n      remove_borders: 4\n      normalize_output: True \n\n    keypoints_detector_3d:\n      #KPCONV configurations\n      num_layers: 4\n      KP_extent: 2.0\n      batch_norm_momentum: 0.02\n      use_batch_norm: true\n      in_points_dim: 3\n      fixed_kernel_points: 'center' #['center', 'verticals', 'none']\n      KP_influence: 'linear'\n      aggregation_mode: 'sum' #['closest', 'sum']\n      modulated: false \n      first_subsampling_dl:  0.025\n      conv_radius: 2.5\n      deform_radius: 5\n      in_features_dim: 1 #3\n      first_feats_dim: 128\n      num_kernel_points: 15\n      final_feats_dim: *descriptor_dim #256 #32\n      normalize_output: True \n      gnn_feats_dim: 128 #256\n    context_fea_extractor_3d:\n      #KPCONV configurations\n      num_layers: 4\n      KP_extent: 2.0\n      batch_norm_momentum: 0.02\n      use_batch_norm: true\n      in_points_dim: 3\n      fixed_kernel_points: 'center' #['center', 'verticals', 'none']\n      KP_influence: 'linear'\n      aggregation_mode: 'sum' #['closest', 'sum']\n      modulated: false \n      first_subsampling_dl:  0.025\n      conv_radius: 2.5\n      deform_radius: 5\n      in_features_dim: 1 #3\n      first_feats_dim: 128\n      num_kernel_points: 15\n      final_feats_dim: 256 #*descriptor_dim #256 #32\n      normalize_output: False \n      gnn_feats_dim: 128 #256\n  motion_net:\n    IS_CALIBRATED: True\n    RESCALE_IMAGES: False\n    ITER_COUNT: 4 \n    RENDER_ITER_COUNT: 3 #2 #1 #3\n    TRAIN_RESIDUAL_WEIGHT: 0 #0.1 \n    TRAIN_FLOW_WEIGHT: 0.5 #0.1 #1 \n    TRAIN_REPROJ_WEIGHT: 0 \n    OPTIM_ITER_COUNT: 1\n    FLOW_NET: 'raft' \n    SYN_OBSERVED: False\n    ONLINE_CROP: True\n    raft:\n      small: False #True\n      fea_net: \"default\"\n      mixed_precision: True\n      # pretrained_model: \"/mnt/workspace/datasets/weights/models/raft-small.pth\"\n      pretrained_model: \"/mnt/workspace/datasets/weights/models/raft-chairs.pth\"\n      input_dim: 3 \n      iters: 1 \n      \n  loss:\n    metric_loss:\n      type: \"normal\" \n      pos_radius: *correspondence_radius_threshold # the radius used to find the positive correspondences\n      safe_radius: 0.02 #0.13\n      pos_margin: 0.1\n      neg_margin: 1.4\n      max_points: 256\n      matchability_radius: 0.06\n      weight: 0.001\n    saliency_loss:\n      loss_weight: 1\n      reg_weight: 0.01\n    geometric_loss:\n      loss_weight: 1\n      reg_weight: 0.5 #0.1 \n\n\ntrain_config:\n\n  optimizer: \n      adam_optimizer: \n        learning_rate: \n          one_cycle: \n            lr_maxs: []\n            lr_max: 0.0001 # \n            moms: [0.95, 0.85]\n            div_factor: 10.0\n            pct_start: 0.01 #0.05\n        amsgrad: false\n        weight_decay: 0.0001 \n      fixed_weight_decay: true\n      use_moving_average: false\n\n  steps: 200000 \n  steps_per_eval:  10000\n  loss_scale_factor: -1\n  clear_metrics_every_epoch: true\n\ntrain_input_reader:\n  
dataset:\n    dataset_class_name: \"LinemodDeepIMSynRealV2\" \n    info_path: [\"/home/RNNPose/Projects/Works/RNNPose_release/EXPDATA/data_info/deepim/linemod_orig_deepim.info.train\",  \"/home/RNNPose/Projects/Works/RNNPose_release/EXPDATA/data_info/deepim/linemod_syn_deepim.info.train\", \n        \"/home/RNNPose/Projects/Works/RNNPose_release/EXPDATA/data_info/linemod_fusesformatted_all10k_deepim.info.train\",\n    ] \n    root_path: [\"/home/RNNPose/Projects/Works/RNNPose_release/EXPDATA/LM6d_converted/LM6d_refine\",  \"/home/RNNPose/Projects/Works/RNNPose_release/EXPDATA/LM6d_converted/LM6d_refine_syn\",\n     \"/home/RNNPose/Projects/Works/RNNPose_release/EXPDATA/LINEMOD/fuse_formatted/\"\n     ]\n\n    model_point_dim: 3\n    max_points: 20000\n    seq_names: *seq_name \n\n  batch_size: *batch_size \n  preprocess: \n    correspondence_radius_threshold: *correspondence_radius_threshold\n    num_workers: 3 \n    image_scale: 1\n    crop_param:\n      rand_crop: false\n      margin_ratio: 0.85 \n      output_size: *input_h \n      crop_with_init_pose: True\n    \n\neval_input_reader:\n  dataset:\n    dataset_class_name: \"LinemodDeepIMSynRealV2\"\n    info_path:  [\"/home/RNNPose/Projects/Works/RNNPose_release/EXPDATA/data_info/deepim/linemod_bop_lmocc_pvnetdr.info.eval\"]\n    root_path:  [\"/home/RNNPose/Projects/Works/RNNPose_release/EXPDATA/lmo\"] \n    init_post_type: \"PVNET_LINEMOD_OCC\"\n    model_point_dim: 3\n    max_points: 20000\n    seq_names: *seq_name \n  batch_size: *batch_size \n  preprocess:\n    correspondence_radius_threshold: *correspondence_radius_threshold\n    num_workers: 3 \n    image_scale: 1\n    crop_param:\n      rand_crop: false\n      margin_ratio: 0.85 #0.5 \n      output_size: *input_h #\n      crop_with_init_pose: True\n      \n      \n"
  },
  {
    "path": "data/__init__.py",
    "content": "from . import dataset\nfrom . import linemod_dataset"
  },
  {
    "path": "data/dataset.py",
    "content": "import pathlib\nimport pickle\nimport time\nfrom functools import partial\n\nimport numpy as np\n\n\nREGISTERED_DATASET_CLASSES = {}\n\n\ndef register_dataset(cls, name=None):\n    global REGISTERED_DATASET_CLASSES\n    if name is None:\n        name = cls.__name__\n    assert name not in REGISTERED_DATASET_CLASSES, f\"exist class: {REGISTERED_DATASET_CLASSES}\"\n    REGISTERED_DATASET_CLASSES[name] = cls\n    return cls\n\n\ndef get_dataset_class(name):\n    global REGISTERED_DATASET_CLASSES\n    assert name in REGISTERED_DATASET_CLASSES, f\"available class: {REGISTERED_DATASET_CLASSES}\"\n    return REGISTERED_DATASET_CLASSES[name]\n\n\nclass Dataset(object):\n    NumPointFeatures = -1\n\n    def __getitem__(self, index):\n        raise NotImplementedError\n\n    def __len__(self):\n        raise NotImplementedError\n\n    def _read_data(self, query):\n\n        raise NotImplementedError\n\n    def evaluation(self, dt_annos, output_dir):\n        \"\"\"Dataset must provide a evaluation function to evaluate model.\"\"\"\n        raise NotImplementedError\n"
  },
  {
    "path": "data/linemod/linemod_config.py",
    "content": "import numpy as np\ndiameters = {\n    'cat': 15.2633,\n    'ape': 9.74298,\n    'benchvise': 28.6908,\n    'bowl': 17.1185,\n    'cam': 17.1593,\n    'camera': 17.1593,\n    'can': 19.3416,\n    'cup': 12.5961,\n    'driller': 25.9425,\n    'duck': 10.7131,\n    'eggbox': 17.6364,\n    'glue': 16.4857,\n    'holepuncher': 14.8204,\n    'iron': 30.3153,\n    'lamp': 28.5155,\n    'phone': 20.8394\n}\n\nlinemod_cls_names = ['ape', 'cam', 'cat', 'duck', 'glue', 'iron', 'phone', 'benchvise', 'can', 'driller', 'eggbox', 'holepuncher', 'lamp']\n\nlinemod_K = np.array([[572.4114, 0., 325.2611],\n                  [0., 573.57043, 242.04899],\n                  [0., 0., 1.]])\n\n\nblender_K = np.array([[700., 0., 320.],\n                    [0., 700., 240.],\n                    [0., 0., 1.]])"
  },
  {
    "path": "data/linemod_dataset.py",
    "content": "import numpy as np \nimport random\nimport os \nfrom data.dataset import Dataset, register_dataset\nimport pickle\nimport PIL\nimport cv2\nimport torch\nimport time\nimport scipy\n\nfrom utils.geometric import range_to_depth, render_pointcloud\nfrom .transforms import make_transforms \nfrom thirdparty.kpconv.lib.utils import square_distance\nfrom utils.geometric import rotation_angle\n# from utils.visualize import *\nfrom transforms3d.quaternions import mat2quat, quat2mat, qmult\nfrom transforms3d.euler import mat2euler, euler2mat, euler2quat, quat2euler\nimport math\nfrom config.default import get_cfg\n\nCURRENT_DIR=os.path.dirname(os.path.abspath(__file__))\n\ntry:\n    from pytorch3d.io import load_obj, load_ply\nexcept:\n    print(\"Warning: error occurs when importing pytorch3d \")\n    pass\n\n\ndef se3_q2m(se3_q):\n    assert se3_q.size == 7\n    se3_mx = np.zeros((3, 4))\n    # quat = se3_q[0:4] / LA.norm(se3_q[0:4])\n    quat = se3_q[:4]\n    R = quat2mat(quat)\n    se3_mx[:, :3] = R\n    se3_mx[:, 3] = se3_q[4:]\n    return se3_mx\n\ndef info_convertor(info,):\n    \"\"\"\n        [Transform the original kitti info file]\n    \"\"\"\n\n    seqs = info.keys() #['cat']#\n    seq_lengths = [len(info[i]) for i in seqs]\n    data = []\n    for seq in seqs:\n        print(seq)\n        data.append(info[seq])\n\n    new_infos = {\n        \"seqs\": list(seqs),\n        \"seq_lengths\": seq_lengths,\n        \"data\": data\n    }\n    return new_infos\n\ndef resize(im, target_size, max_size, stride=0, interpolation=cv2.INTER_LINEAR):\n    \"\"\"\n    only resize input image to target size and return scale\n    :param im: BGR image input by opencv\n    :param target_size: one dimensional size (the short side)\n    :param max_size: one dimensional max size (the long side)\n    :param stride: if given, pad the image to designated stride\n    :param interpolation: if given, using given interpolation method to resize image\n    :return:\n    \"\"\"\n    im_shape = im.shape\n    im_size_min = np.min(im_shape[0:2])\n    im_size_max = np.max(im_shape[0:2])\n    im_scale = float(target_size) / float(im_size_min)\n    # prevent bigger axis from being more than max_size:\n    if np.round(im_scale * im_size_max) > max_size:\n        im_scale = float(max_size) / float(im_size_max)\n    im = cv2.resize(im, None, None, fx=im_scale, fy=im_scale, interpolation=interpolation)\n\n    if stride == 0:\n        return im, im_scale\n    else:\n        # pad to product of stride\n        im_height = int(np.ceil(im.shape[0] / float(stride)) * stride)\n        im_width = int(np.ceil(im.shape[1] / float(stride)) * stride)\n        im_channel = im.shape[2]\n        padded_im = np.zeros((im_height, im_width, im_channel))\n        padded_im[: im.shape[0], : im.shape[1], :] = im\n        return padded_im, im_scale\ndef sample_poses(pose_tgt):\n    SYN_STD_ROTATION = 15\n    SYN_STD_TRANSLATION = 0.01\n    ANGLE_MAX=45\n    pose_src = pose_tgt.copy()\n    num = pose_tgt.shape[0]\n    for i in range(num):\n        euler = mat2euler(pose_tgt[i, :3, :3])\n        euler += SYN_STD_ROTATION * np.random.randn(3) * math.pi / 180.0\n        pose_src[i, :3, :3] = euler2mat(euler[0], euler[1], euler[2])\n\n        pose_src[i, 0, 3] = pose_tgt[i, 0, 3]+ SYN_STD_TRANSLATION * np.random.randn(1)\n        pose_src[i, 1, 3] = pose_tgt[i, 1, 3] + SYN_STD_TRANSLATION * np.random.randn(1)\n        pose_src[i, 2, 3] = pose_tgt[i, 2, 3]  + 5 * SYN_STD_TRANSLATION * np.random.randn(1)\n\n        r_dist = 
        r_dist = np.arccos((np.trace(pose_src[i, :3,:3].transpose(-1,-2) @ pose_tgt[i, :3,:3]) - 1 )/2)/math.pi*180\n\n        while r_dist > ANGLE_MAX:\n            print(\"Angular difference too large. Resampling the pose...\")\n            euler = mat2euler(pose_tgt[i, :3, :3])\n            euler += SYN_STD_ROTATION * np.random.randn(3) * math.pi / 180.0\n            pose_src[i, :3, :3] = euler2mat(euler[0], euler[1], euler[2])\n\n            pose_src[i, 0, 3] = pose_tgt[i, 0, 3] + SYN_STD_TRANSLATION * np.random.randn(1)\n            pose_src[i, 1, 3] = pose_tgt[i, 1, 3] + SYN_STD_TRANSLATION * np.random.randn(1)\n            pose_src[i, 2, 3] = pose_tgt[i, 2, 3] + 5 * SYN_STD_TRANSLATION * np.random.randn(1)\n\n            r_dist = np.arccos((np.trace(pose_src[i, :3,:3].transpose(-1,-2) @ pose_tgt[i, :3,:3]) - 1 )/2)/math.pi*180\n    return pose_src.squeeze()\n\n\n\n\n@register_dataset\nclass LinemodDeepIMSynRealV2(Dataset):\n    # uses the DeepIM 3D models for geometric feature extraction and mixes synthetic and real data\n    def __init__(self, root_path,\n                 info_path, model_point_dim,\n                 is_train,\n                 prep_func=None,\n                 seq_names=None, \n                 cfg={}\n                 ):\n        super().__init__()\n\n        assert info_path is not None\n        assert isinstance(root_path, (tuple, list)) and isinstance(info_path, (tuple, list))\n        assert len(root_path) == len(info_path)\n        print(\"Info:\", info_path)\n        self.is_train = is_train\n        self.VOC_ROOT = get_cfg('DATA').VOC_ROOT  # e.g. \"/DATA/yxu/LINEMOD_DEEPIM/\"\n\n        infos=[]\n        for ipath in info_path:\n            with open(ipath, 'rb') as f:\n                info = pickle.load(f)\n\n                if seq_names is not None:\n                    for k in list(info.keys()):\n                        if k not in seq_names:\n                            del info[k]\n                infos.append( info_convertor(info) )\n\n        # merge multiple infos\n        self.infos = infos[0]\n        self.infos['dataset_idx'] = [0]*len(self.infos['seqs'])\n        for i, info in enumerate(infos[1:]):\n            for k in self.infos:\n                if k == 'dataset_idx':\n                    self.infos[k].extend([i+1]*len(info['seqs']))\n                else:\n                    self.infos[k].extend(info[k])\n\n\n        self.root_paths = root_path\n        self.model_point_dim = model_point_dim\n        self.prep_func = prep_func\n        self.rgb_transformer = make_transforms(None, is_train=is_train)\n        print(\"dataset size:\", self.__len__())\n\n        # note: the config key is spelled \"init_post_type\" throughout this repo\n        self.init_pose_type = cfg.get(\"init_post_type\", \"POSECNN_LINEMOD\")\n        # self.init_pose_type = cfg.get(\"init_post_type\", \"PVNET_LINEMOD_OCC\")\n        # self.init_pose_type = cfg.get(\"init_post_type\", \"PVNET_LINEMOD\")\n        print(\"INIT_POSE_TYPE:\", self.init_pose_type)\n        # Load posecnn results\n        if not self.is_train:\n            with open(f\"{CURRENT_DIR}/../EXPDATA/init_poses/linemod_posecnn_results.pkl\", 'rb') as f:\n                self.pose_cnn_results_test_posecnn=pickle.load(f)\n            try:\n                if self.init_pose_type == \"POSECNN_LINEMOD\":\n        
            #load posecnn results \n                    self.pose_cnn_results_test=self.pose_cnn_results_test_posecnn\n                elif self.init_pose_type ==\"PVNET_LINEMOD\":\n                    self.pose_cnn_results_test=np.load(f\"{CURRENT_DIR}/../EXPDATA/init_poses/pvnet/pvnet_linemod_test.npy\", allow_pickle=True).flat[0]\n                elif self.init_pose_type ==\"PVNET_LINEMOD_OCC\":\n                    self.pose_cnn_results_test=np.load(f\"{CURRENT_DIR}/../EXPDATA/init_poses/pvnet/pvnet_linemodocc_test.npy\", allow_pickle=True).flat[0]\n                else: \n                    raise NotImplementedError\n            except:\n                print(\"Loading posecnn results failed!\")\n                self.pose_cnn_results_test=None\n            try:\n                # self.blender_to_bop_pose=np.load(f\"{CURRENT_DIR}/../EXPDATA/init_poses/metricpose/blender2bop_RT.npy\", allow_pickle=True).flat[0]\n                self.blender_to_bop_pose=np.load(f\"{CURRENT_DIR}/../EXPDATA/init_poses/pose_conversion/blender2bop_RT.npy\", allow_pickle=True).flat[0]\n            except:\n                print(\"Loading pose conversion matrix failed!\")\n                self.blender_to_bop_pose=None \n                \n        else:\n            self.pose_cnn_results_test=None\n            self.blender_to_bop_pose=None\n        \n    def load_random_background(self, im_observed, mask):\n        VOC_root = os.path.join(self.VOC_ROOT, \"VOCdevkit/VOC2012\")\n        VOC_image_set_dir = os.path.join(VOC_root, \"ImageSets/Main\")\n        VOC_bg_list_path = os.path.join(VOC_image_set_dir, \"diningtable_trainval.txt\")\n        with open(VOC_bg_list_path, \"r\") as f:\n            VOC_bg_list = [\n                line.strip(\"\\r\\n\").split()[0] for line in f.readlines() if line.strip(\"\\r\\n\").split()[1] == \"1\"\n            ]\n        height, width, channel = im_observed.shape\n        target_size = min(height, width)\n        max_size = max(height, width)\n        observed_hw_ratio = float(height) / float(width)\n\n        k = random.randint(0, len(VOC_bg_list) - 1)\n        bg_idx = VOC_bg_list[k]\n        bg_path = os.path.join(VOC_root, \"JPEGImages/{}.jpg\".format(bg_idx))\n        bg_image = cv2.imread(bg_path, cv2.IMREAD_COLOR)[...,::-1] #RGB\n        bg_h, bg_w, bg_c = bg_image.shape\n        bg_image_resize = np.zeros((height, width, channel), dtype=\"uint8\")\n        if (float(height) / float(width) < 1 and float(bg_h) / float(bg_w) < 1) or (\n            float(height) / float(width) >= 1 and float(bg_h) / float(bg_w) >= 1\n        ):\n            if bg_h >= bg_w:\n                bg_h_new = int(np.ceil(bg_w * observed_hw_ratio))\n                if bg_h_new < bg_h:\n                    bg_image_crop = bg_image[0:bg_h_new, 0:bg_w, :]\n                else:\n                    bg_image_crop = bg_image\n            else:\n                bg_w_new = int(np.ceil(bg_h / observed_hw_ratio))\n                if bg_w_new < bg_w:\n                    bg_image_crop = bg_image[0:bg_h, 0:bg_w_new, :]\n                else:\n                    bg_image_crop = bg_image\n        else:\n            if bg_h >= bg_w:\n                bg_h_new = int(np.ceil(bg_w * observed_hw_ratio))\n                bg_image_crop = bg_image[0:bg_h_new, 0:bg_w, :]\n            else:  # bg_h < bg_w\n                bg_w_new = int(np.ceil(bg_h / observed_hw_ratio))\n                print(bg_w_new)\n                bg_image_crop = bg_image[0:bg_h, 0:bg_w_new, :]\n\n        bg_image_resize_0, _ = 
resize(bg_image_crop, target_size, max_size)\n        h, w, c = bg_image_resize_0.shape\n        bg_image_resize[0:h, 0:w, :] = bg_image_resize_0\n\n        # add background to image_observed\n        res_image = bg_image_resize.copy()\n        res_image[mask>0]=im_observed[mask>0]\n\n        # im_observed = res_image\n        return res_image\n\n    def _read_data(self, idx):\n        \"\"\"\n        info structure:\n        {\n            'cat':[\n                {\n                \"index\": idx,\n                \"model_path\": str,\n                \"rgb_path\": str,\n                \"depth_path\": str,\n                \"RT\": np.array([3,4]),\n                \"K\":  np.array([3,3]),\n                },\n                {\n                \"index\": idx,\n                \"model_path\": str,\n                \"rgb_path\": str,\n                \"depth_path\": str,\n                \"RT\": np.array([3,4]),\n                \"K\":  np.array([3,3]),\n                }\n            ...\n            ],\n            'dog':[\n\n            ]\n            ...\n        }\n\n        \"\"\"\n\n        if isinstance(idx, (tuple, list)):\n            idx, seed = idx\n        else:\n            seed = None\n\n        seq_lengths = np.array(self.infos['seq_lengths'])\n        seq_lengths_cum = np.cumsum(seq_lengths)\n        seq_lengths_cum = np.insert(seq_lengths_cum, 0, 0)  # insert a dummy 0\n        seq_idx = np.nonzero(seq_lengths_cum > idx)[0][0]-1\n\n        frame_idx = idx - seq_lengths_cum[seq_idx]\n\n        info = self.infos[\"data\"][seq_idx]\n        dataset_idx = self.infos[\"dataset_idx\"][seq_idx]\n        \n\n        model_points_path = os.path.join(f'{os.path.dirname(__file__)}/../EXPDATA/LM6d_converted/models/{self.infos[\"seqs\"][seq_idx]}/textured.obj' ) # TODO: need check\n\n        rgb_path = os.path.join(self.root_paths[dataset_idx], info[frame_idx]['rgb_observed_path']) \n        depth_path = os.path.join(self.root_paths[dataset_idx], info[frame_idx]['depth_gt_observed_path']) \n\n        if info[frame_idx].get('rgb_noisy_rendered', None) is not None:\n            rgb_noisy_rendered_path = os.path.join(self.root_paths[dataset_idx], info[frame_idx]['rgb_noisy_rendered']) \n        else:\n            rgb_noisy_rendered_path = None\n        if info[frame_idx].get('depth_noisy_rendered', None) is not None:\n            depth_noisy_rendered_path = os.path.join(self.root_paths[dataset_idx], info[frame_idx]['depth_noisy_rendered']) \n        else:\n            depth_noisy_rendered_path = None\n\n        if info[frame_idx].get('pose_noisy_rendered', None) is not None:\n            rendered_RT = info[frame_idx]['pose_noisy_rendered'].astype(np.float32)\n        # else:\n        elif self.is_train:\n            rendered_RT = sample_poses( info[frame_idx]['gt_pose'].astype(np.float32)[None] )\n\n        K = info[frame_idx]['K'].astype(np.float32)\n        RT = info[frame_idx]['gt_pose'].astype(np.float32) #[R,t]\n\n        # evaluation \n        if not self.is_train:\n            if self.pose_cnn_results_test is not None:\n                class_name=self.infos[\"seqs\"][seq_idx]\n\n                if self.init_pose_type == \"PVNET_LINEMOD\":\n                    try:\n                        posecnn_RT = self.pose_cnn_results_test[class_name][frame_idx] # if self.pose_cnn_results_test is not None else np.zeros_like(RT)\n                        #Transformations are needed as the pvnet has a different coordinate system. 
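(Note: the stored PVNet poses appear to be in the original Blender model frame; the per-class blender_to_bop_pose transform below maps R and t into the BOP frame used here.)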
\n                        posecnn_RT[:3,:3] =  posecnn_RT[:3,:3]@self.blender_to_bop_pose[class_name][:3,:3].T\n                        posecnn_RT[:3,3:] =  -posecnn_RT[:3,:3] @self.blender_to_bop_pose[class_name][:3,3:]  + posecnn_RT[:3,3:] \n                    except:\n                        print(\"Warning: frame_idx is out of the range of self.pose_cnn_results_test!\", flush=True)\n                        posecnn_RT= se3_q2m(self.pose_cnn_results_test_posecnn[class_name][frame_idx]['pose']) #np.zeros_like(RT)\n                elif self.init_pose_type ==\"POSECNN_LINEMOD\":\n                    posecnn_RT= se3_q2m(self.pose_cnn_results_test_posecnn[class_name][frame_idx]['pose']) \n                elif self.init_pose_type == \"PVNET_LINEMOD_OCC\":\n                    try:\n                        posecnn_RT = self.pose_cnn_results_test[class_name][frame_idx].copy()# if self.pose_cnn_results_test is not None else np.zeros_like(RT)\n                        #Transformations are needed as the pvnet has a different coordinate system. \n                        posecnn_RT[:3,:3] =  posecnn_RT[:3,:3]@self.blender_to_bop_pose[class_name][:3,:3].T\n                        posecnn_RT[:3,3:] =  -posecnn_RT[:3,:3] @self.blender_to_bop_pose[class_name][:3,3:]  + posecnn_RT[:3,3:] \n                    except:\n                        # print(frame_idx)\n                        raise\n                else:\n                    raise NotImplementedError \n                \n                rendered_RT = posecnn_RT\n            else:\n                print(\"Warning: fail to load cnn poses!\", flush=True)\n                posecnn_RT = np.zeros_like(RT)\n        else:\n            posecnn_RT = np.zeros_like(RT)\n        \n        #add noise--for testing purpose only, should always be disabled in normal cases \n#         rot_std=0; trans_std=0.04; ang_max=1000;\n#         print(f\"Add pose noises rot_std={rot_std}, trans_std={trans_std}\", flush=True)\n#         rendered_RT=sample_poses(rendered_RT[None], rot_std=rot_std, trans_std=trans_std, ang_max=ang_max) \n\n        # Regularize the matrix to be a valid rotation\n        rendered_RT[:3,:3] = rendered_RT[:3,:3]@ np.linalg.inv(scipy.linalg.sqrtm(rendered_RT[:3,:3].T@rendered_RT[:3,:3]))\n        \n        # model_points = np.fromfile(\n        #     str(model_points_path), dtype=np.float32, count=-1).reshape([-1, self.model_point_dim]) # N x model_point_dim\n        model_points, _,_ = load_obj(str(model_points_path) )\n        model_points = model_points.numpy()\n        \n        visb = model_points[:,-1:]  # N x model_point_dim\n\n        model_point_features=np.ones_like(model_points[:,:1]).astype(np.float32)\n\n\n        rgb =  np.asarray(PIL.Image.open(rgb_path))\n\n        if depth_path.endswith('.npy'):\n            depth = np.load(depth_path) # blender \n        else:\n            depth = cv2.imread(depth_path, -1) /1000.\n\n        if self.is_train and \"LM6d_refine_syn\" in self.root_paths[dataset_idx]: #synthetic data\n            rgb = self.load_random_background(rgb, mask=(depth>0)[...,None].repeat(rgb.shape[-1], axis=-1) )\n\n\n        \n        rgb_rendered =  np.asarray(PIL.Image.open(rgb_noisy_rendered_path)) if rgb_noisy_rendered_path is not None else None\n        depth_rendered = np.asarray(PIL.Image.open(depth_noisy_rendered_path))/1000 if depth_noisy_rendered_path is not None else None #TODO: need check\n\n        ren_mask = render_pointcloud(model_points, rendered_RT[None],K=K[None], render_image_size=rgb.shape[:2] 
).squeeze()>0\n        # depth = range_to_depth(depth<1, depth*2, K)\n\n        return {\n            \"class_name\":  self.infos[\"seqs\"][seq_idx], \n            \"idx\": idx,\n            \"model_points\": model_points,\n            \"visibility\": visb,\n            \"model_point_features\": model_point_features,\n            \"image\": rgb,\n            \"depth\": depth,\n            \"mask\": depth>0,\n            \"rendered_image\": rgb_rendered,\n            \"rendered_depth\": depth_rendered,\n            \"K\": K,\n            \"RT\": RT,\n            \"rendered_RT\": rendered_RT.astype(np.float32),\n            \"ren_mask\": ren_mask,\n            \"POSECNN_RT\": posecnn_RT.astype(np.float32), # for test, TODO\n            \"scale\": 1  # model_scale * scale = depth_scale\n        }\n\n\n\n    def __getitem__(self, idx):\n\n        data = self._read_data(idx)\n        try:\n            data_p = self.prep_func(data, rand_rgb_transformer=self.rgb_transformer, find_2d3d_correspondence=self.is_train)\n        except Exception as e:\n            if e.args and e.args[0] == \"Too few correspondences are found!\":\n                # fall back to the next sample when too few 2D-3D correspondences are found\n                if isinstance(idx, (tuple, list)):\n                    # idx, seed = idx\n                    idx = [(idx[0]+1) % self.__len__(), idx[1]]\n                else:\n                    idx = (idx+1) % self.__len__()\n                data_p = self.__getitem__(idx)\n            else:\n                raise\n\n        return data_p\n\n    def __len__(self):\n        return np.sum(self.infos['seq_lengths'])\n"
  },
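  {
    "path": "data/pose_sampling_demo.py",
    "content": "\"\"\"Hypothetical usage sketch (not part of the original sources).\n\nIllustrates the pose-perturbation scheme used by sample_poses above: a Gaussian\nperturbation of rotation and translation that is re-drawn until the geodesic\nrotation distance r_dist = arccos((trace(R_src^T R_tgt) - 1) / 2) * 180 / pi\nstays below a maximum angle. The constants and helper names here are\nillustrative assumptions, not the repository's exact values.\n\"\"\"\nimport math\n\nimport numpy as np\n\n\ndef rot_geodesic_deg(R_src, R_tgt):\n    # angle (in degrees) of the relative rotation R_src^T @ R_tgt\n    cos = (np.trace(R_src.T @ R_tgt) - 1.0) / 2.0\n    return math.degrees(math.acos(float(np.clip(cos, -1.0, 1.0))))\n\n\ndef rodrigues(axis_angle):\n    # minimal axis-angle -> rotation matrix (Rodrigues' formula)\n    theta = np.linalg.norm(axis_angle)\n    if theta < 1e-12:\n        return np.eye(3)\n    k = axis_angle / theta\n    K = np.array([[0, -k[2], k[1]], [k[2], 0, -k[0]], [-k[1], k[0], 0]])\n    return np.eye(3) + math.sin(theta) * K + (1 - math.cos(theta)) * (K @ K)\n\n\ndef perturb_pose(RT, rot_std_deg=15.0, trans_std=0.01, ang_max_deg=45.0, rng=None):\n    # resample a noisy pose around RT until the rotation offset is small enough;\n    # like the dataset code, the z-translation gets 5x the noise of x/y\n    rng = rng if rng is not None else np.random.default_rng()\n    out = RT.copy()\n    while True:\n        noise = math.radians(rot_std_deg) * rng.standard_normal(3)\n        out[:3, :3] = rodrigues(noise) @ RT[:3, :3]\n        out[:3, 3] = RT[:3, 3] + trans_std * rng.standard_normal(3) * np.array([1.0, 1.0, 5.0])\n        if rot_geodesic_deg(out[:3, :3], RT[:3, :3]) <= ang_max_deg:\n            return out\n\n\nif __name__ == '__main__':\n    gt = np.hstack([np.eye(3), np.array([[0.0], [0.0], [0.7]])])  # a dummy [R|t]\n    print(perturb_pose(gt, rng=np.random.default_rng(0)))\n"
  },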
  {
    "path": "data/preprocess.py",
    "content": "import open3d as o3d\nimport copy\nimport os\n\nimport pathlib\nimport pickle\nimport time\nfrom collections import defaultdict\nfrom functools import partial\n\nimport cv2\nimport numpy as np\nimport quaternion\nfrom skimage import io as imgio\nfrom utils.timer import simple_timer\n\nimport matplotlib.pyplot as plt\nfrom collections.abc import Iterable\nimport torch\nimport torch.nn.functional as F\nimport quaternion\n\nfrom functools import partial\nimport thirdparty.kpconv.cpp_wrappers.cpp_subsampling.grid_subsampling as cpp_subsampling\nimport thirdparty.kpconv.cpp_wrappers.cpp_neighbors.radius_neighbors as cpp_neighbors\nfrom thirdparty.kpconv.lib.timer import Timer\nfrom utils.geometric import range_to_depth, mask_depth_to_point_cloud\nfrom utils.furthest_point_sample import fragmentation_fps\nfrom utils.rand_utils import truncated_normal\n\n\n\ndef merge_batch(batch_list):\n    # [batch][key][seq]->example[key][seq][batch]\n    # Or [batch][key]->example[key][batch]\n    example_merged = defaultdict(list)\n    for example in batch_list:  # batch dim\n        for k, v in example.items():  # key dim\n            # assert isinstance(v, list)\n            if isinstance(v, list):\n                seq_len = len(v)\n                if k not in example_merged:\n                    example_merged[k] = [[] for i in range(seq_len)]\n                for i, vi in enumerate(v):  # seq dim\n                    example_merged[k][i].append(vi)\n\n            else:\n                example_merged[k].append(v)\n\n    ret = {}\n    for key, elems in example_merged.items():\n        if key in ['model_points', \"original_model_points\", 'visibility']:\n            # concat the points of lenghts (N1,N2...) to a longer one with length (N1+N2+...)\n            ret[key] = np.concatenate(elems, axis=0)\n            # record the point numbers for original batches\n            ret['batched_model_point_lengths'] = np.array(\n                [len(p) for p in elems], dtype=np.int32)\n        elif key in ['rand_model_points', ]:\n            # concat the points of lenghts (N1,N2...) 
to a longer one with length (N1+N2+...)\n            ret[key] = np.concatenate(elems, axis=0)\n            # record the point numbers for original batches\n            ret['batched_rand_model_point_lengths'] = np.array(\n                [len(p) for p in elems], dtype=np.int32)\n        elif key in ['model_point_features']:\n            ret[key] = np.concatenate(elems, axis=0)\n\n        # ['odometry/tq','odometry/RT','odometry/invRT' ]:\n        elif key in ['image', 'depth', 'K', 'RT', 'original_RT' ,'rand_RT', 'correspondences_2d3d', 'scale',  'POSECNN_RT','rendered_image', 'rendered_depth', 'rendered_RT', '3d_keypoint_inds', '3d_keypoints', 'mask', 'ren_mask']: #'depth_coords2d','lifted_points', \n            try:\n                ret[key] = np.stack(elems, axis=0)\n            except:\n                print(key, flush=True)\n                raise\n        elif key == 'metrics':\n            ret[key] = elems\n        else:\n            ret[key] = []\n            for e in elems:\n                ret[key].append(e)\n\n    return ret\n\n\ndef get_correspondences(src_pcd, tgt_pcd, search_voxel_size, K=None, trans=None):\n    if trans is not None:\n        src_pcd.transform(trans)\n    pcd_tree = o3d.geometry.KDTreeFlann(tgt_pcd)\n\n    correspondences = []\n    for i, point in enumerate(src_pcd.points):\n        [count, idx, _] = pcd_tree.search_radius_vector_3d(\n            point, search_voxel_size)\n        if K is not None:\n            idx = idx[:K]\n        for j in idx:\n            correspondences.append([i, j])\n\n    correspondences = np.array(correspondences)\n    # correspondences = torch.from_numpy(correspondences)\n    return correspondences\n\n\ndef to_pcd(xyz):\n    pcd = o3d.geometry.PointCloud()\n    pcd.points = o3d.utility.Vector3dVector(xyz)\n    return pcd\n\n\ndef to_tsfm(rot, trans):\n    tsfm = np.eye(4)\n    tsfm[:3, :3] = rot\n    tsfm[:3, 3] = trans.flatten()\n    return tsfm\n\n\ndef CameraIntrinsicUpdate(old_K, aug_param):\n    '''\n    old_K: array of shape (N,3,3), the old camera intrinsic parameters\n    aug_pram: dict, the data augmentation parameters\n    '''\n    aug_type = aug_param['aug_type']\n    assert aug_type in ['crop', 'scale', 'flip']\n\n    new_K = np.copy(old_K)\n    if aug_type == 'crop':\n        cx, cy = aug_param['crop/left_top_corner']  # x,y\n        new_K[..., 0, 2] = new_K[..., 0, 2] - cx\n        new_K[..., 1, 2] = new_K[..., 1, 2] - cy\n    elif aug_type == 'scale':\n        s_x, s_y = aug_param['scale/scale']\n        new_K[..., 0, 0] = s_x * new_K[..., 0, 0]\n        new_K[..., 1, 1] = s_y * new_K[..., 1, 1]\n\n        new_K[..., 0, 2] = s_x * new_K[..., 0, 2]\n        new_K[..., 1, 2] = s_y * new_K[..., 1, 2]\n    elif aug_type == 'flip':\n        w = aug_param['flip/width']\n        # h = aug_param['flip/heigh']\n        new_K[..., 0, 2] = w - new_K[..., 0, 2]  # px' = w-px\n        # new_K[...,1,2] = h- new_K[...,1,2]\n        new_K[..., 0, 0] = - new_K[..., 0, 0]  # fx' = -fx\n\n    return new_K\n\n\ndef crop_transform(images, depths, Ks, crop_param, ):\n    assert(len(images) == len(depths) == len(Ks))\n\n    crop_type = crop_param[\"crop_type\"]\n    assert(crop_type in [\"fixed\", \"center\", \"random\"])\n\n    crop_size = crop_param[\"crop_size\"]\n    iheight, iwidth = images[0].shape[:2]\n\n    if crop_type == \"fixed\":\n        lt_corner = crop_param[\"lt_corner\"]\n        op = transforms.Crop(\n            lt_corner[0], lt_corner[1], crop_size[0], crop_size[1])\n    elif crop_type == \"center\":\n        op = 
transforms.CenterCrop(crop_size)\n\n        ci, cj, _, _ = op.get_params(images[0], crop_size)\n        lt_corner = ci, cj\n\n    elif crop_type == \"random\":\n        op = transforms.RandomCrop((iheight, iwidth), crop_size)\n\n        lt_corner = op.i, op.j\n\n    for i, _ in enumerate(images):\n        images[i] = op(images[i])\n        depths[i] = op(depths[i])\n\n        Ks[i] = CameraIntrinsicUpdate(Ks[i],\n                                      {\"aug_type\": \"crop\", \"crop/left_top_corner\": (lt_corner[1], lt_corner[0])})\n\n    return images, depths, Ks\n\n\n# def patch_crop(image, depth, mask, K_old, margin_ratio=0.2, output_size=128, offset_ratio=(0,0),bbox=None, mask_depth=True):\ndef patch_crop(image, depth, mask, K_old, margin_ratio=0.2, output_size=128, offset_ratio=(0,0), bbox=None, mask_depth=False):\n    '''\n        image: HxWx3\n        mask: HxW\n        K_old: 3x3\n        offset_ratio: (offset_h, offset_w)\n    '''\n\n    H, W, _ = image.shape\n    \n    mask = mask.astype('uint8')*255\n    if bbox is None:\n        _x, _y, _w, _h = cv2.boundingRect(mask)\n    else:\n        _x, _y, _w, _h = bbox[1], bbox[0], bbox[3]-bbox[1], bbox[2]-bbox[0]\n\n    # center = [_x+_w/2, _y+_h/2]\n    center = [_x+_w/2+offset_ratio[1]*_w, _y+_h/2+offset_ratio[0]*_h ]\n\n    L = int(max(_w, _h) * (1+2*margin_ratio))\n    \n    if L<=0:\n        # degenerate mask/bbox; fall back to a default crop size\n        print(mask.sum(), depth.sum(), '!!!', flush=True)\n        L=128\n\n    x = max(0, int(center[0] - L/2))\n    y = max(0, int(center[1] - L/2))\n\n    crop = image[y:y+L, x:x+L]\n    # only keep the ROI depth\n\n    if mask_depth:\n        depth[mask < 1] = 0 # removed by dy at 0810\n    depth_crop = depth[y:y+L, x:x+L]\n    mask_crop = mask[y:y+L, x:x+L]\n\n    w = h = L  # actual crop size\n    # automatically handle the \"out of range\" problem\n    patch = np.zeros([h, w, 3], dtype=image.dtype)\n    depth_patch = np.zeros([h, w], dtype=depth.dtype)\n    mask_patch = np.zeros([h, w], dtype=depth.dtype)\n\n    try:\n        xp = 0\n        yp = 0\n        patch[xp: xp+crop.shape[0], yp:yp+crop.shape[1]] = crop\n        depth_patch[xp: xp+crop.shape[0], yp:yp+crop.shape[1]] = depth_crop\n        mask_patch[xp: xp+crop.shape[0], yp:yp+crop.shape[1]] = mask_crop\n    except Exception:\n        print(\"patch_crop: crop does not fit the pre-allocated patch\", flush=True)\n        raise\n    patch = cv2.resize(patch, (output_size, output_size),\n                       interpolation=cv2.INTER_LINEAR)\n    depth_patch = cv2.resize(\n        depth_patch, (output_size, output_size), interpolation=cv2.INTER_NEAREST)\n    mask_patch = cv2.resize(\n        mask_patch, (output_size, output_size), interpolation=cv2.INTER_NEAREST)\n\n    # update the intrinsic parameters\n    K_new = np.zeros_like(K_old)\n    scale = output_size/L\n    K_new[0, 2] = (K_old[0, 2]-x)*scale\n    K_new[1, 2] = (K_old[1, 2]-y)*scale\n    K_new[0, 0] = K_old[0, 0]*scale\n    K_new[1, 1] = K_old[1, 1]*scale\n    K_new[2, 2] = 1\n\n    # return patch, depth_patch, K_new\n    return patch, depth_patch, mask_patch, K_new\n\n\ndef preprocess_deepim(\n    input_dict,\n    max_points,\n    correspondence_radius,\n    normalize_model=True,\n    rand_transform_model=False,  # False,#True,\n    rand_rgb_transformer=None,\n    image_scale=None,\n    patch_cropper=None,  # func patch_crop(...)\n    \n):\n    output_dict = copy.deepcopy(input_dict)\n\n    ################################### process 3D point clouds 
###################################\n\n    if (output_dict['model_points'].shape[0] > max_points):\n        # if(output_dict['model_points'].shape[0] > 20000):\n        idx = np.random.permutation(\n            output_dict['model_points'].shape[0])[:max_points]\n        print(idx, output_dict['model_points'].shape, flush=True)\n        output_dict['model_points'] = output_dict['model_points'][idx]\n        output_dict['model_point_features'] = output_dict['model_point_features'][idx]\n\n    output_dict['original_RT'] = copy.deepcopy(output_dict['RT'])\n    if normalize_model:\n        points = output_dict['model_points']\n        mean = points.mean(axis=0)\n        scope = points.max(axis=0)-points.min(axis=0)\n        points_normalize = (points-mean)/scope.max()\n        # points_normalize.tofile(bin_save_path)\n        # modify the extrinsic parameters\n        output_dict['RT'][:, 3:] = output_dict['RT'][:, :3] @ mean[:,\n                                                                   None] + output_dict['RT'][:, 3:]  # 3x3 @ 3x1 + 3x1\n        # input_dict['RT'][:,:3] *=scope.max()\n        output_dict['scale'] = scope.max()\n        output_dict['original_model_points'] = output_dict['model_points']\n        output_dict['model_points'] = points_normalize\n\n\n    if rand_transform_model:\n        points = output_dict['model_points']\n        rand_quat = np.random.randn(1, 4)\n        rand_quat = rand_quat/np.linalg.norm(rand_quat, axis=-1)\n        rand_rot = quaternion.as_rotation_matrix(\n            quaternion.from_float_array(rand_quat)).squeeze()  # 3x3\n        output_dict['rand_model_points'] = (\n            rand_rot@ points.T).T.astype(np.float32)\n        output_dict['rand_RT'] = copy.deepcopy(output_dict['RT'])\n        # output_dict['RT'][:,:3]@rand_rot.T\n        output_dict['rand_RT'][:, :3] = rand_rot\n        output_dict['rand_RT'][:, 3] = 0\n\n    ################################### process 2D images ###################################\n    # carve out image patches\n    if patch_cropper is not None:\n        ref_mask = output_dict['depth'] > 0\n        output_dict['image'], output_dict['depth'], output_dict['K'] = patch_cropper(\n            output_dict['image'], output_dict['depth'],  ref_mask, output_dict['K'])\n\n        output_dict['rendered_image'], output_dict['rendered_depth'], _ = patch_cropper(\n            output_dict['rendered_image'], output_dict['rendered_depth'],  ref_mask, output_dict['K'].copy() )\n\n    # rescale image\n    if image_scale is not None:\n        output_dict['image'] = cv2.resize(output_dict['image'],\n                                          (output_dict['image'].shape[1]*image_scale,\n                                           output_dict['image'].shape[0]*image_scale),\n                                          interpolation=cv2.INTER_AREA)\n        output_dict['depth'] = cv2.resize(output_dict['depth'],\n                                          (output_dict['depth'].shape[1]*image_scale,\n                                           output_dict['depth'].shape[0]*image_scale),\n                                          interpolation=cv2.INTER_NEAREST)\n        output_dict['K'][:2] = output_dict['K'][:2]*image_scale\n\n    # lift depth\n    depth = output_dict['depth'].squeeze()  # H,W\n    depth_pts, depth_coords2d = mask_depth_to_point_cloud(\n        depth != 0, depth, output_dict['K'])\n    depth_pts = (output_dict['RT'][:, :3].T@(depth_pts.T - output_dict['RT']\n                                             [:, 3:])).T / 
output_dict['scale']  # transformed to the model frame\n\n    # find 2d-3d correspondences\n    tsfm = np.eye(4)\n    tsfm[:3] = output_dict['RT']\n    model_pcd = output_dict['model_points']\n\n    correspondences_2d3d = get_correspondences(\n        to_pcd(depth_pts), to_pcd(model_pcd),  correspondence_radius, K=5)\n    if len(correspondences_2d3d.shape) < 2 or len(correspondences_2d3d) < 10:\n        print(depth_pts.shape, model_pcd.shape)\n        print(\"correspondences_2d3d.shape:\",\n              correspondences_2d3d.shape, flush=True)\n        # raise ValueError(\"Too few correspondences are found!\")\n        raise Exception(\"Too few correspondences are found!\")\n\n    output_dict['depth_coords2d'] = depth_coords2d\n    output_dict['lifted_points'] = depth_pts\n    # output_dict['correspondences_2d3d'] = np.zeros(1)#correspondences_2d3d\n    output_dict['correspondences_2d3d'] = correspondences_2d3d\n\n    if rand_rgb_transformer is not None:\n        output_dict['image'], _, _ = rand_rgb_transformer(output_dict['image'])\n    # TO TENSOR\n    output_dict['image'] = (output_dict['image'].astype(\n        np.float32)/255.0).transpose([2, 0, 1])  # .mean(axis=0, keepdims=True) # 1,H,W\n    output_dict['depth'] = output_dict['depth'].astype(np.float32)[\n        None]  # 1,H,W\n\n    return output_dict\n\ndef preprocess(\n    input_dict,\n    max_points,\n    correspondence_radius,\n    normalize_model=True,\n    rand_transform_model=False, \n    rand_rgb_transformer=None,\n    image_scale=None,\n    crop_param=None,\n    kp_3d_param=None,\n    use_coords_as_3d_feat=False,\n    find_2d3d_correspondence=True,\n    \n):\n    output_dict = copy.deepcopy(input_dict)\n\n    ################################### process 3D point clouds ###################################\n    if use_coords_as_3d_feat:\n        output_dict['model_point_features'] = output_dict['model_points'][:,:3]\n\n    if (output_dict['model_points'].shape[0] > max_points):\n        # if(output_dict['model_points'].shape[0] > 20000):\n        idx = np.random.permutation(\n            output_dict['model_points'].shape[0])[:max_points]\n        print(idx, output_dict['model_points'].shape, flush=True)\n        output_dict['model_points'] = output_dict['model_points'][idx]\n        output_dict['model_point_features'] = output_dict['model_point_features'][idx]\n\n    output_dict['original_RT'] = copy.deepcopy(output_dict['RT'])\n    output_dict['original_model_points'] = output_dict['model_points']\n    if normalize_model:\n        points = output_dict['model_points']\n        mean = points.mean(axis=0)\n        scope = points.max(axis=0)-points.min(axis=0)\n        points_normalize = (points-mean)/scope.max()\n        # modify the extrinsic parameters\n        output_dict['RT'][:, 3:] = output_dict['RT'][:, :3] @ mean[:,\n                                                                   None] + output_dict['RT'][:, 3:]  # 3x3 @ 3x1 + 3x1\n        output_dict['scale'] = scope.max()\n        output_dict['model_points'] = points_normalize\n\n\n    if rand_transform_model:\n        points = output_dict['model_points']\n        rand_quat = np.random.randn(1, 4)\n        rand_quat = rand_quat/np.linalg.norm(rand_quat, axis=-1)\n        rand_rot = quaternion.as_rotation_matrix(\n            quaternion.from_float_array(rand_quat)).squeeze()  # 3x3\n        output_dict['rand_model_points'] = (\n            rand_rot@ points.T).T.astype(np.float32)\n        output_dict['rand_RT'] = copy.deepcopy(output_dict['RT'])\n        
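# swap in the random rotation and zero out the translation of the randomized copy\n        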
output_dict['rand_RT'][:, :3] = rand_rot\n        output_dict['rand_RT'][:, 3] = 0\n\n    ################################### process 2D images ###################################\n    #crop image\n    if crop_param is not None:# and output_dict['mask'].sum()>0:\n        #without random cropping\n        if not crop_param.rand_crop: \n            if crop_param.get(\"crop_with_init_pose\", False):\n                # bbox= output_dict.get('bbox', None)\n                bbox=None\n                output_dict['image'], output_dict['depth'], output_dict['mask'], output_dict['K'] = patch_crop(output_dict['image'], output_dict['depth'], mask=output_dict['ren_mask'],\n                                K_old=output_dict['K'], margin_ratio=crop_param.margin_ratio, output_size=crop_param.output_size,  bbox=bbox\n                                                )\n            elif crop_param.get(\"crop_with_rand_bbox_shift\", True): \n                bbox= output_dict.get('bbox', None)\n                # offset_ratio= [truncated_normal(0,0.5,-1,1)*crop_param.max_rand_offset_ratio, truncated_normal(0,0.5,-1,1)*crop_param.max_rand_offset_ratio] \n                offset_ratio= [truncated_normal(0,0.33,-1,1)*1, truncated_normal(0,0.33,-1,1)*1] \n                output_dict['image'], output_dict['depth'], output_dict['mask'], output_dict['K'] = patch_crop(output_dict['image'], output_dict['depth'], mask=output_dict['mask'],\n                                                    K_old=output_dict['K'], margin_ratio=crop_param.margin_ratio, output_size=crop_param.output_size, offset_ratio=offset_ratio, bbox=output_dict.get('bbox', None) \n                                                )\n            else:\n                bbox= output_dict.get('bbox', None)\n                output_dict['image'], output_dict['depth'], output_dict['mask'], output_dict['K'] = patch_crop(output_dict['image'], output_dict['depth'], mask=output_dict['mask'],\n                                                    K_old=output_dict['K'], margin_ratio=crop_param.margin_ratio, output_size=crop_param.output_size, bbox=output_dict.get('bbox', None) \n                                                )\n        else:\n            margin_ratio= truncated_normal(0.5, 0.5, 0, 1) *crop_param.max_rand_margin_ratio\n            offset_ratio= [truncated_normal(0,0.5,-1,1)*crop_param.max_rand_offset_ratio, truncated_normal(0,0.5,-1,1)*crop_param.max_rand_offset_ratio] \n            output_dict['image'], output_dict['depth'], output_dict['mask'], output_dict['K'] = patch_crop(output_dict['image'], output_dict['depth'], mask=output_dict['mask'],\n                                                K_old=output_dict['K'], margin_ratio=margin_ratio, output_size=crop_param.output_size, offset_ratio=offset_ratio,  bbox=output_dict.get('bbox', None) \n                                              )\n            \n    # rescale image\n    if image_scale is not None:\n        output_dict['image'] = cv2.resize(output_dict['image'],\n                                          (output_dict['image'].shape[1]*image_scale,\n                                           output_dict['image'].shape[0]*image_scale),\n                                          interpolation=cv2.INTER_AREA)\n        output_dict['depth'] = cv2.resize(output_dict['depth'],\n                                          (output_dict['depth'].shape[1]*image_scale,\n                                           output_dict['depth'].shape[0]*image_scale),\n                                          
interpolation=cv2.INTER_NEAREST)\n        output_dict['K'][:2] = output_dict['K'][:2]*image_scale\n\n    # lift depths\n    depth = output_dict['depth'].squeeze()  # H,W\n    depth_pts, depth_coords2d = mask_depth_to_point_cloud(\n        depth != 0, depth, output_dict['K'])\n\n    depth_pts = (output_dict['RT'][:, :3].T@(depth_pts.T - output_dict['RT']\n                                             [:, 3:])).T / output_dict['scale']  # transformed to the model frame\n\n    # find 2d-3d correspondences\n    if find_2d3d_correspondence:\n        tsfm = np.eye(4)\n        tsfm[:3] = output_dict['RT']\n        model_pcd = output_dict['model_points']\n        correspondences_2d3d = get_correspondences(\n            to_pcd(depth_pts), to_pcd(model_pcd),  correspondence_radius, K=5)\n        if len(correspondences_2d3d.shape) < 2 or len(correspondences_2d3d) < 10:# or ( \"mask\" in output_dict and output_dict['mask'].sum()<10 ) :\n            print(depth_pts.shape, model_pcd.shape)\n            print(\"correspondences_2d3d.shape:\",\n                  correspondences_2d3d.shape, flush=True)\n            raise Exception(\"Too few correspondences are found!\")\n\n        output_dict['depth_coords2d'] = depth_coords2d\n        output_dict['lifted_points'] = depth_pts\n        output_dict['correspondences_2d3d'] = correspondences_2d3d\n    else:\n        output_dict['depth_coords2d'] = depth_coords2d\n        output_dict['lifted_points'] = depth_pts\n        output_dict['correspondences_2d3d'] = np.zeros([10,2], dtype=np.int64) \n\n\n    if rand_rgb_transformer is not None:\n        output_dict['image'], _, _ = rand_rgb_transformer(output_dict['image'])\n    # TO TENSORs\n    output_dict['image'] = (output_dict['image'].astype(\n        np.float32)/255.0).transpose([2, 0, 1])  # .mean(axis=0, keepdims=True) # 1,H,W\n    output_dict['depth'] = output_dict['depth'].astype(np.float32)[\n        None]  # 1,H,W\n\n    return output_dict\n\ndef batch_grid_subsampling_kpconv(points, batches_len, features=None, labels=None, sampleDl=0.1, max_p=0, verbose=0, random_grid_orient=True):\n    \"\"\"\n    CPP wrapper for a grid subsampling (method = barycenter for points and features)\n    \"\"\"\n    if (features is None) and (labels is None):\n        s_points, s_len = cpp_subsampling.subsample_batch(points,\n                                                          batches_len,\n                                                          sampleDl=sampleDl,\n                                                          max_p=max_p,\n                                                          verbose=verbose)\n        return torch.from_numpy(s_points), torch.from_numpy(s_len)\n\n    elif (labels is None):\n        s_points, s_len, s_features = cpp_subsampling.subsample_batch(points,\n                                                                      batches_len,\n                                                                      features=features,\n                                                                      sampleDl=sampleDl,\n                                                                      max_p=max_p,\n                                                                      verbose=verbose)\n        return torch.from_numpy(s_points), torch.from_numpy(s_len), torch.from_numpy(s_features)\n\n    elif (features is None):\n        s_points, s_len, s_labels = cpp_subsampling.subsample_batch(points,\n                                                                    batches_len,\n                           
                                         classes=labels,\n                                                                    sampleDl=sampleDl,\n                                                                    max_p=max_p,\n                                                                    verbose=verbose)\n        return torch.from_numpy(s_points), torch.from_numpy(s_len), torch.from_numpy(s_labels)\n\n    else:\n        s_points, s_len, s_features, s_labels = cpp_subsampling.subsample_batch(points,\n                                                                                batches_len,\n                                                                                features=features,\n                                                                                classes=labels,\n                                                                                sampleDl=sampleDl,\n                                                                                max_p=max_p,\n                                                                                verbose=verbose)\n        return torch.from_numpy(s_points), torch.from_numpy(s_len), torch.from_numpy(s_features), torch.from_numpy(s_labels)\n\n\ndef batch_neighbors_kpconv(queries, supports, q_batches, s_batches, radius, max_neighbors):\n    \"\"\"\n    Computes neighbors for a batch of queries and supports, apply radius search\n    :param queries: (N1, 3) the query points\n    :param supports: (N2, 3) the support points\n    :param q_batches: (B) the list of lengths of batch elements in queries\n    :param s_batches: (B)the list of lengths of batch elements in supports\n    :param radius: float32\n    :return: neighbors indices\n    \"\"\"\n\n    neighbors = cpp_neighbors.batch_query(\n        queries, supports, q_batches, s_batches, radius=radius)\n    # print(\"neighbors.shape\" , neighbors.shape, queries.shape,flush=True)\n    if max_neighbors > 0:\n        return torch.from_numpy(neighbors[:, :max_neighbors])\n    else:\n        return torch.from_numpy(neighbors)\n\n\ndef collate_fn_descriptor(list_data, config, neighborhood_limits):\n    ret = merge_batch(list_data)\n\n    batched_points = torch.from_numpy(ret['model_points'])\n    batched_lengths = torch.from_numpy(ret['batched_model_point_lengths'])\n    batched_features = torch.from_numpy(ret['model_point_features'])\n\n    if ret.get('rand_model_points', None) is not None:\n        batched_rand_points = torch.from_numpy(ret['rand_model_points'])\n        batched_rand_lengths = torch.from_numpy(\n            ret['batched_rand_model_point_lengths'])\n\n        batched_points = torch.cat(\n            [batched_points, batched_rand_points], dim=0)\n        batched_lengths = torch.cat(\n            [batched_lengths, batched_rand_lengths], dim=0)\n        batched_features = torch.cat(\n            [batched_features, batched_features], dim=0)\n\n    # Starting radius of convolutions\n    r_normal = config.first_subsampling_dl * config.conv_radius\n    # Starting layer\n    layer_blocks = []\n    layer = 0\n\n    # Lists of inputs\n    input_points = []\n    input_neighbors = []\n    input_pools = []\n    input_upsamples = []\n    input_batches_len = []\n    timer = Timer()\n    for block_i, block in enumerate(config.architecture):\n        # Stop when meeting a global pooling or upsampling\n        if 'global' in block or 'upsample' in block:\n            break\n\n        # Get all blocks of the layer\n        if not ('pool' in block or 'strided' in block):\n            
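# accumulate the plain (non-strided) blocks of the current layer; neighbor indices are computed once the layer is complete\n            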
layer_blocks += [block]\n            if block_i < len(config.architecture) - 1 and not ('upsample' in config.architecture[block_i + 1]):\n                continue\n\n        # Convolution neighbors indices\n        # *****************************\n\n        if layer_blocks:\n            # Convolutions are done in this layer, compute the neighbors with the good radius\n            if np.any(['deformable' in blck for blck in layer_blocks[:-1]]):\n                r = r_normal * config.deform_radius / config.conv_radius\n            else:\n                r = r_normal\n            conv_i = batch_neighbors_kpconv(\n                batched_points, batched_points, batched_lengths, batched_lengths, r, neighborhood_limits[layer])\n\n        else:\n            # This layer only perform pooling, no neighbors required\n            conv_i = torch.zeros((0, 1), dtype=torch.int64)\n\n        # Pooling neighbors indices\n        # *************************\n\n        if 'pool' in block or 'strided' in block:\n\n            # New subsampling length\n            dl = 2 * r_normal / config.conv_radius\n\n            # Subsampled points\n            pool_p, pool_b = batch_grid_subsampling_kpconv(\n                batched_points, batched_lengths, sampleDl=dl)\n\n            # Radius of pooled neighbors\n            if 'deformable' in block:\n                r = r_normal * config.deform_radius / config.conv_radius\n            else:\n                r = r_normal\n\n            # Subsample indices\n            pool_i = batch_neighbors_kpconv(\n                pool_p, batched_points, pool_b, batched_lengths, r, neighborhood_limits[layer])\n\n            # Upsample indices (with the radius of the next layer to keep wanted density)\n            up_i = batch_neighbors_kpconv(\n                batched_points, pool_p, batched_lengths, pool_b, 2 * r, neighborhood_limits[layer])\n\n        else:\n            # No pooling in the end of this layer, no pooling indices required\n            pool_i = torch.zeros((0, 1), dtype=torch.int64)\n            pool_p = torch.zeros((0, 3), dtype=torch.float32)\n            pool_b = torch.zeros((0,), dtype=torch.int64)\n            up_i = torch.zeros((0, 1), dtype=torch.int64)\n\n        # Updating input lists\n        input_points += [batched_points.float()]\n        input_neighbors += [conv_i.long()]\n        input_pools += [pool_i.long()]\n        input_upsamples += [up_i.long()]\n        input_batches_len += [batched_lengths]\n\n        # New points for next layer\n        batched_points = pool_p\n        batched_lengths = pool_b\n\n        # Update radius and reset blocks\n        r_normal *= 2\n        layer += 1\n        layer_blocks = []\n\n    ###############\n    # Return inputs\n    ###############\n    dict_inputs = {\n        \"idx\": ret[\"idx\"],\n        'model_points': input_points,\n        'visibility': torch.from_numpy(ret['visibility']),\n        'neighbors': input_neighbors,\n        'pools': input_pools,\n        'upsamples': input_upsamples,\n        'model_point_features': batched_features.float(),\n        'stack_lengths': input_batches_len,\n        'image': torch.from_numpy(ret['image']),\n        'depth': torch.from_numpy(ret['depth']),\n        'mask': torch.from_numpy(ret['mask']),\n        'ren_mask': torch.from_numpy(ret['ren_mask']),\n        'K': torch.from_numpy(ret['K']),\n        'RT': torch.from_numpy(ret['RT']),\n        'original_RT': torch.from_numpy(ret['original_RT']),\n        'POSECNN_RT': torch.from_numpy(ret.get('POSECNN_RT', 
np.zeros_like(ret['RT']) ) ),\n        'rand_RT': torch.from_numpy(ret.get('rand_RT', np.zeros_like(ret['RT']))),\n        # \"lifted_points\": torch.from_numpy(ret['lifted_points']),\n        \"lifted_points\": [torch.from_numpy(d) for d in ret['lifted_points'] ] ,\n        # 'depth_coords2d': torch.from_numpy(ret['depth_coords2d']),\n        'depth_coords2d': [torch.from_numpy(d) for d in ret['depth_coords2d']],\n        \"correspondences_2d3d\": torch.from_numpy(ret['correspondences_2d3d']),\n        \"original_model_points\": torch.from_numpy(ret['original_model_points']),\n        \"class_name\": ret['class_name'],\n        \"3d_keypoint_inds\": torch.from_numpy(ret['3d_keypoint_inds']),\n        \"3d_keypoints\": torch.from_numpy(ret['3d_keypoints'] ) \n    }\n\n    return dict_inputs\n\n\ndef collate_fn_descriptor_deepim(list_data, config, neighborhood_limits):\n    ret = merge_batch(list_data)\n\n    batched_points = torch.from_numpy(ret['model_points'])\n    batched_lengths = torch.from_numpy(ret['batched_model_point_lengths'])\n    batched_features = torch.from_numpy(ret['model_point_features'])\n    \n\n    if ret.get('rand_model_points', None) is not None:\n        # torch.from_numpy(np.concatenate(batched_points_list, axis=0))\n        batched_rand_points = torch.from_numpy(ret['rand_model_points'])\n        # torch.from_numpy(np.concatenate(batched_points_list, axis=0))\n        batched_rand_lengths = torch.from_numpy(\n            ret['batched_rand_model_point_lengths'])\n\n        batched_points = torch.cat(\n            [batched_points, batched_rand_points], dim=0)\n        batched_lengths = torch.cat(\n            [batched_lengths, batched_rand_lengths], dim=0)\n        batched_features = torch.cat(\n            [batched_features, batched_features], dim=0)\n\n    # Starting radius of convolutions\n    r_normal = config.first_subsampling_dl * config.conv_radius\n    # Starting layer\n    layer_blocks = []\n    layer = 0\n\n    # Lists of inputs\n    input_points = []\n    input_neighbors = []\n    input_pools = []\n    input_upsamples = []\n    input_batches_len = []\n    timer = Timer()\n    for block_i, block in enumerate(config.architecture):\n        # timer.tic()\n\n        # Stop when meeting a global pooling or upsampling\n        if 'global' in block or 'upsample' in block:\n            break\n\n        # Get all blocks of the layer\n        if not ('pool' in block or 'strided' in block):\n            layer_blocks += [block]\n            if block_i < len(config.architecture) - 1 and not ('upsample' in config.architecture[block_i + 1]):\n                continue\n\n        # Convolution neighbors indices\n        # *****************************\n\n        if layer_blocks:\n            # Convolutions are done in this layer, compute the neighbors with the good radius\n            if np.any(['deformable' in blck for blck in layer_blocks[:-1]]):\n                r = r_normal * config.deform_radius / config.conv_radius\n            else:\n                r = r_normal\n            conv_i = batch_neighbors_kpconv(\n                batched_points, batched_points, batched_lengths, batched_lengths, r, neighborhood_limits[layer])\n\n        else:\n            # This layer only perform pooling, no neighbors required\n            conv_i = torch.zeros((0, 1), dtype=torch.int64)\n\n        # Pooling neighbors indices\n        # *************************\n\n        # If end of layer is a pooling operation\n        if 'pool' in block or 'strided' in block:\n\n            # New 
subsampling length\n            dl = 2 * r_normal / config.conv_radius\n\n            # Subsampled points\n            pool_p, pool_b = batch_grid_subsampling_kpconv(\n                batched_points, batched_lengths, sampleDl=dl)\n\n            # Radius of pooled neighbors\n            if 'deformable' in block:\n                r = r_normal * config.deform_radius / config.conv_radius\n            else:\n                r = r_normal\n\n            # Subsample indices\n            pool_i = batch_neighbors_kpconv(\n                pool_p, batched_points, pool_b, batched_lengths, r, neighborhood_limits[layer])\n\n            # Upsample indices (with the radius of the next layer to keep wanted density)\n            up_i = batch_neighbors_kpconv(\n                batched_points, pool_p, batched_lengths, pool_b, 2 * r, neighborhood_limits[layer])\n\n        else:\n            # No pooling in the end of this layer, no pooling indices required\n            pool_i = torch.zeros((0, 1), dtype=torch.int64)\n            pool_p = torch.zeros((0, 3), dtype=torch.float32)\n            pool_b = torch.zeros((0,), dtype=torch.int64)\n            up_i = torch.zeros((0, 1), dtype=torch.int64)\n\n        # Updating input lists\n        input_points += [batched_points.float()]\n        input_neighbors += [conv_i.long()]\n        input_pools += [pool_i.long()]\n        input_upsamples += [up_i.long()]\n        input_batches_len += [batched_lengths]\n\n        # New points for next layer\n        batched_points = pool_p\n        batched_lengths = pool_b\n\n        # Update radius and reset blocks\n        r_normal *= 2\n        layer += 1\n        layer_blocks = []\n\n        # timer.toc()\n    ###############\n    # Return inputs\n    ###############\n    dict_inputs = {\n        \"idx\": ret[\"idx\"],\n        'model_points': input_points,\n        'visibility': torch.from_numpy(ret['visibility']),\n        'neighbors': input_neighbors,\n        'pools': input_pools,\n        'upsamples': input_upsamples,\n        'model_point_features': batched_features.float(),\n        'stack_lengths': input_batches_len,\n        'image': torch.from_numpy(ret['image']),\n        'depth': torch.from_numpy(ret['depth']),\n        \"ren_mask\": torch.from_numpy(ret['ren_mask']),\n        'K': torch.from_numpy(ret['K']),\n        'RT': torch.from_numpy(ret['RT']),\n        'original_RT': torch.from_numpy(ret['original_RT']),\n        'rendered_RT': torch.from_numpy(ret['rendered_RT']) if ret.get('rendered_RT', None) is not None else None ,\n        'POSECNN_RT': torch.from_numpy(ret.get('POSECNN_RT', np.zeros_like(ret['RT']) ) ),\n        # TODO\n        'rand_RT': torch.from_numpy(ret.get('rand_RT', np.zeros_like(ret['RT']))),\n        # \"lifted_points\": torch.from_numpy(ret['lifted_points']),\n        \"lifted_points\": [torch.from_numpy(d) for d in ret['lifted_points'] ] ,\n        # 'depth_coords2d': torch.from_numpy(ret['depth_coords2d']),\n        'depth_coords2d': [torch.from_numpy(d) for d in ret['depth_coords2d']],\n        \"correspondences_2d3d\": torch.from_numpy(ret['correspondences_2d3d']),\n        \"original_model_points\": torch.from_numpy(ret['original_model_points']),\n        \"class_name\": ret['class_name'],\n    }\n\n    return dict_inputs\n\n\ndef calibrate_neighbors(dataset, config, collate_fn, keep_ratio=0.8, samples_threshold=2000):\n    timer = Timer()\n    last_display = timer.total_time\n\n    # From config parameter, compute higher bound of neighbors number in a neighborhood\n    hist_n = 
int(np.ceil(4 / 3 * np.pi * (config.deform_radius + 1) ** 3))\n    neighb_hists = np.zeros((config.num_layers, hist_n), dtype=np.int32)\n\n    # Get histogram of neighborhood sizes i in 1 epoch max.\n    for i in range(len(dataset)):\n        timer.tic()\n\n        batched_input = collate_fn(\n            [dataset[i]], config, neighborhood_limits=[hist_n] * 5)\n        # update histogram\n        counts = [torch.sum(neighb_mat < neighb_mat.shape[0], dim=1).numpy()\n                  for neighb_mat in batched_input['neighbors']]\n        \n        hists = [np.bincount(c, minlength=hist_n)[:hist_n] for c in counts]\n        neighb_hists += np.vstack(hists)\n        timer.toc()\n\n        if timer.total_time - last_display > 0.1:\n            last_display = timer.total_time\n            print(f\"Calib Neighbors {i:08d}: timings {timer.total_time:4.2f}s\")\n\n        if np.min(np.sum(neighb_hists, axis=1)) > samples_threshold:\n            break\n\n    cumsum = np.cumsum(neighb_hists.T, axis=0)\n    percentiles = np.sum(cumsum < (keep_ratio * cumsum[hist_n - 1, :]), axis=0)\n\n    neighborhood_limits = percentiles\n    print('\\n')\n\n    return neighborhood_limits\n\n\ndef get_dataloader(dataset, kpconv_config, batch_size=1, num_workers=4, shuffle=True, sampler=None, neighborhood_limits=None):\n    if neighborhood_limits is None:\n        # neighborhood_limits = calibrate_neighbors(dataset, dataset.config, collate_fn=collate_fn_descriptor)\n        neighborhood_limits = calibrate_neighbors(\n            dataset, kpconv_config, collate_fn=collate_fn_descriptor)\n    print(\"neighborhood:\", neighborhood_limits)\n    dataloader = torch.utils.data.DataLoader(\n        dataset,\n        batch_size=batch_size,\n        shuffle=shuffle,\n        num_workers=num_workers,\n        # https://discuss.pytorch.org/t/supplying-arguments-to-collate-fn/25754/4\n        collate_fn=partial(collate_fn_descriptor, config=kpconv_config,\n                           neighborhood_limits=neighborhood_limits),\n        sampler=sampler,\n        drop_last=False\n    )\n    return dataloader, neighborhood_limits\n\ndef get_dataloader_deepim(dataset, kpconv_config, batch_size=1, num_workers=4, shuffle=True, sampler=None, neighborhood_limits=None):\n    if neighborhood_limits is None:\n        neighborhood_limits = calibrate_neighbors(\n            dataset, kpconv_config, collate_fn=collate_fn_descriptor_deepim)\n    print(\"Neighborhood:\", neighborhood_limits)\n    dataloader = torch.utils.data.DataLoader(\n        dataset,\n        batch_size=batch_size,\n        shuffle=shuffle,\n        num_workers=num_workers,\n        # https://discuss.pytorch.org/t/supplying-arguments-to-collate-fn/25754/4\n        collate_fn=partial(collate_fn_descriptor_deepim, config=kpconv_config,\n                           neighborhood_limits=neighborhood_limits),\n        sampler=sampler,\n        drop_last=False\n    )\n    return dataloader, neighborhood_limits\n\nif __name__ == '__main__':\n    pass\n"
  },
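  {
    "path": "data/preprocess_demo.py",
    "content": "\"\"\"Hypothetical worked example (not part of the original sources).\n\nShows the camera-intrinsics bookkeeping that patch_crop / CameraIntrinsicUpdate\nin data/preprocess.py rely on: cropping at (x, y) shifts the principal point by\n(-x, -y), and resizing a crop of side L to output_size multiplies fx, fy, cx, cy\nby scale = output_size / L. All numbers below are made up for illustration.\n\"\"\"\nimport numpy as np\n\n\ndef crop_and_scale_K(K, x, y, L, output_size):\n    # the same update patch_crop applies: shift the principal point, then rescale\n    s = output_size / L\n    K_new = np.zeros_like(K)\n    K_new[0, 0] = K[0, 0] * s\n    K_new[1, 1] = K[1, 1] * s\n    K_new[0, 2] = (K[0, 2] - x) * s\n    K_new[1, 2] = (K[1, 2] - y) * s\n    K_new[2, 2] = 1.0\n    return K_new\n\n\ndef project(K, X):\n    # pinhole projection of a single 3D point in the camera frame\n    uv = K @ X\n    return uv[:2] / uv[2]\n\n\nif __name__ == '__main__':\n    K = np.array([[572.4, 0.0, 325.3], [0.0, 573.6, 242.0], [0.0, 0.0, 1.0]])\n    X = np.array([0.05, -0.02, 0.7])  # a point in the camera frame\n    x, y, L, out = 200, 150, 160, 128\n    u, v = project(K, X)\n    u2, v2 = project(crop_and_scale_K(K, x, y, L, out), X)\n    # the projection through the new K must equal the crop+scale of the old pixel coords\n    assert np.allclose([u2, v2], [(u - x) * out / L, (v - y) * out / L])\n    print((u, v), (u2, v2))\n"
  },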
  {
    "path": "data/transforms.py",
    "content": "import numpy as np\nimport random\nimport torch\nimport torchvision\nfrom torchvision.transforms import functional as F\nimport cv2\nfrom PIL import Image\n\n\nclass Compose(object):\n\n    def __init__(self, transforms):\n        self.transforms = transforms\n\n    def __call__(self, img, kpts=None, mask=None):\n        for t in self.transforms:\n            img, kpts, mask = t(img, kpts, mask)\n        return img, kpts, mask\n\n    def __repr__(self):\n        format_string = self.__class__.__name__ + \"(\"\n        for t in self.transforms:\n            format_string += \"\\n\"\n            format_string += \"    {0}\".format(t)\n        format_string += \"\\n)\"\n        return format_string\n\n\nclass ToTensor(object):\n\n    def __call__(self, img, kpts, mask):\n        return np.asarray(img).astype(np.float32) / 255., kpts, mask\n\n\nclass Normalize(object):\n\n    def __init__(self, mean, std, to_bgr=True):\n        self.mean = mean\n        self.std = std\n        self.to_bgr = to_bgr\n\n    def __call__(self, img, kpts, mask):\n        img -= self.mean\n        img /= self.std\n        if self.to_bgr:\n            img = img.transpose(2, 0, 1).astype(np.float32)\n        return img, kpts, mask\n\n\nclass ColorJitter(object):\n\n    def __init__(self,\n                 brightness=None,\n                 contrast=None,\n                 saturation=None,\n                 hue=None,\n                 ):\n        self.color_jitter = torchvision.transforms.ColorJitter(\n            brightness=brightness,\n            contrast=contrast,\n            saturation=saturation,\n            hue=hue,)\n\n    def __call__(self, image, kpts, mask):\n        image = np.asarray(self.color_jitter(Image.fromarray(np.ascontiguousarray(image, np.uint8))))\n        return image, kpts, mask\n\n\nclass RandomBlur(object):\n\n    def __init__(self, prob=0.5):\n        self.prob = prob\n\n    def __call__(self, image, kpts, mask):\n        if random.random() < self.prob:\n            sigma = np.random.choice([3, 5, 7, 9])\n            image = cv2.GaussianBlur(image, (sigma, sigma), 0)\n        return image, kpts, mask\n\n\ndef make_transforms(cfg, is_train):\n    if is_train is True:\n        transform = Compose(\n            [\n                RandomBlur(0.5),\n                ColorJitter(0.1, 0.1, 0.05, 0.05),\n                # ToTensor(),\n                # Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n            ]\n        )\n    else:\n        transform = Compose(\n            [\n                # ToTensor(),\n                # Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n            ]\n        )\n\n    return transform\n"
  },
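  {
    "path": "data/transforms_demo.py",
    "content": "\"\"\"Hypothetical usage sketch (not part of the original sources).\n\nRuns a random uint8 image through the training augmentation pipeline built by\nmake_transforms in data/transforms.py (RandomBlur + ColorJitter). Assumes the\nrepository root is on PYTHONPATH so that data.transforms is importable.\n\"\"\"\nimport numpy as np\n\nfrom data.transforms import make_transforms\n\nif __name__ == '__main__':\n    rng = np.random.default_rng(0)\n    img = rng.integers(0, 256, size=(120, 160, 3), dtype=np.uint8)\n    transform = make_transforms(None, is_train=True)  # cfg is currently unused\n    img_aug, kpts, mask = transform(img, None, None)  # kpts/mask pass through untouched\n    print(img_aug.shape, img_aug.dtype)  # stays HWC uint8 while ToTensor/Normalize are commented out\n"
  },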
  {
    "path": "data/ycb/basic.py",
    "content": "import mmcv \nbop_ycb_idx2class={\n        1: '002_master_chef_can', \n        2: '003_cracker_box',\n        3: '004_sugar_box', \n        4: '005_tomato_soup_can',\n        5: '006_mustard_bottle',\n        6: '007_tuna_fish_can',\n        7: '008_pudding_box', \n        8: '009_gelatin_box',\n        9: '010_potted_meat_can', \n        10: '011_banana', \n        11: '019_pitcher_base', \n        12: '021_bleach_cleanser', \n        13: '024_bowl', \n        14: '025_mug', \n        15: '035_power_drill',\n        16: '036_wood_block',\n        17: '037_scissors', \n        18: '040_large_marker',\n        19: '051_large_clamp',\n        20: '052_extra_large_clamp',\n        21: '061_foam_brick', \n    }\nbop_ycb_class2idx = dict([[bop_ycb_idx2class[k],k ] for k in bop_ycb_idx2class.keys() ])\n\n\n"
  },
  {
    "path": "doc/prepare_data.md",
    "content": "# Data Preparation Tips\nAll the related data for data preparation can be downloaded [here](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155139432_link_cuhk_edu_hk/EoXnZ96Tuy9PpYlZCvDN8vUBPdP1lP-PWQWiZH2KtIQoaQ?e=lpE472). You could download them first and then follow the instructions below for data preparation. \n\n\n\n## Download Datasets \nFirst, the following dataset need to be downloaded and extracted to the folder *EXPDATA/* \n\n[LINEMOD](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155139432_link_cuhk_edu_hk/EYFaYrk0kcdBgC6WMtLJqP0B9Ar0_Nff9qhI2Cs95qDbdA?e=yYxexC)\n\n[LINEMOD_OCC_TEST](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155139432_link_cuhk_edu_hk/EUKcRnwyy9RGu2ASwA3QDXsBnMRrFP-U4X4Eqq-g_MhmIQ?e=hv6H2s)\n\n## Synthetic Data Generation\n\nThe preprocessed data following [DeepIM](https://github.com/liyi14/mx-DeepIM) and [PVNet](https://github.com/zju3dv/pvnet-rendering) can be downloaded from [LM6d_converted](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155139432_link_cuhk_edu_hk/EYFaYrk0kcdBgC6WMtLJqP0B9Ar0_Nff9qhI2Cs95qDbdA) and [raw_data](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155139432_link_cuhk_edu_hk/ESSFXi_7qs1AgNmty7_9y4AB8ffFsGJWOC3ikgD5BIeXHQ?e=qOmvds). \nAfter downloading, you should put the downloaded files into the folder *EXPDATA/* (lying in the repository's root directory). \nTo create occluded objects during training, we follow [PVNet](https://github.com/zju3dv/pvnet-rendering) to randomly create occlusions. \nYou could run the following scripts to transform the data format for our dataloader. \n```\n    bash scripts/run_dataformatter.sh\n```\nThe command above will automatically save the formatted data into *EXPDATA/*. \n\n## Download the Object CAD Models\nYou also need to download the [object models](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155139432_link_cuhk_edu_hk/EQScZuLrkPNPmN4eO3kePaUBjOe92EvbKb7kGJk2vKz-bA?e=8McAdh) and put the extracted folder *models* into *./EXPDATA/LM6d_converted/. \n\n## Download Background Images\n[Pascal VOC 2012](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar) need to be downloaded to folder *EXPDATA/*. These images will be necessary for the random background generation for training. \n\n## Download Initial Poses \nThe initial poses estimated by PoseCNN and PVNet can be downloaded from [here](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155139432_link_cuhk_edu_hk/EQh5y0M_zHVMnbVszjEviCUBNAX_22MFN26Msa48XlJ5MQ?e=rfhT7k). \nThe initial pose folder also should be put into the folder  *EXPDATA/*\n\n## Generate the Information Files\nRun the following script to generate the info files, which is put into the folder *EXPDATA/data_info/*\n\n```\nbash scripts/run_datainfo_generation.sh\n```\n\n\nAfter the the data preparation, the expected directory structure should be \n\n\n```\n./EXPDATA\n    |──LM6d_converted \n    |        |──LM6d_refine \n    |        |──LM6d_refine_syn\n    |        └──models\n    |──LINEMOD\n    |        └──fuse_formatted\n    |──lmo\n    |──VOCdevkit\n    |──raw_data\n    |──init_poses\n    └──data_info\n```\n\n"
  },
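  {
    "path": "doc/verify_expdata.py",
    "content": "\"\"\"Hypothetical helper sketch (not part of the original sources).\n\nChecks that the directory layout described in doc/prepare_data.md exists under\nEXPDATA/ after all download and formatting steps. The expected sub-paths are\ntaken from the tree at the end of that document.\n\"\"\"\nfrom pathlib import Path\n\nEXPECTED = [\n    'LM6d_converted/LM6d_refine',\n    'LM6d_converted/LM6d_refine_syn',\n    'LM6d_converted/models',\n    'LINEMOD/fuse_formatted',\n    'lmo',\n    'VOCdevkit',\n    'raw_data',\n    'init_poses',\n    'data_info',\n]\n\n\ndef check(root='EXPDATA'):\n    # report every expected sub-directory that is still missing\n    root = Path(root)\n    missing = [p for p in EXPECTED if not (root / p).exists()]\n    for p in missing:\n        print('missing:', root / p)\n    return not missing\n\n\nif __name__ == '__main__':\n    raise SystemExit(0 if check() else 1)\n"
  },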
  {
    "path": "docker/Dockerfile",
    "content": "FROM nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04\n\nRUN apt-key del 7fa2af80\nRUN apt-key adv --fetch-keys http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub\nRUN rm /etc/apt/sources.list.d/cuda.list\nRUN rm /etc/apt/sources.list.d/nvidia-ml.list\n\n# Dependencies for glvnd and X11.\nRUN apt-get update \nRUN  apt-get install -y -qq --no-install-recommends \\\n    libglvnd0 \\\n    libgl1 \\\n    libglx0 \\\n    libegl1 \\\n    libxext6 \\\n    libx11-6 \\\n  && rm -rf /var/lib/apt/lists/*\n# Env vars for the nvidia-container-runtime.\nENV NVIDIA_VISIBLE_DEVICES all\nENV NVIDIA_DRIVER_CAPABILITIES graphics,utility,compute\n\n#env vars for cuda\nENV CUDA_HOME /usr/local/cuda\n\n#install miniconda\nRUN apt-get update --fix-missing && \\\n    apt-get install -y wget bzip2 ca-certificates curl git && \\\n    apt-get clean && \\\n    rm -rf /var/lib/apt/lists/*\n\nRUN wget --quiet https://mirrors.tuna.tsinghua.edu.cn/anaconda/miniconda/Miniconda3-py37_4.9.2-Linux-x86_64.sh -O ~/miniconda.sh && \\\n    /bin/bash ~/miniconda.sh -b -p /opt/miniconda3 && \\\n    rm ~/miniconda.sh && \\\n    /opt/miniconda3/bin/conda clean -tipsy && \\\n    ln -s /opt/miniconda3/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \\\n    echo \". /opt/miniconda3/etc/profile.d/conda.sh\" >> ~/.bashrc && \\\n    echo \"conda activate base\" >> ~/.bashrc && \\\n    echo \"conda deactivate && conda activate py37\" >> ~/.bashrc\n\n#https://blog.csdn.net/Mao_Jonah/article/details/89502380\nCOPY freeze.yml freeze.yml\nRUN /opt/miniconda3/bin/conda env create -n py37 -f freeze.yml\n\n# WORKDIR /tmp/\n# COPY config.jupyter.tar config.jupyter.tar\n# RUN tar -xvf config.jupyter.tar -C /root/\n\n#install apex\nENV TORCH_CUDA_ARCH_LIST \"6.0 6.2 7.0 7.2\"\n# make sure we don't overwrite some existing directory called \"apex\"\nWORKDIR /tmp/unique_for_apex\n# uninstall Apex if present, twice to make absolutely sure :)\nRUN /opt/miniconda3/envs/py37/bin/pip3 uninstall -y apex || :\nRUN /opt/miniconda3/envs/py37/bin/pip3 uninstall -y apex || :\n# SHA is something the user can touch to force recreation of this Docker layer,\n# and therefore force cloning of the latest version of Apex\nRUN SHA=ToUcHMe git clone https://github.com/NVIDIA/apex.git\nWORKDIR /tmp/unique_for_apex/apex\nRUN /opt/miniconda3/envs/py37/bin/pip3 install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" .\n#install pytorch3d \n# RUN /opt/miniconda3/envs/py37/bin/pip install pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py37_cu102_pyt171/download.html\n# RUN /opt/miniconda3/envs/py37/bin/pip install \"git+https://github.com/facebookresearch/pytorch3d.git\"\n# RUN /opt/miniconda3/bin/conda install pytorch3d==0.5.0 -c pytorch3d -n py37\n\n\n\n#other pkgs\nRUN apt-get update \\\n  && apt-get install -y -qq --no-install-recommends \\\n  cmake build-essential vim xvfb unzip tmux psmisc  \\\n  libx11-dev libassimp-dev \\\n  mesa-common-dev freeglut3-dev \\\n  rsync \\\n  && apt-get clean \\\n  && rm -rf /var/lib/apt/lists/*\n\n#create some directories\nRUN mkdir -p /home/RNNPose\n\nEXPOSE 8887 8888 8889 10000 10001 10002 \nWORKDIR /home/RNNPose\n\n"
  },
  {
    "path": "docker/freeze.yml",
    "content": "name: py37_tmp\nchannels:\n  - pytorch\n  - pytorch3d\n  - open3d-admin\n  - bottler\n  - iopath\n  - fvcore\n  - conda-forge\n  - defaults\ndependencies:\n  - pytorch3d=0.5.0\n  - _libgcc_mutex=0.1=main\n  - _openmp_mutex=4.5=1_gnu\n  - anyio=2.2.0=py37h06a4308_1\n  - argon2-cffi=20.1.0=py37h27cfd23_1\n  - async_generator=1.10=py37h28b3542_0\n  - attrs=21.2.0=pyhd3eb1b0_0\n  - babel=2.9.1=pyhd3eb1b0_0\n  - backcall=0.2.0=pyhd3eb1b0_0\n  - blas=1.0=mkl\n  - bleach=3.3.0=pyhd3eb1b0_0\n  - brotlipy=0.7.0=py37h27cfd23_1003\n  - ca-certificates=2021.5.30=ha878542_0\n  - certifi=2021.5.30=py37h89c1867_0\n  - cffi=1.14.5=py37h261ae71_0\n  - charset-normalizer=2.0.4=pyhd3eb1b0_0\n  - cryptography=3.4.7=py37hd23ed53_0\n  - cudatoolkit=10.2.89=h8f6ccaa_8\n  - cycler=0.10.0=py37_0\n  - dbus=1.13.18=hb2f20db_0\n  - defusedxml=0.7.1=pyhd3eb1b0_0\n  - entrypoints=0.3=py37_0\n  - expat=2.4.1=h2531618_2\n  - fontconfig=2.13.1=h6c09931_0\n  - freetype=2.10.4=h5ab3b9f_0\n  - fvcore=0.1.5.post20210825=py37\n  - glib=2.68.2=h36276a3_0\n  - gst-plugins-base=1.14.0=h8213a91_2\n  - gstreamer=1.14.0=h28cd5cc_2\n  - icu=58.2=he6710b0_3\n  - idna=3.2=pyhd3eb1b0_0\n  - importlib-metadata=3.10.0=py37h06a4308_0\n  - importlib_metadata=3.10.0=hd3eb1b0_0\n  - intel-openmp=2021.2.0=h06a4308_610\n  - iopath=0.1.9=py37\n  - ipykernel=5.3.4=py37h5ca1d4c_0\n  - ipython=7.22.0=py37hb070fc8_0\n  - ipython_genutils=0.2.0=pyhd3eb1b0_1\n  - ipywidgets=7.6.3=pyhd3eb1b0_1\n  - jedi=0.17.0=py37_0\n  - jinja2=3.0.0=pyhd3eb1b0_0\n  - joblib=1.0.1=pyhd3eb1b0_0\n  - jpeg=9b=h024ee3a_2\n  - json5=0.9.6=pyhd3eb1b0_0\n  - jsonschema=3.2.0=py_2\n  - jupyter=1.0.0=py37h89c1867_6\n  - jupyter_client=6.1.12=pyhd3eb1b0_0\n  - jupyter_console=6.4.0=pyhd8ed1ab_0\n  - jupyter_core=4.7.1=py37h06a4308_0\n  - jupyter_server=1.4.1=py37h06a4308_0\n  - jupyterlab=3.0.16=pyhd8ed1ab_0\n  - jupyterlab_pygments=0.1.2=py_0\n  - jupyterlab_server=2.7.1=pyhd3eb1b0_0\n  - jupyterlab_widgets=1.0.0=pyhd3eb1b0_1\n  - kiwisolver=1.3.1=py37h2531618_0\n  - kornia=0.5.3=pyhd8ed1ab_0\n  - lcms2=2.12=h3be6417_0\n  - ld_impl_linux-64=2.35.1=h7274673_9\n  - libffi=3.3=he6710b0_2\n  - libgcc-ng=9.3.0=h5101ec6_17\n  - libgfortran-ng=7.5.0=ha8ba4b0_17\n  - libgfortran4=7.5.0=ha8ba4b0_17\n  - libgomp=9.3.0=h5101ec6_17\n  - libpng=1.6.37=hbc83047_0\n  - libsodium=1.0.18=h7b6447c_0\n  - libstdcxx-ng=9.3.0=hd4cf53a_17\n  - libtiff=4.2.0=h85742a9_0\n  - libuuid=1.0.3=h1bed415_2\n  - libuv=1.40.0=h7b6447c_0\n  - libwebp-base=1.2.0=h27cfd23_0\n  - libxcb=1.14=h7b6447c_0\n  - libxml2=2.9.10=hb55368b_3\n  - lz4-c=1.9.3=h2531618_0\n  - markupsafe=2.0.1=py37h27cfd23_0\n  - matplotlib=3.3.4=py37h06a4308_0\n  - matplotlib-base=3.3.4=py37h62a2d02_0\n  - mistune=0.8.4=py37h14c3975_1001\n  - mkl=2021.2.0=h06a4308_296\n  - mkl-service=2.3.0=py37h27cfd23_1\n  - mkl_fft=1.3.0=py37h42c9631_2\n  - mkl_random=1.2.1=py37ha9443f7_2\n  - nbclassic=0.2.6=pyhd3eb1b0_0\n  - nbclient=0.5.3=pyhd3eb1b0_0\n  - nbconvert=6.0.7=py37_0\n  - nbformat=5.1.3=pyhd3eb1b0_0\n  - ncurses=6.2=he6710b0_1\n  - nest-asyncio=1.5.1=pyhd3eb1b0_0\n  - ninja=1.10.2=hff7bd54_1\n  - notebook=6.4.0=py37h06a4308_0\n  - numpy=1.20.2=py37h2d18471_0\n  - numpy-base=1.20.2=py37hfae3a4d_0\n  - nvidiacub=1.10.0=0\n  - olefile=0.46=py37_0\n  - open3d=0.13.0=py37_0\n  - openssl=1.1.1k=h7f98852_0\n  - packaging=20.9=pyhd3eb1b0_0\n  - pandas=1.2.4=py37h2531618_0\n  - pandoc=2.12=h06a4308_0\n  - pandocfilters=1.4.3=py37h06a4308_1\n  - parso=0.8.2=pyhd3eb1b0_0\n  - pcre=8.44=he6710b0_0\n  - pexpect=4.8.0=pyhd3eb1b0_3\n  - 
pickleshare=0.7.5=pyhd3eb1b0_1003\n  - pillow=8.2.0=py37he98fc37_0\n  - pip=21.1.2=py37h06a4308_0\n  - plyfile=0.7.4=pyhd8ed1ab_0\n  - portalocker=2.3.0=py37h06a4308_0\n  - prometheus_client=0.11.0=pyhd3eb1b0_0\n  - prompt-toolkit=3.0.17=pyh06a4308_0\n  - prompt_toolkit=3.0.17=hd3eb1b0_0\n  - ptyprocess=0.7.0=pyhd3eb1b0_2\n  - pycparser=2.20=py_2\n  - pygments=2.9.0=pyhd3eb1b0_0\n  - pyopenssl=20.0.1=pyhd3eb1b0_1\n  - pyparsing=2.4.7=pyhd3eb1b0_0\n  - pyqt=5.9.2=py37h05f1152_2\n  - pyrsistent=0.17.3=py37h7b6447c_0\n  - pysocks=1.7.1=py37_1\n  - python=3.7.10=h12debd9_4\n  - python-dateutil=2.8.1=pyhd3eb1b0_0\n  - python_abi=3.7=1_cp37m\n  - pytorch=1.7.1=py3.7_cuda10.2.89_cudnn7.6.5_0\n  - pytz=2021.1=pyhd3eb1b0_0\n  - pyzmq=20.0.0=py37h2531618_1\n  - qt=5.9.7=h5867ecd_1\n  - qtconsole=5.1.1=pyhd8ed1ab_0\n  - qtpy=1.10.0=pyhd8ed1ab_0\n  - readline=8.1=h27cfd23_0\n  - requests=2.26.0=pyhd3eb1b0_0\n  - scikit-learn=0.24.2=py37ha9443f7_0\n  - scipy=1.6.2=py37had2a1c9_1\n  - send2trash=1.5.0=pyhd3eb1b0_1\n  - setuptools=52.0.0=py37h06a4308_0\n  - sip=4.19.8=py37hf484d3e_0\n  - six=1.16.0=pyhd3eb1b0_0\n  - sniffio=1.2.0=py37h06a4308_1\n  - sqlite=3.35.4=hdfb4753_0\n  - tabulate=0.8.9=py37h06a4308_0\n  - terminado=0.9.4=py37h06a4308_0\n  - testpath=0.4.4=pyhd3eb1b0_0\n  - threadpoolctl=2.1.0=pyh5ca1d4c_0\n  - tk=8.6.10=hbc83047_0\n  - torchvision=0.8.2=py37_cu102\n  - tornado=6.1=py37h27cfd23_0\n  - traitlets=5.0.5=pyhd3eb1b0_0\n  - typing_extensions=3.7.4.3=pyha847dfd_0\n  - wcwidth=0.2.5=py_0\n  - webencodings=0.5.1=py37_1\n  - wheel=0.36.2=pyhd3eb1b0_0\n  - widgetsnbextension=3.5.1=py37_0\n  - xz=5.2.5=h7b6447c_0\n  - yacs=0.1.6=py_0\n  - yaml=0.2.5=h7b6447c_0\n  - zeromq=4.3.4=h2531618_0\n  - zipp=3.4.1=pyhd3eb1b0_0\n  - zlib=1.2.11=h7b6447c_3\n  - zstd=1.4.9=haebb681_0\n  - pip:\n    - absl-py==0.13.0\n    - addict==2.4.0\n    - anykeystore==0.2\n    - cachetools==4.2.2\n    - cryptacular==1.5.5\n    - cython==0.29.24\n    - decorator==4.4.2\n    - easydict==1.9\n    - einops==0.3.0\n    - fire==0.4.0\n    - flow-vis==0.1\n    - freetype-py==2.2.0\n    - future==0.18.2\n    - glumpy==1.2.0\n    - google-auth==1.31.0\n    - google-auth-oauthlib==0.4.4\n    - greenlet==1.1.0\n    - grpcio==1.38.0\n    - hupper==1.10.3\n    - imageio==2.9.0\n    - llvmlite==0.36.0\n    - loguru==0.5.3\n    - markdown==3.3.4\n    - networkx==2.5.1\n    - numba==0.53.1\n    - numpy-quaternion==2021.6.9.13.34.11\n    - oauthlib==3.1.1\n    - opencv-python==4.5.2.54\n    - pastedeploy==2.1.1\n    - pbkdf2==1.3\n    - plaster==1.0\n    - plaster-pastedeploy==0.7\n    - protobuf==3.17.3\n    - pyasn1==0.4.8\n    - pyasn1-modules==0.2.8\n    - pyassimp==4.1.3\n    - pyglet==1.5.17\n    - pyopengl==3.1.5\n    - pyopengl-accelerate==3.1.5\n    - pyramid==2.0\n    - pyramid-mailer==0.15.1\n    - python3-openid==3.2.0\n    - pywavelets==1.1.1\n    - pyyaml==5.4.1\n    - repoze-sendmail==4.4.1\n    - requests-oauthlib==1.3.0\n    - rsa==4.7.2\n    - scikit-image==0.18.1\n    - sqlalchemy==1.4.18\n    - tensorboard==2.5.0\n    - tensorboard-data-server==0.6.1\n    - tensorboard-plugin-wit==1.8.0\n    - tensorboardx==2.2\n    - termcolor==1.1.0\n    - tifffile==2021.6.14\n    - tqdm==4.61.1\n    - transaction==3.0.1\n    - transforms3d==0.3.1\n    - translationstring==1.4\n    - triangle==20200424\n    - urllib3==1.26.5\n    - velruse==1.1.1\n    - venusian==3.0.0\n    - vispy==0.6.6\n    - webob==1.8.7\n    - werkzeug==2.0.1\n    - zope-deprecation==4.4.0\n    - zope-interface==5.4.0\n    - mmcv \nprefix: 
/opt/miniconda3/envs/py37_tmp\n"
  },
  {
    "path": "geometry/__init__.py",
    "content": ""
  },
  {
    "path": "geometry/cholesky.py",
    "content": "# import tensorflow as tf\nimport torch #as tf\nimport numpy as np\n# from utils.einsum import einsum\nfrom torch import einsum\n\n\n\nclass _cholesky_solve(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, H, b):\n        chol = torch.cholesky(H)\n        xx = torch.cholesky_solve(b, chol)\n        ctx.save_for_backward(chol, xx)\n\n        return xx\n\n    # see OptNet: https://arxiv.org/pdf/1703.00443.pdf\n    @staticmethod\n    def backward(ctx, dx):\n        chol, xx = ctx.saved_tensors\n\n        dz = torch.cholesky_solve(dx, chol)\n        xs = torch.squeeze(xx,  -1)\n        zs = torch.squeeze(dz, -1)\n        dH = -einsum('...i,...j->...ij', xs, zs)\n\n        return dH, dz\ndef cholesky_solve(H, b):\n    return _cholesky_solve.apply(H,b)\n\ndef solve(H, b, max_update=1.0):\n    \"\"\" Solves the linear system Hx = b, H > 0\"\"\"\n\n    # small system, solve on cpu\n    H = H.to(dtype=torch.float64) \n    b = b.to(dtype=torch.float64) \n\n    b = torch.unsqueeze(b, -1)\n    x = cholesky_solve(H, b)\n\n    # replaces nans and clip large updates\n    bad_values = torch.isnan(x) \n    x = torch.where(bad_values, torch.zeros_like(x), x)\n    x = torch.clamp(x, -max_update, max_update)\n\n    x = torch.squeeze(x, -1)\n    x = x.to(dtype=torch.float32) \n        \n    return x\n\n\n\ndef __test__():\n    import numpy as np \n    np.random.seed(0)\n    M=np.random.uniform(size=(3,3))\n    H=torch.tensor(M@M.transpose(-1,-2), requires_grad=True )\n\n    b=torch.tensor(np.random.uniform(size=(3,) ), requires_grad=True )\n\n    x= solve(H,b )\n\n    x.backward(torch.ones_like(x) )\n\n\n    print(f\"H={H}, b={b}, x={x}, grad={H.grad, b.grad}\")\n\nif __name__==\"__main__\":\n    __test__()\n"
  },
  {
    "path": "geometry/diff_render.py",
    "content": "import torch\nimport torch.nn as nn \nimport torch.nn.functional as F \n\nimport numpy \n\nfrom pytorch3d.renderer import (\n    OpenGLPerspectiveCameras, look_at_view_transform, look_at_rotation,\n    RasterizationSettings, MeshRenderer, MeshRasterizer, BlendParams,\n    camera_position_from_spherical_angles, HardPhongShader, PointLights,FoVPerspectiveCameras, PerspectiveCameras, SoftPhongShader, Materials\n) \ntry:\n    from pytorch3d.structures import Meshes, Textures\n    use_textures = True\nexcept:\n    from pytorch3d.structures import Meshes\n    from pytorch3d.renderer import TexturesVertex\n    from pytorch3d.renderer import TexturesVertex as Textures\n\n    use_textures = False\n\nimport pytorch3d.renderer.mesh.utils as utils\nfrom pytorch3d.io import load_obj, load_ply, load_objs_as_meshes\nfrom pytorch3d.renderer.mesh.rasterizer import Fragments\n\nfrom plyfile import PlyData\nfrom utils.furthest_point_sample import fragmentation_fps\nimport time\n\n\ndef rasterize(R, T, meshes, rasterizer, blur_radius=0):\n    # It will automatically update the camera settings -> R, T in rasterizer.camera\n    fragments = rasterizer(meshes, R=R, T=T)\n\n    # Copy from pytorch3D source code, try if it is necessary to do gradient decent\n    if blur_radius > 0.0:\n        clipped_bary_coords = utils._clip_barycentric_coordinates(\n            fragments.bary_coords\n        )\n        clipped_zbuf = utils._interpolate_zbuf(\n            fragments.pix_to_face, clipped_bary_coords, meshes\n        )\n        fragments = Fragments(\n            bary_coords=clipped_bary_coords,\n            zbuf=clipped_zbuf,\n            dists=fragments.dists,\n            pix_to_face=fragments.pix_to_face,\n        )\n    return fragments\n\ndef set_bary_coords_to_nearest(bary_coords_):\n    ori_shape = bary_coords_.shape\n    exr = bary_coords_ * (bary_coords_ < 0)\n    bary_coords_ = bary_coords_.view(-1, bary_coords_.shape[-1])\n    arg_max_idx = bary_coords_.argmax(1)\n    return torch.zeros_like(bary_coords_).scatter(1, arg_max_idx.unsqueeze(1), 1.0).view(*ori_shape) + exr\n\nclass MeshRendererWithDepth(nn.Module):\n    def __init__(self, rasterizer, shader):\n        super().__init__()\n        self.rasterizer = rasterizer\n        self.shader = shader\n\n    def to(self, device):\n        # Rasterizer and shader have submodules which are not of type nn.Module\n        self.rasterizer.to(device)\n        self.shader.to(device)\n        return self\n\n    def forward(self, meshes_world, **kwargs) -> torch.Tensor:\n        fragments = self.rasterizer(meshes_world, **kwargs)\n        images = self.shader(fragments, meshes_world, **kwargs)\n        return images, fragments.zbuf\n\nclass DiffRender(nn.Module):\n    def __init__(self, mesh_path, render_texture=False):\n        super().__init__()\n\n        # self.mesh = mesh\n        if mesh_path.endswith('.ply'):\n            verts, faces = load_ply(mesh_path)\n            self.mesh = Meshes(verts=[verts], faces=[faces])\n        elif mesh_path.endswith('.obj'):\n            verts, faces,_ = load_obj(mesh_path)\n            faces=faces.verts_idx\n            self.mesh=load_objs_as_meshes([mesh_path])\n\n        self.verts = verts\n        self.faces = faces\n        self.cam_opencv2pytch3d = torch.tensor(\n                                [[-1,0,0,0],\n                                [0,-1,0, 0],\n                                [0,0, 1, 0],\n                                [0,0, 0, 1]], dtype=torch.float32\n                                )\n      
  self.render_texture = render_texture\n\n        #get patch infos\n        self.pat_centers, self.pat_center_inds,  self.vert_frag_ids= fragmentation_fps(verts.detach().cpu().numpy(), 64)\n        self.pat_centers = torch.from_numpy(self.pat_centers)\n        self.pat_center_inds = torch.from_numpy(self.pat_center_inds)\n        self.vert_frag_ids = torch.from_numpy(self.vert_frag_ids)[...,None] #Nx1\n\n\n\n\n    def to(self, *args, **kwargs):\n        if 'device' in kwargs.keys():\n            device = kwargs['device']\n        else:\n            device = args[0]\n        super().to(device)\n        self.mesh = self.mesh.to(device)\n        self.verts = self.verts.to(device)\n        self.faces = self.faces.to(device)\n        self.pat_centers = self.pat_centers.to(device)\n        self.pat_center_inds = self.pat_center_inds.to(device)\n        self.vert_frag_ids = self.vert_frag_ids.to(device)\n        \n        return self\n\n    def get_patch_center_depths(self, T, K):\n        #no need to pre-transform, as here we do not use pytorch3d rendering\n        # T = self.cam_opencv2pytch3d.to(device=T.device)@T\n\n        ## X_cam = X_world R + t\n        R = T[...,:3,:3].transpose(-1,-2)\n        t = T[...,:3,3]\n\n        #render depths\n        X_cam= (self.pat_centers@R+t) #BxKx3\n        depth= X_cam[...,2:] #BxKx1\n        x=X_cam@K.transpose(-1,-2)  #BxNx3\n        x = x/x[...,-1:]\n        img_coords= x[...,:2]\n\n\n        return depth, img_coords \n\n    # Calculate interpolated maps -> [n, c, h, w]\n    # face_memory.shape: [n_face, 3, c]\n    @staticmethod\n    def forward_interpolate(R, t, meshes, face_memory, rasterizer, blur_radius=0, mode='bilinear', return_depth=True):\n\n        fragments = rasterize(R, t, meshes, rasterizer, blur_radius=blur_radius)\n\n        # [n, h, w, 1, d]\n        if mode == 'nearest':\n            out_map = utils.interpolate_face_attributes(fragments.pix_to_face, set_bary_coords_to_nearest(fragments.bary_coords), face_memory)\n        else:\n            out_map = utils.interpolate_face_attributes(fragments.pix_to_face, fragments.bary_coords, face_memory)\n\n        out_map = out_map.squeeze(dim=3)\n        out_map = out_map.transpose(3, 2).transpose(2, 1)\n        if return_depth:\n            return out_map, fragments.zbuf.permute(0,3,1,2) # depth\n        else:\n            return out_map\n\n    def render_mesh(self,  T, K, render_image_size, near=0.1, far=6, lights=(1,1,-1) ):\n        B=T.shape[0]\n\n        device = T.device\n        T = self.cam_opencv2pytch3d.to(device=T.device)@T\n\n        ## X_cam = X_world R + t\n        R = T[...,:3,:3].transpose(-1,-2)\n        t = T[...,:3,3]\n\n        cameras = PerspectiveCameras(focal_length= torch.stack([K[:,0,0], K[:,1,1] ], dim=-1), \n            principal_point=K[:,:2,2],  R=R, T=t, image_size=[render_image_size]*B, in_ndc=False, device=device)\n        lights = PointLights(device=device, location=[lights])\n\n        raster_settings = RasterizationSettings(\n            image_size=render_image_size,\n            blur_radius=0.0,\n            faces_per_pixel=1,\n            bin_size=None, #0\n            perspective_correct=True\n        )\n        materials = Materials(\n            device=device,\n            # specular_color=[[0.0, 1.0, 0.0]],\n            shininess=0\n        )\n        renderer = MeshRendererWithDepth(\n            rasterizer=MeshRasterizer(\n                cameras=cameras, \n                raster_settings=raster_settings\n            ),\n            
shader=SoftPhongShader(\n                device=device, \n                cameras=cameras,\n                lights=lights, \n                blend_params=BlendParams(1e-4, 1e-4, (0, 0, 0))\n            )\n        )\n        image, depth = renderer(self.mesh, lights=lights, materials=materials)\n\n        return image.permute(0,3,1,2)[:,:3], depth.permute(0,3,1,2) # to BCHW\n\n    def render_offset_map(self,  T, K, render_image_size, near=0.1, far=6):\n        yy, xx = torch.meshgrid(torch.arange(render_image_size[0], device=T.device), torch.arange(render_image_size[1], device=T.device) )\n        coords_grid = torch.stack( [ xx.to(dtype=torch.float32),  yy.to(dtype=torch.float32)], dim=-1 )\n\n        #no need to pre-transform, as here we do not use pytorch3d rendering\n        # T = self.cam_opencv2pytch3d.to(device=T.device)@T\n\n        ## X_cam = X_world R + t\n        R = T[...,:3,:3].transpose(-1,-2)\n        t = T[...,:3,3]\n\n        #render depths\n        X_cam= (self.pat_centers@R+t)\n        x=X_cam@K.transpose(-1,-2)  #BxNx3\n        x = x/x[...,-1:]\n\n        offset = x[...,None,None,:2] - coords_grid #BxNx1x1x2-HxWx2\n        \n        return offset.permute(0,1,4,2,3) #BxNx2xHxW\n\n    def forward(self, vert_attribute, T, K, render_image_size, near=0.1, far=6, mode='bilinear'):\n        \"\"\"\n        Args:\n            vert_attribute: (N,C)\n            T: (B,3,4) or (B,4,4)\n            K: (B,3,3)\n            render_image_size (tuple): (h,w)\n            near (float, optional): Defaults to 0.1.\n            far (int, optional): Defaults to 6.\n        \"\"\"\n\n        if vert_attribute is None:\n            return self.render_mesh(T, K, render_image_size, near=near, far=far)\n        if self.render_texture:\n            ren_tex=self.render_mesh(T, K, render_image_size, near=near, far=far)\n\n\n        B=T.shape[0]\n        face_attribute = vert_attribute[self.faces.long()]\n\n        device = T.device\n\n        T = self.cam_opencv2pytch3d.to(device=T.device)@T\n\n        ## X_cam = X_world R + t\n        R = T[...,:3,:3].transpose(-1,-2)\n        t = T[...,:3,3]\n        # t = -(R@T[...,:3,3:]).squeeze(-1)\n\n        cameras = PerspectiveCameras(focal_length= torch.stack([K[:,0,0], K[:,1,1] ], dim=-1), \n            principal_point=K[:,:2,2], image_size=[render_image_size]*B, in_ndc=False, device=device)\n\n        raster_settings = RasterizationSettings(\n            image_size=render_image_size,\n            blur_radius=0.0,\n            faces_per_pixel=1,\n            bin_size=None, #0\n            perspective_correct=True\n        )\n\n        rasterizer = MeshRasterizer(\n            cameras=cameras,\n            raster_settings=raster_settings\n        )\n\n        out_map, out_depth=self.forward_interpolate(R, t, self.mesh, face_attribute, rasterizer, blur_radius=0, mode=mode)\n        \n        if not self.render_texture:\n            return out_map, out_depth\n        else:\n            return torch.cat([ren_tex[0], out_map ], dim=1), out_depth\n\n    def render_depth(self, T, K, render_image_size, near=0.1, far=6, mode='nearest'):\n        \"\"\"\n        Args:\n            T: (B,3,4) or (B,4,4)\n            K: (B,3,3)\n            render_image_size (tuple): (h,w)\n            near (float, optional): Defaults to 0.1.\n            far (int, optional): Defaults to 6.\n            mode: 'bilinear' or 'nearest'\n        \"\"\"\n\n        B=T.shape[0]\n        device =
T.device\n\n        T = self.cam_opencv2pytch3d.to(device=T.device)@T\n\n        ## X_cam = X_world R + t\n        R = T[...,:3,:3].transpose(-1,-2)\n        t = T[...,:3,3]\n        cameras = PerspectiveCameras(focal_length= torch.stack([K[:,0,0], K[:,1,1] ], dim=-1), \n            principal_point=K[:,:2,2], image_size=[render_image_size]*B, in_ndc=False, device=device)\n\n        raster_settings = RasterizationSettings(\n            image_size=render_image_size,\n            blur_radius=0.0,\n            faces_per_pixel=1,\n            bin_size=0\n        )\n\n        rasterizer = MeshRasterizer(\n            cameras=cameras,\n            raster_settings=raster_settings\n        )\n\n\n        #render depths\n        vert_depths= (self.verts@R+t).squeeze(0)[...,2:]\n        face_depths = vert_depths[self.faces.long()]\n        out_depth=self.forward_interpolate(R, t, self.mesh, face_depths, rasterizer, blur_radius=0, mode='nearest', return_depth=False)\n\n        return out_depth\n\n\nclass DiffRendererWrapper(nn.Module):\n    def __init__(self, obj_paths, device=\"cuda\", render_texture=False ):\n        super().__init__()\n\n        self.renderers = []\n        for obj_path in obj_paths:\n            self.renderers.append( \n                DiffRender(obj_path, render_texture).to(device=device)\n            )\n\n        self.renderers=nn.ModuleList(self.renderers)\n        self.cls2idx=None #could be updated outside\n\n    def get_patch_center_depths(self, model_names, T, K):\n        \n        depths= []\n        image_coords= []\n        for b,_ in enumerate(model_names):\n            model_idx = self.cls2idx[model_names[b]]\n            depth, img_coord = self.renderers[model_idx].get_patch_center_depths(T[b:b+1], K )\n            depths.append(depth)\n            image_coords.append(img_coord)\n        \n        return torch.cat(depths, dim=0), torch.cat(image_coords, dim=0)\n\n\n    def render_offset_map(self, model_names,  T, K, render_image_size, near=0.1, far=6):\n        offsets= []\n        for b,_ in enumerate(model_names):\n            model_idx = self.cls2idx[model_names[b]]\n\n            offset = self.renderers[model_idx].render_offset_map(T[b:b+1], K[b:b+1], render_image_size, near, far )\n            offsets.append(offset)\n        \n        return torch.cat(offsets, dim=0)\n\n    def render_pat_id(self, model_names,  T, K, render_image_size, near=0.1, far=6):\n\n        pat_ids= []\n        for b,_ in enumerate(model_names):\n            model_idx = self.cls2idx[model_names[b]]\n            \n            pat_id,_ = self.renderers[model_idx].forward(self.renderers[model_idx].vert_frag_ids.float()+1,T[b:b+1], K[b:b+1], render_image_size, near, far, 'nearest' )\n            pat_ids.append(pat_id-1) #+1 -1, set invalid parts as -1's  \n        \n        return torch.cat(pat_ids, dim=0)\n\n    def render_depth(self, model_names,  T, K, render_image_size, near=0.1, far=6):\n\n        depth_outputs= []\n        for b,_ in enumerate(model_names):\n            model_idx = self.cls2idx[model_names[b]]\n\n            depth = self.renderers[model_idx].render_depth( T[b:b+1], K[b:b+1], render_image_size, near, far, 'nearest' )\n            depth_outputs.append(depth)\n        \n        return torch.cat(depth_outputs, dim=0)\n\n    def forward(self, model_names,  vert_attribute, T, K, render_image_size, near=0.1, far=6):\n\n        map_outputs= []\n        depth_outputs= []\n        for b,_ in enumerate(model_names):\n            model_idx = self.cls2idx[model_names[b]]\n\n         
   feamap, depth = self.renderers[model_idx](vert_attribute[b], T[b:b+1], K[b:b+1], render_image_size, near, far)\n\n            map_outputs.append(feamap)\n            depth_outputs.append(depth)\n        return torch.cat(map_outputs, dim=0), torch.cat(depth_outputs, dim=0)\n\n\n
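# A minimal usage sketch (illustrative only; the mesh path is a placeholder and\n# the intrinsics roughly follow the LINEMOD camera):\n#   renderer = DiffRender(\"obj_01.ply\").to(\"cuda\")\n#   T = torch.eye(4, device=\"cuda\")[None]  # (1,4,4) object-to-camera pose\n#   T[:, 2, 3] = 0.5                       # place the object 0.5m in front of the camera\n#   K = torch.tensor([[[572.4, 0.0, 325.3],\n#                      [0.0, 573.6, 242.0],\n#                      [0.0, 0.0, 1.0]]], device=\"cuda\")\n#   rgb, depth = renderer(None, T, K, (480, 640))  # BCHW image and depth\n"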
  },
  {
    "path": "geometry/diff_render_optim.py",
    "content": "## Speed optimized: sharing the rasterization among different rendering process\n\nimport torch\nimport torch.nn as nn \nimport torch.nn.functional as F \n\nimport numpy \n\nfrom pytorch3d.renderer import (\n    OpenGLPerspectiveCameras, look_at_view_transform, look_at_rotation,\n    RasterizationSettings, MeshRenderer, MeshRasterizer, BlendParams,\n    camera_position_from_spherical_angles, HardPhongShader, PointLights,FoVPerspectiveCameras, PerspectiveCameras, SoftPhongShader, Materials \n) \ntry:\n    from pytorch3d.structures import Meshes, Textures\n    use_textures = True\nexcept:\n    from pytorch3d.structures import Meshes\n    from pytorch3d.renderer import TexturesVertex\n    from pytorch3d.renderer import TexturesVertex as Textures\n\n    use_textures = False\n\nimport pytorch3d.renderer.mesh.utils as utils\nfrom pytorch3d.io import load_obj, load_ply, load_objs_as_meshes\nfrom pytorch3d.renderer.mesh.rasterizer import Fragments\n\nfrom plyfile import PlyData\nfrom utils.furthest_point_sample import fragmentation_fps\n\n\n\ndef rasterize(R, T, meshes, rasterizer, blur_radius=0):\n    # It will automatically update the camera settings -> R, T in rasterizer.camera\n    fragments = rasterizer(meshes, R=R, T=T)\n\n    # Copy from pytorch3D source code, try if it is necessary to do gradient decent\n    if blur_radius > 0.0:\n        clipped_bary_coords = utils._clip_barycentric_coordinates(\n            fragments.bary_coords\n        )\n        clipped_zbuf = utils._interpolate_zbuf(\n            fragments.pix_to_face, clipped_bary_coords, meshes\n        )\n        fragments = Fragments(\n            bary_coords=clipped_bary_coords,\n            zbuf=clipped_zbuf,\n            dists=fragments.dists,\n            pix_to_face=fragments.pix_to_face,\n        )\n    return fragments\n\ndef set_bary_coords_to_nearest(bary_coords_):\n    ori_shape = bary_coords_.shape\n    exr = bary_coords_ * (bary_coords_ < 0)\n    bary_coords_ = bary_coords_.view(-1, bary_coords_.shape[-1])\n    arg_max_idx = bary_coords_.argmax(1)\n    return torch.zeros_like(bary_coords_).scatter(1, arg_max_idx.unsqueeze(1), 1.0).view(*ori_shape) + exr\n\nclass MeshRendererWithDepth(nn.Module):\n    def __init__(self, rasterizer, shader):\n        super().__init__()\n        self.rasterizer = rasterizer\n        self.shader = shader\n\n    def to(self, device):\n        # Rasterizer and shader have submodules which are not of type nn.Module\n        self.rasterizer.to(device)\n        self.shader.to(device)\n        return self\n\n    def forward(self, meshes_world, **kwargs) -> torch.Tensor:\n        fragments = self.rasterizer(meshes_world, **kwargs)\n        images = self.shader(fragments, meshes_world, **kwargs)\n        return images, fragments.zbuf\n\nclass MeshRendererWithDepth_v2(nn.Module):\n    def __init__(self, rasterizer, shader):\n        super().__init__()\n        self.rasterizer = rasterizer\n        self.shader = shader\n\n    # def to(self, device):\n    def to(self, *args, **kwargs):\n        if 'device' in kwargs.keys():\n            device = kwargs['device']\n        else:\n            device = args[0]\n        super().to(device)\n        # Rasterizer and shader have submodules which are not of type nn.Module\n        self.rasterizer.to(device)\n        self.shader.to(device)\n        return self\n\n    def forward(self, meshes_world, **kwargs) -> torch.Tensor:\n        if 'fragments' not in kwargs.keys() or kwargs['fragments'] is None: # sharing fragment results with others for 
speed, as the rasterizing process occupies most of time\n            if 'fragments' in kwargs:\n                del kwargs['fragments']\n                \n            fragments = self.rasterizer(meshes_world, **kwargs)\n        else:\n            fragments = kwargs['fragments']\n            del kwargs['fragments']\n\n        images = self.shader(fragments, meshes_world, **kwargs)\n        return images, fragments.zbuf\n\nclass DiffRender(nn.Module):\n    def __init__(self, mesh_path, render_texture=False):\n        super().__init__()\n\n        # self.mesh = mesh\n        if mesh_path.endswith('.ply'):\n            verts, faces = load_ply(mesh_path)\n            self.mesh = Meshes(verts=[verts], faces=[faces])\n        elif mesh_path.endswith('.obj'):\n            verts, faces,_ = load_obj(mesh_path)\n            # import pdb; pdb.set_trace()\n            faces=faces.verts_idx\n            self.mesh=load_objs_as_meshes([mesh_path])\n\n        # self.mesh = Meshes(verts=verts, faces=faces, textures=None)\n        self.verts = verts\n        self.faces = faces\n        # self.mesh = Meshes(verts=[verts], faces=[faces])\n        # self.feature=feature\n        self.cam_opencv2pytch3d = torch.tensor(\n                                [[-1,0,0,0],\n                                [0,-1,0, 0],\n                                [0,0, 1, 0],\n                                [0,0, 0, 1]], dtype=torch.float32\n                                )\n        self.render_texture = render_texture\n\n        #get patch infos\n        self.pat_centers, self.pat_center_inds,  self.vert_frag_ids= fragmentation_fps(verts.detach().cpu().numpy(), 64)\n        self.pat_centers = torch.from_numpy(self.pat_centers)\n        self.pat_center_inds = torch.from_numpy(self.pat_center_inds)\n        self.vert_frag_ids = torch.from_numpy(self.vert_frag_ids)[...,None] #Nx1\n\n\n\n\n    def to(self, *args, **kwargs):\n        if 'device' in kwargs.keys():\n            device = kwargs['device']\n        else:\n            device = args[0]\n        super().to(device)\n        # self.rasterizer.cameras = self.rasterizer.cameras.to(device)\n        # self.face_memory = self.face_memory.to(device)\n        self.mesh = self.mesh.to(device)\n        self.verts = self.verts.to(device)\n        self.faces = self.faces.to(device)\n        self.pat_centers = self.pat_centers.to(device)\n        self.pat_center_inds = self.pat_center_inds.to(device)\n        self.vert_frag_ids = self.vert_frag_ids.to(device)\n\n        \n        # self.cam_opencv2pytch3d = self.cam_opencv2pytch3d.to(device=device)\n        return self\n\n    def get_patch_center_depths(self, T, K):\n\n        #no need to pre-transform, as here we do not use pytorch3d rendering\n        # T = self.cam_opencv2pytch3d.to(device=T.device)@T\n\n        ## X_cam = X_world R + t\n        R = T[...,:3,:3].transpose(-1,-2)\n        t = T[...,:3,3]\n\n        #render depths\n        X_cam= (self.pat_centers@R+t) #BxKx3\n        depth= X_cam[...,2:] #BxKx1\n        x=X_cam@K.transpose(-1,-2)  #BxNx3\n        x = x/x[...,-1:]\n        img_coords= x[...,:2]\n\n\n        return depth, img_coords \n\n    # Calculate interpolated maps -> [n, c, h, w]\n    # face_memory.shape: [n_face, 3, c]\n    @staticmethod\n    def forward_interpolate(R, t, meshes, face_memory, rasterizer, blur_radius=0, mode='bilinear', return_depth=True):\n\n        fragments = rasterize(R, t, meshes, rasterizer, blur_radius=blur_radius)\n\n        # [n, h, w, 1, d]\n        if mode == 'nearest':\n            out_map 
= utils.interpolate_face_attributes(fragments.pix_to_face, set_bary_coords_to_nearest(fragments.bary_coords), face_memory)\n        else:\n            out_map = utils.interpolate_face_attributes(fragments.pix_to_face, fragments.bary_coords, face_memory)\n        out_map = out_map.squeeze(dim=3)\n        out_map = out_map.transpose(3, 2).transpose(2, 1)\n        if return_depth:\n            return out_map, fragments.zbuf.permute(0,3,1,2), fragments # depth\n        else:\n            return out_map, fragments\n\n    def render_mesh(self,  T, K, render_image_size, near=0.1, far=6, lights=(1,1,-1), fragments=None ):\n        B=T.shape[0]\n        # face_attribute = vert_attribute[self.faces.long()]\n\n        device = T.device\n        T = self.cam_opencv2pytch3d.to(device=T.device)@T\n\n        ## X_cam = X_world R + t\n        R = T[...,:3,:3].transpose(-1,-2)\n        t = T[...,:3,3]\n\n        cameras = PerspectiveCameras(focal_length= torch.stack([K[:,0,0], K[:,1,1] ], dim=-1), \n            principal_point=K[:,:2,2],  R=R, T=t, image_size=[render_image_size]*B, in_ndc=False, device=device)\n        lights = PointLights(device=device, location=[lights])\n\n        raster_settings = RasterizationSettings(\n            image_size=render_image_size,\n            blur_radius=0.0,\n            faces_per_pixel=1, #5,\n            bin_size=None, #0\n            perspective_correct=True\n        )\n        materials = Materials(\n            device=device,\n            # specular_color=[[0.0, 1.0, 0.0]],\n            shininess=0\n        )\n        # renderer = MeshRendererWithDepth(\n        renderer = MeshRendererWithDepth_v2(\n            rasterizer=MeshRasterizer(\n                cameras=cameras, \n                raster_settings=raster_settings\n            ),\n            shader=SoftPhongShader(\n            # shader=SoftGouraudShader(\n                device=device, \n                cameras=cameras,\n                lights=lights, \n                blend_params=BlendParams(1e-4, 1e-4, (0, 0, 0))\n            )\n        )\n        image,depth =renderer(self.mesh, lights=lights, materials=materials, fragments=fragments)\n\n        return image.permute(0,3,1,2)[:,:3], depth.permute(0,3,1,2) # to BCHW\n\n    def render_offset_map(self,  T, K, render_image_size, near=0.1, far=6):\n        yy, xx = torch.meshgrid(torch.arange(render_image_size[0], device=T.device), torch.arange(render_image_size[1], device=T.device) )\n        # xx = xx.to(dtype=torch.float32)\n        # yy = yy.to(dtype=torch.float32)\n        coords_grid = torch.stack( [ xx.to(dtype=torch.float32),  yy.to(dtype=torch.float32)], dim=-1 )\n\n        #no need to pre-transform, as here we do not use pytorch3d rendering\n        # T = self.cam_opencv2pytch3d.to(device=T.device)@T\n\n        ## X_cam = X_world R + t\n        R = T[...,:3,:3].transpose(-1,-2)\n        t = T[...,:3,3]\n\n        #render depths\n        X_cam= (self.pat_centers@R+t)#.squeeze(0)[...,2:]\n        x=X_cam@K.transpose(-1,-2)  #BxNx3\n        x = x/x[...,-1:]\n\n        offset = x[...,None,None,:2] - coords_grid #BxNx1x1x2-HxWx2\n        \n        return offset.permute(0,1,4,2,3) #BxNx2xHxW\n\n    # def forward(self, face_attribute, T, K, render_image_size, near=0.1, far=6):\n    def forward(self, vert_attribute, T, K, render_image_size, near=0.1, far=6, render_texture=None, mode='bilinear') :\n        \"\"\"\n        Args:\n            vert_attribute: (N,C)\n            T: (B,3,4) or (B,4,4)\n            K: (B,3,3)\n            render_image_size 
(tuple): (h,w)\n            near (float, optional): Defaults to 0.1.\n            far (int, optional): Defaults to 6.\n        \"\"\"\n\n        # use default rendering settings\n        if render_texture is None:\n            render_texture= self.render_texture\n\n        if vert_attribute is None:\n            # only render the rgb image\n            return self.render_mesh(T, K, render_image_size, near=near, far=far)\n\n        B=T.shape[0]\n        face_attribute = vert_attribute[self.faces.long()]\n\n        device = T.device\n\n        T = self.cam_opencv2pytch3d.to(device=T.device)@T\n\n        ## X_cam = X_world R + t\n        R = T[...,:3,:3].transpose(-1,-2)\n        t = T[...,:3,3]\n        # t = -(R@T[...,:3,3:]).squeeze(-1)\n        \n        cameras = PerspectiveCameras(focal_length= torch.stack([K[:,0,0], K[:,1,1] ], dim=-1), \n            principal_point=K[:,:2,2], image_size=[render_image_size]*B, in_ndc=False, device=device)\n\n        raster_settings = RasterizationSettings(\n            image_size=render_image_size,\n            blur_radius=0.0,\n            faces_per_pixel=1,\n            bin_size=None, #0\n            perspective_correct=True\n        )\n\n        rasterizer = MeshRasterizer(\n            cameras=cameras,\n            raster_settings=raster_settings\n        )\n\n        out_map, out_depth, fragments=self.forward_interpolate(R, t, self.mesh, face_attribute, rasterizer, blur_radius=0, mode=mode)\n\n        if not render_texture:\n            return out_map, out_depth\n        else:\n            # reuse the fragments so the texture pass skips rasterization\n            ren_tex=self.render_mesh(T, K, render_image_size, near=near, far=far, fragments=fragments)\n\n            #The first 3 channels contain the rendered textures\n            return torch.cat([ren_tex[0], out_map ], dim=1), out_depth\n\n    def render_depth(self, T, K, render_image_size, near=0.1, far=6, mode='nearest'):\n        \"\"\"\n        Args:\n            T: (B,3,4) or (B,4,4)\n            K: (B,3,3)\n            render_image_size (tuple): (h,w)\n            near (float, optional): Defaults to 0.1.\n            far (int, optional): Defaults to 6.\n            mode: 'bilinear' or 'nearest'\n        \"\"\"\n\n        B=T.shape[0]\n        device = T.device\n\n        T = self.cam_opencv2pytch3d.to(device=T.device)@T\n\n        ## X_cam = X_world R + t\n        R = T[...,:3,:3].transpose(-1,-2)\n        t = T[...,:3,3]\n        cameras = PerspectiveCameras(focal_length= torch.stack([K[:,0,0], K[:,1,1] ], dim=-1), \n            principal_point=K[:,:2,2], image_size=[render_image_size]*B, in_ndc=False, device=device)\n\n        raster_settings = RasterizationSettings(\n            image_size=render_image_size,\n            blur_radius=0.0,\n            faces_per_pixel=1,\n            bin_size=0\n        )\n\n        rasterizer = MeshRasterizer(\n            cameras=cameras,\n            raster_settings=raster_settings\n        )\n\n\n        #render depths\n        vert_depths= (self.verts@R+t).squeeze(0)[...,2:]\n        face_depths = vert_depths[self.faces.long()]\n        out_depth, _ =self.forward_interpolate(R, t, self.mesh, face_depths, rasterizer, blur_radius=0, mode='nearest', return_depth=False)\n\n        return out_depth\n\n    def render_pointcloud(self, T, K, render_image_size, near=0.1, far=6):\n        \"\"\"\n        Args:\n            T: (B,3,4) or (B,4,4)\n            K: (B,3,3)\n            render_image_size (tuple): (h,w)
\n            near (float, optional): Defaults to 0.1.\n            far (int, optional): Defaults to 6.\n            mode: 'bilinear' or 'nearest'\n        \"\"\"\n\n        B=T.shape[0]\n        device = T.device\n\n        # T = self.cam_opencv2pytch3d.to(device=T.device)@T\n\n        ## X_cam = X_world R + t\n        R = T[...,:3,:3].transpose(-1,-2)\n        t = T[...,:3,3]\n\n        #render depths\n        X_cam= (self.verts@R+t)\n\n        x=X_cam@K.transpose(-1,-2)  #BxNx3\n        depth = x[...,-1]\n        x = x/x[...,-1:]\n\n        # splat the vertices as a sparse depth image (on collisions the last write wins)\n        out = torch.zeros([1,1, *render_image_size], dtype=R.dtype, device=R.device)\n        out[:, :, \n            torch.round(x[0, :, 1]).long().clamp(0, out.shape[2]-1),\n            torch.round(x[0, :, 0]).long().clamp(0, out.shape[3]-1)] = depth \n\n        return out #1x1xHxW\n\n\nclass DiffRendererWrapper(nn.Module):\n    def __init__(self, obj_paths, device=\"cuda\", render_texture=False):\n        super().__init__()\n\n        self.renderers = []\n        for obj_path in obj_paths:\n            self.renderers.append( \n                DiffRender(obj_path, render_texture).to(device=device)\n            )\n\n        self.renderers=nn.ModuleList(self.renderers)\n        self.cls2idx=None #updated outside\n\n    def get_patch_center_depths(self, model_names, T, K):\n        \n        depths= []\n        image_coords= []\n        for b,_ in enumerate(model_names):\n            model_idx = self.cls2idx[model_names[b]]\n            depth, img_coord = self.renderers[model_idx].get_patch_center_depths(T[b:b+1], K )\n            depths.append(depth)\n            image_coords.append(img_coord)\n        \n        return torch.cat(depths, dim=0), torch.cat(image_coords, dim=0)\n\n\n    def render_offset_map(self, model_names,  T, K, render_image_size, near=0.1, far=6):\n        offsets= []\n        for b,_ in enumerate(model_names):\n            model_idx = self.cls2idx[model_names[b]]\n\n            offset = self.renderers[model_idx].render_offset_map(T[b:b+1], K[b:b+1], render_image_size, near, far )\n            offsets.append(offset)\n        \n        return torch.cat(offsets, dim=0)\n\n    def render_pat_id(self, model_names,  T, K, render_image_size, near=0.1, far=6):\n\n        pat_ids= []\n        for b,_ in enumerate(model_names):\n            model_idx = self.cls2idx[model_names[b]]\n\n            # pass mode as a keyword: the positional slot after `far` is `render_texture`\n            pat_id,_ = self.renderers[model_idx].forward(self.renderers[model_idx].vert_frag_ids.float()+1, T[b:b+1], K[b:b+1], render_image_size, near, far, mode='nearest')\n            pat_ids.append(pat_id-1) #+1 -1, set invalid parts as -1's\n        \n        return torch.cat(pat_ids, dim=0)\n\n    def render_depth(self, model_names,  T, K, render_image_size, near=0.1, far=6):\n\n        depth_outputs= []\n        for b,_ in enumerate(model_names):\n            model_idx = self.cls2idx[model_names[b]]\n\n            depth = self.renderers[model_idx].render_depth( T[b:b+1], K[b:b+1], render_image_size, near, far, 'nearest' )\n            depth_outputs.append(depth)\n        \n        return torch.cat(depth_outputs, dim=0)\n\n    def render_mesh(self, model_names,  T, K, render_image_size, near=0.1, far=6):\n\n        outputs= []\n        for b,_ in enumerate(model_names):\n            model_idx = self.cls2idx[model_names[b]]\n\n            img = self.renderers[model_idx].render_mesh(T[b:b+1], K[b:b+1], render_image_size, near, far)[0]\n            outputs.append(img)\n        \n        return torch.cat(outputs, dim=0)\n\n    def render_pointcloud(self, model_names, T, K, render_image_size, near=0.1, far=6):\n        outputs= []\n        for b,_ in enumerate(model_names):\n            model_idx = self.cls2idx[model_names[b]]\n            depth = self.renderers[model_idx].render_pointcloud( T[b:b+1], K[b:b+1], render_image_size, near, far )\n            outputs.append(depth)\n        \n        return torch.cat(outputs, dim=0)\n\n    def forward(self, model_names,  vert_attribute, T, K, render_image_size, near=0.1, far=6, render_tex=False):\n\n        map_outputs= []\n        depth_outputs= []\n        for b,_ in enumerate(model_names):\n            model_idx = self.cls2idx[model_names[b]]\n\n            feamap, depth= self.renderers[model_idx]( vert_attribute[b], T[b:b+1], K[b:b+1], render_image_size, near, far, render_texture=render_tex )\n\n            map_outputs.append(feamap)\n            depth_outputs.append(depth)\n        return torch.cat(map_outputs, dim=0), torch.cat(depth_outputs, dim=0)\n\n\n
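# Sketch of the shared-rasterization path (names as defined above; the mesh path\n# is a placeholder): with render_texture=True, DiffRender.forward rasterizes once\n# and reuses the fragments for the textured render, so the first 3 channels of\n# the output contain the texture and the rest the interpolated vertex attributes.\n#   renderer = DiffRender(\"obj_01.ply\", render_texture=True).to(\"cuda\")\n#   feat_map, depth = renderer(vert_attr, T, K, (480, 640), render_texture=True)\n#   rgb, attr = feat_map[:, :3], feat_map[:, 3:]\n"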
  },
  {
    "path": "geometry/einsum.py",
    "content": "# import tensorflow as torch\nimport torch as torch\n\nimport numpy as np\nimport re\nimport string\n\ndef einsum(equation, *inputs):\n\n    equation = equation.replace(' ', '')\n    # input_shapes = [x.get_shape() for x in list(inputs)]\n    input_shapes = [x.shape for x in list(inputs)]\n    match = re.match('^([a-zA-Z,.]+)(->[a-zA-Z.]*)?$', equation)\n    if not match:\n        raise ValueError('Indices have incorrect format: %s' % equation)\n\n    input_axis_labels = match.group(1).split(',')\n    output_axis_labels = match.group(2)[2:] if match.group(2) else None\n\n    if len(input_shapes) != len(input_axis_labels):\n        raise ValueError('Got %d arguments for equation \"%s\", expecting %d' %\n                        (len(input_shapes), equation, len(input_axis_labels)))\n\n    # Resolve Ellipsis\n    # Assign axes labels for unspecified dimensions in inputs. Labels taken\n    # from unused labels. Follow numpy einsum broadcasting conventions for\n    # tensors of different length and unlabeled output.\n    ellipsis_axes = ''\n    if '...' in equation:\n        unused = ''.join([c for c in string.ascii_lowercase\n                        if c not in ''.join(input_axis_labels)])\n        for i, ax in enumerate(input_axis_labels):\n            if '...' in ax:\n                parts = ax.split('...')\n                if len(parts) != 2:\n                    raise ValueError('Unable to resolve ellipsis. Excess number found.')\n\n                # n = input_shapes[i].ndims - len(''.join(parts))\n                n = len(input_shapes[i]) - len(''.join(parts))\n                if n < 0:\n                    raise ValueError('Ellipses lengths do not match.')\n                if len(unused) < n:\n                    raise ValueError(\n                        'Unable to resolve ellipsis, too many distinct labels.')\n                replace_axes = unused[-n:] if n > 0 else ''\n                input_axis_labels[i] = input_axis_labels[i].replace('...',\n                                                                    replace_axes)\n                if len(replace_axes) > len(ellipsis_axes):\n                    ellipsis_axes = replace_axes\n                    \n    equation = equation.replace('...', ellipsis_axes)\n    out = torch.einsum(equation, *inputs)\n    # torch.add_to_collection(\"checkpoints\", out)\n    return out\n"
  },
  {
    "path": "geometry/intrinsics.py",
    "content": "import torch\nimport numpy as np\n# from utils.einsum import einsum\nfrom .einsum import einsum\n\ndef intrinsics_vec_to_matrix(kvec):\n    fx, fy, cx, cy = torch.unbind(kvec, dim=-1)\n    z = torch.zeros_like(fx)\n    o = torch.ones_like(fx)\n\n    K = torch.stack([fx, z, cx, z, fy, cy, z, z, o], dim=-1)\n    K = torch.reshape(K, list(kvec.shape)[:-1] + [3,3])\n    return K\n\ndef intrinsics_matrix_to_vec(kmat):\n    fx = kmat[..., 0, 0]\n    fy = kmat[..., 1, 1]\n    cx = kmat[..., 0, 2]\n    cy = kmat[..., 1, 2]\n    return torch.stack([fx, fy, cx, cy], dim=-1)\n\ndef update_intrinsics(intrinsics, delta_focal):\n    kvec = intrinsics_matrix_to_vec(intrinsics)\n    fx, fy, cx, cy = torch.unstack(kvec, num=4, axis=-1)\n    df = torch.squeeze(delta_focal, -1)\n\n    # update the focal lengths\n    fx = torch.exp(df) * fx\n    fy = torch.exp(df) * fy\n\n    kvec = torch.stack([fx, fy, cx, cy], axis=-1)\n    kmat = intrinsics_vec_to_matrix(kvec)\n    return kmat\n\ndef rescale_depth(depth, downscale=4):\n    depth = depth[:,None]\n    new_shape = [depth.shape[-2]//downscale, depth.shape[-1]//downscale]\n    depth = torch.nn.functional.interpolate(depth, new_shape, mode='nearest')\n    return torch.squeeze(depth, dim=1)\n\ndef rescale_depth_and_intrinsics(depth, intrinsics, downscale=4):\n    sc = torch.tensor([1.0/downscale, 1.0/downscale, 1.0], dtype=torch.float32, device=depth.device)\n    intrinsics = einsum('...ij,i->...ij', intrinsics, sc)\n    depth = rescale_depth(depth, downscale=downscale)\n    return depth, intrinsics\n\ndef rescale_depths_and_intrinsics(depth, intrinsics, downscale=4):\n    batch, frames, height, width = [depth.shape[i] for i in range(4)]\n    depth = torch.reshape(depth, [batch*frames, height, width])\n    depth, intrinsics = rescale_depth_and_intrinsics(depth, intrinsics, downscale)\n    depth = torch.reshape(depth,\n        [batch, frames]+list(depth.shape)[1:])\n    return depth, intrinsics\n"
  },
  {
    "path": "geometry/projective_ops.py",
    "content": "import numpy as np\nimport torch \n\n# from utils.einsum import einsum\nfrom torch import einsum\n\n\n# MIN_DEPTH = 0.1\nMIN_DEPTH = 0.01\n\ndef normalize_coords_grid(coords):\n    \"\"\" normalize the coordinates to [-1,1]\n\n    Args:\n        coords: BxKxHxWx2\n    \"\"\"\n    coords=coords.clone()\n    B,K,H,W,_ = coords.shape\n\n    coords[...,0] = 2*coords[...,0]/(W-1)-1\n    coords[...,1] = 2*coords[...,1]/(H-1)-1\n\n    return coords\n\ndef coords_grid(ref, homogeneous=True):\n    \"\"\" grid of pixel coordinates \"\"\"\n    shape = ref.shape\n\n    yy, xx = torch.meshgrid(torch.arange(shape[-2], device=ref.device), torch.arange(shape[-1], device=ref.device) )\n\n    xx = xx.to(dtype=torch.float32)\n    yy = yy.to(dtype=torch.float32)\n\n    if homogeneous:\n        coords = torch.stack([xx, yy, torch.ones_like(xx)], dim=-1)\n    else:\n        coords = torch.stack([xx, yy], dim=-1)\n\n    new_shape = [1]*len(shape[:-2]) +  list(shape[-2:]) + [-1]\n    coords = torch.reshape(coords, new_shape)\n\n    tile = list(shape[:-2])+ [1,1,1]\n    coords = coords.repeat(tile)\n    return coords # BxKxHxWx2\n\n\ndef extract_and_reshape_intrinsics(intrinsics, shape=None):\n    \"\"\" Extracts (fx, fy, cx, cy) from intrinsics matrix \"\"\"\n\n    fx = intrinsics[:, 0, 0]\n    fy = intrinsics[:, 1, 1]\n    cx = intrinsics[:, 0, 2]\n    cy = intrinsics[:, 1, 2]\n\n    if shape is not None:\n        batch = list(fx.shape[:1])\n        fillr = [1]*len(shape[1:]) \n        k_shape = batch+fillr\n\n        fx = torch.reshape(fx, k_shape)\n        fy = torch.reshape(fy, k_shape)\n        cx = torch.reshape(cx, k_shape)\n        cy = torch.reshape(cy, k_shape)\n\n    return (fx, fy, cx, cy)\n\n\ndef backproject(depth, intrinsics, jacobian=False, depth_coords=None):\n    \"\"\" backproject depth map to point cloud \"\"\"\n    #depth_coords: (BxKxHxWx2)\n\n    if depth_coords is None:\n        coords = coords_grid(depth, homogeneous=True)\n        x, y, _ = torch.unbind(coords, axis=-1)\n    else:\n        x, y =  torch.unbind(depth_coords, axis=-1)\n\n    x_shape = x.shape \n    \n    fx, fy, cx, cy = extract_and_reshape_intrinsics(intrinsics, x_shape) #Bx1x1x1\n\n    Z = depth  #BxKxHxW\n    X = Z * (x - cx) / fx\n    Y = Z * (y - cy) / fy \n    points = torch.stack([X, Y, Z], axis=-1)\n\n    if jacobian:\n        o = torch.zeros_like(Z) # used to fill in zeros\n\n        # jacobian w.r.t (fx, fy) , of shape BxKxHxWx4x1\n        jacobian_intrinsics = torch.stack([\n            torch.stack([-X / fx], dim=-1),\n            torch.stack([-Y / fy], dim=-1),\n            torch.stack([o], dim=-1),\n            torch.stack([o], dim=-1)], axis=-2)\n\n        return points, jacobian_intrinsics\n    \n    return points\n    # return points, coords\n\n\ndef project(points, intrinsics, jacobian=False):\n    \n    \"\"\" project point cloud onto image \"\"\"\n    X, Y, Z = torch.unbind(points, axis=-1)\n    Z = torch.clamp(Z, min=MIN_DEPTH)\n\n    x_shape = X.shape\n    fx, fy, cx, cy = extract_and_reshape_intrinsics(intrinsics, x_shape)\n\n    x = fx * (X / Z) + cx\n    y = fy * (Y / Z) + cy\n    coords = torch.stack([x, y], axis=-1)\n\n    if jacobian:\n        o = torch.zeros_like(x) # used to fill in zeros\n        zinv1 = torch.where(Z <= MIN_DEPTH+.01, torch.zeros_like(Z), 1.0 / Z)\n        zinv2 = torch.where(Z <= MIN_DEPTH+.01, torch.zeros_like(Z), 1.0 / Z**2)\n\n        # jacobian w.r.t (X, Y, Z)\n        jacobian_points = torch.stack([\n            torch.stack([fx * zinv1, o, -fx * X * 
zinv2], axis=-1),\n            torch.stack([o, fy * zinv1, -fy * Y * zinv2], axis=-1)], axis=-2)\n\n        # jacobian w.r.t (fx, fy)\n        jacobian_intrinsics = torch.stack([\n            torch.stack([X * zinv1], axis=-1),\n            torch.stack([Y * zinv1], axis=-1)], axis=-2)\n\n        return coords, (jacobian_points, jacobian_intrinsics)\n\n    return coords\n
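\n\nif __name__ == \"__main__\":\n    # Illustrative round-trip check: project(backproject(depth)) recovers the pixel grid.\n    depth = torch.rand(1, 1, 32, 32) + 0.5  # BxKxHxW, safely above MIN_DEPTH\n    K = torch.tensor([[[50.0, 0.0, 16.0], [0.0, 50.0, 16.0], [0.0, 0.0, 1.0]]])\n    pts = backproject(depth, K)              # BxKxHxWx3\n    coords = project(pts, K)                 # BxKxHxWx2\n    grid = coords_grid(depth, homogeneous=False)\n    print((coords - grid).abs().max())       # ~0\n"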
  },
  {
    "path": "geometry/se3.py",
    "content": "\"\"\"\nSO3 and SE3 operations, exponentials and logarithms adapted from Sophus\n\"\"\"\n\nimport numpy as np\nimport torch\nfrom .einsum import einsum\n\n\nMIN_THETA = 1e-4\n\ndef matdotv(A,b):\n    return torch.squeeze(torch.matmul(A, torch.expand_dims(b, -1)), -1)\n\ndef hat(a):\n    a1, a2, a3 = torch.split(a, [1,1,1], dim=-1)\n    zz = torch.zeros_like(a1)\n\n    ax = torch.stack([\n        torch.cat([zz,-a3,a2], dim=-1),\n        torch.cat([a3,zz,-a1], dim=-1),\n        torch.cat([-a2,a1,zz], dim=-1)\n    ], dim=-2)\n\n    return ax\n    \n\n### quaternion functions ###\n\ndef quaternion_rotate_point(q, pt, eq=None):\n    if eq is None:\n        w, vec = torch.split(q, [1, 3], axis=-1)\n        uv = 2*matdotv(hat(vec), pt)\n        return pt + w*uv + matdotv(hat(vec), uv)\n    else:\n        w, vec = torch.split(q, [1, 3], axis=-1)\n        uv1 = 2*einsum(eq, hat(w*vec), pt)\n        uv2 = 2*einsum(eq, hat(vec), pt)\n        return pt + uv1 + einsum(eq, hat(vec), uv2)\n\ndef quaternion_rotate_matrix(q, mat, eq=None):\n    if eq is None:\n        w, vec = torch.split(q, [1, 3], axis=-1)\n        uv = 2*torch.matmul(hat(vec), mat)\n        return mat + w*uv + torch.matmul(hat(vec), uv)\n    else:\n        w, vec = torch.split(q, [1, 3], axis=-1)\n        uv1 = 2*einsum(eq, hat(w*vec), mat)\n        uv2 = 2*einsum(eq, hat(vec), mat)\n        return mat + uv1 + einsum(eq, hat(vec), uv2)\n\ndef quaternion_inverse(q):\n    return q * [1, -1, -1, -1]\n\ndef quaternion_multiply(a, b):\n    aw, ax, ay, az = torch.split(a, [1,1,1,1], axis=-1)\n    bw, bx, by, bz = torch.split(b, [1,1,1,1], axis=-1)\n    \n    q = torch.concat([\n        aw * bw - ax * bx - ay * by - az * bz,\n        aw * bx + ax * bw + ay * bz - az * by,\n        aw * by + ay * bw + az * bx - ax * bz,\n        aw * bz + az * bw + ax * by - ay * bx,\n    ], axis=-1)\n\n    return q\n\ndef quaternion_to_matrix(q):\n    w, x, y, z = torch.split(q, [1,1,1,1], axis=-1)\n\n    r11 = 1 - 2 * y**2 - 2 * z**2\n    r12 = 2 * x * y - 2 * w * z\n    r13 = 2 * z * x + 2 * w * y\n\n    r21 = 2 * x * y + 2 * w * z\n    r22 = 1 - 2 * x**2 - 2 * z**2\n    r23 = 2 * y * z - 2 * w * x\n\n    r31 = 2 * z * x - 2 * w * y\n    r32 = 2 * y * z + 2 * w * x\n    r33 = 1 - 2 * x**2 - 2 * y**2\n    \n    R = torch.stack([\n        torch.concat([r11,r12,r13], axis=-1),\n        torch.concat([r21,r22,r23], axis=-1),\n        torch.concat([r31,r32,r33], axis=-1)\n    ], axis=-2)\n\n    return R\n\ndef rotation_matrix_to_quaternion(R):\n    Rxx, Ryx, Rzx = R[...,0,0], R[...,0,1], R[...,0,2]\n    Rxy, Ryy, Rzy = R[...,1,0], R[...,1,1], R[...,1,2]\n    Rxz, Ryz, Rzz = R[...,2,0], R[...,2,1], R[...,2,2]\n\n    zz = torch.zeros_like(Rxx)\n    k1 = torch.stack([Rxx-Ryy-Rzz, zz, zz, zz], axis=-1)\n    k2 = torch.stack([Ryx+Rxy, Ryy-Rxx-Rzz, zz, zz], axis=-1)\n    k3 = torch.stack([Rzx+Rxz, Rzy+Ryz, Rzz-Rxx-Ryy,zz], axis=-1)\n    k4 = torch.stack([Ryz-Rzy, Rzx-Rxz, Rxy-Ryx, Rxx+Ryy+Rzz], axis=-1)\n\n    K = torch.stack([k1, k2, k3, k4], axis=-2)\n    eigvals, eigvecs = torch.linalg.eigh(K)\n\n    x, y, z, w = torch.split(eigvecs[...,-1], [1,1,1,1], axis=-1)\n    qvec = torch.concat([w, x, y, z], axis=-1)\n    qvec /= torch.sqrt(torch.reduce_sum(qvec**2, axis=-1, keepdims=True))\n\n    return qvec * torch.sign(w)\n\ndef so3_expm_and_theta(omega):\n    \"\"\" omega in \\so3 \"\"\"\n    theta_sq = torch.reduce_sum(omega**2, axis=-1)\n    theta = torch.sqrt(theta_sq)\n    half_theta = 0.5*theta\n\n    ### small ###\n    imag_factor = 
torch.where(theta>MIN_THETA, \n        torch.sin(half_theta) / (theta + 1e-12), \n        0.5 - (1.0/48.0)*theta_sq + (1.0/3840.0)*theta_sq*theta_sq)\n\n    real_factor = torch.where(theta>MIN_THETA, torch.cos(half_theta),\n        1.0 - (1.0/8.0)*theta_sq + (1.0/384.0)*theta_sq*theta_sq)\n\n    qw = real_factor\n    qx = imag_factor * omega[...,0]\n    qy = imag_factor * omega[...,1]\n    qz = imag_factor * omega[...,2]\n\n    quat = torch.stack([qw, qx, qy, qz], axis=-1)\n    return quat, theta\n        \ndef so3_logm_and_theta(so3):\n    w, vec = torch.split(so3, [1,3], axis=-1)\n    squared_n = torch.reduce_sum(vec**2, axis=-1, keepdims=True)\n    n = torch.sqrt(squared_n)\n\n    two_atan_nbyw_by_n = torch.where(n<MIN_THETA,\n        2/w - w*squared_n / (w*w*w),\n        2*torch.atan(n/w) / (n+1e-12))\n\n    theta = two_atan_nbyw_by_n * n\n    omega = two_atan_nbyw_by_n * vec\n    return omega, theta\n\ndef se3_expm(xi):\n    \"\"\" xi in \\se3 \"\"\"\n    tau, omega = torch.split(xi, [3, 3], axis=-1)\n    q, theta = so3_expm_and_theta(omega)\n\n\n    theta = theta[...,torch.newaxis,torch.newaxis]\n    theta = torch.tile(theta, \n        torch.concat([torch.ones_like(torch.shape(q)[:-1]), [3,3]], axis=-1))\n\n    theta_sq = theta * theta\n    Omega = hat(omega)\n    Omega_sq = torch.matmul(Omega, Omega)\n\n    Vs = torch.eye(3, batch_shape=torch.shape(xi)[:-1]) + \\\n         (1-torch.cos(theta)) / (theta_sq + 1e-12) * Omega + \\\n         (theta - torch.sin(theta)) / (theta_sq*theta + 1e-12) * Omega_sq\n\n    V = torch.where(theta<MIN_THETA, quaternion_to_matrix(q), Vs)\n    t = matdotv(V, tau)\n    return q, t\n\ndef se3_logm(so3, t):\n    omega, theta = so3_logm_and_theta(so3)\n    Omega = hat(omega)\n    Omega_sq = torch.matmul(Omega, Omega)\n\n    theta = theta[...,torch.newaxis]\n    theta = torch.tile(theta, \n        torch.concat([torch.ones_like(torch.shape(omega)[:-1]), [3,3]], axis=-1))\n    half_theta = 0.5*theta\n\n    Vinv_approx = torch.eye(3, batch_shape=torch.shape(omega)[:-1]) - \\\n        0.5*Omega + (1.0/12.0) * Omega_sq\n\n    Vinv_exact = torch.eye(3, batch_shape=torch.shape(omega)[:-1]) - \\\n        0.5*Omega + (1-theta*torch.cos(half_theta) / \\\n        (2*torch.sin(half_theta)+1e-12)) / (theta*theta + 1e-12) * Omega_sq\n\n    Vinv = torch.where(theta<MIN_THETA, Vinv_approx, Vinv_exact)\n    tau = matdotv(Vinv, t)\n\n    upsilon = torch.concat([tau, omega], axis=-1)\n    return upsilon\n\n\n### matrix functions ###\n\ndef se3_matrix_inverse(G):\n    \"\"\" Invert SE3 matrix \"\"\"\n    inp_shape = G.shape \n    G = torch.reshape(G, [-1, 4, 4])\n\n    R, t = G[:, :3, :3], G[:, :3, 3:]\n    R = R.permute(0, 2, 1)\n    t = -torch.matmul(R, t)\n\n    filler = torch.tensor([0.0, 0.0, 0.0, 1.0], device=G.device, dtype=G.dtype) \n    filler = torch.reshape(filler, [1, 1, 4])\n    filler = filler.repeat([G.shape[0], 1, 1]) \n\n    Ginv = torch.cat([R, t], dim=-1)\n    Ginv = torch.cat([Ginv, filler], dim=-2)\n    return torch.reshape(Ginv, inp_shape)\n\n\ndef _se3_matrix_expm_grad(grad):\n    grad_upsilon_omega = torch.stack([\n        grad[..., 0, 3],\n        grad[..., 1, 3],\n        grad[..., 2, 3],\n        grad[..., 2, 1] - grad[..., 1, 2],\n        grad[..., 0, 2] - grad[..., 2, 0],\n        grad[..., 1, 0] - grad[..., 0, 1]\n    ], axis=-1)\n\n    return grad_upsilon_omega\n\ndef _se3_matrix_expm_shape(op):\n    return [op.inputs[0].get_shape().as_list()[:-1] + [4, 4]]\n\n\ndef _se3_matrix_expm(upsilon_omega):\n    \"\"\" se3 matrix exponential se(3) -> 
SE(3), works for arbitrary batch dimensions\n    - Note: gradient is overridden with _se3_matrix_expm_grad, which approximates \n    gradient for small upsilon_omega\n    \"\"\"\n\n    eps=1e-12\n    inp_shape = upsilon_omega.shape \n    out_shape = list(inp_shape)[:-1]+[4,4] \n\n    upsilon_omega = torch.reshape(upsilon_omega, [-1, 6])\n    batch = upsilon_omega.shape[0]\n    v, w = torch.split(upsilon_omega, [3, 3], dim=-1)\n\n    theta_sq = torch.sum(w**2, dim=1 )\n    theta_sq = torch.reshape(theta_sq, [-1, 1, 1])\n\n    theta = torch.sqrt(theta_sq)\n    theta_po4 = theta_sq * theta_sq\n\n    wx = hat(w)\n    wx_sq = torch.matmul(wx, wx)\n    I = torch.eye(3, dtype=upsilon_omega.dtype, device=upsilon_omega.device).repeat([batch,1,1])\n\n    ### taylor approximations ###\n    R1 =  I + (1.0 - (1.0/6.0)*theta_sq + (1.0/120.0)*theta_po4) * wx + \\\n        (0.5 - (1.0/12.0)*theta_sq + (1.0/720.0)*theta_po4) * wx_sq\n    \n    V1 = I + (0.5 - (1.0/24.0)*theta_sq + (1.0/720.0)*theta_po4)*wx + \\\n        ((1.0/6.0) - (1.0/120.0)*theta_sq + (1.0/5040.0)*theta_po4)*wx_sq\n\n    ### exact values ###\n    R2 = I + (torch.sin(theta) / (theta+eps)) * wx +\\\n        ((1 - torch.cos(theta)) / (theta_sq+eps)) * wx_sq\n\n    V2 = I + ((1 - torch.cos(theta)) / (theta_sq + eps)) * wx + \\\n        ((theta - torch.sin(theta))/(theta_sq*theta + eps)) * wx_sq\n\n    # print(theta.shape, R1.shape, R2.shape, \">>>\", flush=True)\n    # R = torch.where(theta[:, 0, 0]<MIN_THETA, R1, R2)\n    # V = torch.where(theta[:, 0, 0]<MIN_THETA, V1, V2)\n    R = torch.where(theta<MIN_THETA, R1, R2)\n    V = torch.where(theta<MIN_THETA, V1, V2)\n\n    t = torch.matmul(V, v[...,None]) \n\n    fill = torch.tensor([0, 0, 0, 1], dtype=torch.float32, device=R.device)\n    fill = torch.reshape(fill, [1, 1, 4])\n    fill = fill.repeat([batch, 1, 1])\n\n    G = torch.cat([R, t], dim=2)\n    G = torch.cat([G, fill], dim=1)\n    G = torch.reshape(G, out_shape)\n    return G\n\n\nclass SE3_Matrix_Expm(torch.autograd.Function):\n    \n    @staticmethod\n    def forward(ctx, upsilon_omega):\n        G=_se3_matrix_expm(upsilon_omega)\n        # ctx.save_for_backward(G)\n        return G \n\n    @staticmethod\n    def backward(ctx, grad_output):\n        # result, = ctx.saved_tensors\n        \n        return _se3_matrix_expm_grad(grad_output)\n\n\ndef se3_matrix_expm(upsilon_omega):\n    return SE3_Matrix_Expm.apply(upsilon_omega)\n\n\ndef se3_matrix_increment(G, upsilon_omega):\n    \"\"\" Left increment of rigid body transformation: G = expm(xi) G\"\"\"\n    dG = se3_matrix_expm(upsilon_omega)\n    return torch.matmul(dG, G)"
  },
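  {
    "path": "examples/se3_sanity_check.py",
    "content": "\"\"\"\nIllustrative sanity check for geometry/se3.py. This is a hypothetical usage sketch, not part of the original training code: it verifies that the SE3 matrix exponential composed with the matrix inverse gives the identity, and that the quaternion <-> rotation-matrix conversions round-trip.\n\"\"\"\n\nimport torch\nfrom geometry.se3 import (se3_matrix_expm, se3_matrix_inverse,\n                          quaternion_to_matrix, rotation_matrix_to_quaternion)\n\nif __name__ == '__main__':\n    # batch of random twists xi = [tau | omega] with small rotations\n    xi = 0.1 * torch.randn(4, 6)\n    G = se3_matrix_expm(xi)                     # 4x4x4 rigid transforms\n    I = torch.matmul(G, se3_matrix_inverse(G))  # should be the identity\n    print('expm/inverse residual:', (I - torch.eye(4)).abs().max().item())\n\n    # quaternion round-trip on the rotation blocks\n    q = rotation_matrix_to_quaternion(G[:, :3, :3])\n    R = quaternion_to_matrix(q)\n    print('quaternion round-trip residual:', (R - G[:, :3, :3]).abs().max().item())\n"
  },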
  {
    "path": "geometry/transformation.py",
    "content": "import torch  \nimport numpy as np\n\n# from core.config import cfg\nfrom config.default import get_cfg\nfrom .se3 import *\nfrom .intrinsics import *\nfrom . import projective_ops as pops\nfrom . import cholesky\n\nfrom .einsum import einsum\n\ncholesky_solve = cholesky.solve\n\n\nMIN_DEPTH = 0.1\nMAX_RESIDUAL = 250.0\n\n# can use both matrix or quaternions to represent rotations\nDEFAULT_INTERNAL = 'matrix'\n\n\ndef clip_dangerous_gradients(x):\n    return x\n\n\ndef jac_local_perturb(pt, fill=False):\n\n    X, Y, Z = torch.split(pt,[1,1,1], dim=-1)  # torch.split(pt, [1, 1, 1], axis=-1)\n    o, i = torch.zeros_like(X), torch.ones_like(X)\n    if fill:\n        j1 = torch.cat([i,  o,  o, o], dim=-1)\n        j2 = torch.cat([o,  i,  o, o], dim=-1)\n        j3 = torch.cat([o,  o,  i, o], dim=-1)\n        j4 = torch.cat([o, -Z,  Y, o], dim=-1)\n        j5 = torch.cat([Z,  o, -X, o], dim=-1)\n        j6 = torch.cat([-Y,  X,  o, o],dim=-1)\n    else:\n        j1 = torch.cat([i,  o,  o], dim=-1)\n        j2 = torch.cat([o,  i,  o], dim=-1)\n        j3 = torch.cat([o,  o,  i], dim=-1)\n        j4 = torch.cat([o, -Z,  Y], dim=-1)\n        j5 = torch.cat([Z,  o, -X], dim=-1)\n        j6 = torch.cat([-Y,  X,  o],dim=-1)\n    jac = torch.stack([j1, j2, j3, j4, j5, j6], dim=-1)\n    return jac\n\n\ndef cond_transform(cond, T1, T2):\n    \"\"\" Return T1 if cond, else T2 \"\"\"\n\n    if T1.internal == 'matrix':\n        mat = torch.cond(cond, lambda: T1.matrix(), lambda: T2.matrix())\n        T = T1.__class__(matrix=mat, internal=T1.internal)\n\n    elif T1.internal == 'quaternion':\n        so3 = torch.cond(cond, lambda: T1.so3, lambda: T2.so3)\n        translation = torch.cond(cond, lambda: T1.translation,\n                              lambda: T2.translation)\n        T = T1.__class__(so3=so3, translation=translation,\n                         internal=T1.internal)\n    return T\n\n\nclass SE3:\n    def __init__(self, upsilon=None, matrix=None, so3=None, translation=None, eq=None, internal=DEFAULT_INTERNAL):\n        self.eq = eq\n        self.internal = internal\n\n        if internal == 'matrix':\n            if upsilon is not None:\n                self.G = se3_matrix_expm(upsilon)\n            elif matrix is not None:\n                self.G = matrix\n        else:\n            raise NotImplementedError \n\n    def __call__(self, pt, jacobian=False):\n        \"\"\" Transform set of points \"\"\"\n\n        if self.internal == 'matrix':\n\n            pt = torch.cat([pt, torch.ones_like(pt[..., :1])],\n                        dim=-1)  # convert to homogenous\n            pt = einsum(self.eq, self.G[..., :3, :], pt)\n        else:\n            raise NotImplementedError\n\n        if jacobian:\n            jacobian = jac_local_perturb(pt)\n            return pt, jacobian\n\n        return pt\n\n    def __mul__(self, other):\n        if self.internal == 'matrix':\n            G = torch.matmul(self.G, other.G)\n            return self.__class__(matrix=G, internal=self.internal)\n        else:\n            raise NotImplementedError\n\n    def identity_(self):\n        if self.internal == 'matrix':\n            shape=self.G.shape\n            self.G=torch.eye(4, device=self.G.device).repeat([*shape[:-2],1,1])\n        else:\n            raise NotImplementedError\n\n\n    def increment(self, upsilon):\n        if self.internal == 'matrix':\n            G = se3_matrix_increment(self.G, upsilon)\n            return self.__class__(matrix=G, internal=self.internal)\n        else:\n       
     raise NotImplementedError\n\n    def concat(self, other, axis=0):\n        if self.internal == 'matrix':\n            G = torch.concat([self.G, other.G], axis=axis)\n        else:\n            raise NotImplementedError\n\n\n    def copy(self, stop_gradients=False):\n\n        if self.internal == 'matrix':\n            if stop_gradients:\n                # return self.__class__(matrix=torch.stop_gradient(self.G), internal=self.internal)\n                return self.__class__(matrix=self.G.detach(), internal=self.internal)\n            else:\n                return self.__class__(matrix=self.G, internal=self.internal)\n\n        else:\n            raise NotImplementedError\n\n    def to_vec(self):\n        return torch.concat([self.so3, self.translation], axis=-1)\n\n    def inv(self):\n        if self.internal == 'matrix':\n            Ginv = se3_matrix_inverse(self.matrix())\n            return self.__class__(matrix=Ginv, internal=self.internal)\n        else:\n            raise NotImplementedError\n\n    def adj(self):\n        if self.internal == 'matrix':\n            R = self.G[..., :3, :3]\n            t = self.G[..., :3, 3]\n            A11 = R\n            A12 = torch.matmul(hat(t), R)\n            A21 = torch.zeros_like(A11)\n            A22 = R\n        else:\n            raise NotImplementedError\n\n\n        Ax = torch.concat([\n            torch.concat([A11, A12], axis=-1),\n            torch.concat([A21, A22], axis=-1)\n        ], axis=-2)\n\n        return Ax\n\n    def logm(self):\n        return se3_logm(self.so3, self.translation)\n\n    def shape(self):\n        # return torch.shape(self.so3)[:-1]\n        if self.internal == 'matrix':\n            my_shape = self.G.shape  # torch.shape(self.G)\n        else:\n            raise NotImplementedError\n\n        return (my_shape[0], my_shape[1])\n\n    def matrix(self, fill=True):\n        if self.internal == 'matrix':\n            return self.G\n        else:\n            raise NotImplementedError\n       \n\n    def transform(self, depth, intrinsics, valid_mask=False, return3d=False):\n        \n        # pt = pops.backproject(depth, intrinsics)\n        pt = pops.backproject(depth, intrinsics)\n        pt_new = self.__call__(pt)\n        coords = pops.project(pt_new, intrinsics)\n        if return3d:\n            return coords, pt_new\n        if valid_mask:\n            vmask = (pt[..., -1] > MIN_DEPTH) & (pt_new[..., -1] > MIN_DEPTH)\n            # vmask = torch.cast(vmask, torch.float32)[..., torch.newaxis]\n            # vmask = vmask.to(dtype=torch.float32)[..., None, :,:] #BxKx1xHxW\n            vmask = vmask.to(dtype=torch.float32)[..., :, :, None]  # BxKx1xHxW\n            return coords, vmask\n        return coords\n\n    def induced_flow(self, depth, intrinsics, valid_mask=False):\n        coords0 = pops.coords_grid(depth, homogeneous=False)\n\n        if valid_mask:\n            coords1, vmask = self.transform(\n                depth, intrinsics, valid_mask=valid_mask)\n            return coords1 - coords0, vmask\n        coords1 = self.transform(depth, intrinsics, valid_mask=valid_mask)\n        return coords1 - coords0\n\n    def depth_change(self, depth, intrinsics):\n        pt = pops.backproject(depth, intrinsics)\n        pt_new = self.__call__(pt)\n        return pt_new[..., -1] - pt[..., -1]\n    \n    def identity(self):\n        \"\"\" Push identity transformation to start of collection \"\"\"\n        batch, frames = self.shape()\n        if self.internal == 'matrix':\n            # I = 
torch.eye(4, batch_shape=[batch, 1])\n            I = torch.eye(4, dtype=self.G.dtype, device=self.G.device).repeat(\n                [batch, 1, 1, 1])\n            # return self.__class__(matrix=I, internal=self.internal, eq=self.eq)\n            return self.__class__(matrix=I, internal=self.internal, eq=self.eq)\n        else:\n            raise NotImplementedError\n\n\n\n\nclass SE3Sequence(SE3):\n    \"\"\" Stores collection of SE3 objects \"\"\"\n\n    def __init__(self, upsilon=None, matrix=None, so3=None, translation=None, eq= \"aijk,ai...k->ai...j\",internal=DEFAULT_INTERNAL):\n        super().__init__(\n            upsilon, matrix, so3, translation, internal=internal, eq=eq)\n\n        # self.eq = \"aijk,ai...k->ai...j\"\n    def __call__(self, pt, inds=None, jacobian=False):\n        if self.internal == 'matrix':\n            return super().__call__(pt, jacobian=jacobian)\n        else:\n            raise NotImplementedError\n\n\n    def gather(self, inds):\n        if self.internal == 'matrix':\n            G = torch.index_select(self.G, index=inds, dim=1)\n            return SE3Sequence(matrix=G, internal=self.internal)\n        else:\n            raise NotImplementedError\n\n    # def append_identity(self):\n    #     \"\"\" Push identity transformation to start of collection \"\"\"\n    #     batch, frames = self.shape()\n    #     if self.internal == 'matrix':\n    #         # I = torch.eye(4, batch_shape=[batch, 1])\n    #         I = torch.eye(4, dtype=self.G.dtype, device=self.G.device).repeat(\n    #             [batch, 1, 1, 1])\n\n    #         G = torch.cat([I, self.G], dim=1)\n    #         return SE3Sequence(matrix=G, internal=self.internal)\n    #     else:\n    #         raise NotImplementedError\n\n    def reprojction_optim(self,\n                       target,\n                       weight,\n                       depth,\n                       intrinsics,\n                       num_iters=2,\n                       depth_img_coords=None\n                       ):\n\n        target = clip_dangerous_gradients(target).to(dtype=torch.float64)\n        weight = clip_dangerous_gradients(weight).to(dtype=torch.float64)\n\n        X0 = pops.backproject(depth, intrinsics, depth_coords=depth_img_coords)\n        w = weight[..., None] \n\n        lm_lmbda = get_cfg(\"LM\").LM_LMBDA\n        ep_lmbda = get_cfg(\"LM\").EP_LMBDA\n\n        T = self.copy(stop_gradients=False)\n        for i in range(num_iters):\n            ### compute the jacobians of the transformation ###\n            X1, jtran = T(X0, jacobian=True)\n            x1, (jproj, jkvec) = pops.project(X1, intrinsics, jacobian=True)\n\n            v = (X0[..., -1] > MIN_DEPTH) & (X1[..., -1] > MIN_DEPTH)\n            # v = v.to(dtype=torch.float32)[..., None, None]\n            v = v.to(dtype=torch.float64)[..., None, None]\n\n            ### weighted gauss-newton update ###\n            J = einsum('...ij,...jk->...ik', jproj.to(dtype=torch.float64), jtran.to(dtype=torch.float64 ))  \n\n            H = einsum('ai...j,ai...k->aijk', v*w*J, J)\n            b = einsum('ai...j,ai...->aij', v*w*J, target-x1)\n\n            ### add dampening and apply increment ###\n            H += ep_lmbda*torch.eye(6, dtype=H.dtype, device=H.device) + lm_lmbda*H*torch.eye(6,dtype=H.dtype, device=H.device)\n            try:\n                delta_upsilon = cholesky_solve(H, b)\n            except:\n                # print(w.shape,v.shape, w.mean(), v.mean(),H,b, '!!!!')\n                raise\n            T = 
T.increment(delta_upsilon)\n\n        # update\n        if self.internal == 'matrix':\n            self.G = T.matrix()\n            T = SE3Sequence(\n                matrix=T.matrix(), internal=self.internal)\n        else:\n            raise NotImplementedError\n\n        return T\n\n\n    def transform(self, depth, intrinsics, valid_mask=False, return3d=False):\n        return super().transform(depth, intrinsics, valid_mask, return3d)\n"
  },
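  {
    "path": "examples/transformation_sanity_check.py",
    "content": "\"\"\"\nIllustrative usage sketch for geometry/transformation.py (hypothetical file, not part of the original sources): builds a batched SE3Sequence from twist vectors, applies it to a grid of 3D points following the einsum pattern 'aijk,ai...k->ai...j', and checks that the inverse maps the points back.\n\"\"\"\n\nimport torch\nfrom geometry.transformation import SE3Sequence\n\nif __name__ == '__main__':\n    # 2 sequences x 1 frame, each built from a small twist vector\n    T = SE3Sequence(upsilon=0.01 * torch.randn(2, 1, 6))\n    pts = torch.randn(2, 1, 8, 8, 3)   # B x frames x H x W x 3 points\n    warped = T(pts)                    # rigidly transformed points\n    back = T.inv()(warped)             # apply the inverse transform\n    print('round-trip error:', (back - pts).abs().max().item())\n\n    # transforms compose like matrices\n    T2 = T * T\n    print('composed shape:', T2.matrix().shape)   # (2, 1, 4, 4)\n"
  },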
  {
    "path": "model/CFNet.py",
    "content": "\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom thirdparty.raft.update import BasicUpdateBlock\nfrom thirdparty.raft.extractor import BasicEncoder\nfrom thirdparty.raft.corr import CorrBlock, AlternateCorrBlock\nfrom thirdparty.raft.utils.utils import bilinear_sampler, coords_grid, upflow\n\ntry:\n    autocast = torch.cuda.amp.autocast\nexcept:\n    # dummy autocast for PyTorch < 1.6\n    class autocast:\n        def __init__(self, enabled):\n            pass\n        def __enter__(self):\n            pass\n        def __exit__(self, *args):\n            pass\n\n\nclass ImageFeaEncoder(nn.Module):\n    def __init__(self, input_dim=3, output_dim=256):\n        super().__init__()\n        self.fnet = BasicEncoder(output_dim=output_dim, norm_fn='instance', dropout=False, input_dim=input_dim)        \n\n        if 1:#self.args.pretrained_model is not None:\n            print(\"Loading the weights of RAFT...\")\n            import os             \n            self.load_state_dict(\n                #  torch.load(self.args.pretrained_model, map_location='cpu'), strict=False\n                 torch.load( f\"{os.path.dirname(os.path.abspath(__file__)) }/../weights/img_fea_enc.pth\", map_location='cpu'), strict=True\n            )\n        else:\n            print(\"ImageFeaEncoder will be trained from scratch...\")\n\n    def forward(self, image1, image2):\n        image1 = 2 * (image1 / 255.0) - 1.0\n        image2 = 2 * (image2 / 255.0) - 1.0\n\n        image1 = image1.contiguous()\n        image2 = image2.contiguous()\n        with autocast(enabled=True):\n            fmap1, fmap2 = self.fnet([image1, image2])\n        return fmap1, fmap2\n\n\nclass GRU_CFUpdator(nn.Module):\n    def __init__(self, args):\n        super().__init__()\n        self.args = args\n        input_dim =  args.get(\"input_dim\", 3)\n\n        self.hidden_dim = hdim = 128\n        self.context_dim = cdim = 128\n        args.corr_levels = 4\n        args.corr_radius = 4\n\n        if 'alternate_corr' not in self.args:\n            self.args.alternate_corr = False\n\n        self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)\n\n        if self.args.pretrained_model is not None:\n            print(\"Loading the weights of RAFT...\")\n            import os             \n            self.load_state_dict(\n                #  torch.load(self.args.pretrained_model, map_location='cpu'), strict=False\n                 torch.load( f\"{os.path.dirname(os.path.abspath(__file__)) }/../weights/gru_update.pth\", map_location='cpu'), strict=True\n            )\n        else:\n            print(\"GRU_CFUpdator will be trained from scratch...\")\n\n\n    \n    \n    def freeze_bn(self):\n        for m in self.modules():\n            if isinstance(m, nn.BatchNorm2d):\n                m.eval()\n\n    def initialize_flow(self, img, downsample_rate=8):\n        \"\"\" Flow is represented as difference between two coordinate grids flow = coords1 - coords0\"\"\"\n        N, C, H, W = img.shape\n        coords0 = coords_grid(N, H//downsample_rate, W//downsample_rate).to(img.device)\n        coords1 = coords_grid(N, H//downsample_rate, W//downsample_rate).to(img.device)\n\n        # optical flow computed as difference: flow = coords1 - coords0\n        return coords0, coords1\n\n    def upsample_flow(self, flow, mask, upsample_scale=8):\n        \"\"\" Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination \"\"\"\n        N, _, H, W = flow.shape\n    
    mask = mask.view(N, 1, 9, upsample_scale, upsample_scale, H, W)\n        mask = torch.softmax(mask, dim=2)\n\n        up_flow = F.unfold(upsample_scale * flow, [3,3], padding=1)\n        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)\n\n        up_flow = torch.sum(mask * up_flow, dim=2)\n        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)\n        return up_flow.reshape(N, 2, upsample_scale*H, upsample_scale*W)\n\n\n    def forward(self, fmap1, fmap2, iters=1, flow_init=None, upsample=True, test_mode=False, context_fea=None, update_corr_fn=True):\n        \"\"\" Estimate optical flow between pair of frames \"\"\"\n\n        hdim = self.hidden_dim\n        cdim = self.context_dim\n\n        if update_corr_fn:  # needs careful handling outside\n            # run the feature network\n            self.fmap1 = fmap1.float()\n            self.fmap2 = fmap2.float()\n            if self.args.alternate_corr:\n                self.corr_fn = AlternateCorrBlock(self.fmap1, self.fmap2, radius=self.args.corr_radius)\n            else:\n                self.corr_fn = CorrBlock(self.fmap1, self.fmap2, radius=self.args.corr_radius)\n\n        if update_corr_fn:\n            # run the context network\n            with autocast(enabled=self.args.mixed_precision):\n                assert context_fea is not None\n                ds = context_fea.shape[-1]//self.fmap1.shape[-1]\n                cnet = F.interpolate(context_fea, scale_factor=1/ds, mode='bilinear', align_corners=True)\n\n                self.net, self.inp = torch.split(cnet, [hdim, cdim], dim=1)\n                self.net = torch.tanh(self.net)\n                self.inp = torch.relu(self.inp)\n\n        coords0, coords1 = self.initialize_flow(flow_init)\n\n        if flow_init is not None:\n            ds = flow_init.shape[-1]//coords0.shape[-1]\n            if ds != 1:\n                flow_init /= ds\n                flow_init = F.interpolate(flow_init, scale_factor=1/ds, mode='bilinear', align_corners=True)\n\n            coords1 = coords1 + flow_init\n\n        flow_predictions = []\n        for itr in range(iters):\n            coords1 = coords1.detach()\n            corr = self.corr_fn(coords1)  # index correlation volume\n\n            flow = coords1 - coords0\n            with autocast(enabled=self.args.mixed_precision):\n                self.net, up_mask, delta_flow = self.update_block(self.net, self.inp, corr, flow)\n\n            # F(t+1) = F(t) + \\Delta(t)\n            coords1 = coords1 + delta_flow\n\n            # upsample predictions\n            if up_mask is None:\n                # `image1` is not available here; assume the default 8x feature stride\n                flow_up = upflow(coords1 - coords0, scale=8)\n            else:\n                if self.args.fea_net in [\"bigdx4\"]:\n                    flow_up = self.upsample_flow(coords1 - coords0, up_mask, upsample_scale=4)\n                else:\n                    flow_up = self.upsample_flow(coords1 - coords0, up_mask)\n\n            flow_predictions.append(flow_up)\n\n        if test_mode:\n            return coords1 - coords0, flow_up\n\n        return flow_predictions\n\n"
  },
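  {
    "path": "examples/convex_upsample_sketch.py",
    "content": "\"\"\"\nIllustrative sketch of the RAFT-style convex upsampling used by GRU_CFUpdator.upsample_flow in model/CFNet.py (hypothetical file): every fine-resolution flow vector is a softmax-weighted convex combination of its 3x3 coarse-resolution neighbourhood, with the weights predicted as a mask.\n\"\"\"\n\nimport torch\nimport torch.nn.functional as F\n\ndef convex_upsample(flow, mask, scale=8):\n    # flow: N x 2 x H x W, mask: N x (9*scale*scale) x H x W\n    N, _, H, W = flow.shape\n    mask = torch.softmax(mask.view(N, 1, 9, scale, scale, H, W), dim=2)\n    # gather the 3x3 neighbourhood of each coarse pixel (flow is rescaled by `scale`)\n    up = F.unfold(scale * flow, [3, 3], padding=1).view(N, 2, 9, 1, 1, H, W)\n    up = torch.sum(mask * up, dim=2)   # N x 2 x scale x scale x H x W\n    return up.permute(0, 1, 4, 2, 5, 3).reshape(N, 2, scale * H, scale * W)\n\nif __name__ == '__main__':\n    flow = torch.randn(1, 2, 6, 8)\n    mask = torch.randn(1, 9 * 8 * 8, 6, 8)\n    print(convex_upsample(flow, mask).shape)   # torch.Size([1, 2, 48, 64])\n"
  },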
  {
    "path": "model/HybridNet.py",
    "content": "\nimport torch \nimport torch.nn as nn \n\nfrom thirdparty.kpconv.kpconv_blocks import *\nimport torch.nn.functional as F\nimport numpy as np\nfrom kpconv.lib.utils import square_distance\nfrom model.descriptor2D import  SuperPoint2D\nfrom model.descriptor3D import  KPSuperpoint3Dv2\n\n\n\nREGISTERED_HYBRID_NET_CLASSES={}\ndef register_hybrid_net(cls, name=None):\n    global REGISTERED_HYBRID_NET_CLASSES\n    if name is None:\n        name = cls.__name__\n    assert name not in REGISTERED_HYBRID_NET_CLASSES, f\"exist class: {REGISTERED_HYBRID_NET_CLASSES}\"\n    REGISTERED_HYBRID_NET_CLASSES[name] = cls\n    return cls\n\n\ndef get_hybrid_net(name):\n    global REGISTERED_HYBRID_NET_CLASSES\n    assert name in REGISTERED_HYBRID_NET_CLASSES, f\"available class: {REGISTERED_HYBRID_NET_CLASSES}\"\n    return REGISTERED_HYBRID_NET_CLASSES[name]\n\nclass ContextFeatureNet(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.context_fea_extractor_3d= KPSuperpoint3Dv2(config['context_fea_extractor_3d'] )\n    \n    def forward(self, batch):\n        # x = batch['features'].clone().detach()\n        # assert len(batch['stack_lengths'][-1])==1, \"Only support bs=1 for now\" \n        len_src_c = batch['stack_lengths'][-1][0]\n        pcd_c = batch['model_points'][-1]\n        pcd_c = pcd_c[:len_src_c]\n\n        image=batch['image']\n\n        ############### encode 3d and 2d features ###############\n        batch3d={\n            'points': batch['model_points'], \n            'neighbors': batch['neighbors'], \n            'pools':  batch['pools'], \n            'upsamples': batch['upsamples'],\n            'features': batch['model_point_features'], \n            'stack_lengths': batch['stack_lengths'],\n        }\n        ctx_descriptors_3d = self.context_fea_extractor_3d(batch3d)\n\n\n        return {\n            \"ctx_fea_3d\":ctx_descriptors_3d,\n        }\n\n\n\n@register_hybrid_net\nclass HybridDescNet(nn.Module):\n    #independent 2d and 3d network\n    def __init__(self, config):\n        super().__init__()\n\n        self.corr_fea_extractor_2d= SuperPoint2D(config['keypoints_detector_2d'] )\n        self.corr_fea_extractor_3d= KPSuperpoint3Dv2(config['keypoints_detector_3d'] )\n        self.descriptors_3d = {}\n\n\n    def forward(self, batch):\n        assert len(set(batch['class_name']))==1, \"A batch should contain data of the same class.\"\n        class_name = batch['class_name'][0]\n\n        len_src_c = batch['stack_lengths'][-1][0]\n        pcd_c = batch['model_points'][-1]\n        pcd_c = pcd_c[:len_src_c]#, pcd_c[len_src_c:]\n\n        image=batch['image']\n\n        ############### encode 3d and 2d features ###############\n        batch3d={\n            'points': batch['model_points'], \n            'neighbors': batch['neighbors'], \n            'pools':  batch['pools'], \n            'upsamples': batch['upsamples'],\n            'features': batch['model_point_features'], \n            'stack_lengths': batch['stack_lengths'],\n        }\n        if self.training:\n            self.descriptors_3d[class_name] = self.corr_fea_extractor_3d(batch3d)\n        else:\n            if class_name not in self.descriptors_3d:\n                self.descriptors_3d[class_name] = self.corr_fea_extractor_3d(batch3d)\n\n        descriptors_2d = self.corr_fea_extractor_2d(image)['descriptors']\n\n\n        return {\n            \"descriptors_2d\":descriptors_2d,\n            \"descriptors_3d\":self.descriptors_3d[class_name],\n            
\"scores_saliency_3d\": None, \n            \"scores_overlap_3d\":None, \n\n        }\n"
  },
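  {
    "path": "examples/hybrid_net_registry_sketch.py",
    "content": "\"\"\"\nIllustrative sketch of the registry pattern used in model/HybridNet.py (hypothetical file): a network class is registered under its class name with the @register_hybrid_net decorator and later retrieved by name, e.g. from a config string.\n\"\"\"\n\nimport torch.nn as nn\nfrom model.HybridNet import register_hybrid_net, get_hybrid_net\n\n@register_hybrid_net\nclass DummyHybridNet(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n\n    def forward(self, batch):\n        return {}\n\nif __name__ == '__main__':\n    cls = get_hybrid_net('DummyHybridNet')   # lookup by class name\n    net = cls(config={})\n    print(type(net).__name__)\n"
  },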
  {
    "path": "model/PoseRefiner.py",
    "content": "import os \nimport time\nimport cv2\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.distributed as dist\nimport numpy as np\nfrom easydict import EasyDict as edict\nfrom functools import partial\n\nfrom geometry.transformation import *\nfrom geometry.intrinsics import *\nfrom geometry.projective_ops import coords_grid, normalize_coords_grid\nfrom model.CFNet import GRU_CFUpdator , ImageFeaEncoder\nfrom utils.pose_utils import pose_padding\nfrom config.default import get_cfg\n\n\n\nEPS = 1e-5\nMIN_DEPTH = 0.1\nMAX_ERROR = 100.0\n\n# exclude extremly large displacements\nMAX_FLOW = 400\n\n\ndef raft_sequence_flow_loss(flow_preds, flow_gt, valid, gamma=0.8, max_flow=MAX_FLOW):\n    \"\"\" Loss function defined over sequence of flow predictions \"\"\"\n\n    n_predictions = len(flow_preds)    \n    flow_loss = 0.0\n    \n    # exlude invalid pixels and extremely large diplacements\n    mag = torch.sum(flow_gt**2, dim=1).sqrt()\n    \n    valid = (valid >= 0.5) & (mag < max_flow)\n\n    for i in range(n_predictions):\n        i_weight = gamma**(n_predictions - i - 1)\n        i_loss = (flow_preds[i] - flow_gt).abs()\n        flow_loss += i_weight * (valid[:, None] * i_loss).mean()\n\n    epe = torch.sum((flow_preds[-1] - flow_gt)**2, dim=1).sqrt()\n    epe = epe.view(-1)[valid.view(-1)]\n\n    metrics = {\n        'epe': epe.mean().item(),\n        '1px': (epe < 1).float().mean().item(),\n        '3px': (epe < 3).float().mean().item(),\n        '5px': (epe < 5).float().mean().item(),\n    }\n\n    return flow_loss, metrics\n\n\n\n\nclass PoseRefiner(nn.Module):\n    def __init__(self, cfg,\n                 reuse=False,\n                 schedule=None,\n                 use_regressor=True,\n                 is_calibrated=True,\n                 bn_is_training=False,\n                 is_training=True,\n                 renderer=None,\n                 ):\n\n        super().__init__()\n\n        self.legacy=True\n        self.cfg = cfg\n        self.reuse = reuse\n        self.sigma=nn.ParameterList( [nn.Parameter(torch.ones(1)*1 )] )\n        self.with_corr_weight = self.cfg.get(\"with_corr_weight\", True)\n        if not self.with_corr_weight:\n            print(\"Warning: the correlation weighting is disabled.\")\n\n        self.is_calibrated = cfg.IS_CALIBRATED\n        if not is_calibrated:\n            self.is_calibrated = is_calibrated\n\n        self.is_training = is_training\n        self.use_regressor = use_regressor\n\n        self.residual_pose_history = []\n        self.Ti_history =[]\n        self.coords_history = []\n        self.residual_history = []\n        self.inds_history = []\n        # self.weights_history = []\n        self.flow_history = []\n        self.intrinsics_history = []\n        self.summaries = []\n\n        if self.cfg.FLOW_NET=='raft':\n            self.image_fea_enc= ImageFeaEncoder()\n            self.cf_net = GRU_CFUpdator(self.cfg.raft) \n        else:\n            raise NotImplementedError \n        self.renderer = renderer  \n\n    def _clear(self,):\n        self.residual_pose_history = []\n        self.Ti_history =[]\n        self.coords_history = []\n        self.residual_history = []\n        self.inds_history = []\n        # self.weights_history = []\n        self.flow_history = []\n        self.intrinsics_history = []\n        self.summaries = []\n\n    def __len__(self):\n        return len(self.residual_pose_history)\n\n    def render(self, params, render_tex=False):\n        \"\"\" render a 
batch of images given the intrinsic and extrinsic params \n\n        Args:\n            params.K: np.array, of shape Bx3x3\n            params.camera_extrinsics: np.array, of shape Bx3x4\n\n        Returns:\n            [type]: [description]\n        \"\"\"\n\n        bs=params.K.shape[0]\n        colors=[]\n        depths=[]\n\n        color,depth= self.renderer( params.obj_cls, params.vert_attribute, T=params.camera_extrinsics, K=params.K, \n            render_image_size=params.render_image_size, near=0.1, far=6, render_tex=render_tex)\n        #set the invalid values to zeros\n        depth[depth==-1] = 0\n        return {\n            # 1x3xHxW\n            \"syn_img\": color, \n            \"syn_depth\": depth.detach(), # 1x1xHxW\n        }\n\n\n    def get_affine_transformation(self, mask, crop_center, with_intrinsic_transform=False, output_size=None, margin_ratio=0.4):\n        B,_,H,W = mask.shape\n        ratio = float(H) / float(W)\n        affine_matrices = []\n        intrinsic_matrices = []\n\n        for b in range(B):\n            zoom_c_x, zoom_c_y = crop_center[b] #crop_center\n\n            ys, xs = np.nonzero(mask[b][0].detach().cpu().numpy() )\n            if len(ys)>0 and len(xs)>0:\n                obj_imgn_start_x = xs.min() \n                obj_imgn_start_y = ys.min()\n                obj_imgn_end_x = xs.max()\n                obj_imgn_end_y = ys.max()\n            else:\n                obj_imgn_start_x=0\n                obj_imgn_start_y=0\n                obj_imgn_end_x=0\n                obj_imgn_end_y=0\n\n\n            # mask region\n            left_dist = zoom_c_x - obj_imgn_start_x\n            right_dist = obj_imgn_end_x - zoom_c_x\n            up_dist = zoom_c_y - obj_imgn_start_y\n            down_dist = obj_imgn_end_y - zoom_c_y\n            # crop_height = np.max([ratio * right_dist, ratio * left_dist, up_dist, down_dist]) * 2 * 1.4\n            crop_height = np.max([ratio * right_dist, ratio * left_dist, up_dist, down_dist]) * 2 * (1+margin_ratio)\n            crop_width = crop_height / ratio\n\n            # affine transformation for PyTorch\n            x1 = (zoom_c_x - crop_width / 2) * 2 / W - 1;\n            x2 = (zoom_c_x + crop_width / 2) * 2 / W - 1;\n            y1 = (zoom_c_y - crop_height / 2) * 2 / H - 1;\n            y2 = (zoom_c_y + crop_height / 2) * 2 / H - 1;\n\n            pts1 = np.float32([[x1, y1], [x1, y2], [x2, y1]])\n            pts2 = np.float32([[-1, -1], [-1, 1], [1, -1]])\n            affine_matrix = torch.tensor(cv2.getAffineTransform(pts2, pts1), device=mask.device, dtype=torch.float32)\n            affine_matrices.append(affine_matrix)\n\n\n            if with_intrinsic_transform:\n                # affine transformation for PyTorch\n                x1 = (zoom_c_x - crop_width / 2)\n                x2 = (zoom_c_x + crop_width / 2)\n                y1 = (zoom_c_y - crop_height / 2)\n                y2 = (zoom_c_y + crop_height / 2)\n\n                pts1 = np.float32([[x1, y1], [x1, y2], [x2, y1]])\n                # pts2 = np.float32([[0, 0], [0, H-1], [W-1, 0]])\n                pts2 = np.float32([[0, 0], [0, output_size[0]-1], [output_size[1]-1, 0]])\n                # pts2 = np.float32([[0, 0], [0, 1], [1, 0]])\n                intrinsic_matrix = torch.tensor(cv2.getAffineTransform(pts2, pts1), device=mask.device, dtype=torch.float32)\n                intrinsic_matrices.append(intrinsic_matrix)\n                \n        if with_intrinsic_transform:\n            return  torch.stack(affine_matrices, dim=0), 
torch.stack(intrinsic_matrices, dim=0)\n        else:\n            return torch.stack(affine_matrices, dim=0)\n\n    def gen_zoom_crop_grids(self, fg_mask, K, T, output_size, model_center=[0,0,0], margin_ratio=0.4):\n        ##Get the projected model center in image (assuming the model is zero-centered, which should be reconsidered!)  \n        crop_center=K@T[:,:3,3:]\n        crop_center = crop_center[:,:2]/crop_center[:,2:3]\n\n        ##calculate affine transformation parameters\n        affine_matrices, crop_intrinsic_transform=self.get_affine_transformation(fg_mask, crop_center=crop_center.detach().cpu().numpy(), with_intrinsic_transform=True, output_size=(output_size[-2], output_size[-1]),margin_ratio=margin_ratio ) \n        grids = F.affine_grid(affine_matrices, torch.Size(output_size) )\n        ##Get cropped intrinsic_transform\n        intrinsics_crop= torch.inverse(pose_padding(crop_intrinsic_transform) )@K\n\n        return grids, intrinsics_crop\n\n    # def forward(self, image, Ts, intrinsics, fea_3d=None, inds=None, Tj_gt=None, obj_cls=None, geofea_3d=None, geofea_2d=None):\n    def forward(self, image, Ts, intrinsics, fea_3d=None, Tj_gt=None, obj_cls=None, geofea_3d=None, geofea_2d=None):\n        #clear the history data\n        self._clear()\n        cfg = self.cfg\n\n        RANDER_IMAGE_SIZE = get_cfg(\"BASIC\").render_image_size\n        ZOOM_CROP_SIZE=get_cfg(\"BASIC\").zoom_crop_size \n\n        if cfg.RESCALE_IMAGES:\n            images = 2 * (images / 255.0) - 1.0\n\n        Tij_gt=[]\n        syn_imgs=[]\n        syn_depths=[]\n\n        Ti = Ts\n        Tij = Ti.copy().identity()\n\n        for ren_iter in range(cfg.RENDER_ITER_COUNT):\n            # update rendering params\n            Ti = Tij*Ti # accumulate Ti \n            Tij.identity_() #set Tij to identity matrix at the begining of each ren_iter\n            if self.legacy:\n                Tij = Ti*Ti.inv() #set Tij to identity matrix at the begining of each ren_iter\n\n            render_params = edict({\n                    \"K\": intrinsics.detach(), \n                    \"camera_extrinsics\": Ti.matrix().detach().squeeze(1), \n                    \"obj_cls\": obj_cls,\n                    \"render_image_size\": RANDER_IMAGE_SIZE, \n                })\n\n            pc_depth = self.renderer.render_pointcloud(obj_cls, T=render_params.camera_extrinsics, K=render_params.K, \n                                    render_image_size=render_params.render_image_size)\n\n\n            if self.cfg.ONLINE_CROP:\n                #get the forground mask \n                fg_mask = pc_depth>0\n                B,C,_,_= pc_depth.size()\n                \n                ############### Get zoom parameters ###############\n                grids, intrinsics_crop = self.gen_zoom_crop_grids(fg_mask, render_params.K, render_params.camera_extrinsics, output_size=[B,C, *ZOOM_CROP_SIZE], model_center=None)\n\n                ############### Render reference images ###############\n                # Concatentate the 3D ctx feature \"fea_3d\" and 3d descriptor \"geofea_3d\" for feature rendering\n                if geofea_3d is not None:\n                    fea_3d_cat = torch.cat([fea_3d, geofea_3d ], dim=-1) # BxNxC\n                else:\n                    fea_3d_cat = fea_3d\n                render_params.vert_attribute = fea_3d_cat #fea_3d\n                render_params.K = intrinsics_crop.detach()\n                render_params.render_image_size = ZOOM_CROP_SIZE\n                ren_res = self.render(render_params, 
render_tex=True) \n                if geofea_3d is not None:\n                    syn_img, cfea, geofea1 = torch.split(ren_res['syn_img'],[3, fea_3d.shape[-1], geofea_3d.shape[-1] ] ,dim=1)\n                    syn_depth = ren_res['syn_depth']\n                    \n                else:\n                    syn_img, cfea = torch.split(ren_res['syn_img'], [3, fea_3d.shape[-1] ] ,dim=1)\n                    geofea1=None\n                    syn_depth = ren_res['syn_depth']\n                cfea = cfea*0.1 # balance the learning rate with the scale 0.1\n\n                ## Crop and zoom images\n                syn_image_crop = syn_img \n                image_crop= F.grid_sample(image, grids)\n                cfea_crop= cfea \n                if geofea1 is not None and geofea_2d is not None:\n                    # geofea1_crop = F.grid_sample(geofea1, grids)\n                    geofea1_crop = geofea1\n                    geofea2_crop = F.grid_sample(geofea_2d, grids)\n\n                #Render again to get more accurate depth for supervisions in losses \n                #TODO: could be merged into the rendering process above. -> Done\n                if self.legacy:#self.training:\n                    depth_render_params=edict({\n                            \"K\": intrinsics_crop.detach(), \n                            \"camera_extrinsics\": Ti.matrix().detach().squeeze(1), \n                            \"obj_cls\": obj_cls,\n                            \"render_image_size\": ZOOM_CROP_SIZE, \n                        })\n                    syn_depth=self.renderer.render_depth(obj_cls, T=depth_render_params.camera_extrinsics, K=depth_render_params.K, \n                                    render_image_size=depth_render_params.render_image_size, near=0.1, far=6)\n\n            #for visualization only\n            syn_imgs.append(syn_image_crop)\n            syn_imgs.append(image_crop)\n\n            # encode image features\n            feats1, feats2=self.image_fea_enc(syn_image_crop, image_crop)\n            # depths = torch.index_select(syn_depth, index=ii, dim=1) + EPS\n            depths = syn_depth+EPS\n\n            for i in range(cfg.ITER_COUNT):\n                # save for loss calculation \n                self.intrinsics_history.append(intrinsics_crop)\n                syn_depths.append(syn_depth)\n\n                Tij = Tij.copy(stop_gradients=True)\n                intrinsics_crop = intrinsics_crop.detach()\n                \n                #Get the projection in frame j of visible model points in frame i with the current relative pose estimation Tij\n                reproj_coords, vmask = Tij.transform(\n                    depths, intrinsics_crop, valid_mask=True)\n\n                uniform_grids = coords_grid(depths)\n                flow_init = torch.einsum( \"...ijk->...kij\", reproj_coords-uniform_grids[..., :2] ) * (depths>EPS) \n                flow = self.cf_net(feats1, feats2, flow_init=flow_init.squeeze(1), context_fea=cfea_crop, update_corr_fn=i==0)\n\n\n                self.flow_history.append(flow)\n\n                # Get the correspondences in frame j for each point in frame i, based on the current flow estimates\n                if isinstance (flow, (list, tuple)): # flow net may return a list of flow maps\n                    correspondence_target = torch.einsum(\"...ijk->...jki\", flow[-1]) + uniform_grids[..., :2]\n                else:\n                    correspondence_target = torch.einsum(\"...ijk->...jki\", flow) + uniform_grids[..., :2]\n               
 \n                # Optimize for the pose by minimizing errors between the constructed correspondence field \n                # (with the currently estimated pose) and the estimated correspondence field \n                if self.with_corr_weight and geofea1 is not None and geofea_2d is not None:\n                    geofea2_crop_warpped =  F.grid_sample(geofea2_crop, normalize_coords_grid(correspondence_target).squeeze(1) )\n                    corr_weight= torch.sum(geofea1_crop*geofea2_crop_warpped, dim=1,keepdim=True).permute(0,2,3,1)[:,None] #insert frame axis\n                    corr_weight = torch.exp(-torch.abs(1-corr_weight)/self.sigma[0]) * (syn_depth>0)[...,None].float()\n                else: \n                    corr_weight = weight\n                \n                Tij = Tij.reprojction_optim(\n                    correspondence_target, corr_weight, depths, intrinsics_crop, num_iters=cfg.OPTIM_ITER_COUNT )\n\n                reproj_coords, vmask1 = Tij.transform(\n                    depths, intrinsics_crop, valid_mask=True)\n\n                # For the loss calculation later\n                self.residual_pose_history.append(Tij)\n\n                self.Ti_history.append(Ti.copy(stop_gradients=True) )\n                Tij_gt.append( (Tj_gt*Ti.inv()).copy(stop_gradients=True) )\n\n                self.residual_history.append(\n                    vmask*vmask1*(reproj_coords-correspondence_target))  # BxKxHxWx3\n\n        # The final update of Ti \n        Ti = Tij*Ti\n        return {\n            \"Tij\": Tij,\n            \"Ti_pred\": Ti,\n            \"intrinsics\": intrinsics,\n            \"flow\": self.flow_history[0],\n            \"vmask\": syn_depth>0, \n            \"weight\": torch.einsum(\"...ijk->...kij\", corr_weight), \n            \"syn_depth\": syn_depths, #ren_res['syn_depth'],\n            \"syn_img\": syn_imgs+[image_crop, cfea_crop[:,:3]*10,geofea1[:,:3], geofea2_crop[:,:3]],\n            \"Tij_gt\" : Tij_gt\n        }\n\n    def compute_loss(self, Tij_gts, depths, intrinsics, loss='l1', log_error=True, loss3d=None, ):\n\n        total_loss = 0.0\n        for i in range(len(self.residual_pose_history)):\n            intrinsics = self.intrinsics_history[i]\n\n            depth, intrinsics = rescale_depths_and_intrinsics(depths[i], intrinsics, downscale=1)\n\n            Tij = self.residual_pose_history[i]\n\n            Gij = Tij_gts[i] \n\n            # intrinsics_pred = intrinsics\n            zstar = depth + EPS\n            flow_pred, valid_mask_pred = Tij.induced_flow(\n                zstar, intrinsics, valid_mask=True)  \n            flow_star, valid_mask_star = Gij.induced_flow(\n                zstar, intrinsics, valid_mask=True)\n\n            valid_mask = valid_mask_pred * valid_mask_star\n\n            #3D alignment loss \n            loss_3d_proj = 0 \n            if loss3d is not None:\n                Tj_pred=Tij*self.Ti_history[i]\n                Tj_gt=Gij*self.Ti_history[i]\n                loss_3d_proj = loss3d(\n                    R_pred=Tj_pred.G[:, 0, :3, :3], t_pred=Tj_pred.G[:, 0, :3, 3], R_tgt=Tj_gt.G[:, 0, :3, :3], t_tgt=Tj_gt.G[:, 0, :3, 3])\n\n            # flow loss\n            if isinstance( self.flow_history[0], (list, tuple)):\n                #squeeze the frame dimmension\n                self.flow_history[i] = [self.flow_history[i][f].squeeze(1) for f in range(len(self.flow_history[i])) ]\n                flow_mask = valid_mask.squeeze(1).squeeze(-1)\n                loss_flow,_ = 
raft_sequence_flow_loss(self.flow_history[i], flow_gt=torch.einsum(\"...ijk->...kij\", flow_star.squeeze(1)), valid= flow_mask, gamma=0.8, max_flow=MAX_FLOW)\n            else:\n                raise NotImplementedError\n            \n            # reprojection loss \n            reproj_diff = valid_mask * \\\n                torch.clamp(\n                    torch.abs(flow_pred - flow_star), -MAX_ERROR, MAX_ERROR)\n            reproj_loss = torch.mean(reproj_diff)\n\n\n            total_loss +=self.cfg.get(\"TRAIN_PCALIGN_WEIGHT\", 1)*loss_3d_proj+ self.cfg.TRAIN_FLOW_WEIGHT* loss_flow + self.cfg.TRAIN_REPROJ_WEIGHT*reproj_loss\n        \n        # clear the intermediate values\n        self._clear()\n\n        return {\n            \"total_loss\": total_loss,\n            \"reproj_loss\": reproj_loss,\n            \"flow_loss\":loss_flow,\n            \"loss_3d_proj\": loss_3d_proj,\n            \"valid_mask\": valid_mask,\n            \"Tij\": Tj_pred.G,\n            \"Gij\": Tj_gt.G\n        }\n"
  },
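  {
    "path": "examples/gauss_newton_damping_sketch.py",
    "content": "\"\"\"\nIllustrative sketch of the damped Gauss-Newton step behind SE3Sequence.reprojection_optim (hypothetical file): per-pixel 2x6 reprojection Jacobians are accumulated into 6x6 normal equations, damped with a constant epsilon term plus a Levenberg-Marquardt term on the diagonal, and solved for a twist increment. The repo's own einsum wrapper sums over '...'; here the pixel dimensions are flattened explicitly so plain torch.einsum can be used. The ep/lm values are placeholders for the EP_LMBDA/LM_LMBDA config entries.\n\"\"\"\n\nimport torch\n\nif __name__ == '__main__':\n    J = torch.randn(1, 1, 100, 2, 6, dtype=torch.float64)   # per-pixel 2x6 Jacobians\n    r = torch.randn(1, 1, 100, 2, dtype=torch.float64)      # per-pixel residuals (target - projection)\n    w = torch.rand(1, 1, 100, 1, dtype=torch.float64)       # per-pixel confidence weights\n\n    wJ = (w[..., None] * J).reshape(1, 1, -1, 6)            # flatten pixel and residual dims\n    Jf = J.reshape(1, 1, -1, 6)\n    H = torch.einsum('aipj,aipk->aijk', wJ, Jf)             # 1x1x6x6 normal matrix\n    b = torch.einsum('aipj,aip->aij', wJ, r.reshape(1, 1, -1))\n\n    ep, lm = 100.0, 1e-4                                    # placeholder damping constants\n    I6 = torch.eye(6, dtype=H.dtype)\n    H = H + ep * I6 + lm * H * I6                           # constant + LM diagonal damping\n    delta = torch.linalg.solve(H, b[..., None]).squeeze(-1) # twist increment for T.increment(...)\n    print(delta.shape)   # torch.Size([1, 1, 6])\n"
  },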
  {
    "path": "model/RNNPose.py",
    "content": "#\nimport time\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\nimport torch.distributed as dist\nimport apex\nimport numpy as np\nimport os\nfrom easydict import EasyDict as edict\nfrom transforms3d.euler import mat2euler, euler2mat, euler2quat, quat2euler\nfrom functools import partial\n\n\n\nfrom model.HybridNet import HybridDescNet,ContextFeatureNet\nfrom thirdparty.kpconv.lib.utils import square_distance\nfrom model.PoseRefiner import PoseRefiner  \n\nfrom utils.pose_utils import pose_padding\nfrom geometry.transformation import SE3Sequence\n\nfrom geometry.diff_render_optim import DiffRendererWrapper\nfrom config.default import get_cfg\nfrom utils.util import dict_recursive_op\n\n# from model.RNNPose import register_posenet\n\nREGISTERED_NETWORK_CLASSES = {}\n\n\ndef register_posenet(cls, name=None):\n    global REGISTERED_NETWORK_CLASSES\n    if name is None:\n        name = cls.__name__\n    assert name not in REGISTERED_NETWORK_CLASSES, f\"exist class: {REGISTERED_NETWORK_CLASSES}\"\n    REGISTERED_NETWORK_CLASSES[name] = cls\n    return cls\n\n\ndef get_posenet_class(name):\n    global REGISTERED_NETWORK_CLASSES\n    assert name in REGISTERED_NETWORK_CLASSES, f\"available class: {REGISTERED_NETWORK_CLASSES}\"\n    return REGISTERED_NETWORK_CLASSES[name]\n\n\n\n\n@register_posenet\nclass RNNPose(nn.Module):\n    def __init__(self,\n                 criterions,\n                 opt,\n                 name=\"RNNPose\",\n                 **kwargs):\n        super().__init__()\n\n        self.name = name\n        self.opt = opt\n\n        self.hybrid_desc_net=HybridDescNet(opt.descriptor_net)\n        self.ctx_fea_net = ContextFeatureNet(opt.descriptor_net)\n\n        self.ctx_fea = {} \n\n        self.render_params = edict({\n            \"width\": opt.input_w,  # 128,\n            \"height\": opt.input_h,  # 128,\n            \"gpu_id\": opt.get('gpu_id', 0),\n            \"obj_seqs\": opt.obj_seqs\n        })\n\n        renderer, diff_renderer= self._render_init(self.render_params)\n        self.diff_renderer = diff_renderer\n\n        self.motion_net = PoseRefiner(\n            opt.motion_net, bn_is_training=self.training, is_training=self.training,\n            renderer=diff_renderer \n        )\n        self.contrastive_loss = criterions.get(\n            \"metric_loss\", None)\n        self.pose_loss = criterions.get(\"pose_loss\", None)\n\n        self.register_buffer(\"global_step\", torch.LongTensor(1).zero_())\n\n\n    def update_global_step(self):\n        self.global_step += 1\n\n    def get_global_step(self):\n        return int(self.global_step.cpu().numpy()[0])\n\n    def clear_global_step(self):\n        self.global_step.zero_()\n    \n    def sample_poses(self, pose_tgt):\n        SYN_STD_ROTATION = 15\n        SYN_STD_TRANSLATION = 0.01\n        ANGLE_MAX=45\n        pose_src = pose_tgt.copy()\n        num = pose_tgt.shape[0]\n        for i in range(num):\n            euler = mat2euler(pose_tgt[i, :3, :3])\n            euler += SYN_STD_ROTATION * np.random.randn(3) * np.pi / 180.0\n            pose_src[i, :3, :3] = euler2mat(euler[0], euler[1], euler[2])\n\n            pose_src[i, 0, 3] = pose_tgt[i, 0, 3]+ SYN_STD_TRANSLATION * np.random.randn(1)\n            pose_src[i, 1, 3] = pose_tgt[i, 1, 3] + SYN_STD_TRANSLATION * np.random.randn(1)\n            pose_src[i, 2, 3] = pose_tgt[i, 2, 3]  + 5 * SYN_STD_TRANSLATION * np.random.randn(1)\n\n            r_dist = np.arccos((np.trace(pose_src[i, :3,:3].transpose(-1,-2) @ 
pose_tgt[i, :3,:3]) - 1 )/2)/np.pi*180\n\n            while r_dist > ANGLE_MAX:#or not (16 < center_x < (640 - 16) and 16 < center_y < (480 - 16)):\n                print(\"r_dist > ANGLE_MAX, resampling...\")\n                euler = mat2euler(pose_tgt[i, :3, :3])\n                euler += SYN_STD_ROTATION * np.random.randn(3) * np.pi / 180.0\n                pose_src[i, :3, :3] = euler2mat(euler[0], euler[1], euler[2])\n\n                pose_src[i, 0, 3] = pose_tgt[i, 0, 3]+ SYN_STD_TRANSLATION * np.random.randn(1)\n                pose_src[i, 1, 3] = pose_tgt[i, 1, 3] + SYN_STD_TRANSLATION * np.random.randn(1)\n                pose_src[i, 2, 3] = pose_tgt[i, 2, 3]  + 5 * SYN_STD_TRANSLATION * np.random.randn(1)\n\n                r_dist = np.arccos((np.trace(pose_src[i, :3,:3].transpose(-1,-2) @ pose_tgt[i, :3,:3]) - 1 )/2)*np.pi/180\n        return pose_src\n\n    def _render_init(self, config):\n        # from data.ycb.basic import bop_ycb_class2idx\n        print(\"config.gpu_id:\", config.gpu_id)\n\n        obj_paths = []\n        tex_paths = []\n\n        # build cls2idx table for the renderer\n        cls2idx = {}\n        LM_SEQ=[\"ape\", \"benchvise\", \"camera\",\"cam\", \"can\", \"cat\", \"driller\", \"duck\", \"eggbox\", \"glue\", \"holepuncher\", \"iron\", \"lamp\", \"phone\"]\n        # YCB_SEQ=bop_ycb_class2idx.keys()\n        for i, seq in enumerate(set(config.obj_seqs)):\n\n            if seq in LM_SEQ: \n                obj_path = f'{os.path.dirname(__file__)}/../EXPDATA/LM6d_converted/models/{seq}/textured.obj'\n                tex_path = f'{os.path.dirname(__file__)}/../EXPDATA/LM6d_converted/models/{seq}/texture_map.png'\n                assert os.path.exists(obj_path), f\"'{obj_path}' dose not exist!\" \n                assert os.path.exists(tex_path), f\"'{tex_path}' dose not exist!\" \n                obj_paths.append(obj_path)\n                tex_paths.append(tex_path)\n                cls2idx[seq] = i\n            else:\n                raise NotImplementedError\n        renderer=None\n\n        diff_renderer = DiffRendererWrapper(obj_paths)\n        diff_renderer.cls2idx = cls2idx\n\n        return renderer, diff_renderer\n\n\n    def forward(self, sample):\n        assert len(set(sample['class_name']))==1, \"A batch should contain data of the same class.\"\n        class_name = sample['class_name'][0]\n\n        #encode 3d-2d descriptors\n        preds_dict=self.hybrid_desc_net(sample)\n\n        len_src_f = sample['stack_lengths'][0][0]\n        geofea_3d = preds_dict.get('descriptors_3d', None) \n        geofea_2d = preds_dict.get('descriptors_2d', None) \n\n\n        #encode 3D context features \n        if self.training:\n            self.ctx_fea[class_name]=self.ctx_fea_net(sample)\n        else:\n            if class_name not in self.ctx_fea:\n                self.ctx_fea[class_name]=self.ctx_fea_net(sample)\n\n        preds_dict.update(self.ctx_fea[class_name])\n        ctx_fea_3d = preds_dict['ctx_fea_3d'][:len_src_f]\n\n        if self.training:\n            pose=pose_padding(sample['original_RT'])\n\n            if sample.get(\"rendered_RT\", None) is not None:\n                syn_pose = pose_padding(sample['rendered_RT'])\n            else:\n                syn_pose = torch.tensor(self.sample_poses(sample['original_RT'].detach().cpu(\n                ).numpy()), device=sample['original_RT'].device, dtype=sample['original_RT'].dtype)\n                syn_pose = pose_padding(syn_pose)\n        else:\n            pose = 
pose_padding(sample['original_RT'])\n            syn_pose = pose_padding(sample['rendered_RT'])\n            \n\n        # calculate the GT relative pose and the initial pose\n        \n        Ts_pred = SE3Sequence(\n            matrix=torch.stack([syn_pose ], dim=1))\n        mot_res = self.motion_net(\n            Ts=Ts_pred,  \n            intrinsics=sample['K'],\n            image=sample['image'], \n            fea_3d=ctx_fea_3d[None], \n            Tj_gt=SE3Sequence(matrix=pose[:, None]),\n            obj_cls=sample['class_name'],\n            geofea_2d = geofea_2d, \n            geofea_3d = geofea_3d[None]\n        )\n        preds_dict.update(mot_res)\n        sample['syn_depth'] = mot_res['syn_depth']\n        if self.training:\n            ret = self.loss(sample, preds_dict)\n\n            ret['syn_img'] = mot_res['syn_img']\n            ret['syn_depth'] = mot_res['syn_depth']\n            ret['flow'] = mot_res['flow']\n            ret['weight'] = mot_res['weight']\n\n            return ret\n        else:\n            # ret = self.loss(sample, preds_dict)\n            ret={}\n            ret.update(preds_dict)\n            return ret\n\n    \n    def loss(self, sample, preds_dict):\n\n        len_src_f = sample['stack_lengths'][0][0]\n        RT =sample['RT']\n        camera_intrinsic=sample['K']\n        descriptors_2d_map = preds_dict['descriptors_2d']\n        descriptors_3d = preds_dict['descriptors_3d'][:len_src_f]\n        rand_descriptors_3d = preds_dict['descriptors_3d'][len_src_f:]\n        model_points=sample['model_points'][0][:len_src_f]\n        orig_model_points = sample['original_model_points']\n        rand_model_points=sample['model_points'][0][len_src_f:]\n        lifted_points=sample['lifted_points'][0].squeeze(0)\n        correspondence=sample['correspondences_2d3d'].squeeze(0)\n        depth = sample['depth']\n        \n\n        # get the foreground 2d descriptors \n        ys_, xs_=torch.nonzero(sample['depth'].squeeze(), as_tuple=True )\n        descriptors_2d=descriptors_2d_map[:,:,ys_,xs_].squeeze().permute([1,0])\n\n        if self.training: \n            device= lifted_points.device\n            fg_point_num = len(lifted_points)\n            model_point_num = len(model_points)\n          \n            # append bg features \n            ys_bg, xs_bg = torch.nonzero(sample['depth'].squeeze()<=0, as_tuple=True )\n            descriptors_2d_bg = descriptors_2d_map[:,:,ys_bg, xs_bg].squeeze().permute([1,0])\n            # descriptors_2d_bg = descriptors_2d_bg[np.random.randint(0,len(descriptors_2d_bg), size=len(lifted_points) )]\n            descriptors_2d = torch.cat([descriptors_2d, descriptors_2d_bg], dim=0) \n            descriptors_3d = torch.cat([descriptors_3d, descriptors_2d_bg], dim=0 )\n\n            # append handcrafted coordinates to simplify the code(assign the same coordinates for the bg points far away from the fg points, i.e. 
10e6 )\n            lifted_points = torch.cat([lifted_points, torch.ones([len(descriptors_2d_bg) ,3],device=device )*10e6 ] ) #append very distant points\n            model_points = torch.cat([model_points, torch.ones([len(descriptors_2d_bg) ,3], device=device)*10e6 ] ) #append very distant points\n\n            #append one-to-one inds\n            if len(descriptors_2d_bg)>0:\n                # randomly sample the bg points to balance the learning process\n                sample_inds=np.random.randint(0,len(descriptors_2d_bg), size=int(len(correspondence)*0.1) )\n           \n                bg_corr= torch.stack([\n                    torch.arange(fg_point_num, fg_point_num+len(descriptors_2d_bg), device=device )[sample_inds], \n                    torch.arange(model_point_num, model_point_num+len(descriptors_2d_bg), device=device)[sample_inds]\n                 ], dim=-1) #Nx2\n                correspondence = torch.cat([correspondence, bg_corr ], dim=0)\n            \n        if len(lifted_points)>0:\n            contra_loss=self.contrastive_loss(src_pcd=lifted_points, tgt_pcd=model_points,\n                            src_feats=descriptors_2d, tgt_feats=descriptors_3d,\n                            correspondence=correspondence,\n                            scores_overlap=None, scores_saliency=None)\n        else:\n            print(\"Warning: Contrastive loss is skipped, as no lifted point is found!\", flush=True)\n            contra_loss={\n                'circle_loss': torch.zeros([1], device=depth.device),\n                'recall': torch.zeros([1], device=depth.device)\n            }\n\n\n        loss3d = partial(self.pose_loss.forward,\n                         points=orig_model_points[None])\n        motion_loss = self.motion_net.compute_loss(\n            preds_dict['Tij_gt'], sample['syn_depth'], intrinsics=sample['K'], loss='l1', log_error=True, loss3d=loss3d)\n\n        \n        loss =  self.contrastive_loss.weight * contra_loss['circle_loss']  + motion_loss['total_loss'] \n\n        res = {\n            \"loss\": loss, \n            \"circle_loss\": contra_loss['circle_loss'].detach(),\n            \"recall\":contra_loss['recall'].detach(),\n            # \"geometric_loss\": torch.zeros(1),\n            \"reproj_loss\": motion_loss['reproj_loss'].detach(),\n            \"loss_3d_proj\": motion_loss['loss_3d_proj'].detach(), \n            \"valid_mask\": motion_loss[\"valid_mask\"].detach(),\n        }\n        return res\n\n"
  },
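  {
    "path": "examples/rotation_distance_sketch.py",
    "content": "\"\"\"\nIllustrative sketch of the geodesic rotation distance used by RNNPose.sample_poses when resampling initial poses (hypothetical file): the angle between two rotations R1, R2 is arccos((trace(R1^T R2) - 1) / 2), converted to degrees for the ANGLE_MAX test.\n\"\"\"\n\nimport numpy as np\nfrom transforms3d.euler import euler2mat\n\ndef rotation_distance_deg(R1, R2):\n    cos = (np.trace(R1.T @ R2) - 1.0) / 2.0\n    return np.arccos(np.clip(cos, -1.0, 1.0)) / np.pi * 180.0\n\nif __name__ == '__main__':\n    R1 = euler2mat(0.0, 0.0, 0.0)\n    R2 = euler2mat(0.3, 0.0, 0.0)   # 0.3 rad about the first axis\n    print(rotation_distance_deg(R1, R2))   # ~17.19 degrees\n"
  },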
  {
    "path": "model/descriptor2D.py",
    "content": "from easydict import EasyDict as edict\nfrom pathlib import Path\nimport torch\nfrom torch import nn\nfrom torchplus.nn.modules.common import Empty\n\n\n\nclass SuperPoint2D(nn.Module):\n    \"\"\"SuperPoint Convolutional Detector and Descriptor\n\n    SuperPoint: Self-Supervised Interest Point Detection and\n    Description. Daniel DeTone, Tomasz Malisiewicz, and Andrew\n    Rabinovich. In CVPRW, 2019. https://arxiv.org/abs/1712.07629\n\n    \"\"\"\n    default_config = {\n        'descriptor_dim': 256,\n        'nms_radius': 4,\n        'keypoint_threshold': 0.005,\n        'max_keypoints': -1,\n        'remove_borders': 4,\n        'saliency_score_normalization_fuc': 'sigmoid',\n        \"use_instance_norm\": True\n    }\n\n    def __init__(self, config):\n        super().__init__()\n        self.default_config.update(config) \n        self.config=edict(self.default_config)\n\n        self.normalize_output=config.get('normalize_output', True)\n\n        self.saliency_score_normalization_fuc= self.config.saliency_score_normalization_fuc\n        assert self.saliency_score_normalization_fuc in ['sigmoid', 'softmax']\n\n        self.relu = nn.ReLU(inplace=True)\n        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)\n        \n        if self.config.use_instance_norm:\n            self.Normalization=nn.InstanceNorm2d\n        else:\n            self.Normalization = Empty\n\n\n        c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256\n        self.input_dim=config.input_dim\n        # self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1)\n        self.conv1a = nn.Conv2d(config.input_dim, c1, kernel_size=3, stride=1, padding=1)\n        self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)\n        self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)\n        self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)\n        self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)\n        self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)\n        self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)\n        self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)\n\n        self.convPa = nn.Sequential(\n            nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1), \n            self.Normalization(c5)\n        )\n\n        self.convPb = nn.Conv2d(c5, 1, kernel_size=1, stride=1, padding=0)\n\n        self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)\n\n        self.convDb = nn.Conv2d(\n            c5, self.config['descriptor_dim'],\n            kernel_size=1, stride=1, padding=0)\n\n        self.decode1=nn.Sequential(\n            nn.Upsample(scale_factor=2,mode='bilinear'),\n            nn.Conv2d(c4,c4, kernel_size=3, stride=1, padding=1),\n            self.Normalization(c4),\n            nn.ReLU())\n        self.decode2=nn.Sequential(\n            nn.Upsample(scale_factor=2,mode='bilinear'),\n            nn.Conv2d(c4+c3,c4, kernel_size=3, stride=1, padding=1),\n            self.Normalization(c4),\n            nn.ReLU()\n            )\n        \n        self.decode3=nn.Sequential(\n            nn.Upsample(scale_factor=2,mode='bilinear'),\n            nn.Conv2d(c4+c2,c4, kernel_size=3, stride=1, padding=1),\n            self.Normalization(c4),\n            nn.ReLU()\n            )\n        \n        path = Path(__file__).parent.parent/ 'weights/superpoint_v1.pth'\n\n        self.load_state_dict(torch.load(str(path), map_location='cpu' ), 
strict=False)\n\n        mk = self.config['max_keypoints']\n        if mk == 0 or mk < -1:\n            raise ValueError('\\\"max_keypoints\\\" must be positive or \\\"-1\\\"')\n\n        print('Loaded SuperPoint model')\n\n    def load_state_dict(self, state_dict, strict=True):\n        # with strict=False, keep only entries whose name and shape match this model\n        if not strict:\n            updated_state_dict = {}\n            model_dict = self.state_dict()\n            for k, v in state_dict.items():\n                if k in model_dict and v.shape == model_dict[k].shape:\n                    updated_state_dict[k] = v\n        else:\n            updated_state_dict = state_dict\n        super().load_state_dict(updated_state_dict, strict)\n\n    def forward_encoder(self, x):\n        if self.input_dim == 1:\n            x = x.mean(dim=1, keepdim=True)\n        x_skip = []\n        # Shared Encoder\n        x = self.relu(self.conv1a(x))\n        x = self.relu(self.conv1b(x))\n        x_skip.append(x)\n        x = self.pool(x)\n        x = self.relu(self.conv2a(x))\n        x = self.relu(self.conv2b(x))\n        x_skip.append(x)\n        x = self.pool(x)\n        x = self.relu(self.conv3a(x))\n        x = self.relu(self.conv3b(x))\n        x_skip.append(x)\n        x = self.pool(x)\n        x = self.relu(self.conv4a(x))\n        x = self.relu(self.conv4b(x))\n\n        return x, x_skip\n\n    def forward_decoder(self, x, x_skip):\n        # upsample first\n        x = self.decode1(x)\n        x = torch.cat([x, x_skip[-1]], dim=1)\n        x = self.decode2(x)\n        x = torch.cat([x, x_skip[-2]], dim=1)\n        x = self.decode3(x)\n\n        # Compute the dense keypoint scores\n        cPa = self.relu(self.convPa(x))\n        scores = self.convPb(cPa)\n\n        if self.saliency_score_normalization_fuc == 'sigmoid':\n            scores = torch.sigmoid(scores)\n        elif self.saliency_score_normalization_fuc == 'softmax':\n            scores_shape = scores.shape\n            scores = scores.reshape(*scores.shape[0:2], -1)\n            scores = nn.functional.softmax(scores, dim=-1)\n            scores = scores.reshape(*scores_shape)\n        else:\n            raise ValueError\n\n        keypoints = None\n\n        # Compute the dense descriptors\n        cDa = self.relu(self.convDa(x))\n        descriptors = self.convDb(cDa)\n        if self.normalize_output:\n            descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1)\n\n        return keypoints, scores, descriptors\n\n    def forward(self, data):\n        \"\"\" Compute keypoints, scores, descriptors for image \"\"\"\n        # Shared Encoder\n        x, x_skip = self.forward_encoder(data)\n\n        # Compute the dense keypoint scores and descriptors\n        keypoints, scores, descriptors = self.forward_decoder(x, x_skip)\n\n        return {\n            'keypoints': keypoints,\n            'scores': scores,\n            'descriptors': descriptors,\n        }\n"
  },
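  {
    "path": "examples/descriptor2d_example.py",
    "content": "# Illustrative usage sketch (not part of the original release): running\n# SuperPoint2D from model/descriptor2D.py on a dummy image. The file name and\n# config values here are assumptions; the script also assumes the repo root is\n# on PYTHONPATH and that weights/superpoint_v1.pth exists, since __init__\n# loads that checkpoint.\nfrom easydict import EasyDict as edict\nimport torch\n\nfrom model.descriptor2D import SuperPoint2D\n\nconfig = edict({\n    'input_dim': 3,  # RGB input; with input_dim=1 the forward pass averages the channels\n    'saliency_score_normalization_fuc': 'sigmoid',\n    'use_instance_norm': True,\n})\nnet = SuperPoint2D(config).eval()\n\n# H and W should be multiples of 8: the encoder pools three times and the\n# decoder upsamples three times, so the dense maps return at full resolution.\nimage = torch.rand(1, 3, 480, 640)\nwith torch.no_grad():\n    out = net(image)\nprint(out['scores'].shape)       # torch.Size([1, 1, 480, 640])\nprint(out['descriptors'].shape)  # torch.Size([1, 256, 480, 640])\n# out['keypoints'] is None: this variant returns dense maps only.\n"
  },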
  {
    "path": "model/descriptor3D.py",
    "content": "import torch \nimport torch.nn as nn \n\n\nfrom kpconv.kpconv_blocks import *\nimport torch.nn.functional as F\nimport numpy as np\nfrom kpconv.lib.utils import square_distance\n\nclass KPSuperpoint3Dv2(nn.Module):\n    #remove useless channels\n    def __init__(self, config):\n        super().__init__()\n        self.normalize_output=config.get('normalize_output', True)\n\n        # build the architectures\n        config.architecture = [\n        'simple',\n        'resnetb',\n        ]\n        for i in range(config.num_layers-1):\n            config.architecture.append('resnetb_strided')\n            config.architecture.append('resnetb')\n            config.architecture.append('resnetb')\n        for i in range(config.num_layers-2):\n            config.architecture.append('nearest_upsample')\n            config.architecture.append('unary')\n        config.architecture.append('nearest_upsample')\n        config.architecture.append('last_unary')\n\n        ############\n        # Parameters\n        ############\n        # Current radius of convolution and feature dimension\n        layer = 0\n        r = config.first_subsampling_dl * config.conv_radius\n        in_dim = config.in_features_dim\n        out_dim = config.first_feats_dim\n        self.K = config.num_kernel_points\n        self.epsilon = torch.nn.Parameter(torch.tensor(-5.0))\n        self.final_feats_dim = config.final_feats_dim\n\n        #####################\n        # List Encoder blocks\n        #####################\n        # Save all block operations in a list of modules\n        self.encoder_blocks = nn.ModuleList()\n        self.encoder_skip_dims = []\n        self.encoder_skips = []\n\n        # Loop over consecutive blocks\n        for block_i, block in enumerate(config.architecture):\n\n            # Check equivariance\n            if ('equivariant' in block) and (not out_dim % 3 == 0):\n                raise ValueError('Equivariant block but features dimension is not a factor of 3')\n\n            # Detect change to next layer for skip connection\n            if np.any([tmp in block for tmp in ['pool', 'strided', 'upsample', 'global']]):\n                self.encoder_skips.append(block_i)\n                self.encoder_skip_dims.append(in_dim)\n\n            # Detect upsampling block to stop\n            if 'upsample' in block:\n                break\n\n            # Apply the good block function defining tf ops\n            self.encoder_blocks.append(block_decider(block,\n                                                    r,\n                                                    in_dim,\n                                                    out_dim,\n                                                    layer,\n                                                    config))\n\n            # Update dimension of input from output\n            if 'simple' in block:\n                in_dim = out_dim // 2\n            else:\n                in_dim = out_dim\n\n            # Detect change to a subsampled layer\n            if 'pool' in block or 'strided' in block:\n                # Update radius and feature dimension for next layer\n                layer += 1\n                r *= 2\n                out_dim *= 2\n\n        #####################\n        # bottleneck layer \n        #####################\n        botneck_feats_dim = config.gnn_feats_dim\n        self.bottle = nn.Conv1d(in_dim, botneck_feats_dim,kernel_size=1,bias=True)\n        # num_head = config.num_head\n        self.proj_gnn = 
nn.Conv1d(botneck_feats_dim, botneck_feats_dim, kernel_size=1, bias=True)\n\n        #####################\n        # List Decoder blocks\n        #####################\n        out_dim = botneck_feats_dim\n\n        # Save all block operations in a list of modules\n        self.decoder_blocks = nn.ModuleList()\n        self.decoder_concats = []\n\n        # Find first upsampling block\n        start_i = 0\n        for block_i, block in enumerate(config.architecture):\n            if 'upsample' in block:\n                start_i = block_i\n                break\n\n        # Loop over consecutive blocks\n        for block_i, block in enumerate(config.architecture[start_i:]):\n\n            # Add dimension of skip connection concat\n            if block_i > 0 and 'upsample' in config.architecture[start_i + block_i - 1]:\n                in_dim += self.encoder_skip_dims[layer]\n                self.decoder_concats.append(block_i)\n\n            # Apply the good block function defining tf ops\n            self.decoder_blocks.append(block_decider(block,\n                                                    r,\n                                                    in_dim,\n                                                    out_dim,\n                                                    layer,\n                                                    config))\n\n            # Update dimension of input from output\n            in_dim = out_dim\n\n            # Detect change to a subsampled layer\n            if 'upsample' in block:\n                # Update radius and feature dimension for next layer\n                layer -= 1\n                r *= 0.5\n                out_dim = out_dim // 2\n\n    def regular_score(self, score):\n        score = torch.where(torch.isnan(score), torch.zeros_like(score), score)\n        score = torch.where(torch.isinf(score), torch.zeros_like(score), score)\n        return score\n\n    def forward_encoder(self, batch):\n        # Get input features\n        x = batch['features'].clone().detach()\n\n        #################################\n        # 1. encoder\n        skip_x = []\n        for block_i, block_op in enumerate(self.encoder_blocks):\n            if block_i in self.encoder_skips:\n                skip_x.append(x)\n            x = block_op(x, batch)\n\n        #################################\n        # 2. project the bottleneck features\n        feats_c = x.transpose(0, 1).unsqueeze(0)  # [1, C, N]\n        feats_c = self.bottle(feats_c)  # [1, C, N]\n\n        return feats_c, skip_x\n\n    def forward_middle(self, x):\n        feats_c = self.proj_gnn(x)\n        feats_gnn_raw = feats_c.squeeze(0).transpose(0, 1)\n\n        return feats_gnn_raw\n\n    def forward_decoder(self, x, skip_x, batch):\n        for block_i, block_op in enumerate(self.decoder_blocks):\n            if block_i in self.decoder_concats:\n                x = torch.cat([x, skip_x.pop()], dim=1)\n            x = block_op(x, batch)\n\n        feats_f = x[:, :self.final_feats_dim]\n\n        # normalise point-wise features\n        if self.normalize_output:\n            feats_f = F.normalize(feats_f, p=2, dim=1)\n\n        return feats_f\n\n    def forward(self, batch):\n        # Get input features\n        feats_c, skip_x = self.forward_encoder(batch)\n        x = self.forward_middle(feats_c)\n        feats_f = self.forward_decoder(x, skip_x, batch)\n\n        return feats_f\n"
  },
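  {
    "path": "examples/descriptor3d_architecture_sketch.py",
    "content": "# Illustrative sketch (not part of the original release): reproduces the block\n# list that KPSuperpoint3Dv2.__init__ in model/descriptor3D.py assembles from\n# config.num_layers before instantiating the encoder/decoder via block_decider.\ndef build_architecture(num_layers):\n    arch = ['simple', 'resnetb']\n    # encoder: one strided block plus two residual blocks per extra level\n    for _ in range(num_layers - 1):\n        arch += ['resnetb_strided', 'resnetb', 'resnetb']\n    # decoder: nearest upsample + unary per level, then a final projection\n    for _ in range(num_layers - 2):\n        arch += ['nearest_upsample', 'unary']\n    arch += ['nearest_upsample', 'last_unary']\n    return arch\n\n\nif __name__ == '__main__':\n    print(build_architecture(3))\n    # ['simple', 'resnetb',\n    #  'resnetb_strided', 'resnetb', 'resnetb',\n    #  'resnetb_strided', 'resnetb', 'resnetb',\n    #  'nearest_upsample', 'unary',\n    #  'nearest_upsample', 'last_unary']\n"
  },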
  {
    "path": "model/losses.py",
    "content": "from sklearn.metrics import precision_recall_fscore_support\nfrom thirdparty.kpconv.lib.utils import square_distance\nfrom abc import ABCMeta, abstractmethod\n\nimport numpy as np\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nimport apex.amp as amp\n\n\nclass Loss(nn.Module, metaclass=ABCMeta):\n    \"\"\"Abstract base class for loss functions.\"\"\"\n\n    def __init__(self, loss_weight=1):\n        super(Loss, self).__init__()\n        self._loss_weight = loss_weight\n\n    def forward(self,\n                prediction_tensor,\n                target_tensor,\n                ignore_nan_targets=False,\n                scope=None,\n                **params):\n        \"\"\"Call the loss function.\n\n        Args:\n          prediction_tensor: an N-d tensor of shape [batch, anchors, ...]\n            representing predicted quantities.\n          target_tensor: an N-d tensor of shape [batch, anchors, ...] representing\n            regression or classification targets.\n          ignore_nan_targets: whether to ignore nan targets in the loss computation.\n            E.g. can be used if the target tensor is missing groundtruth data that\n            shouldn't be factored into the loss.\n          scope: Op scope name. Defaults to 'Loss' if None.\n          **params: Additional keyword arguments for specific implementations of\n                  the Loss.\n\n        Returns:\n          loss: a tensor representing the value of the loss function.\n        \"\"\"\n        if ignore_nan_targets:\n            target_tensor = torch.where(torch.isnan(target_tensor),\n                                        prediction_tensor,\n                                        target_tensor)\n        return self._loss_weight*self._compute_loss(prediction_tensor, target_tensor, **params)\n\n    @abstractmethod\n    @amp.float_function\n    def _compute_loss(self, prediction_tensor, target_tensor, **params):\n        \"\"\"Method to be overridden by implementations.\n\n        Args:\n          prediction_tensor: a tensor representing predicted quantities\n          target_tensor: a tensor representing regression or classification targets\n          **params: Additional keyword arguments for specific implementations of\n                  the Loss.\n\n        Returns:\n          loss: an N-d tensor of shape [batch, anchors, ...] containing the loss per\n            anchor\n        \"\"\"\n\n        raise NotImplementedError\n\n\nclass L2Loss(Loss):\n\n    def __init__(self, loss_weight=1):\n        super(L2Loss, self).__init__(loss_weight)\n\n    def _compute_loss(self, prediction_tensor, target_tensor, mask=None):\n        \"\"\"Compute loss function.\n\n        Args:\n          prediction_tensor: A float tensor of shape [batch_size, num_anchors,\n            code_size] representing the (encoded) predicted locations of objects.\n          target_tensor: A float tensor of shape [batch_size, num_anchors,\n            code_size] representing the regression targets\n\n        Returns:\n          loss: a scalar tensor representing the value of the loss function.\n        \"\"\"\n        diff = prediction_tensor - target_tensor\n\n        if mask is not None:\n            mask = mask.expand_as(diff).bool()\n            diff = diff[mask]\n        square_diff = diff * diff\n        return square_diff.mean()\n\n\nclass AdaptiveWeightedL2Loss(Loss):\n\n    def __init__(self, init_alpha, learn_alpha=True, loss_weight=1, focal_gamma=0):\n        super(AdaptiveWeightedL2Loss, self).__init__(loss_weight)\n        self.learn_alpha = learn_alpha\n        self.alpha = nn.Parameter(torch.Tensor(\n            [init_alpha]), requires_grad=learn_alpha)\n        self.focal_gamma = focal_gamma\n\n    def _compute_loss(self, prediction_tensor, target_tensor, mask=None, alpha=None, focal_gamma=None):\n        \"\"\"Compute loss function.\n\n        Args:\n          prediction_tensor: A float tensor of shape [batch_size, num_anchors,\n            code_size] representing the (encoded) predicted locations of objects.\n          target_tensor: A float tensor of shape [batch_size, num_anchors,\n            code_size] representing the regression targets\n\n        Returns:\n          loss: a scalar tensor representing the value of the loss function.\n        \"\"\"\n\n        if focal_gamma is None:\n            focal_gamma = self.focal_gamma\n        # allow an explicit alpha to override the learned parameter\n        _alpha = self.alpha if alpha is None else alpha\n        if mask is None:\n            mask = torch.ones_like(target_tensor)\n        else:\n            mask = mask.expand_as(target_tensor)\n\n        diff = prediction_tensor - target_tensor\n        square_diff = (diff * diff) * mask\n\n        input_shape = prediction_tensor.shape\n        loss = torch.sum(square_diff, dim=list(range(1, len(input_shape)))) / \\\n            (torch.sum(mask, dim=list(range(1, len(input_shape)))) + 1e-12)  # (B,)\n\n        focal_weight = (torch.exp(-_alpha) * loss)**focal_gamma\n        focal_weight = focal_weight/(torch.sum(focal_weight) + 1e-12)\n\n        loss = focal_weight*(torch.exp(-_alpha) * loss)\n        loss = loss.sum() + _alpha\n        return loss\n\n\nclass MetricLoss(nn.Module):\n    \"\"\"\n    We evaluate both contrastive loss and circle loss\n    \"\"\"\n\n    def __init__(self, configs, log_scale=16, pos_optimal=0.1, neg_optimal=1.4):\n        super(MetricLoss, self).__init__()\n        self.log_scale = log_scale\n        self.pos_optimal = pos_optimal\n        self.neg_optimal = neg_optimal\n\n        self.pos_margin = configs.pos_margin\n        self.neg_margin = configs.neg_margin\n        self.max_points = configs.max_points\n\n        
self.safe_radius = configs.safe_radius\n        self.matchability_radius = configs.matchability_radius\n        # just to take care of the numeric precision\n        self.pos_radius = configs.pos_radius + 0.001\n        self.weight = configs.get('loss_weight', 1)\n\n    def get_circle_loss(self, coords_dist, feats_dist):\n        \"\"\"\n        Modified from: https://github.com/XuyangBai/D3Feat.pytorch\n        \"\"\"\n        pos_mask = coords_dist < self.pos_radius\n        neg_mask = coords_dist > self.safe_radius\n\n        # get anchors that have both positive and negative pairs\n        row_sel = ((pos_mask.sum(-1) > 0) * (neg_mask.sum(-1) > 0)).detach()\n        col_sel = ((pos_mask.sum(-2) > 0) * (neg_mask.sum(-2) > 0)).detach()\n\n        # get alpha for both positive and negative pairs\n        pos_weight = feats_dist - 1e5 * \\\n            (~pos_mask).float()  # mask the non-positive\n        # mask the uninformative positive\n        pos_weight = (pos_weight - self.pos_optimal)\n        pos_weight = torch.max(torch.zeros_like(\n            pos_weight), pos_weight).detach()\n\n        neg_weight = feats_dist + 1e5 * \\\n            (~neg_mask).float()  # mask the non-negative\n        # mask the uninformative negative\n        neg_weight = (self.neg_optimal - neg_weight)\n        neg_weight = torch.max(torch.zeros_like(\n            neg_weight), neg_weight).detach()\n\n        lse_pos_row = torch.logsumexp(\n            self.log_scale * (feats_dist - self.pos_margin) * pos_weight, dim=-1)\n        lse_pos_col = torch.logsumexp(\n            self.log_scale * (feats_dist - self.pos_margin) * pos_weight, dim=-2)\n\n        lse_neg_row = torch.logsumexp(\n            self.log_scale * (self.neg_margin - feats_dist) * neg_weight, dim=-1)\n        lse_neg_col = torch.logsumexp(\n            self.log_scale * (self.neg_margin - feats_dist) * neg_weight, dim=-2)\n\n        loss_row = F.softplus(lse_pos_row + lse_neg_row)/self.log_scale\n        loss_col = F.softplus(lse_pos_col + lse_neg_col)/self.log_scale\n\n        circle_loss = (loss_row[row_sel].mean() + loss_col[col_sel].mean()) / 2\n\n        return circle_loss\n\n    def get_recall(self, coords_dist, feats_dist):\n        \"\"\"\n        Get feature match recall, divided by number of true inliers\n        \"\"\"\n        pos_mask = coords_dist < self.pos_radius\n        n_gt_pos = (pos_mask.sum(-1) > 0).float().sum()+1e-12\n        _, sel_idx = torch.min(feats_dist, -1)\n\n        sel_dist = torch.gather(coords_dist, dim=-1,\n                                index=sel_idx[:, None])[pos_mask.sum(-1) > 0]\n        n_pred_pos = (sel_dist < self.pos_radius).float().sum()\n        recall = n_pred_pos / n_gt_pos\n        return recall\n\n    def get_weighted_bce_loss(self, prediction, gt):\n        loss = nn.BCELoss(reduction='none')\n\n        class_loss = loss(prediction, gt)\n\n        weights = torch.ones_like(gt)\n        w_negative = gt.sum()/gt.size(0)\n        w_positive = 1 - w_negative\n\n        weights[gt >= 0.5] = w_positive\n        weights[gt < 0.5] = w_negative\n        w_class_loss = torch.mean(weights * class_loss)\n\n        #######################################\n        # get classification precision and recall\n        predicted_labels = prediction.detach().cpu().round().numpy()\n        cls_precision, cls_recall, _, _ = precision_recall_fscore_support(\n            gt.cpu().numpy(), predicted_labels, average='binary')\n\n        return w_class_loss, cls_precision, cls_recall\n\n    def forward(self, src_pcd, 
tgt_pcd, src_feats, tgt_feats, correspondence, scores_overlap, scores_saliency, rot=None, trans=None):\n        \"\"\"\n        Circle loss for metric learning, here we feed the positive pairs only\n        Input:\n            src_pcd:        [N, 3], pcd of the 3d model\n            tgt_pcd:        [M, 3], pcd of the lifted model from 2d depth\n            rot:            [3, 3], rotation used to rotate the src_pcd to the current frame\n            trans:          [3, 1], translation used to translate the src_pcd to the current frame\n            src_feats:      [N, C]\n            tgt_feats:      [M, C]\n        \"\"\"\n\n        if rot is not None and trans is not None:\n            src_pcd = (torch.matmul(rot, src_pcd.transpose(0, 1)) +\n                       trans).transpose(0, 1)\n\n        stats = dict()\n\n        #######################################\n        # filter some of correspondence\n        if correspondence.size(0) > self.max_points:\n            choice = np.random.permutation(correspondence.size(0))[\n                :self.max_points]\n            correspondence = correspondence[choice]\n\n        src_idx = correspondence[:, 0]\n        tgt_idx = correspondence[:, 1]\n        src_pcd, tgt_pcd = src_pcd[src_idx], tgt_pcd[tgt_idx]\n        src_feats, tgt_feats = src_feats[src_idx], tgt_feats[tgt_idx]\n\n        #######################\n        # get L2 distance between source / target point cloud\n\n        coords_dist = torch.sqrt(square_distance(\n            src_pcd[None, :, :], tgt_pcd[None, :, :]).squeeze(0))\n        feats_dist = torch.sqrt(square_distance(\n            src_feats[None, :, :], tgt_feats[None, :, :], normalised=True)).squeeze(0)\n\n        ##############################\n        # get FMR and circle loss\n        ##############################\n        recall = self.get_recall(coords_dist, feats_dist)\n        circle_loss = self.get_circle_loss(coords_dist, feats_dist)\n\n        stats['circle_loss'] = circle_loss\n        stats['recall'] = recall\n\n        return stats\n\n\nclass PointAlignmentLoss(nn.Module):\n    def __init__(self, loss_weight=1):\n        super().__init__()\n        self._loss_weight = loss_weight\n\n    def forward(self, R_pred, t_pred, R_tgt, t_tgt, points):\n        return self._loss_weight*self._compute_loss(R_pred, t_pred, R_tgt, t_tgt, points)\n\n    def _compute_loss(self, R_pred, t_pred, R_tgt, t_tgt, points):\n        \"\"\"Alignment error between model points transformed by the predicted\n        pose and by the target pose.\n\n        Args:\n            R_pred: [Bx3x3] predicted rotations\n            t_pred: [Bx3] predicted translations\n            R_tgt: [Bx3x3] target rotations\n            t_tgt: [Bx3] target translations\n            points: [BxNx3] model points\n\n        Returns:\n            loss: scalar loss, summed over the batch\n        \"\"\"\n\n        loss = 0\n        for b in range(len(points)):\n            diff = points[b]@R_pred[b].transpose(-1, -2) + t_pred[b] - (\n                points[b]@R_tgt[b].transpose(-1, -2)+t_tgt[b])\n\n            # L1 distance; the factor 3 turns the mean over xyz into a sum\n            abs_diff = diff.abs()\n            loss += torch.mean(abs_diff)*3\n\n        return loss\n"
  },
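  {
    "path": "examples/adaptive_l2_sketch.py",
    "content": "# Illustrative sketch (not part of the original release): re-derives the\n# arithmetic of AdaptiveWeightedL2Loss._compute_loss from model/losses.py with\n# plain torch, so the formula can be checked without the apex/kpconv imports.\n# With focal_gamma=0 every sample gets weight 1/B and the loss reduces to the\n# uncertainty-weighted form mean_b(exp(-alpha) * mse_b) + alpha.\nimport torch\n\n\ndef adaptive_weighted_l2(pred, target, alpha, focal_gamma=0.0):\n    diff2 = (pred - target) ** 2\n    # per-sample mean squared error over all non-batch dims -> shape (B,)\n    per_sample = diff2.flatten(1).mean(dim=1)\n    # focal re-weighting across the batch, normalized to sum to 1\n    focal_w = (torch.exp(-alpha) * per_sample) ** focal_gamma\n    focal_w = focal_w / (focal_w.sum() + 1e-12)\n    return (focal_w * torch.exp(-alpha) * per_sample).sum() + alpha\n\n\npred = torch.randn(4, 8)\ntarget = torch.randn(4, 8)\nalpha = torch.tensor(0.0)\nprint(adaptive_weighted_l2(pred, target, alpha))       # uniform weights (gamma=0)\nprint(adaptive_weighted_l2(pred, target, alpha, 2.0))  # focal re-weighting\n"
  },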
  {
    "path": "scripts/compile_3rdparty.sh",
    "content": "#!/usr/bin/bash\n\nSCRIPT_DIR=\"$( cd \"$( dirname \"${BASH_SOURCE[0]}\" )\" &> /dev/null && pwd )\"\n\ncd $SCRIPT_DIR/../thirdparty/kpconv/cpp_wrappers\nbash compile_wrappers.sh\n\ncd $SCRIPT_DIR/../thirdparty/nn\npython setup.py build_ext --inplace"
  },
  {
    "path": "scripts/eval.sh",
    "content": "export PROJECT_ROOT_PATH=/home/RNNPose/Projects/Works/RNNPose_release\n\nexport PYTHONPATH=\"$PROJECT_ROOT_PATH:$PYTHONPATH\"\nexport PYTHONPATH=\"$PROJECT_ROOT_PATH/thirdparty:$PYTHONPATH\"\nexport model_dir='outputs'\nseq=cat\ngpu=1\nstart_gpu_id=0\nmkdir -p $model_dir\n\ntrain_file=$PROJECT_ROOT_PATH/tools/eval.py\nconfig_path=$PROJECT_ROOT_PATH/config/linemod/\"$seq\"_fw0.5.yml\npretrain=$PROJECT_ROOT_PATH/weights/trained_models/\"$seq\".tckpt\n\npython -u $train_file multi_proc_train  \\\n        --config_path $config_path \\\n        --model_dir $model_dir/results \\\n        --use_dist True \\\n        --dist_port 10000 \\\n        --gpus_per_node $gpu \\\n        --optim_eval True \\\n        --use_apex False \\\n        --world_size $gpu \\\n        --start_gpu_id $start_gpu_id \\\n        --pretrained_path $pretrain\n"
  },
  {
    "path": "scripts/eval_lmocc.sh",
    "content": "export PROJECT_ROOT_PATH=/home/RNNPose/Projects/Works/RNNPose_release\n\nexport PYTHONPATH=\"$PROJECT_ROOT_PATH:$PYTHONPATH\"\nexport PYTHONPATH=\"$PROJECT_ROOT_PATH/thirdparty:$PYTHONPATH\"\nexport model_dir='outputs'\nseq=ape\ngpu=1\nstart_gpu_id=0\nmkdir -p $model_dir\n\ntrain_file=$PROJECT_ROOT_PATH/tools/eval.py\nconfig_path=$PROJECT_ROOT_PATH/config/linemod/\"$seq\"_fw0.5_occ.yml\npretrain=$PROJECT_ROOT_PATH/weights/trained_models/\"$seq\".tckpt\n\npython -u $train_file multi_proc_train  \\\n        --config_path $config_path \\\n        --model_dir $model_dir/results \\\n        --use_dist True \\\n        --dist_port 10000 \\\n        --gpus_per_node $gpu \\\n        --optim_eval True \\\n        --use_apex False \\\n        --world_size $gpu \\\n        --start_gpu_id $start_gpu_id \\\n        --pretrained_path $pretrain\n"
  },
  {
    "path": "scripts/run_dataformatter.sh",
    "content": "SCRIPT_DIR=\"$( cd \"$( dirname \"${BASH_SOURCE[0]}\" )\" &> /dev/null && pwd )\"\nPROJ_ROOT=$SCRIPT_DIR/../\n\npython $PROJ_ROOT/tools/transform_data_format.py run --data_type \"LM_FUSE_PVNET\"  --data_info_path $PROJ_ROOT/EXPDATA/data_info/linemod_all_10k_default.info.all  --image_root $PROJ_ROOT/EXPDATA/raw_data/fuse --depth_root $PROJ_ROOT/EXPDATA/raw_data/orig_renders --save_dir $PROJ_ROOT/EXPDATA/LINEMOD/fuse_formatted  && touch 3.finished\n"
  },
  {
    "path": "scripts/run_datainfo_generation.sh",
    "content": "\nSCRIPT_DIR=\"$( cd \"$( dirname \"${BASH_SOURCE[0]}\" )\" &> /dev/null && pwd )\"\nPROJ_ROOT=$SCRIPT_DIR/../\nexport PYTHONPATH=$PROJ_ROOT:$PYTHONPATH\n\nmkdir --parents $PROJ_ROOT/EXPDATA/data_info/deepim/\n\npython $PROJ_ROOT/tools/generate_data_info_deepim_0_orig.py  create_data_info  --data_root $PROJ_ROOT/EXPDATA/LM6d_converted/LM6d_refine  --saving_path $PROJ_ROOT/EXPDATA/data_info/deepim/linemod_orig_deepim.info\n\npython $PROJ_ROOT/tools/generate_data_info_deepim_1_syn.py  create_data_info  --data_root $PROJ_ROOT/EXPDATA/LM6d_converted/LM6d_refine_syn  --saving_path $PROJ_ROOT/EXPDATA/data_info/deepim/linemod_syn_deepim.info  --with_assertion True\n\npython $PROJ_ROOT/tools/generate_data_info_deepim_2_posecnnval.py  create_data_info  --data_root $PROJ_ROOT/EXPDATA/LM6d_converted/LM6d_refine  --saving_path $PROJ_ROOT/EXPDATA/data_info/linemod_posecnn.info  --with_assertion True\n\npython $PROJ_ROOT/tools/generate_data_info_v2_deepim.py  create_data_info  --data_root $PROJ_ROOT/EXPDATA/LINEMOD/fuse_formatted/  --saving_path $PROJ_ROOT/EXPDATA/data_info/linemod_fusesformatted_all10k_deepim.info  --training_data_ratio 1 --shuffle=False\n"
  },
  {
    "path": "scripts/train.sh",
    "content": "export PROJECT_ROOT_PATH=/home/RNNPose/Projects/Works/RNNPose_release\n\nexport PYTHONPATH=\"$PROJECT_ROOT_PATH:$PYTHONPATH\"\nexport PYTHONPATH=\"$PROJECT_ROOT_PATH/thirdparty:$PYTHONPATH\"\nexport model_dir='outputs'\nseq=cat\ngpu=1\nstart_gpu_id=0\nmkdir -p $model_dir\n\ntrain_file=$PROJECT_ROOT_PATH/tools/train.py\nconfig_path=$PROJECT_ROOT_PATH/config/linemod/\"$seq\"_fw0.5.yml\n# pretrain=$PROJECT_ROOT_PATH/weights/trained_models/\"$seq\".tckpt\n\npython -u $train_file multi_proc_train  \\\n        --config_path $config_path \\\n        --model_dir $model_dir/results \\\n        --use_dist True \\\n        --dist_port 10000 \\\n        --gpus_per_node $gpu \\\n        --optim_eval True \\\n        --use_apex False \\\n        --world_size $gpu \\\n        --start_gpu_id $start_gpu_id\n        # To resume from a trained model, re-add '\\\\' after the line above and\n        # uncomment: --pretrained_path $pretrain\n"
  },
  {
    "path": "thirdparty/kpconv/__init__.py",
    "content": ""
  },
  {
    "path": "thirdparty/kpconv/cpp_wrappers/compile_wrappers.sh",
    "content": "#!/bin/bash\n\n# Compile cpp subsampling\ncd cpp_subsampling\npython3 setup.py build_ext --inplace\ncd ..\n\n# Compile cpp neighbors\ncd cpp_neighbors\npython3 setup.py build_ext --inplace\ncd .."
  },
  {
    "path": "thirdparty/kpconv/cpp_wrappers/cpp_neighbors/build.bat",
    "content": "@echo off\npy setup.py build_ext --inplace\n\n\npause"
  },
  {
    "path": "thirdparty/kpconv/cpp_wrappers/cpp_neighbors/neighbors/neighbors.cpp",
    "content": "\n#include \"neighbors.h\"\n\n\nvoid brute_neighbors(vector<PointXYZ>& queries, vector<PointXYZ>& supports, vector<int>& neighbors_indices, float radius, int verbose)\n{\n\n\t// Initialize variables\n\t// ******************\n\n\t// square radius\n\tfloat r2 = radius * radius;\n\n\t// indices\n\tint i0 = 0;\n\n\t// Counting vector\n\tint max_count = 0;\n\tvector<vector<int>> tmp(queries.size());\n\n\t// Search neigbors indices\n\t// ***********************\n\n\tfor (auto& p0 : queries)\n\t{\n\t\tint i = 0;\n\t\tfor (auto& p : supports)\n\t\t{\n\t\t\tif ((p0 - p).sq_norm() < r2)\n\t\t\t{\n\t\t\t\ttmp[i0].push_back(i);\n\t\t\t\tif (tmp[i0].size() > max_count)\n\t\t\t\t\tmax_count = tmp[i0].size();\n\t\t\t}\n\t\t\ti++;\n\t\t}\n\t\ti0++;\n\t}\n\n\t// Reserve the memory\n\tneighbors_indices.resize(queries.size() * max_count);\n\ti0 = 0;\n\tfor (auto& inds : tmp)\n\t{\n\t\tfor (int j = 0; j < max_count; j++)\n\t\t{\n\t\t\tif (j < inds.size())\n\t\t\t\tneighbors_indices[i0 * max_count + j] = inds[j];\n\t\t\telse\n\t\t\t\tneighbors_indices[i0 * max_count + j] = -1;\n\t\t}\n\t\ti0++;\n\t}\n\n\treturn;\n}\n\nvoid ordered_neighbors(vector<PointXYZ>& queries,\n                        vector<PointXYZ>& supports,\n                        vector<int>& neighbors_indices,\n                        float radius)\n{\n\n\t// Initialize variables\n\t// ******************\n\n\t// square radius\n\tfloat r2 = radius * radius;\n\n\t// indices\n\tint i0 = 0;\n\n\t// Counting vector\n\tint max_count = 0;\n\tfloat d2;\n\tvector<vector<int>> tmp(queries.size());\n\tvector<vector<float>> dists(queries.size());\n\n\t// Search neigbors indices\n\t// ***********************\n\n\tfor (auto& p0 : queries)\n\t{\n\t\tint i = 0;\n\t\tfor (auto& p : supports)\n\t\t{\n\t\t    d2 = (p0 - p).sq_norm();\n\t\t\tif (d2 < r2)\n\t\t\t{\n\t\t\t    // Find order of the new point\n\t\t\t    auto it = std::upper_bound(dists[i0].begin(), dists[i0].end(), d2);\n\t\t\t    int index = std::distance(dists[i0].begin(), it);\n\n\t\t\t    // Insert element\n                dists[i0].insert(it, d2);\n                tmp[i0].insert(tmp[i0].begin() + index, i);\n\n\t\t\t    // Update max count\n\t\t\t\tif (tmp[i0].size() > max_count)\n\t\t\t\t\tmax_count = tmp[i0].size();\n\t\t\t}\n\t\t\ti++;\n\t\t}\n\t\ti0++;\n\t}\n\n\t// Reserve the memory\n\tneighbors_indices.resize(queries.size() * max_count);\n\ti0 = 0;\n\tfor (auto& inds : tmp)\n\t{\n\t\tfor (int j = 0; j < max_count; j++)\n\t\t{\n\t\t\tif (j < inds.size())\n\t\t\t\tneighbors_indices[i0 * max_count + j] = inds[j];\n\t\t\telse\n\t\t\t\tneighbors_indices[i0 * max_count + j] = -1;\n\t\t}\n\t\ti0++;\n\t}\n\n\treturn;\n}\n\nvoid batch_ordered_neighbors(vector<PointXYZ>& queries,\n                                vector<PointXYZ>& supports,\n                                vector<int>& q_batches,\n                                vector<int>& s_batches,\n                                vector<int>& neighbors_indices,\n                                float radius)\n{\n\n\t// Initialize variables\n\t// ******************\n\n\t// square radius\n\tfloat r2 = radius * radius;\n\n\t// indices\n\tint i0 = 0;\n\n\t// Counting vector\n\tint max_count = 0;\n\tfloat d2;\n\tvector<vector<int>> tmp(queries.size());\n\tvector<vector<float>> dists(queries.size());\n\n\t// batch index\n\tint b = 0;\n\tint sum_qb = 0;\n\tint sum_sb = 0;\n\n\n\t// Search neigbors indices\n\t// ***********************\n\n\tfor (auto& p0 : queries)\n\t{\n\t    // Check if we changed batch\n\t    if (i0 == sum_qb + 
q_batches[b])\n\t    {\n\t        sum_qb += q_batches[b];\n\t        sum_sb += s_batches[b];\n\t        b++;\n\t    }\n\n\t    // Loop only over the supports of current batch\n\t    vector<PointXYZ>::iterator p_it;\n\t\tint i = 0;\n        for(p_it = supports.begin() + sum_sb; p_it < supports.begin() + sum_sb + s_batches[b]; p_it++ )\n        {\n\t\t    d2 = (p0 - *p_it).sq_norm();\n\t\t\tif (d2 < r2)\n\t\t\t{\n\t\t\t    // Find order of the new point\n\t\t\t    auto it = std::upper_bound(dists[i0].begin(), dists[i0].end(), d2);\n\t\t\t    int index = std::distance(dists[i0].begin(), it);\n\n\t\t\t    // Insert element\n                dists[i0].insert(it, d2);\n                tmp[i0].insert(tmp[i0].begin() + index, sum_sb + i);\n\n\t\t\t    // Update max count\n\t\t\t\tif (tmp[i0].size() > max_count)\n\t\t\t\t\tmax_count = tmp[i0].size();\n\t\t\t}\n\t\t\ti++;\n\t\t}\n\t\ti0++;\n\t}\n\n\t// Reserve the memory\n\tneighbors_indices.resize(queries.size() * max_count);\n\ti0 = 0;\n\tfor (auto& inds : tmp)\n\t{\n\t\tfor (int j = 0; j < max_count; j++)\n\t\t{\n\t\t\tif (j < inds.size())\n\t\t\t\tneighbors_indices[i0 * max_count + j] = inds[j];\n\t\t\telse\n\t\t\t\tneighbors_indices[i0 * max_count + j] = supports.size();\n\t\t}\n\t\ti0++;\n\t}\n\n\treturn;\n}\n\n\nvoid batch_nanoflann_neighbors(vector<PointXYZ>& queries,\n                                vector<PointXYZ>& supports,\n                                vector<int>& q_batches,\n                                vector<int>& s_batches,\n                                vector<int>& neighbors_indices,\n                                float radius)\n{\n\n\t// Initialize variables\n\t// ******************\n\n\t// indices\n\tint i0 = 0;\n\n\t// Square radius\n\tfloat r2 = radius * radius;\n\n\t// Counting vector\n\tint max_count = 0;\n\tfloat d2;\n\tvector<vector<pair<size_t, float>>> all_inds_dists(queries.size());\n\n\t// batch index\n\tint b = 0;\n\tint sum_qb = 0;\n\tint sum_sb = 0;\n\n\t// Nanoflann related variables\n\t// ***************************\n\n\t// CLoud variable\n\tPointCloud current_cloud;\n\n\t// Tree parameters\n\tnanoflann::KDTreeSingleIndexAdaptorParams tree_params(10 /* max leaf */);\n\n\t// KDTree type definition\n    typedef nanoflann::KDTreeSingleIndexAdaptor< nanoflann::L2_Simple_Adaptor<float, PointCloud > ,\n                                                        PointCloud,\n                                                        3 > my_kd_tree_t;\n\n    // Pointer to trees\n    my_kd_tree_t* index;\n\n    // Build KDTree for the first batch element\n    current_cloud.pts = vector<PointXYZ>(supports.begin() + sum_sb, supports.begin() + sum_sb + s_batches[b]);\n    index = new my_kd_tree_t(3, current_cloud, tree_params);\n    index->buildIndex();\n\n\n\t// Search neigbors indices\n\t// ***********************\n\n    // Search params\n    nanoflann::SearchParams search_params;\n    search_params.sorted = true;\n\n\tfor (auto& p0 : queries)\n\t{\n\n\t    // Check if we changed batch\n\t    if (i0 == sum_qb + q_batches[b])\n\t    {\n\t        sum_qb += q_batches[b];\n\t        sum_sb += s_batches[b];\n\t        b++;\n\n\t        // Change the points\n\t        current_cloud.pts.clear();\n            current_cloud.pts = vector<PointXYZ>(supports.begin() + sum_sb, supports.begin() + sum_sb + s_batches[b]);\n\n\t        // Build KDTree of the current element of the batch\n            delete index;\n            index = new my_kd_tree_t(3, current_cloud, tree_params);\n            index->buildIndex();\n\t    }\n\n\t    // 
Initial guess of neighbors size\n        all_inds_dists[i0].reserve(max_count);\n\n\t    // Find neighbors\n\t    float query_pt[3] = { p0.x, p0.y, p0.z};\n\t\tsize_t nMatches = index->radiusSearch(query_pt, r2, all_inds_dists[i0], search_params);\n\n        // Update max count\n        if (nMatches > max_count)\n            max_count = nMatches;\n\n        // Increment query idx\n\t\ti0++;\n\t}\n\n\t// Reserve the memory\n\tneighbors_indices.resize(queries.size() * max_count);\n\ti0 = 0;\n\tsum_sb = 0;\n\tsum_qb = 0;\n\tb = 0;\n\tfor (auto& inds_dists : all_inds_dists)\n\t{\n\t    // Check if we changed batch\n\t    if (i0 == sum_qb + q_batches[b])\n\t    {\n\t        sum_qb += q_batches[b];\n\t        sum_sb += s_batches[b];\n\t        b++;\n\t    }\n\n\t\tfor (int j = 0; j < max_count; j++)\n\t\t{\n\t\t\tif (j < inds_dists.size())\n\t\t\t\tneighbors_indices[i0 * max_count + j] = inds_dists[j].first + sum_sb;\n\t\t\telse\n\t\t\t\tneighbors_indices[i0 * max_count + j] = supports.size();\n\t\t}\n\t\ti0++;\n\t}\n\n\tdelete index;\n\n\treturn;\n}\n\n"
  },
  {
    "path": "thirdparty/kpconv/cpp_wrappers/cpp_neighbors/neighbors/neighbors.h",
    "content": "\n\n#include \"../../cpp_utils/cloud/cloud.h\"\n#include \"../../cpp_utils/nanoflann/nanoflann.hpp\"\n\n#include <set>\n#include <cstdint>\n\nusing namespace std;\n\n\nvoid ordered_neighbors(vector<PointXYZ>& queries,\n                        vector<PointXYZ>& supports,\n                        vector<int>& neighbors_indices,\n                        float radius);\n\nvoid batch_ordered_neighbors(vector<PointXYZ>& queries,\n                                vector<PointXYZ>& supports,\n                                vector<int>& q_batches,\n                                vector<int>& s_batches,\n                                vector<int>& neighbors_indices,\n                                float radius);\n\nvoid batch_nanoflann_neighbors(vector<PointXYZ>& queries,\n                                vector<PointXYZ>& supports,\n                                vector<int>& q_batches,\n                                vector<int>& s_batches,\n                                vector<int>& neighbors_indices,\n                                float radius);\n"
  },
  {
    "path": "thirdparty/kpconv/cpp_wrappers/cpp_neighbors/setup.py",
    "content": "from distutils.core import setup, Extension\nimport numpy.distutils.misc_util\n\n# Adding OpenCV to project\n# ************************\n\n# Adding sources of the project\n# *****************************\n\nSOURCES = [\"../cpp_utils/cloud/cloud.cpp\",\n             \"neighbors/neighbors.cpp\",\n             \"wrapper.cpp\"]\n\nmodule = Extension(name=\"radius_neighbors\",\n                    sources=SOURCES,\n                    extra_compile_args=['-std=c++11',\n                                        '-D_GLIBCXX_USE_CXX11_ABI=0'])\n\n\nsetup(ext_modules=[module], include_dirs=numpy.distutils.misc_util.get_numpy_include_dirs())\n\n\n\n\n\n\n\n\n"
  },
  {
    "path": "thirdparty/kpconv/cpp_wrappers/cpp_neighbors/wrapper.cpp",
    "content": "#include <Python.h>\n#include <numpy/arrayobject.h>\n#include \"neighbors/neighbors.h\"\n#include <string>\n\n\n\n// docstrings for our module\n// *************************\n\nstatic char module_docstring[] = \"This module provides two methods to compute radius neighbors from pointclouds or batch of pointclouds\";\n\nstatic char batch_query_docstring[] = \"Method to get radius neighbors in a batch of stacked pointclouds\";\n\n\n// Declare the functions\n// *********************\n\nstatic PyObject *batch_neighbors(PyObject *self, PyObject *args, PyObject *keywds);\n\n\n// Specify the members of the module\n// *********************************\n\nstatic PyMethodDef module_methods[] = \n{\n\t{ \"batch_query\", (PyCFunction)batch_neighbors, METH_VARARGS | METH_KEYWORDS, batch_query_docstring },\n\t{NULL, NULL, 0, NULL}\n};\n\n\n// Initialize the module\n// *********************\n\nstatic struct PyModuleDef moduledef = \n{\n    PyModuleDef_HEAD_INIT,\n    \"radius_neighbors\",\t\t// m_name\n    module_docstring,       // m_doc\n    -1,                     // m_size\n    module_methods,         // m_methods\n    NULL,                   // m_reload\n    NULL,                   // m_traverse\n    NULL,                   // m_clear\n    NULL,                   // m_free\n};\n\nPyMODINIT_FUNC PyInit_radius_neighbors(void)\n{\n    import_array();\n\treturn PyModule_Create(&moduledef);\n}\n\n\n// Definition of the batch_subsample method\n// **********************************\n\nstatic PyObject* batch_neighbors(PyObject* self, PyObject* args, PyObject* keywds)\n{\n\n\t// Manage inputs\n\t// *************\n\n\t// Args containers\n\tPyObject* queries_obj = NULL;\n\tPyObject* supports_obj = NULL;\n\tPyObject* q_batches_obj = NULL;\n\tPyObject* s_batches_obj = NULL;\n\n\t// Keywords containers\n\tstatic char* kwlist[] = { \"queries\", \"supports\", \"q_batches\", \"s_batches\", \"radius\", NULL };\n\tfloat radius = 0.1;\n\n\t// Parse the input  \n\tif (!PyArg_ParseTupleAndKeywords(args, keywds, \"OOOO|$f\", kwlist, &queries_obj, &supports_obj, &q_batches_obj, &s_batches_obj, &radius))\n\t{\n\t\tPyErr_SetString(PyExc_RuntimeError, \"Error parsing arguments\");\n\t\treturn NULL;\n\t}\n\n\n\t// Interpret the input objects as numpy arrays.\n\tPyObject* queries_array = PyArray_FROM_OTF(queries_obj, NPY_FLOAT, NPY_IN_ARRAY);\n\tPyObject* supports_array = PyArray_FROM_OTF(supports_obj, NPY_FLOAT, NPY_IN_ARRAY);\n\tPyObject* q_batches_array = PyArray_FROM_OTF(q_batches_obj, NPY_INT, NPY_IN_ARRAY);\n\tPyObject* s_batches_array = PyArray_FROM_OTF(s_batches_obj, NPY_INT, NPY_IN_ARRAY);\n\n\t// Verify data was load correctly.\n\tif (queries_array == NULL)\n\t{\n\t\tPy_XDECREF(queries_array);\n\t\tPy_XDECREF(supports_array);\n\t\tPy_XDECREF(q_batches_array);\n\t\tPy_XDECREF(s_batches_array);\n\t\tPyErr_SetString(PyExc_RuntimeError, \"Error converting query points to numpy arrays of type float32\");\n\t\treturn NULL;\n\t}\n\tif (supports_array == NULL)\n\t{\n\t\tPy_XDECREF(queries_array);\n\t\tPy_XDECREF(supports_array);\n\t\tPy_XDECREF(q_batches_array);\n\t\tPy_XDECREF(s_batches_array);\n\t\tPyErr_SetString(PyExc_RuntimeError, \"Error converting support points to numpy arrays of type float32\");\n\t\treturn NULL;\n\t}\n\tif (q_batches_array == NULL)\n\t{\n\t\tPy_XDECREF(queries_array);\n\t\tPy_XDECREF(supports_array);\n\t\tPy_XDECREF(q_batches_array);\n\t\tPy_XDECREF(s_batches_array);\n\t\tPyErr_SetString(PyExc_RuntimeError, \"Error converting query batches to numpy arrays of type int32\");\n\t\treturn 
NULL;\n\t}\n\tif (s_batches_array == NULL)\n\t{\n\t\tPy_XDECREF(queries_array);\n\t\tPy_XDECREF(supports_array);\n\t\tPy_XDECREF(q_batches_array);\n\t\tPy_XDECREF(s_batches_array);\n\t\tPyErr_SetString(PyExc_RuntimeError, \"Error converting support batches to numpy arrays of type int32\");\n\t\treturn NULL;\n\t}\n\n\t// Check that the input array respect the dims\n\tif ((int)PyArray_NDIM(queries_array) != 2 || (int)PyArray_DIM(queries_array, 1) != 3)\n\t{\n\t\tPy_XDECREF(queries_array);\n\t\tPy_XDECREF(supports_array);\n\t\tPy_XDECREF(q_batches_array);\n\t\tPy_XDECREF(s_batches_array);\n\t\tPyErr_SetString(PyExc_RuntimeError, \"Wrong dimensions : query.shape is not (N, 3)\");\n\t\treturn NULL;\n\t}\n\tif ((int)PyArray_NDIM(supports_array) != 2 || (int)PyArray_DIM(supports_array, 1) != 3)\n\t{\n\t\tPy_XDECREF(queries_array);\n\t\tPy_XDECREF(supports_array);\n\t\tPy_XDECREF(q_batches_array);\n\t\tPy_XDECREF(s_batches_array);\n\t\tPyErr_SetString(PyExc_RuntimeError, \"Wrong dimensions : support.shape is not (N, 3)\");\n\t\treturn NULL;\n\t}\n\tif ((int)PyArray_NDIM(q_batches_array) > 1)\n\t{\n\t\tPy_XDECREF(queries_array);\n\t\tPy_XDECREF(supports_array);\n\t\tPy_XDECREF(q_batches_array);\n\t\tPy_XDECREF(s_batches_array);\n\t\tPyErr_SetString(PyExc_RuntimeError, \"Wrong dimensions : queries_batches.shape is not (B,) \");\n\t\treturn NULL;\n\t}\n\tif ((int)PyArray_NDIM(s_batches_array) > 1)\n\t{\n\t\tPy_XDECREF(queries_array);\n\t\tPy_XDECREF(supports_array);\n\t\tPy_XDECREF(q_batches_array);\n\t\tPy_XDECREF(s_batches_array);\n\t\tPyErr_SetString(PyExc_RuntimeError, \"Wrong dimensions : supports_batches.shape is not (B,) \");\n\t\treturn NULL;\n\t}\n\tif ((int)PyArray_DIM(q_batches_array, 0) != (int)PyArray_DIM(s_batches_array, 0))\n\t{\n\t\tPy_XDECREF(queries_array);\n\t\tPy_XDECREF(supports_array);\n\t\tPy_XDECREF(q_batches_array);\n\t\tPy_XDECREF(s_batches_array);\n\t\tPyErr_SetString(PyExc_RuntimeError, \"Wrong number of batch elements: different for queries and supports \");\n\t\treturn NULL;\n\t}\n\n\t// Number of points\n\tint Nq = (int)PyArray_DIM(queries_array, 0);\n\tint Ns= (int)PyArray_DIM(supports_array, 0);\n\n\t// Number of batches\n\tint Nb = (int)PyArray_DIM(q_batches_array, 0);\n\n\t// Call the C++ function\n\t// *********************\n\n\t// Convert PyArray to Cloud C++ class\n\tvector<PointXYZ> queries;\n\tvector<PointXYZ> supports;\n\tvector<int> q_batches;\n\tvector<int> s_batches;\n\tqueries = vector<PointXYZ>((PointXYZ*)PyArray_DATA(queries_array), (PointXYZ*)PyArray_DATA(queries_array) + Nq);\n\tsupports = vector<PointXYZ>((PointXYZ*)PyArray_DATA(supports_array), (PointXYZ*)PyArray_DATA(supports_array) + Ns);\n\tq_batches = vector<int>((int*)PyArray_DATA(q_batches_array), (int*)PyArray_DATA(q_batches_array) + Nb);\n\ts_batches = vector<int>((int*)PyArray_DATA(s_batches_array), (int*)PyArray_DATA(s_batches_array) + Nb);\n\n\t// Create result containers\n\tvector<int> neighbors_indices;\n\n\t// Compute results\n\t//batch_ordered_neighbors(queries, supports, q_batches, s_batches, neighbors_indices, radius);\n\tbatch_nanoflann_neighbors(queries, supports, q_batches, s_batches, neighbors_indices, radius);\n\n\t// Check result\n\tif (neighbors_indices.size() < 1)\n\t{\n\t\tPyErr_SetString(PyExc_RuntimeError, \"Error\");\n\t\treturn NULL;\n\t}\n\n\t// Manage outputs\n\t// **************\n\n\t// Maximal number of neighbors\n\tint max_neighbors = neighbors_indices.size() / Nq;\n\n\t// Dimension of output containers\n\tnpy_intp* neighbors_dims = new 
npy_intp[2];\n\tneighbors_dims[0] = Nq;\n\tneighbors_dims[1] = max_neighbors;\n\n\t// Create output array\n\tPyObject* res_obj = PyArray_SimpleNew(2, neighbors_dims, NPY_INT);\n\tPyObject* ret = NULL;\n\n\t// Fill output array with values\n\tsize_t size_in_bytes = Nq * max_neighbors * sizeof(int);\n\tmemcpy(PyArray_DATA(res_obj), neighbors_indices.data(), size_in_bytes);\n\n\t// Merge results\n\tret = Py_BuildValue(\"N\", res_obj);\n\n\t// Clean up\n\t// ********\n\n\tPy_XDECREF(queries_array);\n\tPy_XDECREF(supports_array);\n\tPy_XDECREF(q_batches_array);\n\tPy_XDECREF(s_batches_array);\n\n\treturn ret;\n}\n"
  },
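  {
    "path": "examples/radius_neighbors_example.py",
    "content": "# Illustrative sketch (not part of the original release): calling the compiled\n# radius_neighbors extension from Python. The import path is an assumption; it\n# expects the module built in place by scripts/compile_3rdparty.sh under\n# thirdparty/kpconv/cpp_wrappers/cpp_neighbors, with that directory importable.\nimport numpy as np\nfrom cpp_wrappers.cpp_neighbors import radius_neighbors\n\n# Two stacked clouds: the first 100 query points belong to batch element 0, etc.\nqueries = np.random.rand(150, 3).astype(np.float32)\nsupports = np.random.rand(300, 3).astype(np.float32)\nq_batches = np.array([100, 50], dtype=np.int32)\ns_batches = np.array([200, 100], dtype=np.int32)\n\n# Returns an (n_queries, max_neighbors) int array; each row is sorted by\n# distance and padded with len(supports) as the 'no neighbor' index (see\n# batch_nanoflann_neighbors in neighbors.cpp).\nneighbors = radius_neighbors.batch_query(\n    queries, supports, q_batches, s_batches, radius=0.1)\nprint(neighbors.shape, neighbors.max())\n"
  },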
  {
    "path": "thirdparty/kpconv/cpp_wrappers/cpp_subsampling/build.bat",
    "content": "@echo off\npy setup.py build_ext --inplace\n\n\npause"
  },
  {
    "path": "thirdparty/kpconv/cpp_wrappers/cpp_subsampling/grid_subsampling/grid_subsampling.cpp",
    "content": "\n#include \"grid_subsampling.h\"\n\n\nvoid grid_subsampling(vector<PointXYZ>& original_points,\n                      vector<PointXYZ>& subsampled_points,\n                      vector<float>& original_features,\n                      vector<float>& subsampled_features,\n                      vector<int>& original_classes,\n                      vector<int>& subsampled_classes,\n                      float sampleDl,\n                      int verbose) {\n\n\t// Initialize variables\n\t// ******************\n\n\t// Number of points in the cloud\n\tsize_t N = original_points.size();\n\n\t// Dimension of the features\n\tsize_t fdim = original_features.size() / N;\n\tsize_t ldim = original_classes.size() / N;\n\n\t// Limits of the cloud\n\tPointXYZ minCorner = min_point(original_points);\n\tPointXYZ maxCorner = max_point(original_points);\n\tPointXYZ originCorner = floor(minCorner * (1/sampleDl)) * sampleDl;\n\n\t// Dimensions of the grid\n\tsize_t sampleNX = (size_t)floor((maxCorner.x - originCorner.x) / sampleDl) + 1;\n\tsize_t sampleNY = (size_t)floor((maxCorner.y - originCorner.y) / sampleDl) + 1;\n\t//size_t sampleNZ = (size_t)floor((maxCorner.z - originCorner.z) / sampleDl) + 1;\n\n\t// Check if features and classes need to be processed\n\tbool use_feature = original_features.size() > 0;\n\tbool use_classes = original_classes.size() > 0;\n\n\n\t// Create the sampled map\n\t// **********************\n\n\t// Verbose parameters\n\tint i = 0;\n\tint nDisp = N / 100;\n\n\t// Initialize variables\n\tsize_t iX, iY, iZ, mapIdx;\n\tunordered_map<size_t, SampledData> data;\n\n\tfor (auto& p : original_points)\n\t{\n\t\t// Position of point in sample map\n\t\tiX = (size_t)floor((p.x - originCorner.x) / sampleDl);\n\t\tiY = (size_t)floor((p.y - originCorner.y) / sampleDl);\n\t\tiZ = (size_t)floor((p.z - originCorner.z) / sampleDl);\n\t\tmapIdx = iX + sampleNX*iY + sampleNX*sampleNY*iZ;\n\n\t\t// If not already created, create key\n\t\tif (data.count(mapIdx) < 1)\n\t\t\tdata.emplace(mapIdx, SampledData(fdim, ldim));\n\n\t\t// Fill the sample map\n\t\tif (use_feature && use_classes)\n\t\t\tdata[mapIdx].update_all(p, original_features.begin() + i * fdim, original_classes.begin() + i * ldim);\n\t\telse if (use_feature)\n\t\t\tdata[mapIdx].update_features(p, original_features.begin() + i * fdim);\n\t\telse if (use_classes)\n\t\t\tdata[mapIdx].update_classes(p, original_classes.begin() + i * ldim);\n\t\telse\n\t\t\tdata[mapIdx].update_points(p);\n\n\t\t// Display\n\t\ti++;\n\t\tif (verbose > 1 && i%nDisp == 0)\n\t\t\tstd::cout << \"\\rSampled Map : \" << std::setw(3) << i / nDisp << \"%\";\n\n\t}\n\n\t// Divide for barycentre and transfer to a vector\n\tsubsampled_points.reserve(data.size());\n\tif (use_feature)\n\t\tsubsampled_features.reserve(data.size() * fdim);\n\tif (use_classes)\n\t\tsubsampled_classes.reserve(data.size() * ldim);\n\tfor (auto& v : data)\n\t{\n\t\tsubsampled_points.push_back(v.second.point * (1.0 / v.second.count));\n\t\tif (use_feature)\n\t\t{\n\t\t    float count = (float)v.second.count;\n\t\t    transform(v.second.features.begin(),\n                      v.second.features.end(),\n                      v.second.features.begin(),\n                      [count](float f) { return f / count;});\n            subsampled_features.insert(subsampled_features.end(),v.second.features.begin(),v.second.features.end());\n\t\t}\n\t\tif (use_classes)\n\t\t{\n\t\t    for (int i = 0; i < ldim; i++)\n\t\t        subsampled_classes.push_back(max_element(v.second.labels[i].begin(), 
v.second.labels[i].end(),\n\t\t        [](const pair<int, int>&a, const pair<int, int>&b){return a.second < b.second;})->first);\n\t\t}\n\t}\n\n\treturn;\n}\n\n\nvoid batch_grid_subsampling(vector<PointXYZ>& original_points,\n                              vector<PointXYZ>& subsampled_points,\n                              vector<float>& original_features,\n                              vector<float>& subsampled_features,\n                              vector<int>& original_classes,\n                              vector<int>& subsampled_classes,\n                              vector<int>& original_batches,\n                              vector<int>& subsampled_batches,\n                              float sampleDl,\n                              int max_p)\n{\n\t// Initialize variables\n\t// ******************\n\n\tint b = 0;\n\tint sum_b = 0;\n\n\t// Number of points in the cloud\n\tsize_t N = original_points.size();\n\n\t// Dimension of the features\n\tsize_t fdim = original_features.size() / N;\n\tsize_t ldim = original_classes.size() / N;\n\n\t// Handle max_p = 0\n\tif (max_p < 1)\n\t    max_p = N;\n\n\t// Loop over batches\n\t// *****************\n\n\tfor (b = 0; b < original_batches.size(); b++)\n\t{\n\n\t    // Extract batch points features and labels\n\t    vector<PointXYZ> b_o_points = vector<PointXYZ>(original_points.begin () + sum_b,\n\t                                                   original_points.begin () + sum_b + original_batches[b]);\n\n        vector<float> b_o_features;\n        if (original_features.size() > 0)\n        {\n            b_o_features = vector<float>(original_features.begin () + sum_b * fdim,\n                                         original_features.begin () + (sum_b + original_batches[b]) * fdim);\n\t    }\n\n\t    vector<int> b_o_classes;\n        if (original_classes.size() > 0)\n        {\n            b_o_classes = vector<int>(original_classes.begin () + sum_b * ldim,\n                                      original_classes.begin () + (sum_b + original_batches[b]) * ldim);\n\t    }\n\n\n        // Create result containers\n        vector<PointXYZ> b_s_points;\n        vector<float> b_s_features;\n        vector<int> b_s_classes;\n\n        // Compute subsampling on current batch\n        grid_subsampling(b_o_points,\n                         b_s_points,\n                         b_o_features,\n                         b_s_features,\n                         b_o_classes,\n                         b_s_classes,\n                         sampleDl,\n\t\t\t\t\t\t 0);\n\n        // Stack batches points features and labels\n        // ****************************************\n\n        // If too many points, keep only max_p of them\n        if (b_s_points.size() <= max_p)\n        {\n            subsampled_points.insert(subsampled_points.end(), b_s_points.begin(), b_s_points.end());\n\n            if (original_features.size() > 0)\n                subsampled_features.insert(subsampled_features.end(), b_s_features.begin(), b_s_features.end());\n\n            if (original_classes.size() > 0)\n                subsampled_classes.insert(subsampled_classes.end(), b_s_classes.begin(), b_s_classes.end());\n\n            subsampled_batches.push_back(b_s_points.size());\n        }\n        else\n        {\n            subsampled_points.insert(subsampled_points.end(), b_s_points.begin(), b_s_points.begin() + max_p);\n\n            if (original_features.size() > 0)\n                subsampled_features.insert(subsampled_features.end(), b_s_features.begin(), b_s_features.begin() + max_p 
* fdim);\n\n            if (original_classes.size() > 0)\n                subsampled_classes.insert(subsampled_classes.end(), b_s_classes.begin(), b_s_classes.begin() + max_p * ldim);\n\n            subsampled_batches.push_back(max_p);\n        }\n\n        // Stack new batch lengths\n        sum_b += original_batches[b];\n\t}\n\n\treturn;\n}\n"
  },
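  {
    "path": "examples/grid_subsampling_example.py",
    "content": "# Illustrative sketch (not part of the original release): calling the compiled\n# grid_subsampling extension whose C++ core (batch_grid_subsampling above) pools\n# each batch element onto a regular grid and returns cell barycenters. The\n# import path is an assumption; it expects the in-place build from\n# scripts/compile_3rdparty.sh under cpp_wrappers/cpp_subsampling.\nimport numpy as np\nfrom cpp_wrappers.cpp_subsampling import grid_subsampling\n\npoints = np.random.rand(1000, 3).astype(np.float32)\nbatches = np.array([600, 400], dtype=np.int32)\n\n# Pool onto a 5 cm grid; max_p=0 keeps all surviving points. With only points\n# and batches given, the wrapper is expected to return the pooled points and\n# the per-batch lengths (features/classes are returned too when passed in).\nsub_points, sub_batches = grid_subsampling.subsample_batch(\n    points, batches, sampleDl=0.05, max_p=0, verbose=0)\nprint(sub_points.shape, sub_batches)\n"
  },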
  {
    "path": "thirdparty/kpconv/cpp_wrappers/cpp_subsampling/grid_subsampling/grid_subsampling.h",
    "content": "\n\n#include \"../../cpp_utils/cloud/cloud.h\"\n\n#include <set>\n#include <cstdint>\n\nusing namespace std;\n\nclass SampledData\n{\npublic:\n\n\t// Elements\n\t// ********\n\n\tint count;\n\tPointXYZ point;\n\tvector<float> features;\n\tvector<unordered_map<int, int>> labels;\n\n\n\t// Methods\n\t// *******\n\n\t// Constructor\n\tSampledData() \n\t{ \n\t\tcount = 0; \n\t\tpoint = PointXYZ();\n\t}\n\n\tSampledData(const size_t fdim, const size_t ldim)\n\t{\n\t\tcount = 0;\n\t\tpoint = PointXYZ();\n\t    features = vector<float>(fdim);\n\t    labels = vector<unordered_map<int, int>>(ldim);\n\t}\n\n\t// Method Update\n\tvoid update_all(const PointXYZ p, vector<float>::iterator f_begin, vector<int>::iterator l_begin)\n\t{\n\t\tcount += 1;\n\t\tpoint += p;\n\t\ttransform (features.begin(), features.end(), f_begin, features.begin(), plus<float>());\n\t\tint i = 0;\n\t\tfor(vector<int>::iterator it = l_begin; it != l_begin + labels.size(); ++it)\n\t\t{\n\t\t    labels[i][*it] += 1;\n\t\t    i++;\n\t\t}\n\t\treturn;\n\t}\n\tvoid update_features(const PointXYZ p, vector<float>::iterator f_begin)\n\t{\n\t\tcount += 1;\n\t\tpoint += p;\n\t\ttransform (features.begin(), features.end(), f_begin, features.begin(), plus<float>());\n\t\treturn;\n\t}\n\tvoid update_classes(const PointXYZ p, vector<int>::iterator l_begin)\n\t{\n\t\tcount += 1;\n\t\tpoint += p;\n\t\tint i = 0;\n\t\tfor(vector<int>::iterator it = l_begin; it != l_begin + labels.size(); ++it)\n\t\t{\n\t\t    labels[i][*it] += 1;\n\t\t    i++;\n\t\t}\n\t\treturn;\n\t}\n\tvoid update_points(const PointXYZ p)\n\t{\n\t\tcount += 1;\n\t\tpoint += p;\n\t\treturn;\n\t}\n};\n\nvoid grid_subsampling(vector<PointXYZ>& original_points,\n                      vector<PointXYZ>& subsampled_points,\n                      vector<float>& original_features,\n                      vector<float>& subsampled_features,\n                      vector<int>& original_classes,\n                      vector<int>& subsampled_classes,\n                      float sampleDl,\n                      int verbose);\n\nvoid batch_grid_subsampling(vector<PointXYZ>& original_points,\n                            vector<PointXYZ>& subsampled_points,\n                            vector<float>& original_features,\n                            vector<float>& subsampled_features,\n                            vector<int>& original_classes,\n                            vector<int>& subsampled_classes,\n                            vector<int>& original_batches,\n                            vector<int>& subsampled_batches,\n                            float sampleDl,\n                            int max_p);\n\n"
  },
  {
    "path": "thirdparty/kpconv/cpp_wrappers/cpp_subsampling/setup.py",
    "content": "from distutils.core import setup, Extension\nimport numpy.distutils.misc_util\n\n# Adding OpenCV to project\n# ************************\n\n# Adding sources of the project\n# *****************************\n\nSOURCES = [\"../cpp_utils/cloud/cloud.cpp\",\n             \"grid_subsampling/grid_subsampling.cpp\",\n             \"wrapper.cpp\"]\n\nmodule = Extension(name=\"grid_subsampling\",\n                    sources=SOURCES,\n                    extra_compile_args=['-std=c++11',\n                                        '-D_GLIBCXX_USE_CXX11_ABI=0'])\n\n\nsetup(ext_modules=[module], include_dirs=numpy.distutils.misc_util.get_numpy_include_dirs())\n\n\n\n\n\n\n\n\n"
  },
  {
    "path": "thirdparty/kpconv/cpp_wrappers/cpp_subsampling/wrapper.cpp",
    "content": "#include <Python.h>\n#include <numpy/arrayobject.h>\n#include \"grid_subsampling/grid_subsampling.h\"\n#include <string>\n\n\n\n// docstrings for our module\n// *************************\n\nstatic char module_docstring[] = \"This module provides an interface for the subsampling of a batch of stacked pointclouds\";\n\nstatic char subsample_docstring[] = \"function subsampling a pointcloud\";\n\nstatic char subsample_batch_docstring[] = \"function subsampling a batch of stacked pointclouds\";\n\n\n// Declare the functions\n// *********************\n\nstatic PyObject *cloud_subsampling(PyObject* self, PyObject* args, PyObject* keywds);\nstatic PyObject *batch_subsampling(PyObject *self, PyObject *args, PyObject *keywds);\n\n\n// Specify the members of the module\n// *********************************\n\nstatic PyMethodDef module_methods[] = \n{\n\t{ \"subsample\", (PyCFunction)cloud_subsampling, METH_VARARGS | METH_KEYWORDS, subsample_docstring },\n\t{ \"subsample_batch\", (PyCFunction)batch_subsampling, METH_VARARGS | METH_KEYWORDS, subsample_batch_docstring },\n\t{NULL, NULL, 0, NULL}\n};\n\n\n// Initialize the module\n// *********************\n\nstatic struct PyModuleDef moduledef = \n{\n    PyModuleDef_HEAD_INIT,\n    \"grid_subsampling\",     // m_name\n    module_docstring,       // m_doc\n    -1,                     // m_size\n    module_methods,         // m_methods\n    NULL,                   // m_reload\n    NULL,                   // m_traverse\n    NULL,                   // m_clear\n    NULL,                   // m_free\n};\n\nPyMODINIT_FUNC PyInit_grid_subsampling(void)\n{\n    import_array();\n\treturn PyModule_Create(&moduledef);\n}\n\n\n// Definition of the batch_subsample method\n// **********************************\n\nstatic PyObject* batch_subsampling(PyObject* self, PyObject* args, PyObject* keywds)\n{\n\n\t// Manage inputs\n\t// *************\n\n\t// Args containers\n\tPyObject* points_obj = NULL;\n\tPyObject* features_obj = NULL;\n\tPyObject* classes_obj = NULL;\n\tPyObject* batches_obj = NULL;\n\n\t// Keywords containers\n\tstatic char* kwlist[] = { \"points\", \"batches\", \"features\", \"classes\", \"sampleDl\", \"method\", \"max_p\", \"verbose\", NULL };\n\tfloat sampleDl = 0.1;\n\tconst char* method_buffer = \"barycenters\";\n\tint verbose = 0;\n\tint max_p = 0;\n\n\t// Parse the input  \n\tif (!PyArg_ParseTupleAndKeywords(args, keywds, \"OO|$OOfsii\", kwlist, &points_obj, &batches_obj, &features_obj, &classes_obj, &sampleDl, &method_buffer, &max_p, &verbose))\n\t{\n\t\tPyErr_SetString(PyExc_RuntimeError, \"Error parsing arguments\");\n\t\treturn NULL;\n\t}\n\n\t// Get the method argument\n\tstring method(method_buffer);\n\n\t// Interpret method\n\tif (method.compare(\"barycenters\") && method.compare(\"voxelcenters\"))\n\t{\n\t\tPyErr_SetString(PyExc_RuntimeError, \"Error parsing method. 
Valid method names are \\\"barycenters\\\" and \\\"voxelcenters\\\" \");\n\t\treturn NULL;\n\t}\n\n\t// Check if using features or classes\n\tbool use_feature = true, use_classes = true;\n\tif (features_obj == NULL)\n\t\tuse_feature = false;\n\tif (classes_obj == NULL)\n\t\tuse_classes = false;\n\n\t// Interpret the input objects as numpy arrays.\n\tPyObject* points_array = PyArray_FROM_OTF(points_obj, NPY_FLOAT, NPY_IN_ARRAY);\n\tPyObject* batches_array = PyArray_FROM_OTF(batches_obj, NPY_INT, NPY_IN_ARRAY);\n\tPyObject* features_array = NULL;\n\tPyObject* classes_array = NULL;\n\tif (use_feature)\n\t\tfeatures_array = PyArray_FROM_OTF(features_obj, NPY_FLOAT, NPY_IN_ARRAY);\n\tif (use_classes)\n\t\tclasses_array = PyArray_FROM_OTF(classes_obj, NPY_INT, NPY_IN_ARRAY);\n\n\t// Verify data was load correctly.\n\tif (points_array == NULL)\n\t{\n\t\tPy_XDECREF(points_array);\n\t\tPy_XDECREF(batches_array);\n\t\tPy_XDECREF(classes_array);\n\t\tPy_XDECREF(features_array);\n\t\tPyErr_SetString(PyExc_RuntimeError, \"Error converting input points to numpy arrays of type float32\");\n\t\treturn NULL;\n\t}\n\tif (batches_array == NULL)\n\t{\n\t\tPy_XDECREF(points_array);\n\t\tPy_XDECREF(batches_array);\n\t\tPy_XDECREF(classes_array);\n\t\tPy_XDECREF(features_array);\n\t\tPyErr_SetString(PyExc_RuntimeError, \"Error converting input batches to numpy arrays of type int32\");\n\t\treturn NULL;\n\t}\n\tif (use_feature && features_array == NULL)\n\t{\n\t\tPy_XDECREF(points_array);\n\t\tPy_XDECREF(batches_array);\n\t\tPy_XDECREF(classes_array);\n\t\tPy_XDECREF(features_array);\n\t\tPyErr_SetString(PyExc_RuntimeError, \"Error converting input features to numpy arrays of type float32\");\n\t\treturn NULL;\n\t}\n\tif (use_classes && classes_array == NULL)\n\t{\n\t\tPy_XDECREF(points_array);\n\t\tPy_XDECREF(batches_array);\n\t\tPy_XDECREF(classes_array);\n\t\tPy_XDECREF(features_array);\n\t\tPyErr_SetString(PyExc_RuntimeError, \"Error converting input classes to numpy arrays of type int32\");\n\t\treturn NULL;\n\t}\n\n\t// Check that the input array respect the dims\n\tif ((int)PyArray_NDIM(points_array) != 2 || (int)PyArray_DIM(points_array, 1) != 3)\n\t{\n\t\tPy_XDECREF(points_array);\n\t\tPy_XDECREF(batches_array);\n\t\tPy_XDECREF(classes_array);\n\t\tPy_XDECREF(features_array);\n\t\tPyErr_SetString(PyExc_RuntimeError, \"Wrong dimensions : points.shape is not (N, 3)\");\n\t\treturn NULL;\n\t}\n\tif ((int)PyArray_NDIM(batches_array) > 1)\n\t{\n\t\tPy_XDECREF(points_array);\n\t\tPy_XDECREF(batches_array);\n\t\tPy_XDECREF(classes_array);\n\t\tPy_XDECREF(features_array);\n\t\tPyErr_SetString(PyExc_RuntimeError, \"Wrong dimensions : batches.shape is not (B,) \");\n\t\treturn NULL;\n\t}\n\tif (use_feature && ((int)PyArray_NDIM(features_array) != 2))\n\t{\n\t\tPy_XDECREF(points_array);\n\t\tPy_XDECREF(batches_array);\n\t\tPy_XDECREF(classes_array);\n\t\tPy_XDECREF(features_array);\n\t\tPyErr_SetString(PyExc_RuntimeError, \"Wrong dimensions : features.shape is not (N, d)\");\n\t\treturn NULL;\n\t}\n\n\tif (use_classes && (int)PyArray_NDIM(classes_array) > 2)\n\t{\n\t\tPy_XDECREF(points_array);\n\t\tPy_XDECREF(batches_array);\n\t\tPy_XDECREF(classes_array);\n\t\tPy_XDECREF(features_array);\n\t\tPyErr_SetString(PyExc_RuntimeError, \"Wrong dimensions : classes.shape is not (N,) or (N, d)\");\n\t\treturn NULL;\n\t}\n\n\t// Number of points\n\tint N = (int)PyArray_DIM(points_array, 0);\n\n\t// Number of batches\n\tint Nb = (int)PyArray_DIM(batches_array, 0);\n\n\t// Dimension of the features\n\tint fdim = 0;\n\tif 
\tif (use_feature)\n\t\tfdim = (int)PyArray_DIM(features_array, 1);\n\n\t// Dimension of the labels\n\tint ldim = 1;\n\tif (use_classes && (int)PyArray_NDIM(classes_array) == 2)\n\t\tldim = (int)PyArray_DIM(classes_array, 1);\n\n\t// Check that the input arrays respect the number of points\n\tif (use_feature && (int)PyArray_DIM(features_array, 0) != N)\n\t{\n\t\tPy_XDECREF(points_array);\n\t\tPy_XDECREF(batches_array);\n\t\tPy_XDECREF(classes_array);\n\t\tPy_XDECREF(features_array);\n\t\tPyErr_SetString(PyExc_RuntimeError, \"Wrong dimensions : features.shape is not (N, d)\");\n\t\treturn NULL;\n\t}\n\tif (use_classes && (int)PyArray_DIM(classes_array, 0) != N)\n\t{\n\t\tPy_XDECREF(points_array);\n\t\tPy_XDECREF(batches_array);\n\t\tPy_XDECREF(classes_array);\n\t\tPy_XDECREF(features_array);\n\t\tPyErr_SetString(PyExc_RuntimeError, \"Wrong dimensions : classes.shape is not (N,) or (N, d)\");\n\t\treturn NULL;\n\t}\n\n\n\t// Call the C++ function\n\t// *********************\n\n\t// Create pyramid\n\tif (verbose > 0)\n\t\tcout << \"Computing cloud pyramid with support points: \" << endl;\n\n\n\t// Convert PyArray to Cloud C++ class\n\tvector<PointXYZ> original_points;\n\tvector<int> original_batches;\n\tvector<float> original_features;\n\tvector<int> original_classes;\n\toriginal_points = vector<PointXYZ>((PointXYZ*)PyArray_DATA(points_array), (PointXYZ*)PyArray_DATA(points_array) + N);\n\toriginal_batches = vector<int>((int*)PyArray_DATA(batches_array), (int*)PyArray_DATA(batches_array) + Nb);\n\tif (use_feature)\n\t\toriginal_features = vector<float>((float*)PyArray_DATA(features_array), (float*)PyArray_DATA(features_array) + N * fdim);\n\tif (use_classes)\n\t\toriginal_classes = vector<int>((int*)PyArray_DATA(classes_array), (int*)PyArray_DATA(classes_array) + N * ldim);\n\n\t// Subsample\n\tvector<PointXYZ> subsampled_points;\n\tvector<float> subsampled_features;\n\tvector<int> subsampled_classes;\n\tvector<int> subsampled_batches;\n\tbatch_grid_subsampling(original_points,\n\t\t\t\t\t\t\tsubsampled_points,\n\t\t\t\t\t\t\toriginal_features,\n\t\t\t\t\t\t\tsubsampled_features,\n\t\t\t\t\t\t\toriginal_classes,\n\t\t\t\t\t\t\tsubsampled_classes,\n\t\t\t\t\t\t\toriginal_batches,\n\t\t\t\t\t\t\tsubsampled_batches,\n\t\t\t\t\t\t\tsampleDl,\n\t\t\t\t\t\t\tmax_p);\n\n\t// Check result\n\tif (subsampled_points.size() < 1)\n\t{\n\t\tPy_DECREF(points_array);\n\t\tPy_DECREF(batches_array);\n\t\tPy_XDECREF(features_array);\n\t\tPy_XDECREF(classes_array);\n\t\tPyErr_SetString(PyExc_RuntimeError, \"Error : subsampling returned an empty cloud\");\n\t\treturn NULL;\n\t}\n\n\t// Manage outputs\n\t// **************\n\n\t// Dimensions of the output containers (stack arrays, nothing to free)\n\tnpy_intp point_dims[2] = { (npy_intp)subsampled_points.size(), 3 };\n\tnpy_intp feature_dims[2] = { (npy_intp)subsampled_points.size(), fdim };\n\tnpy_intp classes_dims[2] = { (npy_intp)subsampled_points.size(), ldim };\n\tnpy_intp batches_dims[1] = { Nb };\n\n\t// Create output array\n\tPyObject* res_points_obj = PyArray_SimpleNew(2, point_dims, NPY_FLOAT);\n\tPyObject* res_batches_obj = PyArray_SimpleNew(1, batches_dims, NPY_INT);\n\tPyObject* res_features_obj = NULL;\n\tPyObject* res_classes_obj = NULL;\n\tPyObject* ret = NULL;\n\n\t// Fill output array with values\n\tsize_t size_in_bytes = subsampled_points.size() * 3 * sizeof(float);\n\tmemcpy(PyArray_DATA(res_points_obj), subsampled_points.data(), size_in_bytes);\n\tsize_in_bytes = Nb * sizeof(int);\n\tmemcpy(PyArray_DATA(res_batches_obj), subsampled_batches.data(), size_in_bytes);\n\tif (use_feature)\n\t{\n\t\tsize_in_bytes = subsampled_points.size() * fdim * sizeof(float);\n\t\tres_features_obj = PyArray_SimpleNew(2, feature_dims, NPY_FLOAT);\n\t\tmemcpy(PyArray_DATA(res_features_obj), subsampled_features.data(), size_in_bytes);\n\t}\n\tif (use_classes)\n\t{\n\t\tsize_in_bytes = subsampled_points.size() * ldim * sizeof(int);\n\t\tres_classes_obj = PyArray_SimpleNew(2, classes_dims, NPY_INT);\n\t\tmemcpy(PyArray_DATA(res_classes_obj), subsampled_classes.data(), size_in_bytes);\n\t}\n\n\n\t// Merge results\n\tif (use_feature && use_classes)\n\t\tret = Py_BuildValue(\"NNNN\", res_points_obj, res_batches_obj, res_features_obj, res_classes_obj);\n\telse if (use_feature)\n\t\tret = Py_BuildValue(\"NNN\", res_points_obj, res_batches_obj, res_features_obj);\n\telse if (use_classes)\n\t\tret = Py_BuildValue(\"NNN\", res_points_obj, res_batches_obj, res_classes_obj);\n\telse\n\t\tret = Py_BuildValue(\"NN\", res_points_obj, res_batches_obj);\n\n\t// Clean up\n\t// ********\n\n\tPy_DECREF(points_array);\n\tPy_DECREF(batches_array);\n\tPy_XDECREF(features_array);\n\tPy_XDECREF(classes_array);\n\n\treturn ret;\n}\n\n// Definition of the subsample method\n// **********************************\n\nstatic PyObject* cloud_subsampling(PyObject* self, PyObject* args, PyObject* keywds)\n{\n\n\t// Manage inputs\n\t// *************\n\n\t// Args containers\n\tPyObject* points_obj = NULL;\n\tPyObject* features_obj = NULL;\n\tPyObject* classes_obj = NULL;\n\n\t// Keywords containers\n\tstatic char* kwlist[] = { \"points\", \"features\", \"classes\", \"sampleDl\", \"method\", \"verbose\", NULL };\n\tfloat sampleDl = 0.1f;\n\tconst char* method_buffer = \"barycenters\";\n\tint verbose = 0;\n\n\t// Parse the input\n\tif (!PyArg_ParseTupleAndKeywords(args, keywds, \"O|$OOfsi\", kwlist, &points_obj, &features_obj, &classes_obj, &sampleDl, &method_buffer, &verbose))\n\t{\n\t\tPyErr_SetString(PyExc_RuntimeError, \"Error parsing arguments\");\n\t\treturn NULL;\n\t}\n\n\t// Get the method argument\n\tstring method(method_buffer);\n\n\t// Interpret method\n\tif (method.compare(\"barycenters\") && method.compare(\"voxelcenters\"))\n\t{\n\t\tPyErr_SetString(PyExc_RuntimeError, \"Error parsing method. Valid method names are \\\"barycenters\\\" and \\\"voxelcenters\\\"\");\n\t\treturn NULL;\n\t}\n
\n\t// Check if using features or classes\n\tbool use_feature = true, use_classes = true;\n\tif (features_obj == NULL)\n\t\tuse_feature = false;\n\tif (classes_obj == NULL)\n\t\tuse_classes = false;\n\n\t// Interpret the input objects as numpy arrays.\n\tPyObject* points_array = PyArray_FROM_OTF(points_obj, NPY_FLOAT, NPY_IN_ARRAY);\n\tPyObject* features_array = NULL;\n\tPyObject* classes_array = NULL;\n\tif (use_feature)\n\t\tfeatures_array = PyArray_FROM_OTF(features_obj, NPY_FLOAT, NPY_IN_ARRAY);\n\tif (use_classes)\n\t\tclasses_array = PyArray_FROM_OTF(classes_obj, NPY_INT, NPY_IN_ARRAY);\n\n\t// Verify data was loaded correctly.\n\tif (points_array == NULL)\n\t{\n\t\tPy_XDECREF(points_array);\n\t\tPy_XDECREF(classes_array);\n\t\tPy_XDECREF(features_array);\n\t\tPyErr_SetString(PyExc_RuntimeError, \"Error converting input points to numpy arrays of type float32\");\n\t\treturn NULL;\n\t}\n\tif (use_feature && features_array == NULL)\n\t{\n\t\tPy_XDECREF(points_array);\n\t\tPy_XDECREF(classes_array);\n\t\tPy_XDECREF(features_array);\n\t\tPyErr_SetString(PyExc_RuntimeError, \"Error converting input features to numpy arrays of type float32\");\n\t\treturn NULL;\n\t}\n\tif (use_classes && classes_array == NULL)\n\t{\n\t\tPy_XDECREF(points_array);\n\t\tPy_XDECREF(classes_array);\n\t\tPy_XDECREF(features_array);\n\t\tPyErr_SetString(PyExc_RuntimeError, \"Error converting input classes to numpy arrays of type int32\");\n\t\treturn NULL;\n\t}\n\n\t// Check that the input arrays respect the dims\n\tif ((int)PyArray_NDIM(points_array) != 2 || (int)PyArray_DIM(points_array, 1) != 3)\n\t{\n\t\tPy_XDECREF(points_array);\n\t\tPy_XDECREF(classes_array);\n\t\tPy_XDECREF(features_array);\n\t\tPyErr_SetString(PyExc_RuntimeError, \"Wrong dimensions : points.shape is not (N, 3)\");\n\t\treturn NULL;\n\t}\n\tif (use_feature && ((int)PyArray_NDIM(features_array) != 2))\n\t{\n\t\tPy_XDECREF(points_array);\n\t\tPy_XDECREF(classes_array);\n\t\tPy_XDECREF(features_array);\n\t\tPyErr_SetString(PyExc_RuntimeError, \"Wrong dimensions : features.shape is not (N, d)\");\n\t\treturn NULL;\n\t}\n\tif (use_classes && (int)PyArray_NDIM(classes_array) > 2)\n\t{\n\t\tPy_XDECREF(points_array);\n\t\tPy_XDECREF(classes_array);\n\t\tPy_XDECREF(features_array);\n\t\tPyErr_SetString(PyExc_RuntimeError, \"Wrong dimensions : classes.shape is not (N,) or (N, d)\");\n\t\treturn NULL;\n\t}\n\n\t// Number of points\n\tint N = (int)PyArray_DIM(points_array, 0);\n\n\t// Dimension of the features\n\tint fdim = 0;\n\tif (use_feature)\n\t\tfdim = (int)PyArray_DIM(features_array, 1);\n\n\t// Dimension of the labels\n\tint ldim = 1;\n\tif (use_classes && (int)PyArray_NDIM(classes_array) == 2)\n\t\tldim = (int)PyArray_DIM(classes_array, 1);\n\n\t// Check that the input arrays respect the number of points\n\tif (use_feature && (int)PyArray_DIM(features_array, 0) != N)\n\t{\n\t\tPy_XDECREF(points_array);\n\t\tPy_XDECREF(classes_array);\n\t\tPy_XDECREF(features_array);\n\t\tPyErr_SetString(PyExc_RuntimeError, \"Wrong dimensions : features.shape is not (N, d)\");\n\t\treturn NULL;\n\t}\n\tif (use_classes && (int)PyArray_DIM(classes_array, 0) != N)\n\t{\n\t\tPy_XDECREF(points_array);\n\t\tPy_XDECREF(classes_array);\n\t\tPy_XDECREF(features_array);\n\t\tPyErr_SetString(PyExc_RuntimeError, \"Wrong dimensions : classes.shape is not (N,) or (N, d)\");\n\t\treturn NULL;\n\t}\n\n\n\t// Call the C++ function\n\t// *********************\n\n\t// Create pyramid\n\tif (verbose > 0)\n\t\tcout << \"Computing cloud pyramid with support points: \" << endl;\n\n\n\t// Convert PyArray to Cloud C++ class\n\tvector<PointXYZ> original_points;\n\tvector<float> original_features;\n\tvector<int> original_classes;\n\toriginal_points = vector<PointXYZ>((PointXYZ*)PyArray_DATA(points_array), (PointXYZ*)PyArray_DATA(points_array) + N);\n\tif (use_feature)\n\t\toriginal_features = vector<float>((float*)PyArray_DATA(features_array), (float*)PyArray_DATA(features_array) + N * fdim);\n\tif (use_classes)\n\t\toriginal_classes = vector<int>((int*)PyArray_DATA(classes_array), (int*)PyArray_DATA(classes_array) + N * ldim);\n\n\t// Subsample\n\tvector<PointXYZ> subsampled_points;\n\tvector<float> subsampled_features;\n\tvector<int> subsampled_classes;\n\tgrid_subsampling(original_points,\n\t\tsubsampled_points,\n\t\toriginal_features,\n\t\tsubsampled_features,\n\t\toriginal_classes,\n\t\tsubsampled_classes,\n\t\tsampleDl,\n\t\tverbose);\n\n\t// Check result\n\tif (subsampled_points.size() < 1)\n\t{\n\t\tPy_DECREF(points_array);\n\t\tPy_XDECREF(features_array);\n\t\tPy_XDECREF(classes_array);\n\t\tPyErr_SetString(PyExc_RuntimeError, \"Error : subsampling returned an empty cloud\");\n\t\treturn NULL;\n\t}\n\n\t// Manage outputs\n\t// **************\n\n\t// Dimensions of the output containers (stack arrays, nothing to free)\n\tnpy_intp point_dims[2] = { (npy_intp)subsampled_points.size(), 3 };\n\tnpy_intp feature_dims[2] = { (npy_intp)subsampled_points.size(), fdim };\n\tnpy_intp classes_dims[2] = { (npy_intp)subsampled_points.size(), ldim };\n\n\t// Create output array\n\tPyObject* res_points_obj = PyArray_SimpleNew(2, point_dims, NPY_FLOAT);\n\tPyObject* res_features_obj = NULL;\n\tPyObject* res_classes_obj = NULL;\n\tPyObject* ret = NULL;\n\n\t// Fill output array with values\n\tsize_t size_in_bytes = subsampled_points.size() * 3 * sizeof(float);\n\tmemcpy(PyArray_DATA(res_points_obj), subsampled_points.data(), size_in_bytes);\n\tif (use_feature)\n\t{\n\t\tsize_in_bytes = subsampled_points.size() * fdim * sizeof(float);\n\t\tres_features_obj = PyArray_SimpleNew(2, feature_dims, NPY_FLOAT);\n\t\tmemcpy(PyArray_DATA(res_features_obj), subsampled_features.data(), size_in_bytes);\n\t}\n\tif (use_classes)\n\t{\n\t\tsize_in_bytes = subsampled_points.size() * ldim * sizeof(int);\n\t\tres_classes_obj = PyArray_SimpleNew(2, classes_dims, NPY_INT);\n\t\tmemcpy(PyArray_DATA(res_classes_obj), subsampled_classes.data(), size_in_bytes);\n\t}\n\n\n\t// Merge results\n\tif (use_feature && use_classes)\n\t\tret = Py_BuildValue(\"NNN\", res_points_obj, res_features_obj, res_classes_obj);\n\telse if (use_feature)\n\t\tret = Py_BuildValue(\"NN\", res_points_obj, res_features_obj);\n\telse if (use_classes)\n\t\tret = Py_BuildValue(\"NN\", res_points_obj, res_classes_obj);\n\telse\n\t\tret = Py_BuildValue(\"N\", res_points_obj);\n\n\t// Clean up\n\t// ********\n\n\tPy_DECREF(points_array);\n\tPy_XDECREF(features_array);\n\tPy_XDECREF(classes_array);\n\n\treturn ret;\n}\n
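\n// Expected Python-side usage (sketch inferred from the kwlists above; array\n// contents are illustrative):\n//\n//   import grid_subsampling\n//   # single cloud:\n//   s_pts = grid_subsampling.subsample(points, sampleDl=0.1)\n//   # batch of stacked clouds (batches holds the length of each cloud):\n//   s_pts, s_batches = grid_subsampling.subsample_batch(points, batches, sampleDl=0.1)\n"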
  },
  {
    "path": "thirdparty/kpconv/cpp_wrappers/cpp_utils/cloud/cloud.cpp",
    "content": "//\n//\n//\t\t0==========================0\n//\t\t|    Local feature test    |\n//\t\t0==========================0\n//\n//\t\tversion 1.0 : \n//\t\t\t> \n//\n//---------------------------------------------------\n//\n//\t\tCloud source :\n//\t\tDefine usefull Functions/Methods\n//\n//----------------------------------------------------\n//\n//\t\tHugues THOMAS - 10/02/2017\n//\n\n\n#include \"cloud.h\"\n\n\n// Getters\n// *******\n\nPointXYZ max_point(std::vector<PointXYZ> points)\n{\n\t// Initialize limits\n\tPointXYZ maxP(points[0]);\n\n\t// Loop over all points\n\tfor (auto p : points)\n\t{\n\t\tif (p.x > maxP.x)\n\t\t\tmaxP.x = p.x;\n\n\t\tif (p.y > maxP.y)\n\t\t\tmaxP.y = p.y;\n\n\t\tif (p.z > maxP.z)\n\t\t\tmaxP.z = p.z;\n\t}\n\n\treturn maxP;\n}\n\nPointXYZ min_point(std::vector<PointXYZ> points)\n{\n\t// Initialize limits\n\tPointXYZ minP(points[0]);\n\n\t// Loop over all points\n\tfor (auto p : points)\n\t{\n\t\tif (p.x < minP.x)\n\t\t\tminP.x = p.x;\n\n\t\tif (p.y < minP.y)\n\t\t\tminP.y = p.y;\n\n\t\tif (p.z < minP.z)\n\t\t\tminP.z = p.z;\n\t}\n\n\treturn minP;\n}"
  },
  {
    "path": "thirdparty/kpconv/cpp_wrappers/cpp_utils/cloud/cloud.h",
    "content": "//\n//\n//\t\t0==========================0\n//\t\t|    Local feature test    |\n//\t\t0==========================0\n//\n//\t\tversion 1.0 : \n//\t\t\t> \n//\n//---------------------------------------------------\n//\n//\t\tCloud header\n//\n//----------------------------------------------------\n//\n//\t\tHugues THOMAS - 10/02/2017\n//\n\n\n# pragma once\n\n#include <vector>\n#include <unordered_map>\n#include <map>\n#include <algorithm>\n#include <numeric>\n#include <iostream>\n#include <iomanip>\n#include <cmath>\n\n#include <time.h>\n\n\n\n\n// Point class\n// ***********\n\n\nclass PointXYZ\n{\npublic:\n\n\t// Elements\n\t// ********\n\n\tfloat x, y, z;\n\n\n\t// Methods\n\t// *******\n\t\n\t// Constructor\n\tPointXYZ() { x = 0; y = 0; z = 0; }\n\tPointXYZ(float x0, float y0, float z0) { x = x0; y = y0; z = z0; }\n\t\n\t// array type accessor\n\tfloat operator [] (int i) const\n\t{\n\t\tif (i == 0) return x;\n\t\telse if (i == 1) return y;\n\t\telse return z;\n\t}\n\n\t// opperations\n\tfloat dot(const PointXYZ P) const\n\t{\n\t\treturn x * P.x + y * P.y + z * P.z;\n\t}\n\n\tfloat sq_norm()\n\t{\n\t\treturn x*x + y*y + z*z;\n\t}\n\n\tPointXYZ cross(const PointXYZ P) const\n\t{\n\t\treturn PointXYZ(y*P.z - z*P.y, z*P.x - x*P.z, x*P.y - y*P.x);\n\t}\t\n\n\tPointXYZ& operator+=(const PointXYZ& P)\n\t{\n\t\tx += P.x;\n\t\ty += P.y;\n\t\tz += P.z;\n\t\treturn *this;\n\t}\n\n\tPointXYZ& operator-=(const PointXYZ& P)\n\t{\n\t\tx -= P.x;\n\t\ty -= P.y;\n\t\tz -= P.z;\n\t\treturn *this;\n\t}\n\n\tPointXYZ& operator*=(const float& a)\n\t{\n\t\tx *= a;\n\t\ty *= a;\n\t\tz *= a;\n\t\treturn *this;\n\t}\n};\n\n\n// Point Opperations\n// *****************\n\ninline PointXYZ operator + (const PointXYZ A, const PointXYZ B)\n{\n\treturn PointXYZ(A.x + B.x, A.y + B.y, A.z + B.z);\n}\n\ninline PointXYZ operator - (const PointXYZ A, const PointXYZ B)\n{\n\treturn PointXYZ(A.x - B.x, A.y - B.y, A.z - B.z);\n}\n\ninline PointXYZ operator * (const PointXYZ P, const float a)\n{\n\treturn PointXYZ(P.x * a, P.y * a, P.z * a);\n}\n\ninline PointXYZ operator * (const float a, const PointXYZ P)\n{\n\treturn PointXYZ(P.x * a, P.y * a, P.z * a);\n}\n\ninline std::ostream& operator << (std::ostream& os, const PointXYZ P)\n{\n\treturn os << \"[\" << P.x << \", \" << P.y << \", \" << P.z << \"]\";\n}\n\ninline bool operator == (const PointXYZ A, const PointXYZ B)\n{\n\treturn A.x == B.x && A.y == B.y && A.z == B.z;\n}\n\ninline PointXYZ floor(const PointXYZ P)\n{\n\treturn PointXYZ(std::floor(P.x), std::floor(P.y), std::floor(P.z));\n}\n\n\nPointXYZ max_point(std::vector<PointXYZ> points);\nPointXYZ min_point(std::vector<PointXYZ> points);\n\n\nstruct PointCloud\n{\n\n\tstd::vector<PointXYZ>  pts;\n\n\t// Must return the number of data points\n\tinline size_t kdtree_get_point_count() const { return pts.size(); }\n\n\t// Returns the dim'th component of the idx'th point in the class:\n\t// Since this is inlined and the \"dim\" argument is typically an immediate value, the\n\t//  \"if/else's\" are actually solved at compile time.\n\tinline float kdtree_get_pt(const size_t idx, const size_t dim) const\n\t{\n\t\tif (dim == 0) return pts[idx].x;\n\t\telse if (dim == 1) return pts[idx].y;\n\t\telse return pts[idx].z;\n\t}\n\n\t// Optional bounding-box computation: return false to default to a standard bbox computation loop.\n\t//   Return true if the BBOX was already computed by the class and returned in \"bb\" so it can be avoided to redo it again.\n\t//   Look at bb.size() to find out the expected 
dimensionality (e.g. 2 or 3 for point clouds)\n\ttemplate <class BBOX>\n\tbool kdtree_get_bbox(BBOX& /* bb */) const { return false; }\n\n};\n\n\n\n\n\n\n\n\n\n\n\n"
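\n// Usage sketch (assumes the usual nanoflann adaptor pattern; see nanoflann.hpp):\n//\n//   PointCloud cloud;\n//   cloud.pts = ...;\n//   typedef nanoflann::KDTreeSingleIndexAdaptor<\n//       nanoflann::L2_Simple_Adaptor<float, PointCloud>, PointCloud, 3> kd_tree_t;\n//   kd_tree_t index(3, cloud, nanoflann::KDTreeSingleIndexAdaptorParams(10));\n//   index.buildIndex();\n"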
  },
  {
    "path": "thirdparty/kpconv/cpp_wrappers/cpp_utils/nanoflann/nanoflann.hpp",
    "content": "/***********************************************************************\n * Software License Agreement (BSD License)\n *\n * Copyright 2008-2009  Marius Muja (mariusm@cs.ubc.ca). All rights reserved.\n * Copyright 2008-2009  David G. Lowe (lowe@cs.ubc.ca). All rights reserved.\n * Copyright 2011-2016  Jose Luis Blanco (joseluisblancoc@gmail.com).\n *   All rights reserved.\n *\n * THE BSD LICENSE\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions\n * are met:\n *\n * 1. Redistributions of source code must retain the above copyright\n *    notice, this list of conditions and the following disclaimer.\n * 2. Redistributions in binary form must reproduce the above copyright\n *    notice, this list of conditions and the following disclaimer in the\n *    documentation and/or other materials provided with the distribution.\n *\n * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR\n * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES\n * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.\n * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,\n * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT\n * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,\n * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY\n * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF\n * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *************************************************************************/\n\n/** \\mainpage nanoflann C++ API documentation\n *  nanoflann is a C++ header-only library for building KD-Trees, mostly\n *  optimized for 2D or 3D point clouds.\n *\n *  nanoflann does not require compiling or installing, just an\n *  #include <nanoflann.hpp> in your code.\n *\n *  See:\n *   - <a href=\"modules.html\" >C++ API organized by modules</a>\n *   - <a href=\"https://github.com/jlblancoc/nanoflann\" >Online README</a>\n *   - <a href=\"http://jlblancoc.github.io/nanoflann/\" >Doxygen\n * documentation</a>\n */\n\n#ifndef NANOFLANN_HPP_\n#define NANOFLANN_HPP_\n\n#include <algorithm>\n#include <array>\n#include <cassert>\n#include <cmath>   // for abs()\n#include <cstdio>  // for fwrite()\n#include <cstdlib> // for abs()\n#include <functional>\n#include <limits> // std::reference_wrapper\n#include <stdexcept>\n#include <vector>\n\n/** Library version: 0xMmP (M=Major,m=minor,P=patch) */\n#define NANOFLANN_VERSION 0x130\n\n// Avoid conflicting declaration of min/max macros in windows headers\n#if !defined(NOMINMAX) &&                                                      \\\n    (defined(_WIN32) || defined(_WIN32_) || defined(WIN32) || defined(_WIN64))\n#define NOMINMAX\n#ifdef max\n#undef max\n#undef min\n#endif\n#endif\n\nnamespace nanoflann {\n/** @addtogroup nanoflann_grp nanoflann C++ library for ANN\n *  @{ */\n\n/** the PI constant (required to avoid MSVC missing symbols) */\ntemplate <typename T> T pi_const() {\n  return static_cast<T>(3.14159265358979323846);\n}\n\n/**\n * Traits if object is resizable and assignable (typically has a resize | assign\n * method)\n */\ntemplate <typename T, typename = int> struct has_resize : std::false_type {};\n\ntemplate <typename T>\nstruct has_resize<T, decltype((void)std::declval<T>().resize(1), 
0)>\n    : std::true_type {};\n\ntemplate <typename T, typename = int> struct has_assign : std::false_type {};\n\ntemplate <typename T>\nstruct has_assign<T, decltype((void)std::declval<T>().assign(1, 0), 0)>\n    : std::true_type {};\n\n/**\n * Free function to resize a resizable object\n */\ntemplate <typename Container>\ninline typename std::enable_if<has_resize<Container>::value, void>::type\nresize(Container &c, const size_t nElements) {\n  c.resize(nElements);\n}\n\n/**\n * Free function that has no effects on non resizable containers (e.g.\n * std::array) It raises an exception if the expected size does not match\n */\ntemplate <typename Container>\ninline typename std::enable_if<!has_resize<Container>::value, void>::type\nresize(Container &c, const size_t nElements) {\n  if (nElements != c.size())\n    throw std::logic_error(\"Try to change the size of a std::array.\");\n}\n\n/**\n * Free function to assign to a container\n */\ntemplate <typename Container, typename T>\ninline typename std::enable_if<has_assign<Container>::value, void>::type\nassign(Container &c, const size_t nElements, const T &value) {\n  c.assign(nElements, value);\n}\n\n/**\n * Free function to assign to a std::array\n */\ntemplate <typename Container, typename T>\ninline typename std::enable_if<!has_assign<Container>::value, void>::type\nassign(Container &c, const size_t nElements, const T &value) {\n  for (size_t i = 0; i < nElements; i++)\n    c[i] = value;\n}\n\n/** @addtogroup result_sets_grp Result set classes\n *  @{ */\ntemplate <typename _DistanceType, typename _IndexType = size_t,\n          typename _CountType = size_t>\nclass KNNResultSet {\npublic:\n  typedef _DistanceType DistanceType;\n  typedef _IndexType IndexType;\n  typedef _CountType CountType;\n\nprivate:\n  IndexType *indices;\n  DistanceType *dists;\n  CountType capacity;\n  CountType count;\n\npublic:\n  inline KNNResultSet(CountType capacity_)\n      : indices(0), dists(0), capacity(capacity_), count(0) {}\n\n  inline void init(IndexType *indices_, DistanceType *dists_) {\n    indices = indices_;\n    dists = dists_;\n    count = 0;\n    if (capacity)\n      dists[capacity - 1] = (std::numeric_limits<DistanceType>::max)();\n  }\n\n  inline CountType size() const { return count; }\n\n  inline bool full() const { return count == capacity; }\n\n  /**\n   * Called during search to add an element matching the criteria.\n   * @return true if the search should be continued, false if the results are\n   * sufficient\n   */\n  inline bool addPoint(DistanceType dist, IndexType index) {\n    CountType i;\n    for (i = count; i > 0; --i) {\n#ifdef NANOFLANN_FIRST_MATCH // If defined and two points have the same\n                             // distance, the one with the lowest-index will be\n                             // returned first.\n      if ((dists[i - 1] > dist) ||\n          ((dist == dists[i - 1]) && (indices[i - 1] > index))) {\n#else\n      if (dists[i - 1] > dist) {\n#endif\n        if (i < capacity) {\n          dists[i] = dists[i - 1];\n          indices[i] = indices[i - 1];\n        }\n      } else\n        break;\n    }\n    if (i < capacity) {\n      dists[i] = dist;\n      indices[i] = index;\n    }\n    if (count < capacity)\n      count++;\n\n    // tell caller that the search shall continue\n    return true;\n  }\n\n  inline DistanceType worstDist() const { return dists[capacity - 1]; }\n};\n\n/** operator \"<\" for std::sort() */\nstruct IndexDist_Sorter {\n  /** PairType will be typically: std::pair<IndexType,DistanceType> 
*/\n  template <typename PairType>\n  inline bool operator()(const PairType &p1, const PairType &p2) const {\n    return p1.second < p2.second;\n  }\n};\n\n/**\n * A result-set class used when performing a radius based search.\n */\ntemplate <typename _DistanceType, typename _IndexType = size_t>\nclass RadiusResultSet {\npublic:\n  typedef _DistanceType DistanceType;\n  typedef _IndexType IndexType;\n\npublic:\n  const DistanceType radius;\n\n  std::vector<std::pair<IndexType, DistanceType>> &m_indices_dists;\n\n  inline RadiusResultSet(\n      DistanceType radius_,\n      std::vector<std::pair<IndexType, DistanceType>> &indices_dists)\n      : radius(radius_), m_indices_dists(indices_dists) {\n    init();\n  }\n\n  inline void init() { clear(); }\n  inline void clear() { m_indices_dists.clear(); }\n\n  inline size_t size() const { return m_indices_dists.size(); }\n\n  inline bool full() const { return true; }\n\n  /**\n   * Called during search to add an element matching the criteria.\n   * @return true if the search should be continued, false if the results are\n   * sufficient\n   */\n  inline bool addPoint(DistanceType dist, IndexType index) {\n    if (dist < radius)\n      m_indices_dists.push_back(std::make_pair(index, dist));\n    return true;\n  }\n\n  inline DistanceType worstDist() const { return radius; }\n\n  /**\n   * Find the worst result (furtherest neighbor) without copying or sorting\n   * Pre-conditions: size() > 0\n   */\n  std::pair<IndexType, DistanceType> worst_item() const {\n    if (m_indices_dists.empty())\n      throw std::runtime_error(\"Cannot invoke RadiusResultSet::worst_item() on \"\n                               \"an empty list of results.\");\n    typedef\n        typename std::vector<std::pair<IndexType, DistanceType>>::const_iterator\n            DistIt;\n    DistIt it = std::max_element(m_indices_dists.begin(), m_indices_dists.end(),\n                                 IndexDist_Sorter());\n    return *it;\n  }\n};\n\n/** @} */\n\n/** @addtogroup loadsave_grp Load/save auxiliary functions\n * @{ */\ntemplate <typename T>\nvoid save_value(FILE *stream, const T &value, size_t count = 1) {\n  fwrite(&value, sizeof(value), count, stream);\n}\n\ntemplate <typename T>\nvoid save_value(FILE *stream, const std::vector<T> &value) {\n  size_t size = value.size();\n  fwrite(&size, sizeof(size_t), 1, stream);\n  fwrite(&value[0], sizeof(T), size, stream);\n}\n\ntemplate <typename T>\nvoid load_value(FILE *stream, T &value, size_t count = 1) {\n  size_t read_cnt = fread(&value, sizeof(value), count, stream);\n  if (read_cnt != count) {\n    throw std::runtime_error(\"Cannot read from file\");\n  }\n}\n\ntemplate <typename T> void load_value(FILE *stream, std::vector<T> &value) {\n  size_t size;\n  size_t read_cnt = fread(&size, sizeof(size_t), 1, stream);\n  if (read_cnt != 1) {\n    throw std::runtime_error(\"Cannot read from file\");\n  }\n  value.resize(size);\n  read_cnt = fread(&value[0], sizeof(T), size, stream);\n  if (read_cnt != size) {\n    throw std::runtime_error(\"Cannot read from file\");\n  }\n}\n/** @} */\n\n/** @addtogroup metric_grp Metric (distance) classes\n * @{ */\n\nstruct Metric {};\n\n/** Manhattan distance functor (generic version, optimized for\n * high-dimensionality data sets). Corresponding distance traits:\n * nanoflann::metric_L1 \\tparam T Type of the elements (e.g. double, float,\n * uint8_t) \\tparam _DistanceType Type of distance variables (must be signed)\n * (e.g. 
float, double, int64_t)\n */\ntemplate <class T, class DataSource, typename _DistanceType = T>\nstruct L1_Adaptor {\n  typedef T ElementType;\n  typedef _DistanceType DistanceType;\n\n  const DataSource &data_source;\n\n  L1_Adaptor(const DataSource &_data_source) : data_source(_data_source) {}\n\n  inline DistanceType evalMetric(const T *a, const size_t b_idx, size_t size,\n                                 DistanceType worst_dist = -1) const {\n    DistanceType result = DistanceType();\n    const T *last = a + size;\n    const T *lastgroup = last - 3;\n    size_t d = 0;\n\n    /* Process 4 items with each loop for efficiency. */\n    while (a < lastgroup) {\n      const DistanceType diff0 =\n          std::abs(a[0] - data_source.kdtree_get_pt(b_idx, d++));\n      const DistanceType diff1 =\n          std::abs(a[1] - data_source.kdtree_get_pt(b_idx, d++));\n      const DistanceType diff2 =\n          std::abs(a[2] - data_source.kdtree_get_pt(b_idx, d++));\n      const DistanceType diff3 =\n          std::abs(a[3] - data_source.kdtree_get_pt(b_idx, d++));\n      result += diff0 + diff1 + diff2 + diff3;\n      a += 4;\n      if ((worst_dist > 0) && (result > worst_dist)) {\n        return result;\n      }\n    }\n    /* Process last 0-3 components.  Not needed for standard vector lengths. */\n    while (a < last) {\n      result += std::abs(*a++ - data_source.kdtree_get_pt(b_idx, d++));\n    }\n    return result;\n  }\n\n  template <typename U, typename V>\n  inline DistanceType accum_dist(const U a, const V b, const size_t) const {\n    return std::abs(a - b);\n  }\n};\n\n/** Squared Euclidean distance functor (generic version, optimized for\n * high-dimensionality data sets). Corresponding distance traits:\n * nanoflann::metric_L2 \\tparam T Type of the elements (e.g. double, float,\n * uint8_t) \\tparam _DistanceType Type of distance variables (must be signed)\n * (e.g. float, double, int64_t)\n */\ntemplate <class T, class DataSource, typename _DistanceType = T>\nstruct L2_Adaptor {\n  typedef T ElementType;\n  typedef _DistanceType DistanceType;\n\n  const DataSource &data_source;\n\n  L2_Adaptor(const DataSource &_data_source) : data_source(_data_source) {}\n\n  inline DistanceType evalMetric(const T *a, const size_t b_idx, size_t size,\n                                 DistanceType worst_dist = -1) const {\n    DistanceType result = DistanceType();\n    const T *last = a + size;\n    const T *lastgroup = last - 3;\n    size_t d = 0;\n\n    /* Process 4 items with each loop for efficiency. */\n    while (a < lastgroup) {\n      const DistanceType diff0 = a[0] - data_source.kdtree_get_pt(b_idx, d++);\n      const DistanceType diff1 = a[1] - data_source.kdtree_get_pt(b_idx, d++);\n      const DistanceType diff2 = a[2] - data_source.kdtree_get_pt(b_idx, d++);\n      const DistanceType diff3 = a[3] - data_source.kdtree_get_pt(b_idx, d++);\n      result += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3;\n      a += 4;\n      if ((worst_dist > 0) && (result > worst_dist)) {\n        return result;\n      }\n    }\n    /* Process last 0-3 components.  Not needed for standard vector lengths. 
*/\n    while (a < last) {\n      const DistanceType diff0 = *a++ - data_source.kdtree_get_pt(b_idx, d++);\n      result += diff0 * diff0;\n    }\n    return result;\n  }\n\n  template <typename U, typename V>\n  inline DistanceType accum_dist(const U a, const V b, const size_t) const {\n    return (a - b) * (a - b);\n  }\n};\n\n/** Squared Euclidean (L2) distance functor (suitable for low-dimensionality\n * datasets, like 2D or 3D point clouds) Corresponding distance traits:\n * nanoflann::metric_L2_Simple \\tparam T Type of the elements (e.g. double,\n * float, uint8_t) \\tparam _DistanceType Type of distance variables (must be\n * signed) (e.g. float, double, int64_t)\n */\ntemplate <class T, class DataSource, typename _DistanceType = T>\nstruct L2_Simple_Adaptor {\n  typedef T ElementType;\n  typedef _DistanceType DistanceType;\n\n  const DataSource &data_source;\n\n  L2_Simple_Adaptor(const DataSource &_data_source)\n      : data_source(_data_source) {}\n\n  inline DistanceType evalMetric(const T *a, const size_t b_idx,\n                                 size_t size) const {\n    DistanceType result = DistanceType();\n    for (size_t i = 0; i < size; ++i) {\n      const DistanceType diff = a[i] - data_source.kdtree_get_pt(b_idx, i);\n      result += diff * diff;\n    }\n    return result;\n  }\n\n  template <typename U, typename V>\n  inline DistanceType accum_dist(const U a, const V b, const size_t) const {\n    return (a - b) * (a - b);\n  }\n};\n\n/** SO2 distance functor\n *  Corresponding distance traits: nanoflann::metric_SO2\n * \\tparam T Type of the elements (e.g. double, float)\n * \\tparam _DistanceType Type of distance variables (must be signed) (e.g.\n * float, double) orientation is constrained to be in [-pi, pi]\n */\ntemplate <class T, class DataSource, typename _DistanceType = T>\nstruct SO2_Adaptor {\n  typedef T ElementType;\n  typedef _DistanceType DistanceType;\n\n  const DataSource &data_source;\n\n  SO2_Adaptor(const DataSource &_data_source) : data_source(_data_source) {}\n\n  inline DistanceType evalMetric(const T *a, const size_t b_idx,\n                                 size_t size) const {\n    return accum_dist(a[size - 1], data_source.kdtree_get_pt(b_idx, size - 1),\n                      size - 1);\n  }\n\n  /** Note: this assumes that input angles are already in the range [-pi,pi] */\n  template <typename U, typename V>\n  inline DistanceType accum_dist(const U a, const V b, const size_t) const {\n    DistanceType result = DistanceType(), PI = pi_const<DistanceType>();\n    result = b - a;\n    if (result > PI)\n      result -= 2 * PI;\n    else if (result < -PI)\n      result += 2 * PI;\n    return result;\n  }\n};\n\n/** SO3 distance functor (Uses L2_Simple)\n *  Corresponding distance traits: nanoflann::metric_SO3\n * \\tparam T Type of the elements (e.g. 
double, float)\n * \\tparam _DistanceType Type of distance variables (must be signed) (e.g.\n * float, double)\n */\ntemplate <class T, class DataSource, typename _DistanceType = T>\nstruct SO3_Adaptor {\n  typedef T ElementType;\n  typedef _DistanceType DistanceType;\n\n  L2_Simple_Adaptor<T, DataSource> distance_L2_Simple;\n\n  SO3_Adaptor(const DataSource &_data_source)\n      : distance_L2_Simple(_data_source) {}\n\n  inline DistanceType evalMetric(const T *a, const size_t b_idx,\n                                 size_t size) const {\n    return distance_L2_Simple.evalMetric(a, b_idx, size);\n  }\n\n  template <typename U, typename V>\n  inline DistanceType accum_dist(const U a, const V b, const size_t idx) const {\n    return distance_L2_Simple.accum_dist(a, b, idx);\n  }\n};\n\n/** Metaprogramming helper traits class for the L1 (Manhattan) metric */\nstruct metric_L1 : public Metric {\n  template <class T, class DataSource> struct traits {\n    typedef L1_Adaptor<T, DataSource> distance_t;\n  };\n};\n/** Metaprogramming helper traits class for the L2 (Euclidean) metric */\nstruct metric_L2 : public Metric {\n  template <class T, class DataSource> struct traits {\n    typedef L2_Adaptor<T, DataSource> distance_t;\n  };\n};\n/** Metaprogramming helper traits class for the L2_simple (Euclidean) metric */\nstruct metric_L2_Simple : public Metric {\n  template <class T, class DataSource> struct traits {\n    typedef L2_Simple_Adaptor<T, DataSource> distance_t;\n  };\n};\n/** Metaprogramming helper traits class for the SO3_InnerProdQuat metric */\nstruct metric_SO2 : public Metric {\n  template <class T, class DataSource> struct traits {\n    typedef SO2_Adaptor<T, DataSource> distance_t;\n  };\n};\n/** Metaprogramming helper traits class for the SO3_InnerProdQuat metric */\nstruct metric_SO3 : public Metric {\n  template <class T, class DataSource> struct traits {\n    typedef SO3_Adaptor<T, DataSource> distance_t;\n  };\n};\n\n/** @} */\n\n/** @addtogroup param_grp Parameter structs\n * @{ */\n\n/**  Parameters (see README.md) */\nstruct KDTreeSingleIndexAdaptorParams {\n  KDTreeSingleIndexAdaptorParams(size_t _leaf_max_size = 10)\n      : leaf_max_size(_leaf_max_size) {}\n\n  size_t leaf_max_size;\n};\n\n/** Search options for KDTreeSingleIndexAdaptor::findNeighbors() */\nstruct SearchParams {\n  /** Note: The first argument (checks_IGNORED_) is ignored, but kept for\n   * compatibility with the FLANN interface */\n  SearchParams(int checks_IGNORED_ = 32, float eps_ = 0, bool sorted_ = true)\n      : checks(checks_IGNORED_), eps(eps_), sorted(sorted_) {}\n\n  int checks;  //!< Ignored parameter (Kept for compatibility with the FLANN\n               //!< interface).\n  float eps;   //!< search for eps-approximate neighbours (default: 0)\n  bool sorted; //!< only for radius search, require neighbours sorted by\n               //!< distance (default: true)\n};\n/** @} */\n\n/** @addtogroup memalloc_grp Memory allocation\n * @{ */\n\n/**\n * Allocates (using C's malloc) a generic type T.\n *\n * Params:\n *     count = number of instances to allocate.\n * Returns: pointer (of type T*) to memory buffer\n */\ntemplate <typename T> inline T *allocate(size_t count = 1) {\n  T *mem = static_cast<T *>(::malloc(sizeof(T) * count));\n  return mem;\n}\n\n/**\n * Pooled storage allocator\n *\n * The following routines allow for the efficient allocation of storage in\n * small chunks from a specified pool.  
Rather than allowing each structure\n * to be freed individually, an entire pool of storage is freed at once.\n * This method has two advantages over just using malloc() and free().  First,\n * it is far more efficient for allocating small objects, as there is\n * no overhead for remembering all the information needed to free each\n * object or consolidating fragmented memory.  Second, the decision about\n * how long to keep an object is made at the time of allocation, and there\n * is no need to track down all the objects to free them.\n *\n */\n\nconst size_t WORDSIZE = 16;\nconst size_t BLOCKSIZE = 8192;\n\nclass PooledAllocator {\n  /* We maintain memory alignment to word boundaries by requiring that all\n      allocations be in multiples of the machine wordsize.  */\n  /* Size of machine word in bytes.  Must be power of 2. */\n  /* Minimum number of bytes requested at a time from\tthe system.  Must be\n   * multiple of WORDSIZE. */\n\n  size_t remaining; /* Number of bytes left in current block of storage. */\n  void *base;       /* Pointer to base of current block of storage. */\n  void *loc;        /* Current location in block to next allocate memory. */\n\n  void internal_init() {\n    remaining = 0;\n    base = NULL;\n    usedMemory = 0;\n    wastedMemory = 0;\n  }\n\npublic:\n  size_t usedMemory;\n  size_t wastedMemory;\n\n  /**\n      Default constructor. Initializes a new pool.\n   */\n  PooledAllocator() { internal_init(); }\n\n  /**\n   * Destructor. Frees all the memory allocated in this pool.\n   */\n  ~PooledAllocator() { free_all(); }\n\n  /** Frees all allocated memory chunks */\n  void free_all() {\n    while (base != NULL) {\n      void *prev =\n          *(static_cast<void **>(base)); /* Get pointer to prev block. */\n      ::free(base);\n      base = prev;\n    }\n    internal_init();\n  }\n\n  /**\n   * Returns a pointer to a piece of new memory of the given size in bytes\n   * allocated from the pool.\n   */\n  void *malloc(const size_t req_size) {\n    /* Round size up to a multiple of wordsize.  The following expression\n        only works for WORDSIZE that is a power of 2, by masking last bits of\n        incremented size to zero.\n     */\n    const size_t size = (req_size + (WORDSIZE - 1)) & ~(WORDSIZE - 1);\n\n    /* Check whether a new block must be allocated.  Note that the first word\n        of a block is reserved for a pointer to the previous block.\n     */\n    if (size > remaining) {\n\n      wastedMemory += remaining;\n\n      /* Allocate new storage. */\n      const size_t blocksize =\n          (size + sizeof(void *) + (WORDSIZE - 1) > BLOCKSIZE)\n              ? size + sizeof(void *) + (WORDSIZE - 1)\n              : BLOCKSIZE;\n\n      // use the standard C malloc to allocate memory\n      void *m = ::malloc(blocksize);\n      if (!m) {\n        fprintf(stderr, \"Failed to allocate memory.\\n\");\n        return NULL;\n      }\n\n      /* Fill first word of new block with pointer to previous block. 
*/\n      static_cast<void **>(m)[0] = base;\n      base = m;\n\n      size_t shift = 0;\n      // int size_t = (WORDSIZE - ( (((size_t)m) + sizeof(void*)) &\n      // (WORDSIZE-1))) & (WORDSIZE-1);\n\n      remaining = blocksize - sizeof(void *) - shift;\n      loc = (static_cast<char *>(m) + sizeof(void *) + shift);\n    }\n    void *rloc = loc;\n    loc = static_cast<char *>(loc) + size;\n    remaining -= size;\n\n    usedMemory += size;\n\n    return rloc;\n  }\n\n  /**\n   * Allocates (using this pool) a generic type T.\n   *\n   * Params:\n   *     count = number of instances to allocate.\n   * Returns: pointer (of type T*) to memory buffer\n   */\n  template <typename T> T *allocate(const size_t count = 1) {\n    T *mem = static_cast<T *>(this->malloc(sizeof(T) * count));\n    return mem;\n  }\n};\n/** @} */\n\n/** @addtogroup nanoflann_metaprog_grp Auxiliary metaprogramming stuff\n * @{ */\n\n/** Used to declare fixed-size arrays when DIM>0, dynamically-allocated vectors\n * when DIM=-1. Fixed size version for a generic DIM:\n */\ntemplate <int DIM, typename T> struct array_or_vector_selector {\n  typedef std::array<T, DIM> container_t;\n};\n/** Dynamic size version */\ntemplate <typename T> struct array_or_vector_selector<-1, T> {\n  typedef std::vector<T> container_t;\n};\n\n/** @} */\n\n/** kd-tree base-class\n *\n * Contains the member functions common to the classes KDTreeSingleIndexAdaptor\n * and KDTreeSingleIndexDynamicAdaptor_.\n *\n * \\tparam Derived The name of the class which inherits this class.\n * \\tparam DatasetAdaptor The user-provided adaptor (see comments above).\n * \\tparam Distance The distance metric to use, these are all classes derived\n * from nanoflann::Metric \\tparam DIM Dimensionality of data points (e.g. 3 for\n * 3D points) \\tparam IndexType Will be typically size_t or int\n */\n\ntemplate <class Derived, typename Distance, class DatasetAdaptor, int DIM = -1,\n          typename IndexType = size_t>\nclass KDTreeBaseClass {\n\npublic:\n  /** Frees the previously-built index. Automatically called within\n   * buildIndex(). 
*/\n  void freeIndex(Derived &obj) {\n    obj.pool.free_all();\n    obj.root_node = NULL;\n    obj.m_size_at_index_build = 0;\n  }\n\n  typedef typename Distance::ElementType ElementType;\n  typedef typename Distance::DistanceType DistanceType;\n\n  /*--------------------- Internal Data Structures --------------------------*/\n  struct Node {\n    /** Union used because a node can be either a LEAF node or a non-leaf node,\n     * so both data fields are never used simultaneously */\n    union {\n      struct leaf {\n        IndexType left, right; //!< Indices of points in leaf node\n      } lr;\n      struct nonleaf {\n        int divfeat;                  //!< Dimension used for subdivision.\n        DistanceType divlow, divhigh; //!< The values used for subdivision.\n      } sub;\n    } node_type;\n    Node *child1, *child2; //!< Child nodes (both=NULL mean its a leaf node)\n  };\n\n  typedef Node *NodePtr;\n\n  struct Interval {\n    ElementType low, high;\n  };\n\n  /**\n   *  Array of indices to vectors in the dataset.\n   */\n  std::vector<IndexType> vind;\n\n  NodePtr root_node;\n\n  size_t m_leaf_max_size;\n\n  size_t m_size;                //!< Number of current points in the dataset\n  size_t m_size_at_index_build; //!< Number of points in the dataset when the\n                                //!< index was built\n  int dim;                      //!< Dimensionality of each data point\n\n  /** Define \"BoundingBox\" as a fixed-size or variable-size container depending\n   * on \"DIM\" */\n  typedef\n      typename array_or_vector_selector<DIM, Interval>::container_t BoundingBox;\n\n  /** Define \"distance_vector_t\" as a fixed-size or variable-size container\n   * depending on \"DIM\" */\n  typedef typename array_or_vector_selector<DIM, DistanceType>::container_t\n      distance_vector_t;\n\n  /** The KD-tree used to find neighbours */\n\n  BoundingBox root_bbox;\n\n  /**\n   * Pooled memory allocator.\n   *\n   * Using a pooled memory allocator is more efficient\n   * than allocating memory directly when there is a large\n   * number small of memory allocations.\n   */\n  PooledAllocator pool;\n\n  /** Returns number of points in dataset  */\n  size_t size(const Derived &obj) const { return obj.m_size; }\n\n  /** Returns the length of each point in the dataset */\n  size_t veclen(const Derived &obj) {\n    return static_cast<size_t>(DIM > 0 ? 
DIM : obj.dim);\n  }\n\n  /// Helper accessor to the dataset points:\n  inline ElementType dataset_get(const Derived &obj, size_t idx,\n                                 int component) const {\n    return obj.dataset.kdtree_get_pt(idx, component);\n  }\n\n  /**\n   * Computes the inde memory usage\n   * Returns: memory used by the index\n   */\n  size_t usedMemory(Derived &obj) {\n    return obj.pool.usedMemory + obj.pool.wastedMemory +\n           obj.dataset.kdtree_get_point_count() *\n               sizeof(IndexType); // pool memory and vind array memory\n  }\n\n  void computeMinMax(const Derived &obj, IndexType *ind, IndexType count,\n                     int element, ElementType &min_elem,\n                     ElementType &max_elem) {\n    min_elem = dataset_get(obj, ind[0], element);\n    max_elem = dataset_get(obj, ind[0], element);\n    for (IndexType i = 1; i < count; ++i) {\n      ElementType val = dataset_get(obj, ind[i], element);\n      if (val < min_elem)\n        min_elem = val;\n      if (val > max_elem)\n        max_elem = val;\n    }\n  }\n\n  /**\n   * Create a tree node that subdivides the list of vecs from vind[first]\n   * to vind[last].  The routine is called recursively on each sublist.\n   *\n   * @param left index of the first vector\n   * @param right index of the last vector\n   */\n  NodePtr divideTree(Derived &obj, const IndexType left, const IndexType right,\n                     BoundingBox &bbox) {\n    NodePtr node = obj.pool.template allocate<Node>(); // allocate memory\n\n    /* If too few exemplars remain, then make this a leaf node. */\n    if ((right - left) <= static_cast<IndexType>(obj.m_leaf_max_size)) {\n      node->child1 = node->child2 = NULL; /* Mark as leaf node. */\n      node->node_type.lr.left = left;\n      node->node_type.lr.right = right;\n\n      // compute bounding-box of leaf points\n      for (int i = 0; i < (DIM > 0 ? DIM : obj.dim); ++i) {\n        bbox[i].low = dataset_get(obj, obj.vind[left], i);\n        bbox[i].high = dataset_get(obj, obj.vind[left], i);\n      }\n      for (IndexType k = left + 1; k < right; ++k) {\n        for (int i = 0; i < (DIM > 0 ? DIM : obj.dim); ++i) {\n          if (bbox[i].low > dataset_get(obj, obj.vind[k], i))\n            bbox[i].low = dataset_get(obj, obj.vind[k], i);\n          if (bbox[i].high < dataset_get(obj, obj.vind[k], i))\n            bbox[i].high = dataset_get(obj, obj.vind[k], i);\n        }\n      }\n    } else {\n      IndexType idx;\n      int cutfeat;\n      DistanceType cutval;\n      middleSplit_(obj, &obj.vind[0] + left, right - left, idx, cutfeat, cutval,\n                   bbox);\n\n      node->node_type.sub.divfeat = cutfeat;\n\n      BoundingBox left_bbox(bbox);\n      left_bbox[cutfeat].high = cutval;\n      node->child1 = divideTree(obj, left, left + idx, left_bbox);\n\n      BoundingBox right_bbox(bbox);\n      right_bbox[cutfeat].low = cutval;\n      node->child2 = divideTree(obj, left + idx, right, right_bbox);\n\n      node->node_type.sub.divlow = left_bbox[cutfeat].high;\n      node->node_type.sub.divhigh = right_bbox[cutfeat].low;\n\n      for (int i = 0; i < (DIM > 0 ? 
DIM : obj.dim); ++i) {\n        bbox[i].low = std::min(left_bbox[i].low, right_bbox[i].low);\n        bbox[i].high = std::max(left_bbox[i].high, right_bbox[i].high);\n      }\n    }\n\n    return node;\n  }\n\n  void middleSplit_(Derived &obj, IndexType *ind, IndexType count,\n                    IndexType &index, int &cutfeat, DistanceType &cutval,\n                    const BoundingBox &bbox) {\n    const DistanceType EPS = static_cast<DistanceType>(0.00001);\n    ElementType max_span = bbox[0].high - bbox[0].low;\n    for (int i = 1; i < (DIM > 0 ? DIM : obj.dim); ++i) {\n      ElementType span = bbox[i].high - bbox[i].low;\n      if (span > max_span) {\n        max_span = span;\n      }\n    }\n    ElementType max_spread = -1;\n    cutfeat = 0;\n    for (int i = 0; i < (DIM > 0 ? DIM : obj.dim); ++i) {\n      ElementType span = bbox[i].high - bbox[i].low;\n      if (span > (1 - EPS) * max_span) {\n        ElementType min_elem, max_elem;\n        computeMinMax(obj, ind, count, i, min_elem, max_elem);\n        ElementType spread = max_elem - min_elem;\n        ;\n        if (spread > max_spread) {\n          cutfeat = i;\n          max_spread = spread;\n        }\n      }\n    }\n    // split in the middle\n    DistanceType split_val = (bbox[cutfeat].low + bbox[cutfeat].high) / 2;\n    ElementType min_elem, max_elem;\n    computeMinMax(obj, ind, count, cutfeat, min_elem, max_elem);\n\n    if (split_val < min_elem)\n      cutval = min_elem;\n    else if (split_val > max_elem)\n      cutval = max_elem;\n    else\n      cutval = split_val;\n\n    IndexType lim1, lim2;\n    planeSplit(obj, ind, count, cutfeat, cutval, lim1, lim2);\n\n    if (lim1 > count / 2)\n      index = lim1;\n    else if (lim2 < count / 2)\n      index = lim2;\n    else\n      index = count / 2;\n  }\n\n  /**\n   *  Subdivide the list of points by a plane perpendicular on axe corresponding\n   *  to the 'cutfeat' dimension at 'cutval' position.\n   *\n   *  On return:\n   *  dataset[ind[0..lim1-1]][cutfeat]<cutval\n   *  dataset[ind[lim1..lim2-1]][cutfeat]==cutval\n   *  dataset[ind[lim2..count]][cutfeat]>cutval\n   */\n  void planeSplit(Derived &obj, IndexType *ind, const IndexType count,\n                  int cutfeat, DistanceType &cutval, IndexType &lim1,\n                  IndexType &lim2) {\n    /* Move vector indices for left subtree to front of list. */\n    IndexType left = 0;\n    IndexType right = count - 1;\n    for (;;) {\n      while (left <= right && dataset_get(obj, ind[left], cutfeat) < cutval)\n        ++left;\n      while (right && left <= right &&\n             dataset_get(obj, ind[right], cutfeat) >= cutval)\n        --right;\n      if (left > right || !right)\n        break; // \"!right\" was added to support unsigned Index types\n      std::swap(ind[left], ind[right]);\n      ++left;\n      --right;\n    }\n    /* If either list is empty, it means that all remaining features\n     * are identical. 
Split in the middle to maintain a balanced tree.\n     */\n    lim1 = left;\n    right = count - 1;\n    for (;;) {\n      while (left <= right && dataset_get(obj, ind[left], cutfeat) <= cutval)\n        ++left;\n      while (right && left <= right &&\n             dataset_get(obj, ind[right], cutfeat) > cutval)\n        --right;\n      if (left > right || !right)\n        break; // \"!right\" was added to support unsigned Index types\n      std::swap(ind[left], ind[right]);\n      ++left;\n      --right;\n    }\n    lim2 = left;\n  }\n\n  DistanceType computeInitialDistances(const Derived &obj,\n                                       const ElementType *vec,\n                                       distance_vector_t &dists) const {\n    assert(vec);\n    DistanceType distsq = DistanceType();\n\n    for (int i = 0; i < (DIM > 0 ? DIM : obj.dim); ++i) {\n      if (vec[i] < obj.root_bbox[i].low) {\n        dists[i] = obj.distance.accum_dist(vec[i], obj.root_bbox[i].low, i);\n        distsq += dists[i];\n      }\n      if (vec[i] > obj.root_bbox[i].high) {\n        dists[i] = obj.distance.accum_dist(vec[i], obj.root_bbox[i].high, i);\n        distsq += dists[i];\n      }\n    }\n    return distsq;\n  }\n\n  void save_tree(Derived &obj, FILE *stream, NodePtr tree) {\n    save_value(stream, *tree);\n    if (tree->child1 != NULL) {\n      save_tree(obj, stream, tree->child1);\n    }\n    if (tree->child2 != NULL) {\n      save_tree(obj, stream, tree->child2);\n    }\n  }\n\n  void load_tree(Derived &obj, FILE *stream, NodePtr &tree) {\n    tree = obj.pool.template allocate<Node>();\n    load_value(stream, *tree);\n    if (tree->child1 != NULL) {\n      load_tree(obj, stream, tree->child1);\n    }\n    if (tree->child2 != NULL) {\n      load_tree(obj, stream, tree->child2);\n    }\n  }\n\n  /**  Stores the index in a binary file.\n   *   IMPORTANT NOTE: The set of data points is NOT stored in the file, so when\n   * loading the index object it must be constructed associated to the same\n   * source of data points used while building it. See the example:\n   * examples/saveload_example.cpp \\sa loadIndex  */\n  void saveIndex_(Derived &obj, FILE *stream) {\n    save_value(stream, obj.m_size);\n    save_value(stream, obj.dim);\n    save_value(stream, obj.root_bbox);\n    save_value(stream, obj.m_leaf_max_size);\n    save_value(stream, obj.vind);\n    save_tree(obj, stream, obj.root_node);\n  }\n\n  /**  Loads a previous index from a binary file.\n   *   IMPORTANT NOTE: The set of data points is NOT stored in the file, so the\n   * index object must be constructed associated to the same source of data\n   * points used while building the index. See the example:\n   * examples/saveload_example.cpp \\sa saveIndex  */\n  void loadIndex_(Derived &obj, FILE *stream) {\n    load_value(stream, obj.m_size);\n    load_value(stream, obj.dim);\n    load_value(stream, obj.root_bbox);\n    load_value(stream, obj.m_leaf_max_size);\n    load_value(stream, obj.vind);\n    load_tree(obj, stream, obj.root_node);\n  }\n};\n\n/** @addtogroup kdtrees_grp KD-tree classes and adaptors\n * @{ */\n\n/** kd-tree static index\n *\n * Contains the k-d trees and other information for indexing a set of points\n * for nearest-neighbor matching.\n *\n *  The class \"DatasetAdaptor\" must provide the following interface (can be\n * non-virtual, inlined methods):\n *\n *  \\code\n *   // Must return the number of data points\n *   inline size_t kdtree_get_point_count() const { ... 
}\n *\n *\n *   // Must return the dim'th component of the idx'th point in the class:\n *   inline T kdtree_get_pt(const size_t idx, const size_t dim) const { ... }\n *\n *   // Optional bounding-box computation: return false to default to a standard\n * bbox computation loop.\n *   //   Return true if the BBOX was already computed by the class and returned\n * in \"bb\" so that recomputing it can be avoided.\n *   //   Look at bb.size() to find out the expected dimensionality (e.g. 2 or 3\n * for point clouds)\n *   template <class BBOX> bool kdtree_get_bbox(BBOX &bb) const\n *   {\n *      bb[0].low = ...; bb[0].high = ...;  // 0th dimension limits\n *      bb[1].low = ...; bb[1].high = ...;  // 1st dimension limits\n *      ...\n *      return true;\n *   }\n *\n *  \\endcode\n *\n * \\tparam DatasetAdaptor The user-provided adaptor (see comments above).\n * \\tparam Distance The distance metric to use: nanoflann::metric_L1,\n * nanoflann::metric_L2, nanoflann::metric_L2_Simple, etc. \\tparam DIM\n * Dimensionality of data points (e.g. 3 for 3D points) \\tparam IndexType Will\n * typically be size_t or int\n */\ntemplate <typename Distance, class DatasetAdaptor, int DIM = -1,\n          typename IndexType = size_t>\nclass KDTreeSingleIndexAdaptor\n    : public KDTreeBaseClass<\n          KDTreeSingleIndexAdaptor<Distance, DatasetAdaptor, DIM, IndexType>,\n          Distance, DatasetAdaptor, DIM, IndexType> {\npublic:\n  /** Deleted copy constructor */\n  KDTreeSingleIndexAdaptor(\n      const KDTreeSingleIndexAdaptor<Distance, DatasetAdaptor, DIM, IndexType>\n          &) = delete;\n\n  /**\n   * The dataset used by this index\n   */\n  const DatasetAdaptor &dataset; //!< The source of our data\n\n  const KDTreeSingleIndexAdaptorParams index_params;\n\n  Distance distance;\n\n  typedef typename nanoflann::KDTreeBaseClass<\n      nanoflann::KDTreeSingleIndexAdaptor<Distance, DatasetAdaptor, DIM,\n                                          IndexType>,\n      Distance, DatasetAdaptor, DIM, IndexType>\n      BaseClassRef;\n\n  typedef typename BaseClassRef::ElementType ElementType;\n  typedef typename BaseClassRef::DistanceType DistanceType;\n\n  typedef typename BaseClassRef::Node Node;\n  typedef Node *NodePtr;\n\n  typedef typename BaseClassRef::Interval Interval;\n  /** Define \"BoundingBox\" as a fixed-size or variable-size container depending\n   * on \"DIM\" */\n  typedef typename BaseClassRef::BoundingBox BoundingBox;\n\n  /** Define \"distance_vector_t\" as a fixed-size or variable-size container\n   * depending on \"DIM\" */\n  typedef typename BaseClassRef::distance_vector_t distance_vector_t;\n\n  /**\n   * KDTree constructor\n   *\n   * Refer to docs in README.md or online in\n   * https://github.com/jlblancoc/nanoflann\n   *\n   * The KD-Tree point dimension (the length of each point in the dataset, e.g. 
3\n   * for 3D points) is determined by means of:\n   *  - The \\a DIM template parameter if >0 (highest priority)\n   *  - Otherwise, the \\a dimensionality parameter of this constructor.\n   *\n   * @param inputData Dataset with the input features\n   * @param params Basically, the maximum leaf node size\n   */\n  KDTreeSingleIndexAdaptor(const int dimensionality,\n                           const DatasetAdaptor &inputData,\n                           const KDTreeSingleIndexAdaptorParams &params =\n                               KDTreeSingleIndexAdaptorParams())\n      : dataset(inputData), index_params(params), distance(inputData) {\n    BaseClassRef::root_node = NULL;\n    BaseClassRef::m_size = dataset.kdtree_get_point_count();\n    BaseClassRef::m_size_at_index_build = BaseClassRef::m_size;\n    BaseClassRef::dim = dimensionality;\n    if (DIM > 0)\n      BaseClassRef::dim = DIM;\n    BaseClassRef::m_leaf_max_size = params.leaf_max_size;\n\n    // Create a permutable array of indices to the input vectors.\n    init_vind();\n  }\n\n  /**\n   * Builds the index\n   */\n  void buildIndex() {\n    BaseClassRef::m_size = dataset.kdtree_get_point_count();\n    BaseClassRef::m_size_at_index_build = BaseClassRef::m_size;\n    init_vind();\n    this->freeIndex(*this);\n    BaseClassRef::m_size_at_index_build = BaseClassRef::m_size;\n    if (BaseClassRef::m_size == 0)\n      return;\n    computeBoundingBox(BaseClassRef::root_bbox);\n    BaseClassRef::root_node =\n        this->divideTree(*this, 0, BaseClassRef::m_size,\n                         BaseClassRef::root_bbox); // construct the tree\n  }\n\n  /** \\name Query methods\n   * @{ */\n\n  /**\n   * Find set of nearest neighbors to vec[0:dim-1]. Their indices are stored\n   * inside the result object.\n   *\n   * Params:\n   *     result = the result object in which the indices of the\n   *              nearest-neighbors are stored\n   *     vec = the vector for which to search the nearest neighbors\n   *\n   * \\tparam RESULTSET Should be any ResultSet<DistanceType>\n   * \\return  True if the requested neighbors could be found.\n   * \\sa knnSearch, radiusSearch\n   */\n  template <typename RESULTSET>\n  bool findNeighbors(RESULTSET &result, const ElementType *vec,\n                     const SearchParams &searchParams) const {\n    assert(vec);\n    if (this->size(*this) == 0)\n      return false;\n    if (!BaseClassRef::root_node)\n      throw std::runtime_error(\n          \"[nanoflann] findNeighbors() called before building the index.\");\n    float epsError = 1 + searchParams.eps;\n\n    distance_vector_t\n        dists; // fixed or variable-sized container (depending on DIM)\n    auto zero = static_cast<decltype(result.worstDist())>(0);\n    assign(dists, (DIM > 0 ? DIM : BaseClassRef::dim),\n           zero); // Fill it with zeros.\n    DistanceType distsq = this->computeInitialDistances(*this, vec, dists);\n\n    searchLevel(result, vec, BaseClassRef::root_node, distsq, dists,\n                epsError); // \"count_leaf\" parameter removed since was neither\n                           // used nor returned to the user.\n\n    return result.full();\n  }\n\n  /**\n   * Find the \"num_closest\" nearest neighbors to the \\a query_point[0:dim-1].\n   * Their indices are stored inside the result object. \\sa radiusSearch,\n   * findNeighbors \\note nChecks_IGNORED is ignored but kept for compatibility\n   * with the original FLANN interface. \\return Number `N` of valid points in\n   * the result set. 
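(In the implementation below, `N` is the value returned by\n   * `resultSet.size()`.) 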
Only the first `N` entries in `out_indices` and\n   * `out_distances_sq` will be valid. Return may be less than `num_closest`\n   * only if the number of elements in the tree is less than `num_closest`.\n   */\n  size_t knnSearch(const ElementType *query_point, const size_t num_closest,\n                   IndexType *out_indices, DistanceType *out_distances_sq,\n                   const int /* nChecks_IGNORED */ = 10) const {\n    nanoflann::KNNResultSet<DistanceType, IndexType> resultSet(num_closest);\n    resultSet.init(out_indices, out_distances_sq);\n    this->findNeighbors(resultSet, query_point, nanoflann::SearchParams());\n    return resultSet.size();\n  }\n\n  /**\n   * Find all the neighbors to \\a query_point[0:dim-1] within a maximum radius.\n   *  The output is given as a vector of pairs, of which the first element is a\n   * point index and the second the corresponding distance. Previous contents of\n   * \\a IndicesDists are cleared.\n   *\n   *  If searchParams.sorted==true, the output list is sorted by ascending\n   * distances.\n   *\n   *  For better performance, it is advisable to call .reserve() on the vector\n   * if you have a rough guess of the number of expected matches.\n   *\n   *  \\sa knnSearch, findNeighbors, radiusSearchCustomCallback\n   * \\return The number of points within the given radius (i.e. indices.size()\n   * or dists.size() )\n   */\n  size_t\n  radiusSearch(const ElementType *query_point, const DistanceType &radius,\n               std::vector<std::pair<IndexType, DistanceType>> &IndicesDists,\n               const SearchParams &searchParams) const {\n    RadiusResultSet<DistanceType, IndexType> resultSet(radius, IndicesDists);\n    const size_t nFound =\n        radiusSearchCustomCallback(query_point, resultSet, searchParams);\n    if (searchParams.sorted)\n      std::sort(IndicesDists.begin(), IndicesDists.end(), IndexDist_Sorter());\n    return nFound;\n  }\n\n  /**\n   * Just like radiusSearch() but with a custom callback class for each point\n   * found in the radius of the query. See the source of RadiusResultSet<> as a\n   * start point for your own classes. \\sa radiusSearch\n   */\n  template <class SEARCH_CALLBACK>\n  size_t radiusSearchCustomCallback(\n      const ElementType *query_point, SEARCH_CALLBACK &resultSet,\n      const SearchParams &searchParams = SearchParams()) const {\n    this->findNeighbors(resultSet, query_point, searchParams);\n    return resultSet.size();\n  }\n\n  /** @} */\n\npublic:\n  /** Make sure the auxiliary list \\a vind has the same size as the current\n   * dataset, and re-generate it if the size has changed. */\n  void init_vind() {\n    // Create a permutable array of indices to the input vectors.\n    BaseClassRef::m_size = dataset.kdtree_get_point_count();\n    if (BaseClassRef::vind.size() != BaseClassRef::m_size)\n      BaseClassRef::vind.resize(BaseClassRef::m_size);\n    for (size_t i = 0; i < BaseClassRef::m_size; i++)\n      BaseClassRef::vind[i] = i;\n  }\n\n  void computeBoundingBox(BoundingBox &bbox) {\n    resize(bbox, (DIM > 0 ? DIM : BaseClassRef::dim));\n    if (dataset.kdtree_get_bbox(bbox)) {\n      // Done! It was implemented in derived class\n    } else {\n      const size_t N = dataset.kdtree_get_point_count();\n      if (!N)\n        throw std::runtime_error(\"[nanoflann] computeBoundingBox() called but \"\n                                 \"no data points found.\");\n      for (int i = 0; i < (DIM > 0 ? 
DIM : BaseClassRef::dim); ++i) {\n        bbox[i].low = bbox[i].high = this->dataset_get(*this, 0, i);\n      }\n      for (size_t k = 1; k < N; ++k) {\n        for (int i = 0; i < (DIM > 0 ? DIM : BaseClassRef::dim); ++i) {\n          if (this->dataset_get(*this, k, i) < bbox[i].low)\n            bbox[i].low = this->dataset_get(*this, k, i);\n          if (this->dataset_get(*this, k, i) > bbox[i].high)\n            bbox[i].high = this->dataset_get(*this, k, i);\n        }\n      }\n    }\n  }\n\n  /**\n   * Performs an exact search in the tree starting from a node.\n   * \\tparam RESULTSET Should be any ResultSet<DistanceType>\n   * \\return true if the search should be continued, false if the results are\n   * sufficient\n   */\n  template <class RESULTSET>\n  bool searchLevel(RESULTSET &result_set, const ElementType *vec,\n                   const NodePtr node, DistanceType mindistsq,\n                   distance_vector_t &dists, const float epsError) const {\n    /* If this is a leaf node, then do check and return. */\n    if ((node->child1 == NULL) && (node->child2 == NULL)) {\n      // count_leaf += (node->lr.right-node->lr.left);  // Removed since was\n      // neither used nor returned to the user.\n      DistanceType worst_dist = result_set.worstDist();\n      for (IndexType i = node->node_type.lr.left; i < node->node_type.lr.right;\n           ++i) {\n        const IndexType index = BaseClassRef::vind[i]; // reorder... : i;\n        DistanceType dist = distance.evalMetric(\n            vec, index, (DIM > 0 ? DIM : BaseClassRef::dim));\n        if (dist < worst_dist) {\n          if (!result_set.addPoint(dist, BaseClassRef::vind[i])) {\n            // the resultset doesn't want to receive any more points, we're done\n            // searching!\n            return false;\n          }\n        }\n      }\n      return true;\n    }\n\n    /* Which child branch should be taken first? */\n    int idx = node->node_type.sub.divfeat;\n    ElementType val = vec[idx];\n    DistanceType diff1 = val - node->node_type.sub.divlow;\n    DistanceType diff2 = val - node->node_type.sub.divhigh;\n\n    NodePtr bestChild;\n    NodePtr otherChild;\n    DistanceType cut_dist;\n    if ((diff1 + diff2) < 0) {\n      bestChild = node->child1;\n      otherChild = node->child2;\n      cut_dist = distance.accum_dist(val, node->node_type.sub.divhigh, idx);\n    } else {\n      bestChild = node->child2;\n      otherChild = node->child1;\n      cut_dist = distance.accum_dist(val, node->node_type.sub.divlow, idx);\n    }\n\n    /* Call recursively to search next level down. */\n    if (!searchLevel(result_set, vec, bestChild, mindistsq, dists, epsError)) {\n      // the resultset doesn't want to receive any more points, we're done\n      // searching!\n      return false;\n    }\n\n    DistanceType dst = dists[idx];\n    mindistsq = mindistsq + cut_dist - dst;\n    dists[idx] = cut_dist;\n    if (mindistsq * epsError <= result_set.worstDist()) {\n      if (!searchLevel(result_set, vec, otherChild, mindistsq, dists,\n                       epsError)) {\n        // the resultset doesn't want to receive any more points, we're done\n        // searching!\n        return false;\n      }\n    }\n    dists[idx] = dst;\n    return true;\n  }\n\npublic:\n  /**  Stores the index in a binary file.\n   *   IMPORTANT NOTE: The set of data points is NOT stored in the file, so when\n   * loading the index object it must be constructed associated to the same\n   * source of data points used while building it. 
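A minimal save/load sketch (\"index1\" and \"index2\" are illustrative\n   * instances, both constructed over the same dataset adaptor):\n   * \\code\n   *   FILE *f = fopen(\"index.bin\", \"wb\");\n   *   index1.saveIndex(f);\n   *   fclose(f);\n   *   // ... later, with an index built over the same data points:\n   *   f = fopen(\"index.bin\", \"rb\");\n   *   index2.loadIndex(f);\n   *   fclose(f);\n   * \\endcode\n   * 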
See the example:\n   * examples/saveload_example.cpp \\sa loadIndex  */\n  void saveIndex(FILE *stream) { this->saveIndex_(*this, stream); }\n\n  /**  Loads a previous index from a binary file.\n   *   IMPORTANT NOTE: The set of data points is NOT stored in the file, so the\n   * index object must be constructed associated to the same source of data\n   * points used while building the index. See the example:\n   * examples/saveload_example.cpp \\sa saveIndex  */\n  void loadIndex(FILE *stream) { this->loadIndex_(*this, stream); }\n\n}; // class KDTree\n\n/** kd-tree dynamic index\n *\n * Contains the k-d trees and other information for indexing a set of points\n * for nearest-neighbor matching.\n *\n *  The class \"DatasetAdaptor\" must provide the following interface (can be\n * non-virtual, inlined methods):\n *\n *  \\code\n *   // Must return the number of data points\n *   inline size_t kdtree_get_point_count() const { ... }\n *\n *   // Must return the dim'th component of the idx'th point in the class:\n *   inline T kdtree_get_pt(const size_t idx, const size_t dim) const { ... }\n *\n *   // Optional bounding-box computation: return false to default to a standard\n * bbox computation loop.\n *   //   Return true if the BBOX was already computed by the class and returned\n * in \"bb\" so that recomputing it can be avoided.\n *   //   Look at bb.size() to find out the expected dimensionality (e.g. 2 or 3\n * for point clouds)\n *   template <class BBOX> bool kdtree_get_bbox(BBOX &bb) const\n *   {\n *      bb[0].low = ...; bb[0].high = ...;  // 0th dimension limits\n *      bb[1].low = ...; bb[1].high = ...;  // 1st dimension limits\n *      ...\n *      return true;\n *   }\n *\n *  \\endcode\n *\n * \\tparam DatasetAdaptor The user-provided adaptor (see comments above).\n * \\tparam Distance The distance metric to use: nanoflann::metric_L1,\n * nanoflann::metric_L2, nanoflann::metric_L2_Simple, etc. \\tparam DIM\n * Dimensionality of data points (e.g. 
3 for 3D points) \\tparam IndexType Will\n * typically be size_t or int\n */\ntemplate <typename Distance, class DatasetAdaptor, int DIM = -1,\n          typename IndexType = size_t>\nclass KDTreeSingleIndexDynamicAdaptor_\n    : public KDTreeBaseClass<KDTreeSingleIndexDynamicAdaptor_<\n                                 Distance, DatasetAdaptor, DIM, IndexType>,\n                             Distance, DatasetAdaptor, DIM, IndexType> {\npublic:\n  /**\n   * The dataset used by this index\n   */\n  const DatasetAdaptor &dataset; //!< The source of our data\n\n  KDTreeSingleIndexAdaptorParams index_params;\n\n  std::vector<int> &treeIndex;\n\n  Distance distance;\n\n  typedef typename nanoflann::KDTreeBaseClass<\n      nanoflann::KDTreeSingleIndexDynamicAdaptor_<Distance, DatasetAdaptor, DIM,\n                                                  IndexType>,\n      Distance, DatasetAdaptor, DIM, IndexType>\n      BaseClassRef;\n\n  typedef typename BaseClassRef::ElementType ElementType;\n  typedef typename BaseClassRef::DistanceType DistanceType;\n\n  typedef typename BaseClassRef::Node Node;\n  typedef Node *NodePtr;\n\n  typedef typename BaseClassRef::Interval Interval;\n  /** Define \"BoundingBox\" as a fixed-size or variable-size container depending\n   * on \"DIM\" */\n  typedef typename BaseClassRef::BoundingBox BoundingBox;\n\n  /** Define \"distance_vector_t\" as a fixed-size or variable-size container\n   * depending on \"DIM\" */\n  typedef typename BaseClassRef::distance_vector_t distance_vector_t;\n\n  /**\n   * KDTree constructor\n   *\n   * Refer to docs in README.md or online in\n   * https://github.com/jlblancoc/nanoflann\n   *\n   * The KD-Tree point dimension (the length of each point in the dataset, e.g. 3\n   * for 3D points) is determined by means of:\n   *  - The \\a DIM template parameter if >0 (highest priority)\n   *  - Otherwise, the \\a dimensionality parameter of this constructor.\n   *\n   * @param inputData Dataset with the input features\n   * @param params Basically, the maximum leaf node size\n   */\n  KDTreeSingleIndexDynamicAdaptor_(\n      const int dimensionality, const DatasetAdaptor &inputData,\n      std::vector<int> &treeIndex_,\n      const KDTreeSingleIndexAdaptorParams &params =\n          KDTreeSingleIndexAdaptorParams())\n      : dataset(inputData), index_params(params), treeIndex(treeIndex_),\n        distance(inputData) {\n    BaseClassRef::root_node = NULL;\n    BaseClassRef::m_size = 0;\n    BaseClassRef::m_size_at_index_build = 0;\n    BaseClassRef::dim = dimensionality;\n    if (DIM > 0)\n      BaseClassRef::dim = DIM;\n    BaseClassRef::m_leaf_max_size = params.leaf_max_size;\n  }\n\n  /** Assignment operator definition */\n  KDTreeSingleIndexDynamicAdaptor_\n  operator=(const KDTreeSingleIndexDynamicAdaptor_ &rhs) {\n    KDTreeSingleIndexDynamicAdaptor_ tmp(rhs);\n    std::swap(BaseClassRef::vind, tmp.BaseClassRef::vind);\n    std::swap(BaseClassRef::m_leaf_max_size, tmp.BaseClassRef::m_leaf_max_size);\n    std::swap(index_params, tmp.index_params);\n    std::swap(treeIndex, tmp.treeIndex);\n    std::swap(BaseClassRef::m_size, tmp.BaseClassRef::m_size);\n    std::swap(BaseClassRef::m_size_at_index_build,\n              tmp.BaseClassRef::m_size_at_index_build);\n    std::swap(BaseClassRef::root_node, tmp.BaseClassRef::root_node);\n    std::swap(BaseClassRef::root_bbox, tmp.BaseClassRef::root_bbox);\n    std::swap(BaseClassRef::pool, tmp.BaseClassRef::pool);\n    return *this;\n  }\n\n  /**\n   * Builds the index\n   */\n  void buildIndex() {\n    
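// Rebuild this sub-tree from the point indices currently held in vind\n    // (the dynamic wrapper fills vind before calling buildIndex()).\n    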
BaseClassRef::m_size = BaseClassRef::vind.size();\n    this->freeIndex(*this);\n    BaseClassRef::m_size_at_index_build = BaseClassRef::m_size;\n    if (BaseClassRef::m_size == 0)\n      return;\n    computeBoundingBox(BaseClassRef::root_bbox);\n    BaseClassRef::root_node =\n        this->divideTree(*this, 0, BaseClassRef::m_size,\n                         BaseClassRef::root_bbox); // construct the tree\n  }\n\n  /** \\name Query methods\n   * @{ */\n\n  /**\n   * Find set of nearest neighbors to vec[0:dim-1]. Their indices are stored\n   * inside the result object.\n   *\n   * Params:\n   *     result = the result object in which the indices of the\n   *              nearest-neighbors are stored\n   *     vec = the vector for which to search the nearest neighbors\n   *\n   * \\tparam RESULTSET Should be any ResultSet<DistanceType>\n   * \\return  True if the requested neighbors could be found.\n   * \\sa knnSearch, radiusSearch\n   */\n  template <typename RESULTSET>\n  bool findNeighbors(RESULTSET &result, const ElementType *vec,\n                     const SearchParams &searchParams) const {\n    assert(vec);\n    if (this->size(*this) == 0)\n      return false;\n    if (!BaseClassRef::root_node)\n      return false;\n    float epsError = 1 + searchParams.eps;\n\n    // fixed or variable-sized container (depending on DIM)\n    distance_vector_t dists;\n    // Fill it with zeros.\n    assign(dists, (DIM > 0 ? DIM : BaseClassRef::dim),\n           static_cast<typename distance_vector_t::value_type>(0));\n    DistanceType distsq = this->computeInitialDistances(*this, vec, dists);\n\n    searchLevel(result, vec, BaseClassRef::root_node, distsq, dists,\n                epsError); // \"count_leaf\" parameter removed since was neither\n                           // used nor returned to the user.\n\n    return result.full();\n  }\n\n  /**\n   * Find the \"num_closest\" nearest neighbors to the \\a query_point[0:dim-1].\n   * Their indices are stored inside the result object. \\sa radiusSearch,\n   * findNeighbors \\note nChecks_IGNORED is ignored but kept for compatibility\n   * with the original FLANN interface. \\return Number `N` of valid points in\n   * the result set. Only the first `N` entries in `out_indices` and\n   * `out_distances_sq` will be valid. Return may be less than `num_closest`\n   * only if the number of elements in the tree is less than `num_closest`.\n   */\n  size_t knnSearch(const ElementType *query_point, const size_t num_closest,\n                   IndexType *out_indices, DistanceType *out_distances_sq,\n                   const int /* nChecks_IGNORED */ = 10) const {\n    nanoflann::KNNResultSet<DistanceType, IndexType> resultSet(num_closest);\n    resultSet.init(out_indices, out_distances_sq);\n    this->findNeighbors(resultSet, query_point, nanoflann::SearchParams());\n    return resultSet.size();\n  }\n\n  /**\n   * Find all the neighbors to \\a query_point[0:dim-1] within a maximum radius.\n   *  The output is given as a vector of pairs, of which the first element is a\n   * point index and the second the corresponding distance. Previous contents of\n   * \\a IndicesDists are cleared.\n   *\n   *  If searchParams.sorted==true, the output list is sorted by ascending\n   * distances.\n   *\n   *  For better performance, it is advisable to call .reserve() on the vector\n   * if you have a rough guess of the number of expected matches.\n   *\n   *  \\sa knnSearch, findNeighbors, radiusSearchCustomCallback\n   * \\return The number of points within the given radius (i.e. 
indices.size()\n   * or dists.size() )\n   */\n  size_t\n  radiusSearch(const ElementType *query_point, const DistanceType &radius,\n               std::vector<std::pair<IndexType, DistanceType>> &IndicesDists,\n               const SearchParams &searchParams) const {\n    RadiusResultSet<DistanceType, IndexType> resultSet(radius, IndicesDists);\n    const size_t nFound =\n        radiusSearchCustomCallback(query_point, resultSet, searchParams);\n    if (searchParams.sorted)\n      std::sort(IndicesDists.begin(), IndicesDists.end(), IndexDist_Sorter());\n    return nFound;\n  }\n\n  /**\n   * Just like radiusSearch() but with a custom callback class for each point\n   * found in the radius of the query. See the source of RadiusResultSet<> as a\n   * start point for your own classes. \\sa radiusSearch\n   */\n  template <class SEARCH_CALLBACK>\n  size_t radiusSearchCustomCallback(\n      const ElementType *query_point, SEARCH_CALLBACK &resultSet,\n      const SearchParams &searchParams = SearchParams()) const {\n    this->findNeighbors(resultSet, query_point, searchParams);\n    return resultSet.size();\n  }\n\n  /** @} */\n\npublic:\n  void computeBoundingBox(BoundingBox &bbox) {\n    resize(bbox, (DIM > 0 ? DIM : BaseClassRef::dim));\n\n    if (dataset.kdtree_get_bbox(bbox)) {\n      // Done! It was implemented in derived class\n    } else {\n      const size_t N = BaseClassRef::m_size;\n      if (!N)\n        throw std::runtime_error(\"[nanoflann] computeBoundingBox() called but \"\n                                 \"no data points found.\");\n      for (int i = 0; i < (DIM > 0 ? DIM : BaseClassRef::dim); ++i) {\n        bbox[i].low = bbox[i].high =\n            this->dataset_get(*this, BaseClassRef::vind[0], i);\n      }\n      for (size_t k = 1; k < N; ++k) {\n        for (int i = 0; i < (DIM > 0 ? DIM : BaseClassRef::dim); ++i) {\n          if (this->dataset_get(*this, BaseClassRef::vind[k], i) < bbox[i].low)\n            bbox[i].low = this->dataset_get(*this, BaseClassRef::vind[k], i);\n          if (this->dataset_get(*this, BaseClassRef::vind[k], i) > bbox[i].high)\n            bbox[i].high = this->dataset_get(*this, BaseClassRef::vind[k], i);\n        }\n      }\n    }\n  }\n\n  /**\n   * Performs an exact search in the tree starting from a node.\n   * \\tparam RESULTSET Should be any ResultSet<DistanceType>\n   */\n  template <class RESULTSET>\n  void searchLevel(RESULTSET &result_set, const ElementType *vec,\n                   const NodePtr node, DistanceType mindistsq,\n                   distance_vector_t &dists, const float epsError) const {\n    /* If this is a leaf node, then do check and return. */\n    if ((node->child1 == NULL) && (node->child2 == NULL)) {\n      // count_leaf += (node->lr.right-node->lr.left);  // Removed since was\n      // neither used nor returned to the user.\n      DistanceType worst_dist = result_set.worstDist();\n      for (IndexType i = node->node_type.lr.left; i < node->node_type.lr.right;\n           ++i) {\n        const IndexType index = BaseClassRef::vind[i]; // reorder... : i;\n        if (treeIndex[index] == -1)\n          continue;\n        DistanceType dist = distance.evalMetric(\n            vec, index, (DIM > 0 ? 
DIM : BaseClassRef::dim));\n        if (dist < worst_dist) {\n          if (!result_set.addPoint(\n                  static_cast<typename RESULTSET::DistanceType>(dist),\n                  static_cast<typename RESULTSET::IndexType>(\n                      BaseClassRef::vind[i]))) {\n            // the resultset doesn't want to receive any more points, we're done\n            // searching!\n            return; // false;\n          }\n        }\n      }\n      return;\n    }\n\n    /* Which child branch should be taken first? */\n    int idx = node->node_type.sub.divfeat;\n    ElementType val = vec[idx];\n    DistanceType diff1 = val - node->node_type.sub.divlow;\n    DistanceType diff2 = val - node->node_type.sub.divhigh;\n\n    NodePtr bestChild;\n    NodePtr otherChild;\n    DistanceType cut_dist;\n    if ((diff1 + diff2) < 0) {\n      bestChild = node->child1;\n      otherChild = node->child2;\n      cut_dist = distance.accum_dist(val, node->node_type.sub.divhigh, idx);\n    } else {\n      bestChild = node->child2;\n      otherChild = node->child1;\n      cut_dist = distance.accum_dist(val, node->node_type.sub.divlow, idx);\n    }\n\n    /* Call recursively to search next level down. */\n    searchLevel(result_set, vec, bestChild, mindistsq, dists, epsError);\n\n    DistanceType dst = dists[idx];\n    mindistsq = mindistsq + cut_dist - dst;\n    dists[idx] = cut_dist;\n    if (mindistsq * epsError <= result_set.worstDist()) {\n      searchLevel(result_set, vec, otherChild, mindistsq, dists, epsError);\n    }\n    dists[idx] = dst;\n  }\n\npublic:\n  /**  Stores the index in a binary file.\n   *   IMPORTANT NOTE: The set of data points is NOT stored in the file, so when\n   * loading the index object it must be constructed associated to the same\n   * source of data points used while building it. See the example:\n   * examples/saveload_example.cpp \\sa loadIndex  */\n  void saveIndex(FILE *stream) { this->saveIndex_(*this, stream); }\n\n  /**  Loads a previous index from a binary file.\n   *   IMPORTANT NOTE: The set of data points is NOT stored in the file, so the\n   * index object must be constructed associated to the same source of data\n   * points used while building the index. See the example:\n   * examples/saveload_example.cpp \\sa saveIndex  */\n  void loadIndex(FILE *stream) { this->loadIndex_(*this, stream); }\n};\n\n/** kd-tree dynamic index\n *\n * Class that creates multiple static indices and merges their results so that\n * they behave as a single dynamic index, as proposed in the logarithmic method.\n *\n *  Example of usage:\n *  examples/dynamic_pointcloud_example.cpp\n *\n * \\tparam DatasetAdaptor The user-provided adaptor (see comments above).\n * \\tparam Distance The distance metric to use: nanoflann::metric_L1,\n * nanoflann::metric_L2, nanoflann::metric_L2_Simple, etc. \\tparam DIM\n * Dimensionality of data points (e.g. 
3 for 3D points) \\tparam IndexType Will\n * typically be size_t or int\n */\ntemplate <typename Distance, class DatasetAdaptor, int DIM = -1,\n          typename IndexType = size_t>\nclass KDTreeSingleIndexDynamicAdaptor {\npublic:\n  typedef typename Distance::ElementType ElementType;\n  typedef typename Distance::DistanceType DistanceType;\n\nprotected:\n  size_t m_leaf_max_size;\n  size_t treeCount;\n  size_t pointCount;\n\n  /**\n   * The dataset used by this index\n   */\n  const DatasetAdaptor &dataset; //!< The source of our data\n\n  std::vector<int> treeIndex; //!< treeIndex[idx] is the index of tree in which\n                              //!< point at idx is stored. treeIndex[idx]=-1\n                              //!< means that point has been removed.\n\n  KDTreeSingleIndexAdaptorParams index_params;\n\n  int dim; //!< Dimensionality of each data point\n\n  typedef KDTreeSingleIndexDynamicAdaptor_<Distance, DatasetAdaptor, DIM>\n      index_container_t;\n  std::vector<index_container_t> index;\n\npublic:\n  /** Get a const ref to the internal list of indices; the number of indices is\n   * adapted dynamically as the dataset grows in size. */\n  const std::vector<index_container_t> &getAllIndices() const { return index; }\n\nprivate:\n  /** Finds the position of the least significant unset bit */\n  int First0Bit(IndexType num) {\n    int pos = 0;\n    while (num & 1) {\n      num = num >> 1;\n      pos++;\n    }\n    return pos;\n  }\n\n  /** Creates multiple empty trees to handle dynamic support */\n  void init() {\n    typedef KDTreeSingleIndexDynamicAdaptor_<Distance, DatasetAdaptor, DIM>\n        my_kd_tree_t;\n    std::vector<my_kd_tree_t> index_(\n        treeCount, my_kd_tree_t(dim /*dim*/, dataset, treeIndex, index_params));\n    index = index_;\n  }\n\npublic:\n  Distance distance;\n\n  /**\n   * KDTree constructor\n   *\n   * Refer to docs in README.md or online in\n   * https://github.com/jlblancoc/nanoflann\n   *\n   * The KD-Tree point dimension (the length of each point in the dataset, e.g. 
3\n   * for 3D points) is determined by means of:\n   *  - The \\a DIM template parameter if >0 (highest priority)\n   *  - Otherwise, the \\a dimensionality parameter of this constructor.\n   *\n   * @param inputData Dataset with the input features\n   * @param params Basically, the maximum leaf node size\n   */\n  KDTreeSingleIndexDynamicAdaptor(const int dimensionality,\n                                  const DatasetAdaptor &inputData,\n                                  const KDTreeSingleIndexAdaptorParams &params =\n                                      KDTreeSingleIndexAdaptorParams(),\n                                  const size_t maximumPointCount = 1000000000U)\n      : dataset(inputData), index_params(params), distance(inputData) {\n    treeCount = static_cast<size_t>(std::log2(maximumPointCount));\n    pointCount = 0U;\n    dim = dimensionality;\n    treeIndex.clear();\n    if (DIM > 0)\n      dim = DIM;\n    m_leaf_max_size = params.leaf_max_size;\n    init();\n    const size_t num_initial_points = dataset.kdtree_get_point_count();\n    if (num_initial_points > 0) {\n      addPoints(0, num_initial_points - 1);\n    }\n  }\n\n  /** Deleted copy constructor */\n  KDTreeSingleIndexDynamicAdaptor(\n      const KDTreeSingleIndexDynamicAdaptor<Distance, DatasetAdaptor, DIM,\n                                            IndexType> &) = delete;\n\n  /** Add points to the set; inserts all points in the closed range [start, end] */\n  void addPoints(IndexType start, IndexType end) {\n    size_t count = end - start + 1;\n    treeIndex.resize(treeIndex.size() + count);\n    for (IndexType idx = start; idx <= end; idx++) {\n      int pos = First0Bit(pointCount);\n      index[pos].vind.clear();\n      treeIndex[pointCount] = pos;\n      for (int i = 0; i < pos; i++) {\n        for (int j = 0; j < static_cast<int>(index[i].vind.size()); j++) {\n          index[pos].vind.push_back(index[i].vind[j]);\n          if (treeIndex[index[i].vind[j]] != -1)\n            treeIndex[index[i].vind[j]] = pos;\n        }\n        index[i].vind.clear();\n        index[i].freeIndex(index[i]);\n      }\n      index[pos].vind.push_back(idx);\n      index[pos].buildIndex();\n      pointCount++;\n    }\n  }\n\n  /** Remove a point from the set (Lazy Deletion) */\n  void removePoint(size_t idx) {\n    if (idx >= pointCount)\n      return;\n    treeIndex[idx] = -1;\n  }\n\n  /**\n   * Find set of nearest neighbors to vec[0:dim-1]. Their indices are stored\n   * inside the result object.\n   *\n   * Params:\n   *     result = the result object in which the indices of the\n   *              nearest-neighbors are stored\n   *     vec = the vector for which to search the nearest neighbors\n   *\n   * \\tparam RESULTSET Should be any ResultSet<DistanceType>\n   * \\return  True if the requested neighbors could be found.\n   * \\sa knnSearch, radiusSearch\n   */\n  template <typename RESULTSET>\n  bool findNeighbors(RESULTSET &result, const ElementType *vec,\n                     const SearchParams &searchParams) const {\n    for (size_t i = 0; i < treeCount; i++) {\n      index[i].findNeighbors(result, &vec[0], searchParams);\n    }\n    return result.full();\n  }\n};\n\n/** An L2-metric KD-tree adaptor for working with data directly stored in an\n * Eigen Matrix, without duplicating the data storage. 
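(The adaptor stores only a std::reference_wrapper to the matrix.) 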
Each row in the matrix\n * represents a point in the state space.\n *\n *  Example of usage:\n * \\code\n * \tEigen::Matrix<num_t,Dynamic,Dynamic>  mat;\n * \t// Fill out \"mat\"...\n *\n * \ttypedef KDTreeEigenMatrixAdaptor< Eigen::Matrix<num_t,Dynamic,Dynamic> >\n * \t    my_kd_tree_t;\n * \tconst int max_leaf = 10;\n * \tmy_kd_tree_t mat_index(mat.cols(), mat, max_leaf);\n * \tmat_index.index->buildIndex();\n * \tmat_index.index->...\n * \\endcode\n *\n *  \\tparam DIM If set to >0, it specifies a compile-time fixed dimensionality\n * for the points in the data set, allowing more compiler optimizations. \\tparam\n * Distance The distance metric to use: nanoflann::metric_L1,\n * nanoflann::metric_L2, nanoflann::metric_L2_Simple, etc.\n */\ntemplate <class MatrixType, int DIM = -1, class Distance = nanoflann::metric_L2>\nstruct KDTreeEigenMatrixAdaptor {\n  typedef KDTreeEigenMatrixAdaptor<MatrixType, DIM, Distance> self_t;\n  typedef typename MatrixType::Scalar num_t;\n  typedef typename MatrixType::Index IndexType;\n  typedef\n      typename Distance::template traits<num_t, self_t>::distance_t metric_t;\n  typedef KDTreeSingleIndexAdaptor<metric_t, self_t,\n                                   MatrixType::ColsAtCompileTime, IndexType>\n      index_t;\n\n  index_t *index; //! The kd-tree index for the user to call its methods as\n                  //! usual with any other FLANN index.\n\n  /// Constructor: takes a const ref to the matrix object with the data points\n  KDTreeEigenMatrixAdaptor(const size_t dimensionality,\n                           const std::reference_wrapper<const MatrixType> &mat,\n                           const int leaf_max_size = 10)\n      : m_data_matrix(mat) {\n    const auto dims = mat.get().cols();\n    if (size_t(dims) != dimensionality)\n      throw std::runtime_error(\n          \"Error: 'dimensionality' must match column count in data matrix\");\n    if (DIM > 0 && int(dims) != DIM)\n      throw std::runtime_error(\n          \"Data set dimensionality does not match the 'DIM' template argument\");\n    index =\n        new index_t(static_cast<int>(dims), *this /* adaptor */,\n                    nanoflann::KDTreeSingleIndexAdaptorParams(leaf_max_size));\n    index->buildIndex();\n  }\n\npublic:\n  /** Deleted copy constructor */\n  KDTreeEigenMatrixAdaptor(const self_t &) = delete;\n\n  ~KDTreeEigenMatrixAdaptor() { delete index; }\n\n  const std::reference_wrapper<const MatrixType> m_data_matrix;\n\n  /** Query for the \\a num_closest closest points to a given point (entered as\n   * query_point[0:dim-1]). Note that this is a short-cut method for\n   * index->findNeighbors(). The user can also call index->... methods as\n   * desired. 
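A minimal\n   * query sketch (the 3-D query point and \"mat_index\" instance are\n   * illustrative):\n   * \\code\n   *   const size_t num_results = 3;\n   *   std::vector<IndexType> ret_indexes(num_results);\n   *   std::vector<num_t> out_dists_sqr(num_results);\n   *   num_t query_pt[3] = {0.5, 0.5, 0.5};\n   *   mat_index.query(&query_pt[0], num_results, &ret_indexes[0],\n   *                   &out_dists_sqr[0]);\n   * \\endcode\n   * 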
\\note nChecks_IGNORED is ignored but kept for compatibility with\n   * the original FLANN interface.\n   */\n  inline void query(const num_t *query_point, const size_t num_closest,\n                    IndexType *out_indices, num_t *out_distances_sq,\n                    const int /* nChecks_IGNORED */ = 10) const {\n    nanoflann::KNNResultSet<num_t, IndexType> resultSet(num_closest);\n    resultSet.init(out_indices, out_distances_sq);\n    index->findNeighbors(resultSet, query_point, nanoflann::SearchParams());\n  }\n\n  /** @name Interface expected by KDTreeSingleIndexAdaptor\n   * @{ */\n\n  const self_t &derived() const { return *this; }\n  self_t &derived() { return *this; }\n\n  // Must return the number of data points\n  inline size_t kdtree_get_point_count() const {\n    return m_data_matrix.get().rows();\n  }\n\n  // Returns the dim'th component of the idx'th point in the class:\n  inline num_t kdtree_get_pt(const IndexType idx, size_t dim) const {\n    return m_data_matrix.get().coeff(idx, IndexType(dim));\n  }\n\n  // Optional bounding-box computation: return false to default to a standard\n  // bbox computation loop.\n  //   Return true if the BBOX was already computed by the class and returned in\n  //   \"bb\" so it can be avoided to redo it again. Look at bb.size() to find out\n  //   the expected dimensionality (e.g. 2 or 3 for point clouds)\n  template <class BBOX> bool kdtree_get_bbox(BBOX & /*bb*/) const {\n    return false;\n  }\n\n  /** @} */\n\n}; // end of KDTreeEigenMatrixAdaptor\n   /** @} */\n\n/** @} */ // end of grouping\n} // namespace nanoflann\n\n#endif /* NANOFLANN_HPP_ */\n"
  },
  {
    "path": "thirdparty/kpconv/kernels/kernel_points.py",
    "content": "\n#\n#\n#      0=================================0\n#      |    Kernel Point Convolutions    |\n#      0=================================0\n#\n#\n# ----------------------------------------------------------------------------------------------------------------------\n#\n#      Functions handling the disposition of kernel points.\n#\n# ----------------------------------------------------------------------------------------------------------------------\n#\n#      Hugues THOMAS - 11/06/2018\n#\n\nimport time\nimport numpy as np\n# Needed for the verbose/debug plots used below\nimport matplotlib.pyplot as plt\nfrom os import makedirs\nfrom os.path import join, exists\n# from lib.ply import read_ply, write_ply\nfrom kpconv.lib.ply import read_ply, write_ply\n\n\nclass bcolors:\n    # Minimal ANSI color codes for the verbose warnings below (the original\n    # code imported these from a separate config module)\n    WARNING = '\\033[93m'\n    ENDC = '\\033[0m'\n\n\n# ------------------------------------------------------------------------------------------\n#\n#           Functions\n#       \\***************/\n#\n#\n\ndef create_3D_rotations(axis, angle):\n    \"\"\"\n    Create rotation matrices from a list of axes and angles. Code from Wikipedia on quaternions\n    :param axis: float32[N, 3]\n    :param angle: float32[N,]\n    :return: float32[N, 3, 3]\n    \"\"\"\n\n    t1 = np.cos(angle)\n    t2 = 1 - t1\n    t3 = axis[:, 0] * axis[:, 0]\n    t6 = t2 * axis[:, 0]\n    t7 = t6 * axis[:, 1]\n    t8 = np.sin(angle)\n    t9 = t8 * axis[:, 2]\n    t11 = t6 * axis[:, 2]\n    t12 = t8 * axis[:, 1]\n    t15 = axis[:, 1] * axis[:, 1]\n    t19 = t2 * axis[:, 1] * axis[:, 2]\n    t20 = t8 * axis[:, 0]\n    t24 = axis[:, 2] * axis[:, 2]\n    R = np.stack([t1 + t2 * t3,\n                  t7 - t9,\n                  t11 + t12,\n                  t7 + t9,\n                  t1 + t2 * t15,\n                  t19 - t20,\n                  t11 - t12,\n                  t19 + t20,\n                  t1 + t2 * t24], axis=1)\n\n    return np.reshape(R, (-1, 3, 3))\n\n\ndef spherical_Lloyd(radius, num_cells, dimension=3, fixed='center', approximation='monte-carlo',\n                    approx_n=5000, max_iter=500, momentum=0.9, verbose=0):\n    \"\"\"\n    Creation of kernel points via Lloyd's algorithm. We use an approximation of the algorithm, and compute the Voronoi\n    cell centers with a discretization of space. 
The exact formula is not trivial, since parts of the sphere form the cell sides.\n    :param radius: Radius of the kernels\n    :param num_cells: Number of cells (kernel points) in the Voronoi diagram.\n    :param dimension: dimension of the space\n    :param fixed: fix position of certain kernel points ('none', 'center' or 'verticals')\n    :param approximation: Approximation method for Lloyd's algorithm ('discretization', 'monte-carlo')\n    :param approx_n: Number of points used for the approximation.\n    :param max_iter: Maximum number of iterations for the algorithm.\n    :param momentum: Momentum of the low pass filter smoothing kernel point positions\n    :param verbose: display option\n    :return: points [num_kernels, num_points, dimension]\n    \"\"\"\n\n    #######################\n    # Parameters definition\n    #######################\n\n    # Radius used for optimization (points are rescaled afterwards)\n    radius0 = 1.0\n\n    #######################\n    # Kernel initialization\n    #######################\n\n    # Random kernel points (Uniform distribution in a sphere)\n    kernel_points = np.zeros((0, dimension))\n    while kernel_points.shape[0] < num_cells:\n        new_points = np.random.rand(num_cells, dimension) * 2 * radius0 - radius0\n        kernel_points = np.vstack((kernel_points, new_points))\n        d2 = np.sum(np.power(kernel_points, 2), axis=1)\n        kernel_points = kernel_points[np.logical_and(d2 < radius0 ** 2, (0.9 * radius0) ** 2 < d2), :]\n    kernel_points = kernel_points[:num_cells, :].reshape((num_cells, -1))\n\n    # Optional fixing\n    if fixed == 'center':\n        kernel_points[0, :] *= 0\n    if fixed == 'verticals':\n        kernel_points[:3, :] *= 0\n        kernel_points[1, -1] += 2 * radius0 / 3\n        kernel_points[2, -1] -= 2 * radius0 / 3\n\n    ##############################\n    # Approximation initialization\n    ##############################\n\n    # Initialize figure\n    if verbose > 1:\n        fig = plt.figure()\n\n    # Initialize discretization if this method is chosen\n    if approximation == 'discretization':\n        side_n = int(np.floor(approx_n ** (1. 
/ dimension)))\n        dl = 2 * radius0 / side_n\n        coords = np.arange(-radius0 + dl/2, radius0, dl)\n        if dimension == 2:\n            x, y = np.meshgrid(coords, coords)\n            X = np.vstack((np.ravel(x), np.ravel(y))).T\n        elif dimension == 3:\n            x, y, z = np.meshgrid(coords, coords, coords)\n            X = np.vstack((np.ravel(x), np.ravel(y), np.ravel(z))).T\n        elif dimension == 4:\n            x, y, z, t = np.meshgrid(coords, coords, coords, coords)\n            X = np.vstack((np.ravel(x), np.ravel(y), np.ravel(z), np.ravel(t))).T\n        else:\n            raise ValueError('Unsupported dimension (max is 4)')\n    elif approximation == 'monte-carlo':\n        X = np.zeros((0, dimension))\n    else:\n        raise ValueError('Wrong approximation method chosen: \"{:s}\"'.format(approximation))\n\n    # Only points inside the sphere are used\n    d2 = np.sum(np.power(X, 2), axis=1)\n    X = X[d2 < radius0 * radius0, :]\n\n    #####################\n    # Kernel optimization\n    #####################\n\n    # Warning if at least one kernel point has no cell\n    warning = False\n\n    # moving vectors of kernel points saved to detect convergence\n    max_moves = np.zeros((0,))\n\n    for iter in range(max_iter):\n\n        # In the case of monte-carlo, renew the sampled points\n        if approximation == 'monte-carlo':\n            X = np.random.rand(approx_n, dimension) * 2 * radius0 - radius0\n            d2 = np.sum(np.power(X, 2), axis=1)\n            X = X[d2 < radius0 * radius0, :]\n\n        # Get the distances matrix [n_approx, K, dim]\n        differences = np.expand_dims(X, 1) - kernel_points\n        sq_distances = np.sum(np.square(differences), axis=2)\n\n        # Compute cell centers\n        cell_inds = np.argmin(sq_distances, axis=1)\n        centers = []\n        for c in range(num_cells):\n            bool_c = (cell_inds == c)\n            num_c = np.sum(bool_c.astype(np.int32))\n            if num_c > 0:\n                centers.append(np.sum(X[bool_c, :], axis=0) / num_c)\n            else:\n                warning = True\n                centers.append(kernel_points[c])\n\n        # Update kernel points with a low-pass filter to smooth the Monte Carlo estimate\n        centers = np.vstack(centers)\n        moves = (1 - momentum) * (centers - kernel_points)\n        kernel_points += moves\n\n        # Check moves for convergence\n        max_moves = np.append(max_moves, np.max(np.linalg.norm(moves, axis=1)))\n\n        # Optional fixing\n        if fixed == 'center':\n            kernel_points[0, :] *= 0\n        if fixed == 'verticals':\n            kernel_points[0, :] *= 0\n            kernel_points[:3, :-1] *= 0\n\n        if verbose:\n            print('iter {:5d} / max move = {:f}'.format(iter, np.max(np.linalg.norm(moves, axis=1))))\n            if warning:\n                print('{:}WARNING: at least one point has no cell{:}'.format(bcolors.WARNING, bcolors.ENDC))\n        if verbose > 1:\n            plt.clf()\n            plt.scatter(X[:, 0], X[:, 1], c=cell_inds, s=20.0,\n                        marker='.', cmap=plt.get_cmap('tab20'))\n            #plt.scatter(kernel_points[:, 0], kernel_points[:, 1], c=np.arange(num_cells), s=100.0,\n            #            marker='+', cmap=plt.get_cmap('tab20'))\n            plt.plot(kernel_points[:, 0], kernel_points[:, 1], 'k+')\n            circle = plt.Circle((0, 0), radius0, color='r', fill=False)\n            fig.axes[0].add_artist(circle)\n            fig.axes[0].set_xlim((-radius0 * 1.1, 
radius0 * 1.1))\n            fig.axes[0].set_ylim((-radius0 * 1.1, radius0 * 1.1))\n            fig.axes[0].set_aspect('equal')\n            plt.draw()\n            plt.pause(0.001)\n            plt.show(block=False)\n\n    ###################\n    # User verification\n    ###################\n\n    # Show the convergence to ask user if this kernel is correct\n    if verbose:\n        if dimension == 2:\n            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=[10.4, 4.8])\n            ax1.plot(max_moves)\n            ax2.scatter(X[:, 0], X[:, 1], c=cell_inds, s=20.0,\n                        marker='.', cmap=plt.get_cmap('tab20'))\n            # plt.scatter(kernel_points[:, 0], kernel_points[:, 1], c=np.arange(num_cells), s=100.0,\n            #            marker='+', cmap=plt.get_cmap('tab20'))\n            ax2.plot(kernel_points[:, 0], kernel_points[:, 1], 'k+')\n            circle = plt.Circle((0, 0), radius0, color='r', fill=False)\n            ax2.add_artist(circle)\n            ax2.set_xlim((-radius0 * 1.1, radius0 * 1.1))\n            ax2.set_ylim((-radius0 * 1.1, radius0 * 1.1))\n            ax2.set_aspect('equal')\n            plt.title('Check if kernel is correct.')\n            plt.draw()\n            plt.show()\n\n        if dimension > 2:\n            plt.figure()\n            plt.plot(max_moves)\n            plt.title('Check if kernel is correct.')\n            plt.show()\n\n    # Rescale kernels with real radius\n    return kernel_points * radius\n\n\ndef kernel_point_optimization_debug(radius, num_points, num_kernels=1, dimension=3,\n                                    fixed='center', ratio=0.66, verbose=0):\n    \"\"\"\n    Creation of kernel points via optimization of potentials.\n    :param radius: Radius of the kernels\n    :param num_points: number of points composing the kernels\n    :param num_kernels: number of wanted kernels\n    :param dimension: dimension of the space\n    :param fixed: fix position of certain kernel points ('none', 'center' or 'verticals')\n    :param ratio: ratio of the radius where you want the kernel points to be placed\n    :param verbose: display option\n    :return: points [num_kernels, num_points, dimension]\n    \"\"\"\n\n    #######################\n    # Parameters definition\n    #######################\n\n    # Radius used for optimization (points are rescaled afterwards)\n    radius0 = 1\n    diameter0 = 2\n\n    # Factor multiplying gradients for moving points (~learning rate)\n    moving_factor = 1e-2\n    continuous_moving_decay = 0.9995\n\n    # Gradient threshold to stop optimization\n    thresh = 1e-5\n\n    # Gradient clipping value\n    clip = 0.05 * radius0\n\n    #######################\n    # Kernel initialization\n    #######################\n\n    # Random kernel points\n    kernel_points = np.random.rand(num_kernels * num_points - 1, dimension) * diameter0 - radius0\n    while (kernel_points.shape[0] < num_kernels * num_points):\n        new_points = np.random.rand(num_kernels * num_points - 1, dimension) * diameter0 - radius0\n        kernel_points = np.vstack((kernel_points, new_points))\n        d2 = np.sum(np.power(kernel_points, 2), axis=1)\n        kernel_points = kernel_points[d2 < 0.5 * radius0 * radius0, :]\n    kernel_points = kernel_points[:num_kernels * num_points, :].reshape((num_kernels, num_points, -1))\n\n    # Optional fixing\n    if fixed == 'center':\n        kernel_points[:, 0, :] *= 0\n    if fixed == 'verticals':\n        kernel_points[:, :3, :] *= 0\n        
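# point 0 stays at the origin; points 1 and 2 are offset along the vertical (last) axis\n        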
kernel_points[:, 1, -1] += 2 * radius0 / 3\n        kernel_points[:, 2, -1] -= 2 * radius0 / 3\n\n    #####################\n    # Kernel optimization\n    #####################\n\n    # Initialize figure\n    if verbose > 1:\n        fig = plt.figure()\n\n    saved_gradient_norms = np.zeros((10000, num_kernels))\n    old_gradient_norms = np.zeros((num_kernels, num_points))\n    for iter in range(10000):\n\n        # Compute gradients\n        # *****************\n\n        # Derivative of the sum of potentials of all points\n        A = np.expand_dims(kernel_points, axis=2)\n        B = np.expand_dims(kernel_points, axis=1)\n        interd2 = np.sum(np.power(A - B, 2), axis=-1)\n        inter_grads = (A - B) / (np.power(np.expand_dims(interd2, -1), 3/2) + 1e-6)\n        inter_grads = np.sum(inter_grads, axis=1)\n\n        # Derivative of the radius potential\n        circle_grads = 10*kernel_points\n\n        # All gradients\n        gradients = inter_grads + circle_grads\n\n        if fixed == 'verticals':\n            gradients[:, 1:3, :-1] = 0\n\n        # Stop condition\n        # **************\n\n        # Compute norm of gradients\n        gradients_norms = np.sqrt(np.sum(np.power(gradients, 2), axis=-1))\n        saved_gradient_norms[iter, :] = np.max(gradients_norms, axis=1)\n\n        # Stop if the gradients of all moving points have stabilized (small change in gradient norms)\n\n        if fixed == 'center' and np.max(np.abs(old_gradient_norms[:, 1:] - gradients_norms[:, 1:])) < thresh:\n            break\n        elif fixed == 'verticals' and np.max(np.abs(old_gradient_norms[:, 3:] - gradients_norms[:, 3:])) < thresh:\n            break\n        elif np.max(np.abs(old_gradient_norms - gradients_norms)) < thresh:\n            break\n        old_gradient_norms = gradients_norms\n\n        # Move points\n        # ***********\n\n        # Clip gradient to get moving dists\n        moving_dists = np.minimum(moving_factor * gradients_norms, clip)\n\n        # Fix central point\n        if fixed == 'center':\n            moving_dists[:, 0] = 0\n        if fixed == 'verticals':\n            moving_dists[:, 0] = 0\n\n        # Move points\n        kernel_points -= np.expand_dims(moving_dists, -1) * gradients / np.expand_dims(gradients_norms + 1e-6, -1)\n\n        if verbose:\n            print('iter {:5d} / max grad = {:f}'.format(iter, np.max(gradients_norms[:, 3:])))\n        if verbose > 1:\n            plt.clf()\n            plt.plot(kernel_points[0, :, 0], kernel_points[0, :, 1], '.')\n            circle = plt.Circle((0, 0), radius, color='r', fill=False)\n            fig.axes[0].add_artist(circle)\n            fig.axes[0].set_xlim((-radius*1.1, radius*1.1))\n            fig.axes[0].set_ylim((-radius*1.1, radius*1.1))\n            fig.axes[0].set_aspect('equal')\n            plt.draw()\n            plt.pause(0.001)\n            plt.show(block=False)\n            print(moving_factor)\n\n        # moving factor decay\n        moving_factor *= continuous_moving_decay\n\n    # Rescale points to fit the desired ratio of the radius\n    r = np.sqrt(np.sum(np.power(kernel_points, 2), axis=-1))\n    kernel_points *= ratio / np.mean(r[:, 1:])\n\n    # Rescale kernels with real radius\n    return kernel_points * radius, saved_gradient_norms\n\n\ndef load_kernels(radius, num_kpoints, dimension, fixed, lloyd=False):\n\n    # Kernel directory\n    kernel_dir = 'kernels/dispositions'\n    if not exists(kernel_dir):\n        makedirs(kernel_dir)\n\n    # Too many points: switch to Lloyd's algorithm\n    if num_kpoints > 30:\n        lloyd 
= True\n\n    # Kernel_file\n    kernel_file = join(kernel_dir, 'k_{:03d}_{:s}_{:d}D.ply'.format(num_kpoints, fixed, dimension))\n\n    # Check if already done\n    if not exists(kernel_file):\n        if lloyd:\n            # Create kernels\n            kernel_points = spherical_Lloyd(1.0,\n                                            num_kpoints,\n                                            dimension=dimension,\n                                            fixed=fixed,\n                                            verbose=0)\n\n        else:\n            # Create kernels\n            kernel_points, grad_norms = kernel_point_optimization_debug(1.0,\n                                                                        num_kpoints,\n                                                                        num_kernels=100,\n                                                                        dimension=dimension,\n                                                                        fixed=fixed,\n                                                                        verbose=0)\n\n            # Find best candidate\n            best_k = np.argmin(grad_norms[-1, :])\n\n            # Save points\n            kernel_points = kernel_points[best_k, :, :]\n\n        write_ply(kernel_file, kernel_points, ['x', 'y', 'z'])\n\n    else:\n        data = read_ply(kernel_file)\n        kernel_points = np.vstack((data['x'], data['y'], data['z'])).T\n\n    # Random rotations for the kernel\n    # N.B. 4D random rotations not supported yet\n    R = np.eye(dimension)\n    theta = np.random.rand() * 2 * np.pi\n    if dimension == 2:\n        if fixed != 'vertical':\n            c, s = np.cos(theta), np.sin(theta)\n            R = np.array([[c, -s], [s, c]], dtype=np.float32)\n\n    elif dimension == 3:\n        if fixed != 'vertical':\n            c, s = np.cos(theta), np.sin(theta)\n            R = np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]], dtype=np.float32)\n\n        else:\n            phi = (np.random.rand() - 0.5) * np.pi\n\n            # Create the first vector in Cartesian coordinates\n            u = np.array([np.cos(theta) * np.cos(phi), np.sin(theta) * np.cos(phi), np.sin(phi)])\n\n            # Choose a random rotation angle\n            alpha = np.random.rand() * 2 * np.pi\n\n            # Create the rotation matrix with this vector and angle\n            R = create_3D_rotations(np.reshape(u, (1, -1)), np.reshape(alpha, (1, -1)))[0]\n\n            R = R.astype(np.float32)\n\n    # Add a small noise\n    kernel_points = kernel_points + np.random.normal(scale=0.01, size=kernel_points.shape)\n\n    # Scale kernels\n    kernel_points = radius * kernel_points\n\n    # Rotate kernels\n    kernel_points = np.matmul(kernel_points, R)\n\n    return kernel_points.astype(np.float32)\n"
  },
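  {
    "path": "thirdparty/kpconv/examples/load_kernels_demo.py",
    "content": "# NOTE: hypothetical usage sketch, not part of the original KPConv release.\n# It shows how kernel point dispositions are generated with load_kernels()\n# from kernels/kernel_points.py. The 'kpconv' package prefix mirrors the\n# import convention used in kpconv_blocks.py; adjust it to your install.\n# Running this writes a cached disposition under 'kernels/dispositions'.\nimport numpy as np\nfrom kpconv.kernels.kernel_points import load_kernels\n\nif __name__ == '__main__':\n    # 15 kernel points in a 3D sphere of radius 1.0, central point fixed\n    kp = load_kernels(radius=1.0, num_kpoints=15, dimension=3, fixed='center')\n    print(kp.shape)                          # (15, 3)\n    print(np.linalg.norm(kp, axis=1).max())  # roughly within the unit sphere\n"
  },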
  {
    "path": "thirdparty/kpconv/kpconv_blocks.py",
    "content": "#\n#\n#      0=================================0\n#      |    Kernel Point Convolutions    |\n#      0=================================0\n#\n#\n# ----------------------------------------------------------------------------------------------------------------------\n#\n#      Define network blocks\n#\n# ----------------------------------------------------------------------------------------------------------------------\n#\n#      Hugues THOMAS - 06/03/2020\n\nimport time\nimport math\nimport torch\nimport torch.nn as nn\nfrom torch.nn.parameter import Parameter\nfrom torch.nn.init import kaiming_uniform_\n# from kernels.kernel_points import load_kernels\nfrom kpconv.kernels.kernel_points import load_kernels\n\n# from lib.ply import write_ply\nfrom kpconv.lib.ply import write_ply\n\n\ndef gather(x, idx, method=2):\n    \"\"\"\n    Implementation of a custom gather operation for a faster backward pass.\n    :param x: input with shape [N, D_1, ... D_d]\n    :param idx: indexing with shape [n_1, ..., n_m]\n    :param method: choice of the method\n    :return: x[idx] with shape [n_1, ..., n_m, D_1, ... D_d]\n    \"\"\"\n\n    if method == 0:\n        return x[idx]\n    elif method == 1:\n        x = x.unsqueeze(1)\n        x = x.expand((-1, idx.shape[-1], -1))\n        idx = idx.unsqueeze(2)\n        idx = idx.expand((-1, -1, x.shape[-1]))\n        return x.gather(0, idx)\n    elif method == 2:\n        for i, ni in enumerate(idx.size()[1:]):\n            x = x.unsqueeze(i+1)\n            new_s = list(x.size())\n            new_s[i+1] = ni\n            x = x.expand(new_s)\n        n = len(idx.size())\n        for i, di in enumerate(x.size()[n:]):\n            idx = idx.unsqueeze(i+n)\n            new_s = list(idx.size())\n            new_s[i+n] = di\n            idx = idx.expand(new_s)\n        return x.gather(0, idx)\n    else:\n        raise ValueError('Unknown method')\n\n\ndef radius_gaussian(sq_r, sig, eps=1e-9):\n    \"\"\"\n    Compute a radius gaussian (gaussian of distance)\n    :param sq_r: input squared radii [dn, ..., d1, d0]\n    :param sig: extents of gaussians [d1, d0] or [d0] or float\n    :return: gaussian of sq_r [dn, ..., d1, d0]\n    \"\"\"\n    return torch.exp(-sq_r / (2 * sig**2 + eps))\n\n\ndef closest_pool(x, inds):\n    \"\"\"\n    Pools features from the closest neighbors. 
WARNING: this function assumes the neighbors are ordered.\n    :param x: [n1, d] features matrix\n    :param inds: [n2, max_num] Only the first column is used for pooling\n    :return: [n2, d] pooled features matrix\n    \"\"\"\n\n    # Add a last row with zero features for shadow pools\n    x = torch.cat((x, torch.zeros_like(x[:1, :])), 0)\n\n    # Get features for each pooling location [n2, d]\n    return gather(x, inds[:, 0])\n\n\ndef max_pool(x, inds):\n    \"\"\"\n    Pools features with the maximum values.\n    :param x: [n1, d] features matrix\n    :param inds: [n2, max_num] pooling indices\n    :return: [n2, d] pooled features matrix\n    \"\"\"\n\n    # Add a last row with zero features for shadow pools\n    x = torch.cat((x, torch.zeros_like(x[:1, :])), 0)\n\n    # Get all features for each pooling location [n2, max_num, d]\n    pool_features = gather(x, inds)\n\n    # Pool the maximum [n2, d]\n    max_features, _ = torch.max(pool_features, 1)\n    return max_features\n\n\ndef global_average(x, batch_lengths):\n    \"\"\"\n    Block performing a global average pooling over the batch\n    :param x: [N, D] input features\n    :param batch_lengths: [B] list of batch lengths\n    :return: [B, D] averaged features\n    \"\"\"\n\n    # Loop over the clouds of the batch\n    averaged_features = []\n    i0 = 0\n    for b_i, length in enumerate(batch_lengths):\n\n        # Average features for each batch cloud\n        averaged_features.append(torch.mean(x[i0:i0 + length], dim=0))\n\n        # Increment for next cloud\n        i0 += length\n\n    # Stack the averaged features of each cloud\n    return torch.stack(averaged_features)\n\n\n# ----------------------------------------------------------------------------------------------------------------------\n#\n#           KPConv class\n#       \\******************/\n#\n\n\nclass KPConv(nn.Module):\n\n    def __init__(self, kernel_size, p_dim, in_channels, out_channels, KP_extent, radius,\n                 fixed_kernel_points='center', KP_influence='linear', aggregation_mode='sum',\n                 deformable=False, modulated=False):\n        \"\"\"\n        Initialize parameters for KPConvDeformable.\n        :param kernel_size: Number of kernel points.\n        :param p_dim: dimension of the point space.\n        :param in_channels: dimension of input features.\n        :param out_channels: dimension of output features.\n        :param KP_extent: influence radius of each kernel point.\n        :param radius: radius used for kernel point init. 
Even for deformable, use the config.conv_radius\n        :param fixed_kernel_points: fix position of certain kernel points ('none', 'center' or 'verticals').\n        :param KP_influence: influence function of the kernel points ('constant', 'linear', 'gaussian').\n        :param aggregation_mode: choose to sum influences, or only keep the closest ('closest', 'sum').\n        :param deformable: choose deformable or not\n        :param modulated: choose if kernel weights are modulated in addition to deformed\n        \"\"\"\n        super(KPConv, self).__init__()\n\n        # Save parameters\n        self.K = kernel_size\n        self.p_dim = p_dim\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.radius = radius\n        self.KP_extent = KP_extent\n        self.fixed_kernel_points = fixed_kernel_points\n        self.KP_influence = KP_influence\n        self.aggregation_mode = aggregation_mode\n        self.deformable = deformable\n        self.modulated = modulated\n\n        # Running variable containing deformed KP distance to input points. (used in regularization loss)\n        self.min_d2 = None\n        self.deformed_KP = None\n        self.offset_features = None\n\n        # Initialize weights\n        self.weights = Parameter(torch.zeros((self.K, in_channels, out_channels), dtype=torch.float32),\n                                 requires_grad=True)\n\n        # Initialize weights for offsets\n        if deformable:\n            if modulated:\n                self.offset_dim = (self.p_dim + 1) * self.K\n            else:\n                self.offset_dim = self.p_dim * self.K\n            self.offset_conv = KPConv(self.K,\n                                      self.p_dim,\n                                      self.in_channels,\n                                      self.offset_dim,\n                                      KP_extent,\n                                      radius,\n                                      fixed_kernel_points=fixed_kernel_points,\n                                      KP_influence=KP_influence,\n                                      aggregation_mode=aggregation_mode)\n            self.offset_bias = Parameter(torch.zeros(self.offset_dim, dtype=torch.float32), requires_grad=True)\n\n        else:\n            self.offset_dim = None\n            self.offset_conv = None\n            self.offset_bias = None\n\n        # Reset parameters\n        self.reset_parameters()\n\n        # Initialize kernel points\n        self.kernel_points = self.init_KP()\n\n        return\n\n    def reset_parameters(self):\n        kaiming_uniform_(self.weights, a=math.sqrt(5))\n        if self.deformable:\n            nn.init.zeros_(self.offset_bias)\n        return\n\n    def init_KP(self):\n        \"\"\"\n        Initialize the kernel point positions in a sphere\n        :return: the tensor of kernel points\n        \"\"\"\n\n        # Create one kernel disposition (as numpy array). 
Choose the KP distance to center thanks to the KP extent\n        K_points_numpy = load_kernels(self.radius,\n                                      self.K,\n                                      dimension=self.p_dim,\n                                      fixed=self.fixed_kernel_points)\n\n        return Parameter(torch.tensor(K_points_numpy, dtype=torch.float32),\n                         requires_grad=False)\n\n    def forward(self, q_pts, s_pts, neighb_inds, x):\n\n        ###################\n        # Offset generation\n        ###################\n\n        if self.deformable:\n\n            # Get offsets with a KPConv that only takes part of the features\n            self.offset_features = self.offset_conv(q_pts, s_pts, neighb_inds, x) + self.offset_bias\n\n            if self.modulated:\n\n                # Get offset (in normalized scale) from features\n                unscaled_offsets = self.offset_features[:, :self.p_dim * self.K]\n                unscaled_offsets = unscaled_offsets.view(-1, self.K, self.p_dim)\n\n                # Get modulations\n                modulations = 2 * torch.sigmoid(self.offset_features[:, self.p_dim * self.K:])\n\n            else:\n\n                # Get offset (in normalized scale) from features\n                unscaled_offsets = self.offset_features.view(-1, self.K, self.p_dim)\n\n                # No modulations\n                modulations = None\n\n            # Rescale offset for this layer\n            offsets = unscaled_offsets * self.KP_extent\n\n        else:\n            offsets = None\n            modulations = None\n\n        ######################\n        # Deformed convolution\n        ######################\n\n        # Add a fake point in the last row for shadow neighbors\n        s_pts = torch.cat((s_pts, torch.zeros_like(s_pts[:1, :]) + 1e6), 0)\n\n        # Get neighbor points [n_points, n_neighbors, dim]\n        neighbors = s_pts[neighb_inds, :]\n\n        # Center every neighborhood\n        neighbors = neighbors - q_pts.unsqueeze(1)\n\n        # Apply offsets to kernel points [n_points, n_kpoints, dim]\n        if self.deformable:\n            self.deformed_KP = offsets + self.kernel_points\n            deformed_K_points = self.deformed_KP.unsqueeze(1)\n        else:\n            deformed_K_points = self.kernel_points\n\n        # Get all difference matrices [n_points, n_neighbors, n_kpoints, dim]\n        neighbors.unsqueeze_(2)\n        differences = neighbors - deformed_K_points\n\n        # Get the square distances [n_points, n_neighbors, n_kpoints]\n        sq_distances = torch.sum(differences ** 2, dim=3)\n\n        # Optimization by ignoring points outside a deformed KP range\n        if self.deformable:\n\n            # Save distances for loss\n            self.min_d2, _ = torch.min(sq_distances, dim=1)\n\n            # Boolean of the neighbors in range of a kernel point [n_points, n_neighbors]\n            in_range = torch.any(sq_distances < self.KP_extent ** 2, dim=2).type(torch.int32)\n\n            # New value of max neighbors\n            new_max_neighb = torch.max(torch.sum(in_range, dim=1))\n\n            # For each row of neighbors, indices of the ones that are in range [n_points, new_max_neighb]\n            neighb_row_bool, neighb_row_inds = torch.topk(in_range, new_max_neighb.item(), dim=1)\n\n            # Gather new neighbor indices [n_points, new_max_neighb]\n            new_neighb_inds = neighb_inds.gather(1, neighb_row_inds, sparse_grad=False)\n\n            # Gather new distances to KP [n_points, 
new_max_neighb, n_kpoints]\n            neighb_row_inds.unsqueeze_(2)\n            neighb_row_inds = neighb_row_inds.expand(-1, -1, self.K)\n            sq_distances = sq_distances.gather(1, neighb_row_inds, sparse_grad=False)\n\n            # New shadow neighbors have to point to the last shadow point\n            new_neighb_inds *= neighb_row_bool\n            new_neighb_inds -= (neighb_row_bool.type(torch.int64) - 1) * int(s_pts.shape[0] - 1)\n        else:\n            new_neighb_inds = neighb_inds\n\n        # Get Kernel point influences [n_points, n_kpoints, n_neighbors]\n        if self.KP_influence == 'constant':\n            # Every point gets an influence of 1.\n            all_weights = torch.ones_like(sq_distances)\n            all_weights = torch.transpose(all_weights, 1, 2)\n\n        elif self.KP_influence == 'linear':\n            # Influence decreases linearly with the distance and reaches zero when d = KP_extent.\n            all_weights = torch.clamp(1 - torch.sqrt(sq_distances) / self.KP_extent, min=0.0)\n            all_weights = torch.transpose(all_weights, 1, 2)\n\n        elif self.KP_influence == 'gaussian':\n            # Influence is a gaussian of the distance.\n            sigma = self.KP_extent * 0.3\n            all_weights = radius_gaussian(sq_distances, sigma)\n            all_weights = torch.transpose(all_weights, 1, 2)\n        else:\n            raise ValueError('Unknown influence function type (config.KP_influence)')\n\n        # In case of closest mode, only the closest KP can influence each point\n        if self.aggregation_mode == 'closest':\n            neighbors_1nn = torch.argmin(sq_distances, dim=2)\n            all_weights *= torch.transpose(nn.functional.one_hot(neighbors_1nn, self.K), 1, 2)\n\n        elif self.aggregation_mode != 'sum':\n            raise ValueError(\"Unknown convolution mode. 
Should be 'closest' or 'sum'\")\n\n        # Add a zero feature for shadow neighbors\n        x = torch.cat((x, torch.zeros_like(x[:1, :])), 0)\n\n        # Get the features of each neighborhood [n_points, n_neighbors, in_fdim]\n        neighb_x = gather(x, new_neighb_inds)\n\n        # Apply distance weights [n_points, n_kpoints, in_fdim]\n        weighted_features = torch.matmul(all_weights, neighb_x)\n\n        # Apply modulations\n        if self.deformable and self.modulated:\n            weighted_features *= modulations.unsqueeze(2)\n\n        # Apply network weights [n_kpoints, n_points, out_fdim]\n        weighted_features = weighted_features.permute((1, 0, 2))\n        \n        kernel_outputs = torch.matmul(weighted_features, self.weights)\n\n        # Convolution sum [n_points, out_fdim]\n        # return torch.sum(kernel_outputs, dim=0)\n        output_features = torch.sum(kernel_outputs, dim=0, keepdim=False)\n\n        # normalization term.\n        neighbor_features_sum = torch.sum(neighb_x, dim=-1)\n        neighbor_num = torch.sum(torch.gt(neighbor_features_sum, 0.0), dim=-1)\n        neighbor_num = torch.max(neighbor_num, torch.ones_like(neighbor_num))\n        output_features = output_features / neighbor_num.unsqueeze(1)\n\n        return output_features\n\n    def __repr__(self):\n        return 'KPConv(radius: {:.2f}, extent: {:.2f}, in_feat: {:d}, out_feat: {:d})'.format(self.radius, self.KP_extent,\n                                                                              self.in_channels,\n                                                                              self.out_channels)\n\n# ----------------------------------------------------------------------------------------------------------------------\n#\n#           Complex blocks\n#       \\********************/\n#\n\ndef block_decider(block_name,\n                  radius,\n                  in_dim,\n                  out_dim,\n                  layer_ind,\n                  config):\n\n    if block_name == 'unary':\n        return UnaryBlock(in_dim, out_dim, config.use_batch_norm, config.batch_norm_momentum)\n    \n    if block_name == 'last_unary':\n        return LastUnaryBlock(in_dim, config.final_feats_dim+2, config.use_batch_norm, config.batch_norm_momentum)\n\n    if block_name == 'last_unary_v2':\n        return LastUnaryBlock(in_dim, config.final_feats_dim+config.keypoint_num+1, config.use_batch_norm, config.batch_norm_momentum)\n    # if block_name == 'last_unary_pose':\n    #     return LastUnaryBlock(in_dim, 7+1, config.use_batch_norm, config.batch_norm_momentum)\n\n\n    elif block_name in ['simple',\n                        'simple_deformable',\n                        'simple_invariant',\n                        'simple_equivariant',\n                        'simple_strided',\n                        'simple_deformable_strided',\n                        'simple_invariant_strided',\n                        'simple_equivariant_strided']:\n        return SimpleBlock(block_name, in_dim, out_dim, radius, layer_ind, config)\n\n    elif block_name in ['resnetb',\n                        'resnetb_invariant',\n                        'resnetb_equivariant',\n                        'resnetb_deformable',\n                        'resnetb_strided',\n                        'resnetb_deformable_strided',\n                        'resnetb_equivariant_strided',\n                        'resnetb_invariant_strided']:\n        return ResnetBottleneckBlock(block_name, in_dim, out_dim, radius, layer_ind, 
config)\n\n    elif block_name == 'max_pool' or block_name == 'max_pool_wide':\n        return MaxPoolBlock(layer_ind)\n\n    elif block_name == 'global_average':\n        return GlobalAverageBlock()\n\n    elif block_name == 'nearest_upsample':\n        return NearestUpsampleBlock(layer_ind)\n\n    else:\n        raise ValueError('Unknown block name in the architecture definition: ' + block_name)\n\n\nclass BatchNormBlock(nn.Module):\n\n    def __init__(self, in_dim, use_bn, bn_momentum):\n        \"\"\"\n        Initialize a batch normalization block. If the network does not use batch normalization, replace it with biases.\n        :param in_dim: dimension of input features\n        :param use_bn: boolean indicating if we use Batch Norm\n        :param bn_momentum: Batch norm momentum\n        \"\"\"\n        super(BatchNormBlock, self).__init__()\n        self.bn_momentum = bn_momentum\n        self.use_bn = use_bn\n        self.in_dim = in_dim\n        if self.use_bn:\n            #self.batch_norm = nn.BatchNorm1d(in_dim, momentum=bn_momentum)\n            self.batch_norm = nn.InstanceNorm1d(in_dim, momentum=bn_momentum)\n        else:\n            self.bias = Parameter(torch.zeros(in_dim, dtype=torch.float32), requires_grad=True)\n        return\n\n    def reset_parameters(self):\n        nn.init.zeros_(self.bias)\n\n    def forward(self, x):\n        if self.use_bn:\n\n            x = x.unsqueeze(2)\n            x = x.transpose(0, 2)\n            x = self.batch_norm(x)\n            x = x.transpose(0, 2)\n            return x.squeeze()\n        else:\n            return x + self.bias\n\n    def __repr__(self):\n        return 'BatchNormBlock(in_feat: {:d}, momentum: {:.3f}, only_bias: {:s})'.format(self.in_dim,\n                                                                                         self.bn_momentum,\n                                                                                         str(not self.use_bn))\n\n\nclass UnaryBlock(nn.Module):\n\n    def __init__(self, in_dim, out_dim, use_bn, bn_momentum, no_relu=False):\n        \"\"\"\n        Initialize a standard unary block with its ReLU and BatchNorm.\n        :param in_dim: dimension of input features\n        :param out_dim: dimension of output features\n        :param use_bn: boolean indicating if we use Batch Norm\n        :param bn_momentum: Batch norm momentum\n        \"\"\"\n\n        super(UnaryBlock, self).__init__()\n        self.bn_momentum = bn_momentum\n        self.use_bn = use_bn\n        self.no_relu = no_relu\n        self.in_dim = in_dim\n        self.out_dim = out_dim\n        self.mlp = nn.Linear(in_dim, out_dim, bias=False)\n        self.batch_norm = BatchNormBlock(out_dim, self.use_bn, self.bn_momentum)\n        if not no_relu:\n            self.leaky_relu = nn.LeakyReLU(0.1)\n        return\n\n    def forward(self, x, batch=None):\n        x = self.mlp(x)\n        x = self.batch_norm(x)\n        if not self.no_relu:\n            x = self.leaky_relu(x)\n        return x\n\n    def __repr__(self):\n        return 'UnaryBlock(in_feat: {:d}, out_feat: {:d}, BN: {:s}, ReLU: {:s})'.format(self.in_dim,\n                                                                                        self.out_dim,\n                                                                                        str(self.use_bn),\n                                                                                        str(not self.no_relu))\n\n\nclass LastUnaryBlock(nn.Module):\n\n    def __init__(self, in_dim, out_dim, use_bn, 
bn_momentum, no_relu=False):\n        \"\"\"\n        Initialize a standard last_unary block without BN or ReLU.\n        :param in_dim: dimension of input features\n        :param out_dim: dimension of output features\n        :param use_bn: boolean indicating if we use Batch Norm\n        :param bn_momentum: Batch norm momentum\n        \"\"\"\n\n        super(LastUnaryBlock, self).__init__()\n        self.in_dim = in_dim\n        self.out_dim = out_dim\n        self.mlp = nn.Linear(in_dim, out_dim, bias=False)\n        return\n\n    def forward(self, x, batch=None):\n        x = self.mlp(x)\n        return x\n\n    def __repr__(self):\n        return 'LastUnaryBlock(in_feat: {:d}, out_feat: {:d})'.format(self.in_dim,\n                                                                      self.out_dim)\n\n\nclass SimpleBlock(nn.Module):\n\n    def __init__(self, block_name, in_dim, out_dim, radius, layer_ind, config):\n        \"\"\"\n        Initialize a simple convolution block with its ReLU and BatchNorm.\n        :param in_dim: dimension of input features\n        :param out_dim: dimension of output features\n        :param radius: current radius of convolution\n        :param config: parameters\n        \"\"\"\n        super(SimpleBlock, self).__init__()\n\n        # Get KP_extent from current radius\n        current_extent = radius * config.KP_extent / config.conv_radius\n\n        # Get other parameters\n        self.bn_momentum = config.batch_norm_momentum\n        self.use_bn = config.use_batch_norm\n        self.layer_ind = layer_ind\n        self.block_name = block_name\n        self.in_dim = in_dim\n        self.out_dim = out_dim\n\n        # Define the KPConv class\n        self.KPConv = KPConv(config.num_kernel_points,\n                             config.in_points_dim,\n                             in_dim,\n                             out_dim // 2,\n                             current_extent,\n                             radius,\n                             fixed_kernel_points=config.fixed_kernel_points,\n                             KP_influence=config.KP_influence,\n                             aggregation_mode=config.aggregation_mode,\n                             deformable='deform' in block_name,\n                             modulated=config.modulated)\n\n        # Other operations\n        self.batch_norm = BatchNormBlock(out_dim // 2, self.use_bn, self.bn_momentum)\n        self.leaky_relu = nn.LeakyReLU(0.1)\n\n        return\n\n    def forward(self, x, batch):\n\n        if 'strided' in self.block_name:\n            q_pts = batch['points'][self.layer_ind + 1]\n            s_pts = batch['points'][self.layer_ind]\n            neighb_inds = batch['pools'][self.layer_ind]\n        else:\n            q_pts = batch['points'][self.layer_ind]\n            s_pts = batch['points'][self.layer_ind]\n            neighb_inds = batch['neighbors'][self.layer_ind]\n\n        x = self.KPConv(q_pts, s_pts, neighb_inds, x)\n        return self.leaky_relu(self.batch_norm(x))\n\n\nclass ResnetBottleneckBlock(nn.Module):\n\n    def __init__(self, block_name, in_dim, out_dim, radius, layer_ind, config):\n        \"\"\"\n        Initialize a resnet bottleneck block.\n        :param in_dim: dimension of input features\n        :param out_dim: dimension of output features\n        :param radius: current radius of convolution\n        :param config: parameters\n        \"\"\"\n        super(ResnetBottleneckBlock, self).__init__()\n\n        # Get KP_extent from current radius\n        current_extent = radius 
* config.KP_extent / config.conv_radius\n\n        # Get other parameters\n        self.bn_momentum = config.batch_norm_momentum\n        self.use_bn = config.use_batch_norm\n        self.block_name = block_name\n        self.layer_ind = layer_ind\n        self.in_dim = in_dim\n        self.out_dim = out_dim\n\n        # First downscaling mlp\n        if in_dim != out_dim // 4:\n            self.unary1 = UnaryBlock(in_dim, out_dim // 4, self.use_bn, self.bn_momentum)\n        else:\n            self.unary1 = nn.Identity()\n\n        # KPConv block\n        self.KPConv = KPConv(config.num_kernel_points,\n                             config.in_points_dim,\n                             out_dim // 4,\n                             out_dim // 4,\n                             current_extent,\n                             radius,\n                             fixed_kernel_points=config.fixed_kernel_points,\n                             KP_influence=config.KP_influence,\n                             aggregation_mode=config.aggregation_mode,\n                             deformable='deform' in block_name,\n                             modulated=config.modulated)\n        self.batch_norm_conv = BatchNormBlock(out_dim // 4, self.use_bn, self.bn_momentum)\n\n        # Second upscaling mlp\n        self.unary2 = UnaryBlock(out_dim // 4, out_dim, self.use_bn, self.bn_momentum, no_relu=True)\n\n        # Optional shortcut mlp\n        if in_dim != out_dim:\n            self.unary_shortcut = UnaryBlock(in_dim, out_dim, self.use_bn, self.bn_momentum, no_relu=True)\n        else:\n            self.unary_shortcut = nn.Identity()\n\n        # Other operations\n        self.leaky_relu = nn.LeakyReLU(0.1)\n\n        return\n\n    def forward(self, features, batch):\n\n        if 'strided' in self.block_name:\n            q_pts = batch['points'][self.layer_ind + 1]\n            s_pts = batch['points'][self.layer_ind]\n            neighb_inds = batch['pools'][self.layer_ind]\n        else:\n            q_pts = batch['points'][self.layer_ind]\n            s_pts = batch['points'][self.layer_ind]\n            neighb_inds = batch['neighbors'][self.layer_ind]\n\n        # First downscaling mlp\n        x = self.unary1(features)\n\n        # Convolution\n        x = self.KPConv(q_pts, s_pts, neighb_inds, x)\n        x = self.leaky_relu(self.batch_norm_conv(x))\n\n        # Second upscaling mlp\n        x = self.unary2(x)\n\n        # Shortcut\n        if 'strided' in self.block_name:\n            shortcut = max_pool(features, neighb_inds)\n        else:\n            shortcut = features\n        shortcut = self.unary_shortcut(shortcut)\n\n        return self.leaky_relu(x + shortcut)\n\n\nclass GlobalAverageBlock(nn.Module):\n\n    def __init__(self):\n        \"\"\"\n        Initialize a global average block.\n        \"\"\"\n        super(GlobalAverageBlock, self).__init__()\n        return\n\n    def forward(self, x, batch):\n        return global_average(x, batch['stack_lengths'][-1])\n\n\nclass NearestUpsampleBlock(nn.Module):\n\n    def __init__(self, layer_ind):\n        \"\"\"\n        Initialize a nearest upsampling block.\n        \"\"\"\n        super(NearestUpsampleBlock, self).__init__()\n        self.layer_ind = layer_ind\n        return\n\n    def forward(self, x, batch):\n        return closest_pool(x, batch['upsamples'][self.layer_ind - 1])\n\n    def __repr__(self):\n        return 'NearestUpsampleBlock(layer: {:d} -> {:d})'.format(self.layer_ind,\n
                                                                 self.layer_ind - 1)\n\n\nclass MaxPoolBlock(nn.Module):\n\n    def __init__(self, layer_ind):\n        \"\"\"\n        Initialize a max pooling block.\n        \"\"\"\n        super(MaxPoolBlock, self).__init__()\n        self.layer_ind = layer_ind\n        return\n\n    def forward(self, x, batch):\n        return max_pool(x, batch['pools'][self.layer_ind + 1])\n"
  },
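  {
    "path": "thirdparty/kpconv/examples/gather_demo.py",
    "content": "# NOTE: hypothetical usage sketch, not part of the original KPConv release.\n# It illustrates the shape convention of the custom gather() and max_pool()\n# helpers defined in kpconv_blocks.py. The 'kpconv' package prefix mirrors\n# the import convention used in that file; adjust it to your install.\nimport torch\nfrom kpconv.kpconv_blocks import gather, max_pool\n\nif __name__ == '__main__':\n    x = torch.arange(12, dtype=torch.float32).view(4, 3)  # [n1=4, d=3] features\n    inds = torch.tensor([[0, 2], [1, 3], [3, 0]])         # [n2=3, max_num=2] neighbor indices\n    print(gather(x, inds).shape)  # torch.Size([3, 2, 3]), i.e. x[inds]\n    print(max_pool(x, inds))      # [3, 3]: element-wise max over each neighbor set\n"
  },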
  {
    "path": "thirdparty/kpconv/lib/__init__.py",
    "content": ""
  },
  {
    "path": "thirdparty/kpconv/lib/ply.py",
    "content": "#\n#\n#      0===============================0\n#      |    PLY files reader/writer    |\n#      0===============================0\n#\n#\n# ----------------------------------------------------------------------------------------------------------------------\n#\n#      Functions to read/write .ply files\n#\n# ----------------------------------------------------------------------------------------------------------------------\n#\n#      Hugues THOMAS - 10/02/2017\n#\n\n\n# ----------------------------------------------------------------------------------------------------------------------\n#\n#          Imports and global variables\n#      \\**********************************/\n#\n\n# Basic libs\nimport numpy as np\nimport sys\n\n# Define PLY types\nply_dtypes = dict([\n    (b'int8', 'i1'),\n    (b'char', 'i1'),\n    (b'uint8', 'u1'),\n    (b'uchar', 'u1'),\n    (b'int16', 'i2'),\n    (b'short', 'i2'),\n    (b'uint16', 'u2'),\n    (b'ushort', 'u2'),\n    (b'int32', 'i4'),\n    (b'int', 'i4'),\n    (b'uint32', 'u4'),\n    (b'uint', 'u4'),\n    (b'float32', 'f4'),\n    (b'float', 'f4'),\n    (b'float64', 'f8'),\n    (b'double', 'f8')\n])\n\n# Numpy reader format\nvalid_formats = {'ascii': '', 'binary_big_endian': '>',\n                 'binary_little_endian': '<'}\n\n\n# ----------------------------------------------------------------------------------------------------------------------\n#\n#           Functions\n#       \\***************/\n#\n\n\ndef parse_header(plyfile, ext):\n    # Variables\n    line = []\n    properties = []\n    num_points = None\n\n    while b'end_header' not in line and line != b'':\n        line = plyfile.readline()\n\n        if b'element' in line:\n            line = line.split()\n            num_points = int(line[2])\n\n        elif b'property' in line:\n            line = line.split()\n            properties.append((line[2].decode(), ext + ply_dtypes[line[1]]))\n\n    return num_points, properties\n\n\ndef parse_mesh_header(plyfile, ext):\n    # Variables\n    line = []\n    vertex_properties = []\n    num_points = None\n    num_faces = None\n    current_element = None\n\n    while b'end_header' not in line and line != b'':\n        line = plyfile.readline()\n\n        # Find point element\n        if b'element vertex' in line:\n            current_element = 'vertex'\n            line = line.split()\n            num_points = int(line[2])\n\n        elif b'element face' in line:\n            current_element = 'face'\n            line = line.split()\n            num_faces = int(line[2])\n\n        elif b'property' in line:\n            if current_element == 'vertex':\n                line = line.split()\n                vertex_properties.append((line[2].decode(), ext + ply_dtypes[line[1]]))\n            elif current_element == 'face':\n                if not line.startswith(b'property list uchar int'):\n                    raise ValueError('Unsupported faces property: ' + str(line))\n\n    return num_points, num_faces, vertex_properties\n\n\ndef read_ply(filename, triangular_mesh=False):\n    \"\"\"\n    Read \".ply\" files\n\n    Parameters\n    ----------\n    filename : string\n        the name of the file to read.\n\n    Returns\n    -------\n    result : array\n        data stored in the file\n\n    Examples\n    --------\n    Store data in file\n\n    >>> points = np.random.rand(5, 3)\n    >>> values = np.random.randint(2, size=5)\n    >>> write_ply('example.ply', [points, values], ['x', 'y', 'z', 'values'])\n\n    Read the file\n\n    >>> 
data = read_ply('example.ply')\n    >>> values = data['values']\n    array([0, 0, 1, 1, 0])\n\n    >>> points = np.vstack((data['x'], data['y'], data['z'])).T\n    array([[ 0.466  0.595  0.324]\n           [ 0.538  0.407  0.654]\n           [ 0.850  0.018  0.988]\n           [ 0.395  0.394  0.363]\n           [ 0.873  0.996  0.092]])\n\n    \"\"\"\n\n    with open(filename, 'rb') as plyfile:\n\n        # Check if the file starts with ply\n        if b'ply' not in plyfile.readline():\n            raise ValueError('The file does not start with the word ply')\n\n        # get binary_little/big or ascii\n        fmt = plyfile.readline().split()[1].decode()\n        if fmt == \"ascii\":\n            raise ValueError('The file is not binary')\n\n        # get extension for building the numpy dtypes\n        ext = valid_formats[fmt]\n\n        # PointCloud reader vs mesh reader\n        if triangular_mesh:\n\n            # Parse header\n            num_points, num_faces, properties = parse_mesh_header(plyfile, ext)\n\n            # Get point data\n            vertex_data = np.fromfile(plyfile, dtype=properties, count=num_points)\n\n            # Get face data\n            face_properties = [('k', ext + 'u1'),\n                               ('v1', ext + 'i4'),\n                               ('v2', ext + 'i4'),\n                               ('v3', ext + 'i4')]\n            faces_data = np.fromfile(plyfile, dtype=face_properties, count=num_faces)\n\n            # Return vertex data and concatenated faces\n            faces = np.vstack((faces_data['v1'], faces_data['v2'], faces_data['v3'])).T\n            data = [vertex_data, faces]\n\n        else:\n\n            # Parse header\n            num_points, properties = parse_header(plyfile, ext)\n\n            # Get data\n            data = np.fromfile(plyfile, dtype=properties, count=num_points)\n\n    return data\n\n\ndef header_properties(field_list, field_names):\n    # List of lines to write\n    lines = []\n\n    # First line describing element vertex\n    lines.append('element vertex %d' % field_list[0].shape[0])\n\n    # Properties lines\n    i = 0\n    for fields in field_list:\n        for field in fields.T:\n            lines.append('property %s %s' % (field.dtype.name, field_names[i]))\n            i += 1\n\n    return lines\n\n\ndef write_ply(filename, field_list, field_names, triangular_faces=None):\n    \"\"\"\n    Write \".ply\" files\n\n    Parameters\n    ----------\n    filename : string\n        the name of the file to which the data is saved. A '.ply' extension will be appended to the\n        file name if it does not already have one.\n\n    field_list : list, tuple, numpy array\n        the fields to be saved in the ply file. Either a numpy array, a list of numpy arrays or a\n        tuple of numpy arrays. Each 1D numpy array and each column of 2D numpy arrays are considered\n        as one field.\n\n    field_names : list\n        the name of each field as a list of strings. 
Has to be the same length as the number of\n        fields.\n\n    Examples\n    --------\n    >>> points = np.random.rand(10, 3)\n    >>> write_ply('example1.ply', points, ['x', 'y', 'z'])\n\n    >>> values = np.random.randint(2, size=10)\n    >>> write_ply('example2.ply', [points, values], ['x', 'y', 'z', 'values'])\n\n    >>> colors = np.random.randint(255, size=(10,3), dtype=np.uint8)\n    >>> field_names = ['x', 'y', 'z', 'red', 'green', 'blue', 'values']\n    >>> write_ply('example3.ply', [points, colors, values], field_names)\n\n    \"\"\"\n\n    # Format list input to the right form\n    field_list = list(field_list) if (type(field_list) == list or type(field_list) == tuple) else list((field_list,))\n    for i, field in enumerate(field_list):\n        if field.ndim < 2:\n            field_list[i] = field.reshape(-1, 1)\n        if field.ndim > 2:\n            print('fields have more than 2 dimensions')\n            return False\n\n    # Check all fields have the same number of data\n    n_points = [field.shape[0] for field in field_list]\n    if not np.all(np.equal(n_points, n_points[0])):\n        print('wrong field dimensions')\n        return False\n\n    # Check if field_names and field_list have the same number of columns\n    n_fields = np.sum([field.shape[1] for field in field_list])\n    if (n_fields != len(field_names)):\n        print('wrong number of field names')\n        return False\n\n    # Add extension if not there\n    if not filename.endswith('.ply'):\n        filename += '.ply'\n\n    # Open in text mode to write the header\n    with open(filename, 'w') as plyfile:\n\n        # First magical word\n        header = ['ply']\n\n        # Encoding format\n        header.append('format binary_' + sys.byteorder + '_endian 1.0')\n\n        # Points properties description\n        header.extend(header_properties(field_list, field_names))\n\n        # Add faces if needed\n        if triangular_faces is not None:\n            header.append('element face {:d}'.format(triangular_faces.shape[0]))\n            header.append('property list uchar int vertex_indices')\n\n        # End of header\n        header.append('end_header')\n\n        # Write all lines\n        for line in header:\n            plyfile.write(\"%s\\n\" % line)\n\n    # Open in binary/append mode to use tofile\n    with open(filename, 'ab') as plyfile:\n\n        # Create a structured array\n        i = 0\n        type_list = []\n        for fields in field_list:\n            for field in fields.T:\n                type_list += [(field_names[i], field.dtype.str)]\n                i += 1\n        data = np.empty(field_list[0].shape[0], dtype=type_list)\n        i = 0\n        for fields in field_list:\n            for field in fields.T:\n                data[field_names[i]] = field\n                i += 1\n\n        data.tofile(plyfile)\n\n        if triangular_faces is not None:\n            triangular_faces = triangular_faces.astype(np.int32)\n            type_list = [('k', 'uint8')] + [(str(ind), 'int32') for ind in range(3)]\n            data = np.empty(triangular_faces.shape[0], dtype=type_list)\n            data['k'] = np.full((triangular_faces.shape[0],), 3, dtype=np.uint8)\n            data['0'] = triangular_faces[:, 0]\n            data['1'] = triangular_faces[:, 1]\n            data['2'] = triangular_faces[:, 2]\n            data.tofile(plyfile)\n\n    return True\n\n\ndef describe_element(name, df):\n    \"\"\" Takes the columns of the dataframe and builds a ply-like description\n\n    Parameters\n   
 ----------\n    name: str\n    df: pandas DataFrame\n\n    Returns\n    -------\n    element: list[str]\n    \"\"\"\n    property_formats = {'f': 'float', 'u': 'uchar', 'i': 'int'}\n    element = ['element ' + name + ' ' + str(len(df))]\n\n    if name == 'face':\n        element.append(\"property list uchar int points_indices\")\n\n    else:\n        for i in range(len(df.columns)):\n            # get first letter of dtype to infer format\n            f = property_formats[str(df.dtypes[i])[0]]\n            element.append('property ' + f + ' ' + df.columns.values[i])\n\n    return element\n\n"
  },
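  {
    "path": "thirdparty/kpconv/examples/ply_roundtrip_demo.py",
    "content": "# NOTE: hypothetical usage sketch, not part of the original KPConv release.\n# It round-trips a small point cloud through write_ply()/read_ply() from\n# lib/ply.py. The 'kpconv' package prefix mirrors the import convention used\n# in kpconv_blocks.py; adjust it to your install.\nimport numpy as np\nfrom kpconv.lib.ply import read_ply, write_ply\n\nif __name__ == '__main__':\n    points = np.random.rand(10, 3).astype(np.float32)\n    values = np.random.randint(2, size=10).astype(np.int32)\n    write_ply('example.ply', [points, values], ['x', 'y', 'z', 'values'])\n    data = read_ply('example.ply')\n    recovered = np.vstack((data['x'], data['y'], data['z'])).T\n    assert np.allclose(points, recovered)  # binary round-trip is exact\n    print(data['values'])\n"
  },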
  {
    "path": "thirdparty/kpconv/lib/timer.py",
    "content": "import time\n\n\nclass AverageMeter(object):\n    \"\"\"Computes and stores the average and current value\"\"\"\n\n    def __init__(self):\n        self.reset()\n\n    def reset(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0.0\n        self.sq_sum = 0.0\n        self.count = 0\n\n    def update(self, val, n=1):\n        self.val = val\n        self.sum += val * n\n        self.count += n\n        self.avg = self.sum / self.count\n        self.sq_sum += val ** 2 * n\n        self.var = self.sq_sum / self.count - self.avg ** 2\n\n\nclass Timer(object):\n    \"\"\"A simple timer.\"\"\"\n\n    def __init__(self):\n        self.total_time = 0.\n        self.calls = 0\n        self.start_time = 0.\n        self.diff = 0.\n        self.avg = 0.\n\n    def reset(self):\n        self.total_time = 0\n        self.calls = 0\n        self.start_time = 0\n        self.diff = 0\n        self.avg = 0\n\n    def tic(self):\n        # Using time.time instead of time.clock because time.clock\n        # does not normalize for multithreading\n        self.start_time = time.time()\n\n    def toc(self, average=True):\n        self.diff = time.time() - self.start_time\n        self.total_time += self.diff\n        self.calls += 1\n        self.avg = self.total_time / self.calls\n        if average:\n            return self.avg\n        else:\n            return self.diff\n"
  },
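  {
    "path": "thirdparty/kpconv/examples/timer_demo.py",
    "content": "# NOTE: hypothetical usage sketch, not part of the original KPConv release.\n# It shows the Timer and AverageMeter helpers from lib/timer.py. The\n# 'kpconv' package prefix mirrors the import convention used in\n# kpconv_blocks.py; adjust it to your install.\nimport time\nfrom kpconv.lib.timer import AverageMeter, Timer\n\nif __name__ == '__main__':\n    timer, meter = Timer(), AverageMeter()\n    for step in range(3):\n        timer.tic()\n        time.sleep(0.01)  # stand-in for one training step\n        meter.update(timer.toc(average=False))\n    print('avg step time: {:.4f}s, running var: {:.6f}'.format(meter.avg, meter.var))\n"
  },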
  {
    "path": "thirdparty/kpconv/lib/utils.py",
    "content": "\"\"\"\nGeneral utility functions\n\nAuthor: Shengyu Huang\nLast modified: 30.11.2020\n\"\"\"\n\nimport os, re, sys, json, yaml, random, argparse, torch, pickle\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nimport numpy as np\nfrom scipy.spatial.transform import Rotation\n\nfrom sklearn.neighbors import NearestNeighbors\nfrom scipy.spatial.distance import minkowski\n_EPS = 1e-7  # To prevent division by zero\n\n\nclass Logger:\n    def __init__(self, path):\n        self.path = path\n        self.fw = open(self.path + '/log', 'a')\n\n    def write(self, text):\n        self.fw.write(text)\n        self.fw.flush()\n\n    def close(self):\n        self.fw.close()\n\ndef save_obj(obj, path):\n    \"\"\"\n    save a dictionary to a pickle file\n    \"\"\"\n    with open(path, 'wb') as f:\n        pickle.dump(obj, f)\n\ndef load_obj(path):\n    \"\"\"\n    read a dictionary from a pickle file\n    \"\"\"\n    with open(path, 'rb') as f:\n        return pickle.load(f)\n\ndef load_config(path):\n    \"\"\"\n    Loads config file:\n\n    Args:\n        path (str): path to the config file\n\n    Returns:\n        config (dict): dictionary of the configuration parameters\n\n    \"\"\"\n    with open(path, 'r') as f:\n        cfg = yaml.safe_load(f)\n\n    return cfg\n\n\ndef setup_seed(seed):\n    \"\"\"\n    fix random seed for deterministic training\n    \"\"\"\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    np.random.seed(seed)\n    random.seed(seed)\n    torch.backends.cudnn.deterministic = True\n\ndef square_distance(src, dst, normalised=False):\n    \"\"\"\n    Calculate the squared Euclidean distance between each pair of points.\n    Args:\n        src: source points, [B, N, C]\n        dst: target points, [B, M, C]\n    Returns:\n        dist: per-point square distance, [B, N, M]\n    \"\"\"\n    B, N, _ = src.shape\n    _, M, _ = dst.shape\n    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))\n    if normalised:\n        dist += 2\n    else:\n        dist += torch.sum(src ** 2, dim=-1)[:, :, None]\n        dist += torch.sum(dst ** 2, dim=-1)[:, None, :]\n\n    dist = torch.clamp(dist, min=1e-12, max=None)\n    return dist\n\n\ndef validate_gradient(model):\n    \"\"\"\n    Confirm all the gradients are non-nan and non-inf\n    \"\"\"\n    for name, param in model.named_parameters():\n        if param.grad is not None:\n            if torch.any(torch.isnan(param.grad)):\n                return False\n            if torch.any(torch.isinf(param.grad)):\n                return False\n    return True\n\n\ndef natural_key(string_):\n    \"\"\"\n    Sort strings by numbers in the name\n    \"\"\"\n    return [int(s) if s.isdigit() else s for s in re.split(r'(\\d+)', string_)]"
  },
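  {
    "path": "thirdparty/kpconv/examples/square_distance_demo.py",
    "content": "# NOTE: hypothetical usage sketch, not part of the original KPConv release.\n# It checks square_distance() from lib/utils.py against a brute-force\n# computation. The 'kpconv' package prefix mirrors the import convention\n# used in kpconv_blocks.py; adjust it to your install.\nimport torch\nfrom kpconv.lib.utils import square_distance\n\nif __name__ == '__main__':\n    src = torch.randn(2, 5, 3)  # [B, N, C]\n    dst = torch.randn(2, 7, 3)  # [B, M, C]\n    dist = square_distance(src, dst)  # [B, N, M] squared distances\n    brute = ((src[:, :, None, :] - dst[:, None, :, :]) ** 2).sum(-1)\n    print(torch.allclose(dist, brute.clamp(min=1e-12), atol=1e-5))  # True\n"
  },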
  {
    "path": "thirdparty/nn/_ext.c",
    "content": "#define _CFFI_\n\n/* We try to define Py_LIMITED_API before including Python.h.\n\n   Mess: we can only define it if Py_DEBUG, Py_TRACE_REFS and\n   Py_REF_DEBUG are not defined.  This is a best-effort approximation:\n   we can learn about Py_DEBUG from pyconfig.h, but it is unclear if\n   the same works for the other two macros.  Py_DEBUG implies them,\n   but not the other way around.\n\n   Issue #350 is still open: on Windows, the code here causes it to link\n   with PYTHON36.DLL (for example) instead of PYTHON3.DLL.  A fix was\n   attempted in 164e526a5515 and 14ce6985e1c3, but reverted: virtualenv\n   does not make PYTHON3.DLL available, and so the \"correctly\" compiled\n   version would not run inside a virtualenv.  We will re-apply the fix\n   after virtualenv has been fixed for some time.  For explanation, see\n   issue #355.  For a workaround if you want PYTHON3.DLL and don't worry\n   about virtualenv, see issue #350.  See also 'py_limited_api' in\n   setuptools_ext.py.\n*/\n#if !defined(_CFFI_USE_EMBEDDING) && !defined(Py_LIMITED_API)\n#  include <pyconfig.h>\n#  if !defined(Py_DEBUG) && !defined(Py_TRACE_REFS) && !defined(Py_REF_DEBUG)\n#    define Py_LIMITED_API\n#  endif\n#endif\n\n#include <Python.h>\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n#include <stddef.h>\n\n/* This part is from file 'cffi/parse_c_type.h'.  It is copied at the\n   beginning of C sources generated by CFFI's ffi.set_source(). */\n\ntypedef void *_cffi_opcode_t;\n\n#define _CFFI_OP(opcode, arg)   (_cffi_opcode_t)(opcode | (((uintptr_t)(arg)) << 8))\n#define _CFFI_GETOP(cffi_opcode)    ((unsigned char)(uintptr_t)cffi_opcode)\n#define _CFFI_GETARG(cffi_opcode)   (((intptr_t)cffi_opcode) >> 8)\n\n#define _CFFI_OP_PRIMITIVE       1\n#define _CFFI_OP_POINTER         3\n#define _CFFI_OP_ARRAY           5\n#define _CFFI_OP_OPEN_ARRAY      7\n#define _CFFI_OP_STRUCT_UNION    9\n#define _CFFI_OP_ENUM           11\n#define _CFFI_OP_FUNCTION       13\n#define _CFFI_OP_FUNCTION_END   15\n#define _CFFI_OP_NOOP           17\n#define _CFFI_OP_BITFIELD       19\n#define _CFFI_OP_TYPENAME       21\n#define _CFFI_OP_CPYTHON_BLTN_V 23   // varargs\n#define _CFFI_OP_CPYTHON_BLTN_N 25   // noargs\n#define _CFFI_OP_CPYTHON_BLTN_O 27   // O  (i.e. 
a single arg)\n#define _CFFI_OP_CONSTANT       29\n#define _CFFI_OP_CONSTANT_INT   31\n#define _CFFI_OP_GLOBAL_VAR     33\n#define _CFFI_OP_DLOPEN_FUNC    35\n#define _CFFI_OP_DLOPEN_CONST   37\n#define _CFFI_OP_GLOBAL_VAR_F   39\n#define _CFFI_OP_EXTERN_PYTHON  41\n\n#define _CFFI_PRIM_VOID          0\n#define _CFFI_PRIM_BOOL          1\n#define _CFFI_PRIM_CHAR          2\n#define _CFFI_PRIM_SCHAR         3\n#define _CFFI_PRIM_UCHAR         4\n#define _CFFI_PRIM_SHORT         5\n#define _CFFI_PRIM_USHORT        6\n#define _CFFI_PRIM_INT           7\n#define _CFFI_PRIM_UINT          8\n#define _CFFI_PRIM_LONG          9\n#define _CFFI_PRIM_ULONG        10\n#define _CFFI_PRIM_LONGLONG     11\n#define _CFFI_PRIM_ULONGLONG    12\n#define _CFFI_PRIM_FLOAT        13\n#define _CFFI_PRIM_DOUBLE       14\n#define _CFFI_PRIM_LONGDOUBLE   15\n\n#define _CFFI_PRIM_WCHAR        16\n#define _CFFI_PRIM_INT8         17\n#define _CFFI_PRIM_UINT8        18\n#define _CFFI_PRIM_INT16        19\n#define _CFFI_PRIM_UINT16       20\n#define _CFFI_PRIM_INT32        21\n#define _CFFI_PRIM_UINT32       22\n#define _CFFI_PRIM_INT64        23\n#define _CFFI_PRIM_UINT64       24\n#define _CFFI_PRIM_INTPTR       25\n#define _CFFI_PRIM_UINTPTR      26\n#define _CFFI_PRIM_PTRDIFF      27\n#define _CFFI_PRIM_SIZE         28\n#define _CFFI_PRIM_SSIZE        29\n#define _CFFI_PRIM_INT_LEAST8   30\n#define _CFFI_PRIM_UINT_LEAST8  31\n#define _CFFI_PRIM_INT_LEAST16  32\n#define _CFFI_PRIM_UINT_LEAST16 33\n#define _CFFI_PRIM_INT_LEAST32  34\n#define _CFFI_PRIM_UINT_LEAST32 35\n#define _CFFI_PRIM_INT_LEAST64  36\n#define _CFFI_PRIM_UINT_LEAST64 37\n#define _CFFI_PRIM_INT_FAST8    38\n#define _CFFI_PRIM_UINT_FAST8   39\n#define _CFFI_PRIM_INT_FAST16   40\n#define _CFFI_PRIM_UINT_FAST16  41\n#define _CFFI_PRIM_INT_FAST32   42\n#define _CFFI_PRIM_UINT_FAST32  43\n#define _CFFI_PRIM_INT_FAST64   44\n#define _CFFI_PRIM_UINT_FAST64  45\n#define _CFFI_PRIM_INTMAX       46\n#define _CFFI_PRIM_UINTMAX      47\n#define _CFFI_PRIM_FLOATCOMPLEX 48\n#define _CFFI_PRIM_DOUBLECOMPLEX 49\n#define _CFFI_PRIM_CHAR16       50\n#define _CFFI_PRIM_CHAR32       51\n\n#define _CFFI__NUM_PRIM         52\n#define _CFFI__UNKNOWN_PRIM           (-1)\n#define _CFFI__UNKNOWN_FLOAT_PRIM     (-2)\n#define _CFFI__UNKNOWN_LONG_DOUBLE    (-3)\n\n#define _CFFI__IO_FILE_STRUCT         (-1)\n\n\nstruct _cffi_global_s {\n    const char *name;\n    void *address;\n    _cffi_opcode_t type_op;\n    void *size_or_direct_fn;  // OP_GLOBAL_VAR: size, or 0 if unknown\n                              // OP_CPYTHON_BLTN_*: addr of direct function\n};\n\nstruct _cffi_getconst_s {\n    unsigned long long value;\n    const struct _cffi_type_context_s *ctx;\n    int gindex;\n};\n\nstruct _cffi_struct_union_s {\n    const char *name;\n    int type_index;          // -> _cffi_types, on a OP_STRUCT_UNION\n    int flags;               // _CFFI_F_* flags below\n    size_t size;\n    int alignment;\n    int first_field_index;   // -> _cffi_fields array\n    int num_fields;\n};\n#define _CFFI_F_UNION         0x01   // is a union, not a struct\n#define _CFFI_F_CHECK_FIELDS  0x02   // complain if fields are not in the\n                                     // \"standard layout\" or if some are missing\n#define _CFFI_F_PACKED        0x04   // for CHECK_FIELDS, assume a packed struct\n#define _CFFI_F_EXTERNAL      0x08   // in some other ffi.include()\n#define _CFFI_F_OPAQUE        0x10   // opaque\n\nstruct _cffi_field_s {\n    const char *name;\n    size_t field_offset;\n    size_t 
field_size;\n    _cffi_opcode_t field_type_op;\n};\n\nstruct _cffi_enum_s {\n    const char *name;\n    int type_index;          // -> _cffi_types, on a OP_ENUM\n    int type_prim;           // _CFFI_PRIM_xxx\n    const char *enumerators; // comma-delimited string\n};\n\nstruct _cffi_typename_s {\n    const char *name;\n    int type_index;   /* if opaque, points to a possibly artificial\n                         OP_STRUCT which is itself opaque */\n};\n\nstruct _cffi_type_context_s {\n    _cffi_opcode_t *types;\n    const struct _cffi_global_s *globals;\n    const struct _cffi_field_s *fields;\n    const struct _cffi_struct_union_s *struct_unions;\n    const struct _cffi_enum_s *enums;\n    const struct _cffi_typename_s *typenames;\n    int num_globals;\n    int num_struct_unions;\n    int num_enums;\n    int num_typenames;\n    const char *const *includes;\n    int num_types;\n    int flags;      /* future extension */\n};\n\nstruct _cffi_parse_info_s {\n    const struct _cffi_type_context_s *ctx;\n    _cffi_opcode_t *output;\n    unsigned int output_size;\n    size_t error_location;\n    const char *error_message;\n};\n\nstruct _cffi_externpy_s {\n    const char *name;\n    size_t size_of_result;\n    void *reserved1, *reserved2;\n};\n\n#ifdef _CFFI_INTERNAL\nstatic int parse_c_type(struct _cffi_parse_info_s *info, const char *input);\nstatic int search_in_globals(const struct _cffi_type_context_s *ctx,\n                             const char *search, size_t search_len);\nstatic int search_in_struct_unions(const struct _cffi_type_context_s *ctx,\n                                   const char *search, size_t search_len);\n#endif\n\n/* this block of #ifs should be kept exactly identical between\n   c/_cffi_backend.c, cffi/vengine_cpy.py, cffi/vengine_gen.py\n   and cffi/_cffi_include.h */\n#if defined(_MSC_VER)\n# include <malloc.h>   /* for alloca() */\n# if _MSC_VER < 1600   /* MSVC < 2010 */\n   typedef __int8 int8_t;\n   typedef __int16 int16_t;\n   typedef __int32 int32_t;\n   typedef __int64 int64_t;\n   typedef unsigned __int8 uint8_t;\n   typedef unsigned __int16 uint16_t;\n   typedef unsigned __int32 uint32_t;\n   typedef unsigned __int64 uint64_t;\n   typedef __int8 int_least8_t;\n   typedef __int16 int_least16_t;\n   typedef __int32 int_least32_t;\n   typedef __int64 int_least64_t;\n   typedef unsigned __int8 uint_least8_t;\n   typedef unsigned __int16 uint_least16_t;\n   typedef unsigned __int32 uint_least32_t;\n   typedef unsigned __int64 uint_least64_t;\n   typedef __int8 int_fast8_t;\n   typedef __int16 int_fast16_t;\n   typedef __int32 int_fast32_t;\n   typedef __int64 int_fast64_t;\n   typedef unsigned __int8 uint_fast8_t;\n   typedef unsigned __int16 uint_fast16_t;\n   typedef unsigned __int32 uint_fast32_t;\n   typedef unsigned __int64 uint_fast64_t;\n   typedef __int64 intmax_t;\n   typedef unsigned __int64 uintmax_t;\n# else\n#  include <stdint.h>\n# endif\n# if _MSC_VER < 1800   /* MSVC < 2013 */\n#  ifndef __cplusplus\n    typedef unsigned char _Bool;\n#  endif\n# endif\n#else\n# include <stdint.h>\n# if (defined (__SVR4) && defined (__sun)) || defined(_AIX) || defined(__hpux)\n#  include <alloca.h>\n# endif\n#endif\n\n#ifdef __GNUC__\n# define _CFFI_UNUSED_FN  __attribute__((unused))\n#else\n# define _CFFI_UNUSED_FN  /* nothing */\n#endif\n\n#ifdef __cplusplus\n# ifndef _Bool\n   typedef bool _Bool;   /* semi-hackish: C++ has no _Bool; bool is builtin */\n# endif\n#endif\n\n/**********  CPython-specific section  **********/\n#ifndef PYPY_VERSION\n\n\n#if 
PY_MAJOR_VERSION >= 3\n# define PyInt_FromLong PyLong_FromLong\n#endif\n\n#define _cffi_from_c_double PyFloat_FromDouble\n#define _cffi_from_c_float PyFloat_FromDouble\n#define _cffi_from_c_long PyInt_FromLong\n#define _cffi_from_c_ulong PyLong_FromUnsignedLong\n#define _cffi_from_c_longlong PyLong_FromLongLong\n#define _cffi_from_c_ulonglong PyLong_FromUnsignedLongLong\n#define _cffi_from_c__Bool PyBool_FromLong\n\n#define _cffi_to_c_double PyFloat_AsDouble\n#define _cffi_to_c_float PyFloat_AsDouble\n\n#define _cffi_from_c_int(x, type)                                        \\\n    (((type)-1) > 0 ? /* unsigned */                                     \\\n        (sizeof(type) < sizeof(long) ?                                   \\\n            PyInt_FromLong((long)x) :                                    \\\n         sizeof(type) == sizeof(long) ?                                  \\\n            PyLong_FromUnsignedLong((unsigned long)x) :                  \\\n            PyLong_FromUnsignedLongLong((unsigned long long)x)) :        \\\n        (sizeof(type) <= sizeof(long) ?                                  \\\n            PyInt_FromLong((long)x) :                                    \\\n            PyLong_FromLongLong((long long)x)))\n\n#define _cffi_to_c_int(o, type)                                          \\\n    ((type)(                                                             \\\n     sizeof(type) == 1 ? (((type)-1) > 0 ? (type)_cffi_to_c_u8(o)        \\\n                                         : (type)_cffi_to_c_i8(o)) :     \\\n     sizeof(type) == 2 ? (((type)-1) > 0 ? (type)_cffi_to_c_u16(o)       \\\n                                         : (type)_cffi_to_c_i16(o)) :    \\\n     sizeof(type) == 4 ? (((type)-1) > 0 ? (type)_cffi_to_c_u32(o)       \\\n                                         : (type)_cffi_to_c_i32(o)) :    \\\n     sizeof(type) == 8 ? (((type)-1) > 0 ? 
(type)_cffi_to_c_u64(o)       \\\n                                         : (type)_cffi_to_c_i64(o)) :    \\\n     (Py_FatalError(\"unsupported size for type \" #type), (type)0)))\n\n#define _cffi_to_c_i8                                                    \\\n                 ((int(*)(PyObject *))_cffi_exports[1])\n#define _cffi_to_c_u8                                                    \\\n                 ((int(*)(PyObject *))_cffi_exports[2])\n#define _cffi_to_c_i16                                                   \\\n                 ((int(*)(PyObject *))_cffi_exports[3])\n#define _cffi_to_c_u16                                                   \\\n                 ((int(*)(PyObject *))_cffi_exports[4])\n#define _cffi_to_c_i32                                                   \\\n                 ((int(*)(PyObject *))_cffi_exports[5])\n#define _cffi_to_c_u32                                                   \\\n                 ((unsigned int(*)(PyObject *))_cffi_exports[6])\n#define _cffi_to_c_i64                                                   \\\n                 ((long long(*)(PyObject *))_cffi_exports[7])\n#define _cffi_to_c_u64                                                   \\\n                 ((unsigned long long(*)(PyObject *))_cffi_exports[8])\n#define _cffi_to_c_char                                                  \\\n                 ((int(*)(PyObject *))_cffi_exports[9])\n#define _cffi_from_c_pointer                                             \\\n    ((PyObject *(*)(char *, struct _cffi_ctypedescr *))_cffi_exports[10])\n#define _cffi_to_c_pointer                                               \\\n    ((char *(*)(PyObject *, struct _cffi_ctypedescr *))_cffi_exports[11])\n#define _cffi_get_struct_layout                                          \\\n    not used any more\n#define _cffi_restore_errno                                              \\\n    ((void(*)(void))_cffi_exports[13])\n#define _cffi_save_errno                                                 \\\n    ((void(*)(void))_cffi_exports[14])\n#define _cffi_from_c_char                                                \\\n    ((PyObject *(*)(char))_cffi_exports[15])\n#define _cffi_from_c_deref                                               \\\n    ((PyObject *(*)(char *, struct _cffi_ctypedescr *))_cffi_exports[16])\n#define _cffi_to_c                                                       \\\n    ((int(*)(char *, struct _cffi_ctypedescr *, PyObject *))_cffi_exports[17])\n#define _cffi_from_c_struct                                              \\\n    ((PyObject *(*)(char *, struct _cffi_ctypedescr *))_cffi_exports[18])\n#define _cffi_to_c_wchar_t                                               \\\n    ((_cffi_wchar_t(*)(PyObject *))_cffi_exports[19])\n#define _cffi_from_c_wchar_t                                             \\\n    ((PyObject *(*)(_cffi_wchar_t))_cffi_exports[20])\n#define _cffi_to_c_long_double                                           \\\n    ((long double(*)(PyObject *))_cffi_exports[21])\n#define _cffi_to_c__Bool                                                 \\\n    ((_Bool(*)(PyObject *))_cffi_exports[22])\n#define _cffi_prepare_pointer_call_argument                              \\\n    ((Py_ssize_t(*)(struct _cffi_ctypedescr *,                           \\\n                    PyObject *, char **))_cffi_exports[23])\n#define _cffi_convert_array_from_object                                  \\\n    ((int(*)(char *, struct _cffi_ctypedescr *, PyObject *))_cffi_exports[24])\n#define _CFFI_CPIDX  
25\n#define _cffi_call_python                                                \\\n    ((void(*)(struct _cffi_externpy_s *, char *))_cffi_exports[_CFFI_CPIDX])\n#define _cffi_to_c_wchar3216_t                                           \\\n    ((int(*)(PyObject *))_cffi_exports[26])\n#define _cffi_from_c_wchar3216_t                                         \\\n    ((PyObject *(*)(int))_cffi_exports[27])\n#define _CFFI_NUM_EXPORTS 28\n\nstruct _cffi_ctypedescr;\n\nstatic void *_cffi_exports[_CFFI_NUM_EXPORTS];\n\n#define _cffi_type(index)   (                           \\\n    assert((((uintptr_t)_cffi_types[index]) & 1) == 0), \\\n    (struct _cffi_ctypedescr *)_cffi_types[index])\n\nstatic PyObject *_cffi_init(const char *module_name, Py_ssize_t version,\n                            const struct _cffi_type_context_s *ctx)\n{\n    PyObject *module, *o_arg, *new_module;\n    void *raw[] = {\n        (void *)module_name,\n        (void *)version,\n        (void *)_cffi_exports,\n        (void *)ctx,\n    };\n\n    module = PyImport_ImportModule(\"_cffi_backend\");\n    if (module == NULL)\n        goto failure;\n\n    o_arg = PyLong_FromVoidPtr((void *)raw);\n    if (o_arg == NULL)\n        goto failure;\n\n    new_module = PyObject_CallMethod(\n        module, (char *)\"_init_cffi_1_0_external_module\", (char *)\"O\", o_arg);\n\n    Py_DECREF(o_arg);\n    Py_DECREF(module);\n    return new_module;\n\n  failure:\n    Py_XDECREF(module);\n    return NULL;\n}\n\n\n#ifdef HAVE_WCHAR_H\ntypedef wchar_t _cffi_wchar_t;\n#else\ntypedef uint16_t _cffi_wchar_t;   /* same random pick as _cffi_backend.c */\n#endif\n\n_CFFI_UNUSED_FN static uint16_t _cffi_to_c_char16_t(PyObject *o)\n{\n    if (sizeof(_cffi_wchar_t) == 2)\n        return (uint16_t)_cffi_to_c_wchar_t(o);\n    else\n        return (uint16_t)_cffi_to_c_wchar3216_t(o);\n}\n\n_CFFI_UNUSED_FN static PyObject *_cffi_from_c_char16_t(uint16_t x)\n{\n    if (sizeof(_cffi_wchar_t) == 2)\n        return _cffi_from_c_wchar_t((_cffi_wchar_t)x);\n    else\n        return _cffi_from_c_wchar3216_t((int)x);\n}\n\n_CFFI_UNUSED_FN static int _cffi_to_c_char32_t(PyObject *o)\n{\n    if (sizeof(_cffi_wchar_t) == 4)\n        return (int)_cffi_to_c_wchar_t(o);\n    else\n        return (int)_cffi_to_c_wchar3216_t(o);\n}\n\n_CFFI_UNUSED_FN static PyObject *_cffi_from_c_char32_t(int x)\n{\n    if (sizeof(_cffi_wchar_t) == 4)\n        return _cffi_from_c_wchar_t((_cffi_wchar_t)x);\n    else\n        return _cffi_from_c_wchar3216_t(x);\n}\n\n\n/**********  end CPython-specific section  **********/\n#else\n_CFFI_UNUSED_FN\nstatic void (*_cffi_call_python_org)(struct _cffi_externpy_s *, char *);\n# define _cffi_call_python  _cffi_call_python_org\n#endif\n\n\n#define _cffi_array_len(array)   (sizeof(array) / sizeof((array)[0]))\n\n#define _cffi_prim_int(size, sign)                                      \\\n    ((size) == 1 ? ((sign) ? _CFFI_PRIM_INT8  : _CFFI_PRIM_UINT8)  :    \\\n     (size) == 2 ? ((sign) ? _CFFI_PRIM_INT16 : _CFFI_PRIM_UINT16) :    \\\n     (size) == 4 ? ((sign) ? _CFFI_PRIM_INT32 : _CFFI_PRIM_UINT32) :    \\\n     (size) == 8 ? ((sign) ? _CFFI_PRIM_INT64 : _CFFI_PRIM_UINT64) :    \\\n     _CFFI__UNKNOWN_PRIM)\n\n#define _cffi_prim_float(size)                                          \\\n    ((size) == sizeof(float) ? _CFFI_PRIM_FLOAT :                       \\\n     (size) == sizeof(double) ? _CFFI_PRIM_DOUBLE :                     \\\n     (size) == sizeof(long double) ? 
_CFFI__UNKNOWN_LONG_DOUBLE :       \\\n     _CFFI__UNKNOWN_FLOAT_PRIM)\n\n#define _cffi_check_int(got, got_nonpos, expected)      \\\n    ((got_nonpos) == (expected <= 0) &&                 \\\n     (got) == (unsigned long long)expected)\n\n#ifdef MS_WIN32\n# define _cffi_stdcall  __stdcall\n#else\n# define _cffi_stdcall  /* nothing */\n#endif\n\n#ifdef __cplusplus\n}\n#endif\n\n/************************************************************/\n\n\n    #include \"src/ext.h\"\n    \n\n/************************************************************/\n\nstatic void *_cffi_types[] = {\n/*  0 */ _CFFI_OP(_CFFI_OP_FUNCTION, 11), // void()(float *, float *, int *, int, int, int, int, int)\n/*  1 */ _CFFI_OP(_CFFI_OP_POINTER, 10), // float *\n/*  2 */ _CFFI_OP(_CFFI_OP_NOOP, 1),\n/*  3 */ _CFFI_OP(_CFFI_OP_POINTER, 4), // int *\n/*  4 */ _CFFI_OP(_CFFI_OP_PRIMITIVE, 7), // int\n/*  5 */ _CFFI_OP(_CFFI_OP_PRIMITIVE, 7),\n/*  6 */ _CFFI_OP(_CFFI_OP_PRIMITIVE, 7),\n/*  7 */ _CFFI_OP(_CFFI_OP_PRIMITIVE, 7),\n/*  8 */ _CFFI_OP(_CFFI_OP_PRIMITIVE, 7),\n/*  9 */ _CFFI_OP(_CFFI_OP_FUNCTION_END, 0),\n/* 10 */ _CFFI_OP(_CFFI_OP_PRIMITIVE, 13), // float\n/* 11 */ _CFFI_OP(_CFFI_OP_PRIMITIVE, 0), // void\n};\n\nstatic void _cffi_d_findNearestPointIdxLauncher(float * x0, float * x1, int * x2, int x3, int x4, int x5, int x6, int x7)\n{\n  findNearestPointIdxLauncher(x0, x1, x2, x3, x4, x5, x6, x7);\n}\n#ifndef PYPY_VERSION\nstatic PyObject *\n_cffi_f_findNearestPointIdxLauncher(PyObject *self, PyObject *args)\n{\n  float * x0;\n  float * x1;\n  int * x2;\n  int x3;\n  int x4;\n  int x5;\n  int x6;\n  int x7;\n  Py_ssize_t datasize;\n  PyObject *arg0;\n  PyObject *arg1;\n  PyObject *arg2;\n  PyObject *arg3;\n  PyObject *arg4;\n  PyObject *arg5;\n  PyObject *arg6;\n  PyObject *arg7;\n\n  if (!PyArg_UnpackTuple(args, \"findNearestPointIdxLauncher\", 8, 8, &arg0, &arg1, &arg2, &arg3, &arg4, &arg5, &arg6, &arg7))\n    return NULL;\n\n  datasize = _cffi_prepare_pointer_call_argument(\n      _cffi_type(1), arg0, (char **)&x0);\n  if (datasize != 0) {\n    if (datasize < 0)\n      return NULL;\n    x0 = (float *)alloca((size_t)datasize);\n    memset((void *)x0, 0, (size_t)datasize);\n    if (_cffi_convert_array_from_object((char *)x0, _cffi_type(1), arg0) < 0)\n      return NULL;\n  }\n\n  datasize = _cffi_prepare_pointer_call_argument(\n      _cffi_type(1), arg1, (char **)&x1);\n  if (datasize != 0) {\n    if (datasize < 0)\n      return NULL;\n    x1 = (float *)alloca((size_t)datasize);\n    memset((void *)x1, 0, (size_t)datasize);\n    if (_cffi_convert_array_from_object((char *)x1, _cffi_type(1), arg1) < 0)\n      return NULL;\n  }\n\n  datasize = _cffi_prepare_pointer_call_argument(\n      _cffi_type(3), arg2, (char **)&x2);\n  if (datasize != 0) {\n    if (datasize < 0)\n      return NULL;\n    x2 = (int *)alloca((size_t)datasize);\n    memset((void *)x2, 0, (size_t)datasize);\n    if (_cffi_convert_array_from_object((char *)x2, _cffi_type(3), arg2) < 0)\n      return NULL;\n  }\n\n  x3 = _cffi_to_c_int(arg3, int);\n  if (x3 == (int)-1 && PyErr_Occurred())\n    return NULL;\n\n  x4 = _cffi_to_c_int(arg4, int);\n  if (x4 == (int)-1 && PyErr_Occurred())\n    return NULL;\n\n  x5 = _cffi_to_c_int(arg5, int);\n  if (x5 == (int)-1 && PyErr_Occurred())\n    return NULL;\n\n  x6 = _cffi_to_c_int(arg6, int);\n  if (x6 == (int)-1 && PyErr_Occurred())\n    return NULL;\n\n  x7 = _cffi_to_c_int(arg7, int);\n  if (x7 == (int)-1 && PyErr_Occurred())\n    return NULL;\n\n  Py_BEGIN_ALLOW_THREADS\n  _cffi_restore_errno();\n  { 
findNearestPointIdxLauncher(x0, x1, x2, x3, x4, x5, x6, x7); }\n  _cffi_save_errno();\n  Py_END_ALLOW_THREADS\n\n  (void)self; /* unused */\n  Py_INCREF(Py_None);\n  return Py_None;\n}\n#else\n#  define _cffi_f_findNearestPointIdxLauncher _cffi_d_findNearestPointIdxLauncher\n#endif\n\nstatic const struct _cffi_global_s _cffi_globals[] = {\n  { \"findNearestPointIdxLauncher\", (void *)_cffi_f_findNearestPointIdxLauncher, _CFFI_OP(_CFFI_OP_CPYTHON_BLTN_V, 0), (void *)_cffi_d_findNearestPointIdxLauncher },\n};\n\nstatic const struct _cffi_type_context_s _cffi_type_context = {\n  _cffi_types,\n  _cffi_globals,\n  NULL,  /* no fields */\n  NULL,  /* no struct_unions */\n  NULL,  /* no enums */\n  NULL,  /* no typenames */\n  1,  /* num_globals */\n  0,  /* num_struct_unions */\n  0,  /* num_enums */\n  0,  /* num_typenames */\n  NULL,  /* no includes */\n  12,  /* num_types */\n  0,  /* flags */\n};\n\n#ifdef __GNUC__\n#  pragma GCC visibility push(default)  /* for -fvisibility= */\n#endif\n\n#ifdef PYPY_VERSION\nPyMODINIT_FUNC\n_cffi_pypyinit__ext(const void *p[])\n{\n    p[0] = (const void *)0x2601;\n    p[1] = &_cffi_type_context;\n#if PY_MAJOR_VERSION >= 3\n    return NULL;\n#endif\n}\n#  ifdef _MSC_VER\n     PyMODINIT_FUNC\n#  if PY_MAJOR_VERSION >= 3\n     PyInit__ext(void) { return NULL; }\n#  else\n     init_ext(void) { }\n#  endif\n#  endif\n#elif PY_MAJOR_VERSION >= 3\nPyMODINIT_FUNC\nPyInit__ext(void)\n{\n  return _cffi_init(\"_ext\", 0x2601, &_cffi_type_context);\n}\n#else\nPyMODINIT_FUNC\ninit_ext(void)\n{\n  _cffi_init(\"_ext\", 0x2601, &_cffi_type_context);\n}\n#endif\n\n#ifdef __GNUC__\n#  pragma GCC visibility pop\n#endif\n"
  },
  {
    "path": "thirdparty/nn/nn_utils.py",
    "content": "# from lib.csrc.nn._ext import lib, ffi\nfrom thirdparty.nn._ext import lib, ffi\nimport numpy as np\n\n\ndef find_nearest_point_idx(ref_pts, que_pts):\n    assert(ref_pts.shape[1] == que_pts.shape[1] and 1 < que_pts.shape[1] <= 3)\n    pn1 = ref_pts.shape[0]\n    pn2 = que_pts.shape[0]\n    dim = ref_pts.shape[1]\n\n    ref_pts = np.ascontiguousarray(ref_pts[None,:,:], np.float32)\n    que_pts = np.ascontiguousarray(que_pts[None,:,:], np.float32)\n    idxs = np.zeros([1, pn2], np.int32)\n\n    ref_pts_ptr = ffi.cast('float *', ref_pts.ctypes.data)\n    que_pts_ptr = ffi.cast('float *', que_pts.ctypes.data)\n    idxs_ptr = ffi.cast('int *', idxs.ctypes.data)\n    lib.findNearestPointIdxLauncher(ref_pts_ptr, que_pts_ptr, idxs_ptr, 1, pn1, pn2, dim, 0)\n\n    return idxs[0]\n"
  },
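  {
    "path": "thirdparty/nn/nn_utils_example.py",
    "content": "# Minimal usage sketch for find_nearest_point_idx. This file is a hypothetical\n# standalone example, not a dependency of the library; it assumes the CUDA\n# extension has been built with setup.py so that thirdparty.nn._ext is\n# importable, and it cross-checks the GPU result against brute-force numpy.\nimport numpy as np\n\nfrom thirdparty.nn.nn_utils import find_nearest_point_idx\n\nif __name__ == '__main__':\n    rng = np.random.RandomState(0)\n    ref_pts = rng.rand(1000, 3).astype(np.float32)  # reference cloud [pn1,3]\n    que_pts = rng.rand(100, 3).astype(np.float32)   # query cloud     [pn2,3]\n\n    idxs = find_nearest_point_idx(ref_pts, que_pts)  # [pn2] int32\n\n    # brute-force nearest neighbors over squared distances\n    d2 = ((que_pts[:, None, :] - ref_pts[None, :, :]) ** 2).sum(-1)\n    assert np.array_equal(idxs, d2.argmin(1))\n    print('GPU nearest-neighbor indices match the brute-force result')\n"
  },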
  {
    "path": "thirdparty/nn/setup.py",
    "content": "import os\n\ncuda_include=os.path.join(os.environ.get('CUDA_HOME'), 'include')\nos.system('nvcc src/nearest_neighborhood.cu -c -o src/nearest_neighborhood.cu.o -x cu -Xcompiler -fPIC -O2 -arch=sm_52 -I {}'.format(cuda_include))\n\nfrom cffi import FFI\nffibuilder = FFI()\n\n\nwith open(os.path.join(os.path.dirname(__file__), \"src/ext.h\")) as f:\n    ffibuilder.cdef(f.read())\n\nffibuilder.set_source(\n    \"_ext\",\n    \"\"\"\n    #include \"src/ext.h\"\n    \"\"\",\n    extra_objects=['src/nearest_neighborhood.cu.o',\n                   os.path.join(os.environ.get('CUDA_HOME'),'lib64/libcudart.so')],\n    libraries=['stdc++']\n)\n\n\nif __name__ == \"__main__\":\n    ffibuilder.compile(verbose=True)\n    os.system(\"rm src/*.o\")\n    os.system(\"rm *.o\")\n"
  },
  {
    "path": "thirdparty/nn/src/ext.h",
    "content": "void findNearestPointIdxLauncher(\n    float* ref_pts,   // [b,pn1,dim]\n    float* que_pts,   // [b,pn2,dim]\n    int* idxs,        // [b,pn2]\n    int b,\n    int pn1,\n    int pn2,\n    int dim,\n    int exclude_self\n);\n"
  },
  {
    "path": "thirdparty/nn/src/nearest_neighborhood.cu",
    "content": "#include <float.h>\n#include <stdio.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <cuda_runtime_api.h>\n#include <stdio.h>\n\n#define gpuErrchk(ans) { gpuAssert((ans), __FILE__, __LINE__); }\n\nvoid gpuAssert(cudaError_t code, const char *file, int line, bool abort=true)\n{\n    if (code != cudaSuccess)\n    {\n        fprintf(stderr,\"GPUassert: %s %s %d\\n\", cudaGetErrorString(code), file, line);\n        if (abort) exit(code);\n    }\n}\n\nint infTwoExp(int val)\n{\n    int inf=1;\n    while(val>inf) inf<<=1;\n    return inf;\n}\n\nvoid getGPULayout(\n        int dim0,int dim1,int dim2,\n        int* bdim0,int* bdim1,int* bdim2,\n        int* tdim0,int* tdim1,int* tdim2\n)\n{\n    (*tdim2)=64;\n    if(dim2<(*tdim2)) (*tdim2)=infTwoExp(dim2);\n    (*bdim2)=dim2/(*tdim2);\n    if(dim2%(*tdim2)>0) (*bdim2)++;\n\n    (*tdim1)=1024/(*tdim2);\n    if(dim1<(*tdim1)) (*tdim1)=infTwoExp(dim1);\n    (*bdim1)=dim1/(*tdim1);\n    if(dim1%(*tdim1)>0) (*bdim1)++;\n\n    (*tdim0)=1024/((*tdim1)*(*tdim2));\n    if(dim0<(*tdim0)) (*tdim0)=infTwoExp(dim0);\n    (*bdim0)=dim0/(*tdim0);\n    if(dim0%(*tdim0)>0) (*bdim0)++;\n}\n\n__global__\nvoid findNearestPoint3DIdxKernel(\n    float* ref_pts,   // [b,pn1,3]\n    float* que_pts,   // [b,pn2,3]\n    int* idxs,        // [b,pn2]\n    int b,\n    int pn1,\n    int pn2,\n    int exclude_self\n)\n{\n    int bi = threadIdx.x + blockIdx.x*blockDim.x;\n    int p2i = threadIdx.y + blockIdx.y*blockDim.y;\n    if(p2i>=pn2||bi>=b) return;\n\n    float x2=que_pts[bi*pn2*3+p2i*3];\n    float y2=que_pts[bi*pn2*3+p2i*3+1];\n    float z2=que_pts[bi*pn2*3+p2i*3+2];\n    float min_dist=FLT_MAX;\n    int min_idx=0;\n    for(int p1i=0;p1i<pn1;p1i++)\n    {\n        if(exclude_self&&p1i==p2i) continue;\n        float x1=ref_pts[bi*pn1*3+p1i*3];\n        float y1=ref_pts[bi*pn1*3+p1i*3+1];\n        float z1=ref_pts[bi*pn1*3+p1i*3+2];\n\n        float dist=(x1-x2)*(x1-x2)+(y1-y2)*(y1-y2)+(z1-z2)*(z1-z2);\n        if(dist<min_dist)\n        {\n            min_dist=dist;\n            min_idx=p1i;\n        }\n    }\n    idxs[bi*pn2+p2i]=min_idx;\n}\n__global__\nvoid findNearestPoint2DIdxKernel(\n    float* ref_pts,   // [b,pn1,2]\n    float* que_pts,   // [b,pn2,2]\n    int* idxs,        // [b,pn2]\n    int b,\n    int pn1,\n    int pn2,\n    int exclude_self\n)\n{\n    int bi = threadIdx.x + blockIdx.x*blockDim.x;\n    int p2i = threadIdx.y + blockIdx.y*blockDim.y;\n    if(p2i>=pn2||bi>=b) return;\n\n    float x2=que_pts[bi*pn2*2+p2i*2];\n    float y2=que_pts[bi*pn2*2+p2i*2+1];\n    float min_dist=FLT_MAX;\n    int min_idx=0;\n    for(int p1i=0;p1i<pn1;p1i++)\n    {\n        if(exclude_self&&p1i==p2i) continue;\n        float x1=ref_pts[bi*pn1*2+p1i*2];\n        float y1=ref_pts[bi*pn1*2+p1i*2+1];\n\n        float dist=(x1-x2)*(x1-x2)+(y1-y2)*(y1-y2);\n        if(dist<min_dist)\n        {\n            min_dist=dist;\n            min_idx=p1i;\n        }\n    }\n    idxs[bi*pn2+p2i]=min_idx;\n}\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\nvoid findNearestPointIdxLauncher(\n    float* ref_pts,   // [b,pn1,dim]\n    float* que_pts,   // [b,pn2,dim]\n    int* idxs,        // [b,pn2]\n    int b,\n    int pn1,\n    int pn2,\n    int dim,\n    int exclude_self\n)\n{\n    float* ref_pts_dev,* que_pts_dev;\n    int* idxs_dev;\n    gpuErrchk(cudaMalloc(&ref_pts_dev,b*pn1*sizeof(float)*dim))\n    gpuErrchk(cudaMalloc(&que_pts_dev,b*pn2*sizeof(float)*dim))\n    gpuErrchk(cudaMalloc(&idxs_dev,b*pn2*sizeof(int)))\n\n    
gpuErrchk(cudaMemcpy(ref_pts_dev,ref_pts,b*pn1*sizeof(float)*dim,cudaMemcpyHostToDevice))\n    gpuErrchk(cudaMemcpy(que_pts_dev,que_pts,b*pn2*sizeof(float)*dim,cudaMemcpyHostToDevice))\n    gpuErrchk(cudaMemcpy(idxs_dev,idxs,b*pn2*sizeof(int),cudaMemcpyHostToDevice))\n\n    int bdim0,bdim1,bdim2;\n    int tdim0,tdim1,tdim2;\n\n    getGPULayout(b,pn2,1,&bdim0,&bdim1,&bdim2,&tdim0,&tdim1,&tdim2);\n\n    dim3 bdim(bdim0,bdim1,bdim2);\n    dim3 tdim(tdim0,tdim1,tdim2);\n\n    if(dim==3)\n        findNearestPoint3DIdxKernel<<<bdim,tdim>>>(ref_pts_dev,que_pts_dev,idxs_dev,b,pn1,pn2,exclude_self);\n    else\n        findNearestPoint2DIdxKernel<<<bdim,tdim>>>(ref_pts_dev,que_pts_dev,idxs_dev,b,pn1,pn2,exclude_self);\n    gpuErrchk(cudaGetLastError())\n\n    gpuErrchk(cudaMemcpy(idxs,idxs_dev,b*pn2*sizeof(int),cudaMemcpyDeviceToHost))\n    gpuErrchk(cudaFree(ref_pts_dev))\n    gpuErrchk(cudaFree(que_pts_dev))\n    gpuErrchk(cudaFree(idxs_dev))\n\n}\n\n#ifdef __cplusplus\n}\n#endif\n"
  },
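  {
    "path": "thirdparty/nn/src/nearest_neighborhood_layout_sketch.py",
    "content": "# Pure-Python mirror of infTwoExp/getGPULayout from nearest_neighborhood.cu,\n# kept as a hypothetical sketch of how the launcher picks its CUDA grid: thread\n# dims are powers of two with tdim0*tdim1*tdim2 <= 1024, and block counts round\n# up so the grid covers (dim0, dim1, dim2).\n\ndef inf_two_exp(val):\n    # smallest power of two >= val (mirrors infTwoExp)\n    inf = 1\n    while val > inf:\n        inf <<= 1\n    return inf\n\ndef get_gpu_layout(dim0, dim1, dim2):\n    tdim2 = min(64, inf_two_exp(dim2))\n    bdim2 = (dim2 + tdim2 - 1) // tdim2\n\n    tdim1 = min(1024 // tdim2, inf_two_exp(dim1))\n    bdim1 = (dim1 + tdim1 - 1) // tdim1\n\n    tdim0 = min(1024 // (tdim1 * tdim2), inf_two_exp(dim0))\n    bdim0 = (dim0 + tdim0 - 1) // tdim0\n    return (bdim0, bdim1, bdim2), (tdim0, tdim1, tdim2)\n\nif __name__ == '__main__':\n    # the launcher calls getGPULayout(b, pn2, 1, ...): batch on x, queries on y\n    blocks, threads = get_gpu_layout(1, 5000, 1)\n    assert threads[0] * threads[1] * threads[2] <= 1024\n    assert all(b * t >= d for b, t, d in zip(blocks, threads, (1, 5000, 1)))\n    print('grid', blocks, 'block', threads)\n"
  },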
  {
    "path": "thirdparty/raft/corr.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom .utils.utils import bilinear_sampler, coords_grid\n\ntry:\n    import alt_cuda_corr\nexcept:\n    # alt_cuda_corr is not compiled\n    pass\n\n\nclass CorrBlock:\n    def __init__(self, fmap1, fmap2, num_levels=4, radius=4, downsample_rate=1):\n        self.num_levels = num_levels\n        self.radius = radius\n        self.corr_pyramid = []\n\n        # all pairs correlation\n        corr = CorrBlock.corr(fmap1, fmap2)\n        if downsample_rate>1:\n            batch, h1, w1, dim, h2, w2 = corr.shape\n            corr=torch.nn.MaxPool2d(corr.reshape(batch, -1, dim, h2,w2), downsample_rate, stride=downsample_rate)\n            corr=torch.nn.MaxPool2d(corr.reshape(batch, h1,w1, dim, -1).permute(0,4,3,1,2), downsample_rate, stride=downsample_rate)\n            corr = corr.permute(0,3,4,3,1).reshape(batch,h1//downsample_rate, w1//downsample_rate, dim, h2//downsample_rate, w2//downsample_rate)\n\n\n\n        batch, h1, w1, dim, h2, w2 = corr.shape\n        corr = corr.reshape(batch*h1*w1, dim, h2, w2)\n        \n        self.corr_pyramid.append(corr)\n        for i in range(self.num_levels-1):\n            corr = F.avg_pool2d(corr, 2, stride=2)\n            self.corr_pyramid.append(corr)\n\n    def __call__(self, coords):\n        r = self.radius\n        coords = coords.permute(0, 2, 3, 1)\n        batch, h1, w1, _ = coords.shape\n\n        out_pyramid = []\n        for i in range(self.num_levels):\n            corr = self.corr_pyramid[i]\n            dx = torch.linspace(-r, r, 2*r+1)\n            dy = torch.linspace(-r, r, 2*r+1)\n            delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)\n\n            centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i\n            delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)\n            coords_lvl = centroid_lvl + delta_lvl\n\n            corr = bilinear_sampler(corr, coords_lvl)\n            corr = corr.view(batch, h1, w1, -1)\n            out_pyramid.append(corr)\n\n        out = torch.cat(out_pyramid, dim=-1)\n        return out.permute(0, 3, 1, 2).contiguous().float()\n\n    @staticmethod\n    def corr(fmap1, fmap2):\n        batch, dim, ht, wd = fmap1.shape\n        fmap1 = fmap1.view(batch, dim, ht*wd)\n        fmap2 = fmap2.view(batch, dim, ht*wd) \n        \n        corr = torch.matmul(fmap1.transpose(1,2), fmap2)\n        corr = corr.view(batch, ht, wd, 1, ht, wd)\n        return corr  / torch.sqrt(torch.tensor(dim).float())\n\n\nclass AlternateCorrBlock:\n    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):\n        self.num_levels = num_levels\n        self.radius = radius\n\n        self.pyramid = [(fmap1, fmap2)]\n        for i in range(self.num_levels):\n            fmap1 = F.avg_pool2d(fmap1, 2, stride=2)\n            fmap2 = F.avg_pool2d(fmap2, 2, stride=2)\n            self.pyramid.append((fmap1, fmap2))\n\n    def __call__(self, coords):\n        coords = coords.permute(0, 2, 3, 1)\n        B, H, W, _ = coords.shape\n        dim = self.pyramid[0][0].shape[1]\n\n        corr_list = []\n        for i in range(self.num_levels):\n            r = self.radius\n            fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()\n            fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()\n\n            coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()\n            corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)\n            corr_list.append(corr.squeeze(1))\n\n        corr = 
torch.stack(corr_list, dim=1)\n        corr = corr.reshape(B, -1, H, W)\n        return corr / torch.sqrt(torch.tensor(dim).float())\n"
  },
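  {
    "path": "thirdparty/raft/corr_example.py",
    "content": "# Hypothetical shape walk-through for CorrBlock: builds the all-pairs\n# correlation volume for tiny feature maps and prints the pyramid shapes,\n# mirroring corr.py (tensor sizes here are illustrative only).\nimport torch\nimport torch.nn.functional as F\n\nif __name__ == '__main__':\n    batch, dim, ht, wd = 2, 16, 8, 8\n    fmap1 = torch.randn(batch, dim, ht, wd)\n    fmap2 = torch.randn(batch, dim, ht, wd)\n\n    # all-pairs correlation, as in CorrBlock.corr: one score per pixel pair\n    corr = torch.matmul(fmap1.view(batch, dim, ht*wd).transpose(1, 2),\n                        fmap2.view(batch, dim, ht*wd))\n    corr = corr.view(batch, ht, wd, 1, ht, wd) / dim ** 0.5\n    print(tuple(corr.shape))  # (2, 8, 8, 1, 8, 8)\n\n    # pyramid over the image-2 dims, as in CorrBlock.__init__\n    levels = [corr.reshape(batch*ht*wd, 1, ht, wd)]\n    for _ in range(3):\n        levels.append(F.avg_pool2d(levels[-1], 2, stride=2))\n    for lvl, c in enumerate(levels):\n        print(lvl, tuple(c.shape))\n"
  },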
  {
    "path": "thirdparty/raft/extractor.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass ResidualBlock(nn.Module):\n    def __init__(self, in_planes, planes, norm_fn='group', stride=1):\n        super(ResidualBlock, self).__init__()\n  \n        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)\n        self.relu = nn.ReLU(inplace=True)\n\n        num_groups = planes // 8\n\n        if norm_fn == 'group':\n            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)\n            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)\n            if not stride == 1:\n                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)\n        \n        elif norm_fn == 'batch':\n            self.norm1 = nn.BatchNorm2d(planes)\n            self.norm2 = nn.BatchNorm2d(planes)\n            if not stride == 1:\n                self.norm3 = nn.BatchNorm2d(planes)\n        \n        elif norm_fn == 'instance':\n            self.norm1 = nn.InstanceNorm2d(planes)\n            self.norm2 = nn.InstanceNorm2d(planes)\n            if not stride == 1:\n                self.norm3 = nn.InstanceNorm2d(planes)\n\n        elif norm_fn == 'none':\n            self.norm1 = nn.Sequential()\n            self.norm2 = nn.Sequential()\n            if not stride == 1:\n                self.norm3 = nn.Sequential()\n\n        if stride == 1:\n            self.downsample = None\n        \n        else:    \n            self.downsample = nn.Sequential(\n                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)\n\n\n    def forward(self, x):\n        y = x\n        y = self.relu(self.norm1(self.conv1(y)))\n        y = self.relu(self.norm2(self.conv2(y)))\n\n        if self.downsample is not None:\n            x = self.downsample(x)\n\n        return self.relu(x+y)\n\n\n\nclass BottleneckBlock(nn.Module):\n    def __init__(self, in_planes, planes, norm_fn='group', stride=1):\n        super(BottleneckBlock, self).__init__()\n  \n        self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)\n        self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)\n        self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)\n        self.relu = nn.ReLU(inplace=True)\n\n        num_groups = planes // 8\n\n        if norm_fn == 'group':\n            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)\n            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)\n            self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)\n            if not stride == 1:\n                self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)\n        \n        elif norm_fn == 'batch':\n            self.norm1 = nn.BatchNorm2d(planes//4)\n            self.norm2 = nn.BatchNorm2d(planes//4)\n            self.norm3 = nn.BatchNorm2d(planes)\n            if not stride == 1:\n                self.norm4 = nn.BatchNorm2d(planes)\n        \n        elif norm_fn == 'instance':\n            self.norm1 = nn.InstanceNorm2d(planes//4)\n            self.norm2 = nn.InstanceNorm2d(planes//4)\n            self.norm3 = nn.InstanceNorm2d(planes)\n            if not stride == 1:\n                self.norm4 = nn.InstanceNorm2d(planes)\n\n        elif norm_fn == 'none':\n            self.norm1 = nn.Sequential()\n            self.norm2 = 
nn.Sequential()\n            self.norm3 = nn.Sequential()\n            if not stride == 1:\n                self.norm4 = nn.Sequential()\n\n        if stride == 1:\n            self.downsample = None\n        \n        else:    \n            self.downsample = nn.Sequential(\n                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)\n\n\n    def forward(self, x):\n        y = x\n        y = self.relu(self.norm1(self.conv1(y)))\n        y = self.relu(self.norm2(self.conv2(y)))\n        y = self.relu(self.norm3(self.conv3(y)))\n\n        if self.downsample is not None:\n            x = self.downsample(x)\n\n        return self.relu(x+y)\n\nclass BasicEncoder(nn.Module):\n    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0, input_dim=3, with_decoder=False, decoder_dim=0):\n        super(BasicEncoder, self).__init__()\n        self.norm_fn = norm_fn\n\n        if self.norm_fn == 'group':\n            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)\n            \n        elif self.norm_fn == 'batch':\n            self.norm1 = nn.BatchNorm2d(64)\n\n        elif self.norm_fn == 'instance':\n            self.norm1 = nn.InstanceNorm2d(64)\n\n        elif self.norm_fn == 'none':\n            self.norm1 = nn.Sequential()\n\n        # self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)\n        self.conv1 = nn.Conv2d(input_dim, 64, kernel_size=7, stride=2, padding=3)\n        self.relu1 = nn.ReLU(inplace=True)\n\n        self.in_planes = 64\n        self.layer1 = self._make_layer(64,  stride=1)\n        self.layer2 = self._make_layer(96, stride=2)\n        self.layer3 = self._make_layer(128, stride=2)\n\n        # output convolution\n        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)\n\n        self.dropout = None\n        if dropout > 0:\n            self.dropout = nn.Dropout2d(p=dropout)\n        \n        self.with_decoder = with_decoder\n        if with_decoder:\n            self.dec_layer3 = nn.Sequential(\n                            nn.Conv2d(128, 96, 3, 1, 1, bias=False),\n                            nn.BatchNorm2d(96),\n                            self.relu1,\n                            nn.UpsamplingBilinear2d(scale_factor=2),\n            )\n            self.dec_layer2 = nn.Sequential(\n                            nn.Conv2d(96+96, 64, 3, 1, 1, bias=False),\n                            nn.BatchNorm2d(64),\n                            self.relu1,\n                            nn.UpsamplingBilinear2d(scale_factor=2),\n            )\n            self.dec_layer1 = nn.Sequential(\n                            nn.Conv2d(64+64, 64, 3, 1, 1, bias=False),\n                            nn.BatchNorm2d(64),\n                            self.relu1,\n                            nn.UpsamplingBilinear2d(scale_factor=2),\n            )\n            self.reg_layer = nn.Conv2d(64, decoder_dim, 1, 1, 0, bias=False)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):\n                if m.weight is not None:\n                    nn.init.constant_(m.weight, 1)\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n        \n\n    def _make_layer(self, dim, stride=1):\n        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)\n        layer2 = ResidualBlock(dim, dim, self.norm_fn, 
stride=1)\n        layers = (layer1, layer2)\n        \n        self.in_planes = dim\n        return nn.Sequential(*layers)\n\n\n    def forward(self, x):\n        \n        ups=[]\n        # if input is list, combine batch dimension\n        is_list = isinstance(x, tuple) or isinstance(x, list)\n        if is_list:\n            batch_dim = x[0].shape[0]\n            x = torch.cat(x, dim=0)\n\n        x = self.conv1(x)\n        x = self.norm1(x)\n        x = self.relu1(x)\n        ups.append(x) #added\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        ups.append(x) #added\n        x = self.layer3(x)\n\n        if self.with_decoder:\n            dec_x=self.dec_layer3(x)\n            dec_x=self.dec_layer2(torch.cat( [dec_x, ups[-1]], dim=1 ) )\n            dec_x=self.dec_layer1(torch.cat([dec_x, ups[-2]], dim=1 ) )\n            dec_x = self.reg_layer(dec_x)\n\n        x = self.conv2(x)\n\n\n        if self.training and self.dropout is not None:\n            x = self.dropout(x)\n\n\n\n        if is_list:\n            x = torch.split(x, [batch_dim, batch_dim], dim=0)\n            if self.with_decoder:\n                dec_x = torch.split(dec_x, [batch_dim, batch_dim], dim=0)\n        if self.with_decoder:\n            return x, dec_x\n        else:\n            return x\n\nclass BasicEncoder_dx4(nn.Module):\n    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):\n        super(BasicEncoder_dx4, self).__init__()\n        self.norm_fn = norm_fn\n\n        if self.norm_fn == 'group':\n            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)\n            \n        elif self.norm_fn == 'batch':\n            self.norm1 = nn.BatchNorm2d(64)\n\n        elif self.norm_fn == 'instance':\n            self.norm1 = nn.InstanceNorm2d(64)\n\n        elif self.norm_fn == 'none':\n            self.norm1 = nn.Sequential()\n\n        # self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)\n        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=1, padding=3)\n        self.relu1 = nn.ReLU(inplace=True)\n\n        self.in_planes = 64\n        self.layer1 = self._make_layer(64,  stride=1)\n        self.layer2 = self._make_layer(96, stride=2)\n        self.layer3 = self._make_layer(128, stride=2)\n\n        # output convolution\n        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)\n\n        self.dropout = None\n        if dropout > 0:\n            self.dropout = nn.Dropout2d(p=dropout)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):\n                if m.weight is not None:\n                    nn.init.constant_(m.weight, 1)\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n\n    def _make_layer(self, dim, stride=1):\n        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)\n        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)\n        layers = (layer1, layer2)\n        \n        self.in_planes = dim\n        return nn.Sequential(*layers)\n\n\n    def forward(self, x):\n\n        # if input is list, combine batch dimension\n        is_list = isinstance(x, tuple) or isinstance(x, list)\n        if is_list:\n            batch_dim = x[0].shape[0]\n            x = torch.cat(x, dim=0)\n\n        x = self.conv1(x)\n        x = self.norm1(x)\n        x = self.relu1(x)\n\n        x 
= self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n\n        x = self.conv2(x)\n\n        if self.training and self.dropout is not None:\n            x = self.dropout(x)\n\n        if is_list:\n            x = torch.split(x, [batch_dim, batch_dim], dim=0)\n\n        return x\n\n\nclass SmallEncoder(nn.Module):\n    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):\n        super(SmallEncoder, self).__init__()\n        self.norm_fn = norm_fn\n\n        if self.norm_fn == 'group':\n            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)\n            \n        elif self.norm_fn == 'batch':\n            self.norm1 = nn.BatchNorm2d(32)\n\n        elif self.norm_fn == 'instance':\n            self.norm1 = nn.InstanceNorm2d(32)\n\n        elif self.norm_fn == 'none':\n            self.norm1 = nn.Sequential()\n\n        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)\n        self.relu1 = nn.ReLU(inplace=True)\n\n        self.in_planes = 32\n        self.layer1 = self._make_layer(32,  stride=1)\n        self.layer2 = self._make_layer(64, stride=2)\n        self.layer3 = self._make_layer(96, stride=2)\n\n        self.dropout = None\n        if dropout > 0:\n            self.dropout = nn.Dropout2d(p=dropout)\n        \n        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):\n                if m.weight is not None:\n                    nn.init.constant_(m.weight, 1)\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n\n    def _make_layer(self, dim, stride=1):\n        layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)\n        layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)\n        layers = (layer1, layer2)\n    \n        self.in_planes = dim\n        return nn.Sequential(*layers)\n\n\n    def forward(self, x):\n\n        # if input is list, combine batch dimension\n        is_list = isinstance(x, tuple) or isinstance(x, list)\n        if is_list:\n            batch_dim = x[0].shape[0]\n            x = torch.cat(x, dim=0)\n\n        x = self.conv1(x)\n        x = self.norm1(x)\n        x = self.relu1(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.conv2(x)\n\n        if self.training and self.dropout is not None:\n            x = self.dropout(x)\n\n        if is_list:\n            x = torch.split(x, [batch_dim, batch_dim], dim=0)\n\n        return x\n\nclass SmallEncoder_dx4(nn.Module):\n    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):\n        super(SmallEncoder_dx4, self).__init__()\n        self.norm_fn = norm_fn\n\n        if self.norm_fn == 'group':\n            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)\n            \n        elif self.norm_fn == 'batch':\n            self.norm1 = nn.BatchNorm2d(32)\n\n        elif self.norm_fn == 'instance':\n            self.norm1 = nn.InstanceNorm2d(32)\n\n        elif self.norm_fn == 'none':\n            self.norm1 = nn.Sequential()\n\n        # self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)\n        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=1, padding=3)\n        self.relu1 = nn.ReLU(inplace=True)\n\n        self.in_planes = 32\n        
self.layer1 = self._make_layer(32,  stride=1)\n        self.layer2 = self._make_layer(64, stride=2)\n        self.layer3 = self._make_layer(96, stride=2)\n\n        self.dropout = None\n        if dropout > 0:\n            self.dropout = nn.Dropout2d(p=dropout)\n        \n        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):\n                if m.weight is not None:\n                    nn.init.constant_(m.weight, 1)\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n\n    def _make_layer(self, dim, stride=1):\n        layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)\n        layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)\n        layers = (layer1, layer2)\n    \n        self.in_planes = dim\n        return nn.Sequential(*layers)\n\n\n    def forward(self, x):\n\n        # if input is list, combine batch dimension\n        is_list = isinstance(x, tuple) or isinstance(x, list)\n        if is_list:\n            batch_dim = x[0].shape[0]\n            x = torch.cat(x, dim=0)\n\n        x = self.conv1(x)\n        x = self.norm1(x)\n        x = self.relu1(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.conv2(x)\n\n        if self.training and self.dropout is not None:\n            x = self.dropout(x)\n\n        if is_list:\n            x = torch.split(x, [batch_dim, batch_dim], dim=0)\n\n        return x"
  },
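  {
    "path": "thirdparty/raft/extractor_example.py",
    "content": "# Hypothetical smoke test for BasicEncoder: the stride-2 stem plus two stride-2\n# residual stages downsample by 8, so a 3x256x256 input should map to\n# output_dim x 32 x 32.\nimport torch\n\nfrom thirdparty.raft.extractor import BasicEncoder\n\nif __name__ == '__main__':\n    enc = BasicEncoder(output_dim=256, norm_fn='instance', dropout=0.0)\n    x = torch.randn(1, 3, 256, 256)\n    with torch.no_grad():\n        y = enc(x)\n    print(tuple(y.shape))  # expected (1, 256, 32, 32)\n"
  },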
  {
    "path": "thirdparty/raft/update.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass FlowHead(nn.Module):\n    def __init__(self, input_dim=128, hidden_dim=256):\n        super(FlowHead, self).__init__()\n        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)\n        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)\n        self.relu = nn.ReLU(inplace=True)\n\n    def forward(self, x):\n        return self.conv2(self.relu(self.conv1(x)))\n\nclass ConvGRU(nn.Module):\n    def __init__(self, hidden_dim=128, input_dim=192+128):\n        super(ConvGRU, self).__init__()\n        self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)\n        self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)\n        self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)\n\n    def forward(self, h, x):\n        hx = torch.cat([h, x], dim=1)\n\n        z = torch.sigmoid(self.convz(hx))\n        r = torch.sigmoid(self.convr(hx))\n        q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))\n\n        h = (1-z) * h + z * q\n        return h\n\nclass SepConvGRU(nn.Module):\n    def __init__(self, hidden_dim=128, input_dim=192+128):\n        super(SepConvGRU, self).__init__()\n        self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))\n        self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))\n        self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))\n\n        self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))\n        self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))\n        self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))\n\n\n    def forward(self, h, x):\n        # horizontal\n        hx = torch.cat([h, x], dim=1)\n        z = torch.sigmoid(self.convz1(hx))\n        r = torch.sigmoid(self.convr1(hx))\n        q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))        \n        h = (1-z) * h + z * q\n\n        # vertical\n        hx = torch.cat([h, x], dim=1)\n        z = torch.sigmoid(self.convz2(hx))\n        r = torch.sigmoid(self.convr2(hx))\n        q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))       \n        h = (1-z) * h + z * q\n\n        return h\n\nclass SmallMotionEncoder(nn.Module):\n    def __init__(self, args):\n        super(SmallMotionEncoder, self).__init__()\n        cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2\n        self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)\n        self.convf1 = nn.Conv2d(2, 64, 7, padding=3)\n        self.convf2 = nn.Conv2d(64, 32, 3, padding=1)\n        self.conv = nn.Conv2d(128, 80, 3, padding=1)\n\n    def forward(self, flow, corr):\n        cor = F.relu(self.convc1(corr))\n        flo = F.relu(self.convf1(flow))\n        flo = F.relu(self.convf2(flo))\n        cor_flo = torch.cat([cor, flo], dim=1)\n        out = F.relu(self.conv(cor_flo))\n        return torch.cat([out, flow], dim=1)\n\nclass BasicMotionEncoder(nn.Module):\n    def __init__(self, args):\n        super(BasicMotionEncoder, self).__init__()\n        cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2\n        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)\n        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)\n        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)\n        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)\n        self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)\n\n    def 
forward(self, flow, corr):\n        cor = F.relu(self.convc1(corr))\n        cor = F.relu(self.convc2(cor))\n        flo = F.relu(self.convf1(flow))\n        flo = F.relu(self.convf2(flo))\n\n        cor_flo = torch.cat([cor, flo], dim=1)\n        out = F.relu(self.conv(cor_flo))\n        return torch.cat([out, flow], dim=1)\n\nclass BasicMotionEncoderGeo(nn.Module):\n    def __init__(self, args):\n        super(BasicMotionEncoderGeo, self).__init__()\n        cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2\n        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)\n        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)\n        self.convc1_geo = nn.Conv2d(cor_planes, 256, 1, padding=0)\n        self.convc2_geo = nn.Conv2d(256, 192, 3, padding=1)\n        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)\n        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)\n        self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)\n\n    def forward(self, flow, corr, corr_geo):\n        cor = F.relu(self.convc1(corr))\n        cor = F.relu(self.convc2(cor))\n        \n        cor_geo = F.relu(self.convc1_geo(corr_geo))\n        cor_geo = F.relu(self.convc2_geo(cor_geo))\n        # cor_geo = F.leaky_relu(self.convc1_geo(corr_geo))\n        # cor_geo = F.leaky_relu(self.convc2_geo(cor_geo))\n\n        flo = F.relu(self.convf1(flow))\n        flo = F.relu(self.convf2(flo))\n\n        # cor_flo = torch.cat([cor, flo], dim=1)\n        cor_flo = torch.cat([cor+cor_geo, flo], dim=1)\n        out = F.relu(self.conv(cor_flo))\n        return torch.cat([out, flow], dim=1)\n\nclass SmallUpdateBlock(nn.Module):\n    def __init__(self, args, hidden_dim=96):\n        super(SmallUpdateBlock, self).__init__()\n        self.encoder = SmallMotionEncoder(args)\n        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)\n        self.flow_head = FlowHead(hidden_dim, hidden_dim=128)\n\n    def forward(self, net, inp, corr, flow):\n        motion_features = self.encoder(flow, corr)\n        inp = torch.cat([inp, motion_features], dim=1)\n        net = self.gru(net, inp)\n        delta_flow = self.flow_head(net)\n\n        return net, None, delta_flow\n\nclass SmallUpdateBlockUpMask(nn.Module):\n    def __init__(self, args, hidden_dim=96):\n        super(SmallUpdateBlockUpMask, self).__init__()\n        self.encoder = SmallMotionEncoder(args)\n        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)\n        self.flow_head = FlowHead(hidden_dim, hidden_dim=128)\n        self.mask = nn.Sequential(\n            nn.Conv2d(96, 192, 3, padding=1),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(192, 64*9, 1, padding=0))\n\n    def forward(self, net, inp, corr, flow):\n        motion_features = self.encoder(flow, corr)\n        inp = torch.cat([inp, motion_features], dim=1)\n        net = self.gru(net, inp)\n        delta_flow = self.flow_head(net)\n        mask = .25 * self.mask(net)\n\n        # return net, None, delta_flow\n        return net, mask, delta_flow\n\nclass BasicUpdateBlock(nn.Module):\n    def __init__(self, args, hidden_dim=128, input_dim=128, downsample_scale=8):\n        super(BasicUpdateBlock, self).__init__()\n        self.args = args\n        self.encoder = BasicMotionEncoder(args)\n        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)\n        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)\n\n        self.mask = nn.Sequential(\n            nn.Conv2d(128, 256, 3, padding=1),\n            nn.ReLU(inplace=True),\n            # 
nn.Conv2d(256, 64*9, 1, padding=0))\n            nn.Conv2d(256, downsample_scale*downsample_scale*9, 1, padding=0))\n\n    def forward(self, net, inp, corr, flow, upsample=True):\n        \n        motion_features = self.encoder(flow, corr)\n        inp = torch.cat([inp, motion_features], dim=1)\n\n        net = self.gru(net, inp)\n        delta_flow = self.flow_head(net)\n\n        # scale mask to balence gradients\n        mask = .25 * self.mask(net)\n        return net, mask, delta_flow\n\nclass BasicUpdateBlockGeo(nn.Module):\n    #add geo corr input\n    def __init__(self, args, hidden_dim=128, input_dim=128, downsample_scale=8):\n        super(BasicUpdateBlockGeo, self).__init__()\n        self.args = args\n        self.encoder = BasicMotionEncoderGeo(args)\n        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)\n        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)\n\n        self.mask = nn.Sequential(\n            nn.Conv2d(128, 256, 3, padding=1),\n            nn.ReLU(inplace=True),\n            # nn.Conv2d(256, 64*9, 1, padding=0))\n            nn.Conv2d(256, downsample_scale*downsample_scale*9, 1, padding=0))\n\n    def forward(self, net, inp, corr, geo_corr, flow, upsample=True):\n        \n        motion_features = self.encoder(flow, corr, geo_corr)\n        inp = torch.cat([inp, motion_features], dim=1)\n\n        net = self.gru(net, inp)\n        delta_flow = self.flow_head(net)\n\n        # scale mask to balence gradients\n        mask = .25 * self.mask(net)\n        return net, mask, delta_flow\n\n"
  },
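  {
    "path": "thirdparty/raft/update_example.py",
    "content": "# Hypothetical one-step sketch of the SepConvGRU recurrent update: the hidden\n# state h and the encoded input x share a spatial grid, and the GRU returns a\n# new hidden state with the same shape as h.\nimport torch\n\nfrom thirdparty.raft.update import SepConvGRU\n\nif __name__ == '__main__':\n    hidden_dim, input_dim = 128, 192 + 128\n    gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=input_dim)\n    h = torch.zeros(1, hidden_dim, 32, 32)\n    x = torch.randn(1, input_dim, 32, 32)\n    with torch.no_grad():\n        h = gru(h, x)\n    print(tuple(h.shape))  # (1, 128, 32, 32)\n"
  },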
  {
    "path": "thirdparty/raft/utils/__init__.py",
    "content": ""
  },
  {
    "path": "thirdparty/raft/utils/augmentor.py",
    "content": "import numpy as np\nimport random\nimport math\nfrom PIL import Image\n\nimport cv2\ncv2.setNumThreads(0)\ncv2.ocl.setUseOpenCL(False)\n\nimport torch\nfrom torchvision.transforms import ColorJitter\nimport torch.nn.functional as F\n\n\nclass FlowAugmentor:\n    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):\n        \n        # spatial augmentation params\n        self.crop_size = crop_size\n        self.min_scale = min_scale\n        self.max_scale = max_scale\n        self.spatial_aug_prob = 0.8\n        self.stretch_prob = 0.8\n        self.max_stretch = 0.2\n\n        # flip augmentation params\n        self.do_flip = do_flip\n        self.h_flip_prob = 0.5\n        self.v_flip_prob = 0.1\n\n        # photometric augmentation params\n        self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)\n        self.asymmetric_color_aug_prob = 0.2\n        self.eraser_aug_prob = 0.5\n\n    def color_transform(self, img1, img2):\n        \"\"\" Photometric augmentation \"\"\"\n\n        # asymmetric\n        if np.random.rand() < self.asymmetric_color_aug_prob:\n            img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)\n            img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)\n\n        # symmetric\n        else:\n            image_stack = np.concatenate([img1, img2], axis=0)\n            image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)\n            img1, img2 = np.split(image_stack, 2, axis=0)\n\n        return img1, img2\n\n    def eraser_transform(self, img1, img2, bounds=[50, 100]):\n        \"\"\" Occlusion augmentation \"\"\"\n\n        ht, wd = img1.shape[:2]\n        if np.random.rand() < self.eraser_aug_prob:\n            mean_color = np.mean(img2.reshape(-1, 3), axis=0)\n            for _ in range(np.random.randint(1, 3)):\n                x0 = np.random.randint(0, wd)\n                y0 = np.random.randint(0, ht)\n                dx = np.random.randint(bounds[0], bounds[1])\n                dy = np.random.randint(bounds[0], bounds[1])\n                img2[y0:y0+dy, x0:x0+dx, :] = mean_color\n\n        return img1, img2\n\n    def spatial_transform(self, img1, img2, flow):\n        # randomly sample scale\n        ht, wd = img1.shape[:2]\n        min_scale = np.maximum(\n            (self.crop_size[0] + 8) / float(ht), \n            (self.crop_size[1] + 8) / float(wd))\n\n        scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)\n        scale_x = scale\n        scale_y = scale\n        if np.random.rand() < self.stretch_prob:\n            scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)\n            scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)\n        \n        scale_x = np.clip(scale_x, min_scale, None)\n        scale_y = np.clip(scale_y, min_scale, None)\n\n        if np.random.rand() < self.spatial_aug_prob:\n            # rescale the images\n            img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)\n            img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)\n            flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)\n            flow = flow * [scale_x, scale_y]\n\n        if self.do_flip:\n            if np.random.rand() < self.h_flip_prob: # h-flip\n                img1 = img1[:, ::-1]\n                img2 = img2[:, ::-1]\n            
    flow = flow[:, ::-1] * [-1.0, 1.0]\n\n            if np.random.rand() < self.v_flip_prob: # v-flip\n                img1 = img1[::-1, :]\n                img2 = img2[::-1, :]\n                flow = flow[::-1, :] * [1.0, -1.0]\n\n        y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])\n        x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])\n        \n        img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]\n        img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]\n        flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]\n\n        return img1, img2, flow\n\n    def __call__(self, img1, img2, flow):\n        img1, img2 = self.color_transform(img1, img2)\n        img1, img2 = self.eraser_transform(img1, img2)\n        img1, img2, flow = self.spatial_transform(img1, img2, flow)\n\n        img1 = np.ascontiguousarray(img1)\n        img2 = np.ascontiguousarray(img2)\n        flow = np.ascontiguousarray(flow)\n\n        return img1, img2, flow\n\nclass SparseFlowAugmentor:\n    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):\n        # spatial augmentation params\n        self.crop_size = crop_size\n        self.min_scale = min_scale\n        self.max_scale = max_scale\n        self.spatial_aug_prob = 0.8\n        self.stretch_prob = 0.8\n        self.max_stretch = 0.2\n\n        # flip augmentation params\n        self.do_flip = do_flip\n        self.h_flip_prob = 0.5\n        self.v_flip_prob = 0.1\n\n        # photometric augmentation params\n        self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)\n        self.asymmetric_color_aug_prob = 0.2\n        self.eraser_aug_prob = 0.5\n        \n    def color_transform(self, img1, img2):\n        image_stack = np.concatenate([img1, img2], axis=0)\n        image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)\n        img1, img2 = np.split(image_stack, 2, axis=0)\n        return img1, img2\n\n    def eraser_transform(self, img1, img2):\n        ht, wd = img1.shape[:2]\n        if np.random.rand() < self.eraser_aug_prob:\n            mean_color = np.mean(img2.reshape(-1, 3), axis=0)\n            for _ in range(np.random.randint(1, 3)):\n                x0 = np.random.randint(0, wd)\n                y0 = np.random.randint(0, ht)\n                dx = np.random.randint(50, 100)\n                dy = np.random.randint(50, 100)\n                img2[y0:y0+dy, x0:x0+dx, :] = mean_color\n\n        return img1, img2\n\n    def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):\n        ht, wd = flow.shape[:2]\n        coords = np.meshgrid(np.arange(wd), np.arange(ht))\n        coords = np.stack(coords, axis=-1)\n\n        coords = coords.reshape(-1, 2).astype(np.float32)\n        flow = flow.reshape(-1, 2).astype(np.float32)\n        valid = valid.reshape(-1).astype(np.float32)\n\n        coords0 = coords[valid>=1]\n        flow0 = flow[valid>=1]\n\n        ht1 = int(round(ht * fy))\n        wd1 = int(round(wd * fx))\n\n        coords1 = coords0 * [fx, fy]\n        flow1 = flow0 * [fx, fy]\n\n        xx = np.round(coords1[:,0]).astype(np.int32)\n        yy = np.round(coords1[:,1]).astype(np.int32)\n\n        v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)\n        xx = xx[v]\n        yy = yy[v]\n        flow1 = flow1[v]\n\n        flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)\n        valid_img = np.zeros([ht1, wd1], dtype=np.int32)\n\n        flow_img[yy, xx] = flow1\n     
   valid_img[yy, xx] = 1\n\n        return flow_img, valid_img\n\n    def spatial_transform(self, img1, img2, flow, valid):\n        # randomly sample scale\n\n        ht, wd = img1.shape[:2]\n        min_scale = np.maximum(\n            (self.crop_size[0] + 1) / float(ht), \n            (self.crop_size[1] + 1) / float(wd))\n\n        scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)\n        scale_x = np.clip(scale, min_scale, None)\n        scale_y = np.clip(scale, min_scale, None)\n\n        if np.random.rand() < self.spatial_aug_prob:\n            # rescale the images\n            img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)\n            img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)\n            flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)\n\n        if self.do_flip:\n            if np.random.rand() < 0.5: # h-flip\n                img1 = img1[:, ::-1]\n                img2 = img2[:, ::-1]\n                flow = flow[:, ::-1] * [-1.0, 1.0]\n                valid = valid[:, ::-1]\n\n        margin_y = 20\n        margin_x = 50\n\n        y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)\n        x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)\n\n        y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])\n        x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])\n\n        img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]\n        img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]\n        flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]\n        valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]\n        return img1, img2, flow, valid\n\n\n    def __call__(self, img1, img2, flow, valid):\n        img1, img2 = self.color_transform(img1, img2)\n        img1, img2 = self.eraser_transform(img1, img2)\n        img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)\n\n        img1 = np.ascontiguousarray(img1)\n        img2 = np.ascontiguousarray(img2)\n        flow = np.ascontiguousarray(flow)\n        valid = np.ascontiguousarray(valid)\n\n        return img1, img2, flow, valid\n"
  },
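  {
    "path": "thirdparty/raft/utils/augmentor_example.py",
    "content": "# Hypothetical smoke test for FlowAugmentor (sizes are arbitrary): inputs are\n# uint8 HxWx3 images plus a float32 HxWx2 flow, and the augmentor returns a\n# photometrically and spatially jittered crop of crop_size.\nimport numpy as np\n\nfrom thirdparty.raft.utils.augmentor import FlowAugmentor\n\nif __name__ == '__main__':\n    aug = FlowAugmentor(crop_size=(96, 128))\n    img1 = np.random.randint(0, 255, (200, 300, 3), dtype=np.uint8)\n    img2 = np.random.randint(0, 255, (200, 300, 3), dtype=np.uint8)\n    flow = np.random.randn(200, 300, 2).astype(np.float32)\n\n    img1, img2, flow = aug(img1, img2, flow)\n    print(img1.shape, img2.shape, flow.shape)  # (96, 128, 3) x2 and (96, 128, 2)\n"
  },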
  {
    "path": "thirdparty/raft/utils/flow_viz.py",
    "content": "# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization\n\n\n# MIT License\n#\n# Copyright (c) 2018 Tom Runia\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to conditions.\n#\n# Author: Tom Runia\n# Date Created: 2018-08-03\n\nimport numpy as np\n\ndef make_colorwheel():\n    \"\"\"\n    Generates a color wheel for optical flow visualization as presented in:\n        Baker et al. \"A Database and Evaluation Methodology for Optical Flow\" (ICCV, 2007)\n        URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf\n\n    Code follows the original C++ source code of Daniel Scharstein.\n    Code follows the the Matlab source code of Deqing Sun.\n\n    Returns:\n        np.ndarray: Color wheel\n    \"\"\"\n\n    RY = 15\n    YG = 6\n    GC = 4\n    CB = 11\n    BM = 13\n    MR = 6\n\n    ncols = RY + YG + GC + CB + BM + MR\n    colorwheel = np.zeros((ncols, 3))\n    col = 0\n\n    # RY\n    colorwheel[0:RY, 0] = 255\n    colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)\n    col = col+RY\n    # YG\n    colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)\n    colorwheel[col:col+YG, 1] = 255\n    col = col+YG\n    # GC\n    colorwheel[col:col+GC, 1] = 255\n    colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)\n    col = col+GC\n    # CB\n    colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)\n    colorwheel[col:col+CB, 2] = 255\n    col = col+CB\n    # BM\n    colorwheel[col:col+BM, 2] = 255\n    colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)\n    col = col+BM\n    # MR\n    colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)\n    colorwheel[col:col+MR, 0] = 255\n    return colorwheel\n\n\ndef flow_uv_to_colors(u, v, convert_to_bgr=False):\n    \"\"\"\n    Applies the flow color wheel to (possibly clipped) flow components u and v.\n\n    According to the C++ source code of Daniel Scharstein\n    According to the Matlab source code of Deqing Sun\n\n    Args:\n        u (np.ndarray): Input horizontal flow of shape [H,W]\n        v (np.ndarray): Input vertical flow of shape [H,W]\n        convert_to_bgr (bool, optional): Convert output image to BGR. 
Defaults to False.\n\n    Returns:\n        np.ndarray: Flow visualization image of shape [H,W,3]\n    \"\"\"\n    flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)\n    colorwheel = make_colorwheel()  # shape [55x3]\n    ncols = colorwheel.shape[0]\n    rad = np.sqrt(np.square(u) + np.square(v))\n    a = np.arctan2(-v, -u)/np.pi\n    fk = (a+1) / 2*(ncols-1)\n    k0 = np.floor(fk).astype(np.int32)\n    k1 = k0 + 1\n    k1[k1 == ncols] = 0\n    f = fk - k0\n    for i in range(colorwheel.shape[1]):\n        tmp = colorwheel[:,i]\n        col0 = tmp[k0] / 255.0\n        col1 = tmp[k1] / 255.0\n        col = (1-f)*col0 + f*col1\n        idx = (rad <= 1)\n        col[idx]  = 1 - rad[idx] * (1-col[idx])\n        col[~idx] = col[~idx] * 0.75   # out of range\n        # Note the 2-i => BGR instead of RGB\n        ch_idx = 2-i if convert_to_bgr else i\n        flow_image[:,:,ch_idx] = np.floor(255 * col)\n    return flow_image\n\n\ndef flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):\n    \"\"\"\n    Expects a two dimensional flow image of shape.\n\n    Args:\n        flow_uv (np.ndarray): Flow UV image of shape [H,W,2]\n        clip_flow (float, optional): Clip maximum of flow values. Defaults to None.\n        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.\n\n    Returns:\n        np.ndarray: Flow visualization image of shape [H,W,3]\n    \"\"\"\n    assert flow_uv.ndim == 3, 'input flow must have three dimensions'\n    assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'\n    if clip_flow is not None:\n        flow_uv = np.clip(flow_uv, 0, clip_flow)\n    u = flow_uv[:,:,0]\n    v = flow_uv[:,:,1]\n    rad = np.sqrt(np.square(u) + np.square(v))\n    rad_max = np.max(rad)\n    epsilon = 1e-5\n    u = u / (rad_max + epsilon)\n    v = v / (rad_max + epsilon)\n    return flow_uv_to_colors(u, v, convert_to_bgr)"
  },
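  {
    "path": "examples/flow_vis_demo.py",
    "content": "# Added illustrative example (not part of the original sources): a minimal\n# sketch of how flow_to_image turns a dense flow field into an RGB image.\n# Assumes the flow visualization module above is importable as `flow_vis`\n# (tools/eval.py imports it under that name); this file name is hypothetical.\nimport numpy as np\nimport flow_vis\n\n# Toy 64x64 flow field: u grows left->right, v grows top->bottom.\nys, xs = np.meshgrid(np.linspace(-1.0, 1.0, 64), np.linspace(-1.0, 1.0, 64), indexing='ij')\nflow = np.stack([xs, ys], axis=-1).astype(np.float32)  # [H,W,2], u = xs, v = ys\n\nrgb = flow_vis.flow_to_image(flow)  # uint8 visualization of shape [H,W,3]\nprint(rgb.shape, rgb.dtype)  # (64, 64, 3) uint8\n"
  },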
  {
    "path": "thirdparty/raft/utils/frame_utils.py",
    "content": "import numpy as np\nfrom PIL import Image\nfrom os.path import *\nimport re\n\nimport cv2\ncv2.setNumThreads(0)\ncv2.ocl.setUseOpenCL(False)\n\nTAG_CHAR = np.array([202021.25], np.float32)\n\ndef readFlow(fn):\n    \"\"\" Read .flo file in Middlebury format\"\"\"\n    # Code adapted from:\n    # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy\n\n    # WARNING: this will work on little-endian architectures (eg Intel x86) only!\n    # print 'fn = %s'%(fn)\n    with open(fn, 'rb') as f:\n        magic = np.fromfile(f, np.float32, count=1)\n        if 202021.25 != magic:\n            print('Magic number incorrect. Invalid .flo file')\n            return None\n        else:\n            w = np.fromfile(f, np.int32, count=1)\n            h = np.fromfile(f, np.int32, count=1)\n            # print 'Reading %d x %d flo file\\n' % (w, h)\n            data = np.fromfile(f, np.float32, count=2*int(w)*int(h))\n            # Reshape data into 3D array (columns, rows, bands)\n            # The reshape here is for visualization, the original code is (w,h,2)\n            return np.resize(data, (int(h), int(w), 2))\n\ndef readPFM(file):\n    file = open(file, 'rb')\n\n    color = None\n    width = None\n    height = None\n    scale = None\n    endian = None\n\n    header = file.readline().rstrip()\n    if header == b'PF':\n        color = True\n    elif header == b'Pf':\n        color = False\n    else:\n        raise Exception('Not a PFM file.')\n\n    dim_match = re.match(rb'^(\\d+)\\s(\\d+)\\s$', file.readline())\n    if dim_match:\n        width, height = map(int, dim_match.groups())\n    else:\n        raise Exception('Malformed PFM header.')\n\n    scale = float(file.readline().rstrip())\n    if scale < 0: # little-endian\n        endian = '<'\n        scale = -scale\n    else:\n        endian = '>' # big-endian\n\n    data = np.fromfile(file, endian + 'f')\n    shape = (height, width, 3) if color else (height, width)\n\n    data = np.reshape(data, shape)\n    data = np.flipud(data)\n    return data\n\ndef writeFlow(filename,uv,v=None):\n    \"\"\" Write optical flow to file.\n    \n    If v is None, uv is assumed to contain both u and v channels,\n    stacked in depth.\n    Original code by Deqing Sun, adapted from Daniel Scharstein.\n    \"\"\"\n    nBands = 2\n\n    if v is None:\n        assert(uv.ndim == 3)\n        assert(uv.shape[2] == 2)\n        u = uv[:,:,0]\n        v = uv[:,:,1]\n    else:\n        u = uv\n\n    assert(u.shape == v.shape)\n    height,width = u.shape\n    f = open(filename,'wb')\n    # write the header\n    f.write(TAG_CHAR)\n    np.array(width).astype(np.int32).tofile(f)\n    np.array(height).astype(np.int32).tofile(f)\n    # arrange into matrix form\n    tmp = np.zeros((height, width*nBands))\n    tmp[:,np.arange(width)*2] = u\n    tmp[:,np.arange(width)*2 + 1] = v\n    tmp.astype(np.float32).tofile(f)\n    f.close()\n\n\ndef readFlowKITTI(filename):\n    flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)\n    flow = flow[:,:,::-1].astype(np.float32)\n    flow, valid = flow[:, :, :2], flow[:, :, 2]\n    flow = (flow - 2**15) / 64.0\n    return flow, valid\n\ndef readDispKITTI(filename):\n    disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0\n    valid = disp > 0.0\n    flow = np.stack([-disp, np.zeros_like(disp)], -1)\n    return flow, valid\n\n\ndef writeFlowKITTI(filename, uv):\n    uv = 64.0 * uv + 2**15\n    valid = np.ones([uv.shape[0], uv.shape[1], 1])\n    uv = 
np.concatenate([uv, valid], axis=-1).astype(np.uint16)\n    cv2.imwrite(filename, uv[..., ::-1])\n    \n\ndef read_gen(file_name, pil=False):\n    ext = splitext(file_name)[-1]\n    if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':\n        return Image.open(file_name)\n    elif ext == '.bin' or ext == '.raw':\n        return np.load(file_name)\n    elif ext == '.flo':\n        return readFlow(file_name).astype(np.float32)\n    elif ext == '.pfm':\n        flow = readPFM(file_name).astype(np.float32)\n        if len(flow.shape) == 2:\n            return flow\n        else:\n            return flow[:, :, :-1]\n    return []"
  },
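  {
    "path": "examples/flo_roundtrip_demo.py",
    "content": "# Added illustrative example (not part of the original sources): a minimal\n# round-trip sketch for the Middlebury .flo helpers above. Assumes\n# thirdparty/raft/utils is on sys.path so `frame_utils` imports directly;\n# note that readFlow only works on little-endian machines.\nimport os\nimport tempfile\n\nimport numpy as np\n\nfrom frame_utils import readFlow, writeFlow\n\nflow = np.random.randn(4, 5, 2).astype(np.float32)  # [H,W,2]\npath = os.path.join(tempfile.mkdtemp(), \"sample.flo\")\n\nwriteFlow(path, flow)      # tag + width + height header, then interleaved u/v rows\nrestored = readFlow(path)  # [H,W,2] float32, or None if the magic number is wrong\n\nassert restored.shape == (4, 5, 2)\nassert np.allclose(flow, restored)\nprint(\"roundtrip ok\")\n"
  },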
  {
    "path": "thirdparty/raft/utils/utils.py",
    "content": "import torch\nimport torch.nn.functional as F\nimport numpy as np\nfrom scipy import interpolate\n\n\nclass InputPadder:\n    \"\"\" Pads images such that dimensions are divisible by 8 \"\"\"\n    def __init__(self, dims, mode='sintel'):\n        self.ht, self.wd = dims[-2:]\n        pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8\n        pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8\n        if mode == 'sintel':\n            self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]\n        else:\n            self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]\n\n    def pad(self, *inputs):\n        return [F.pad(x, self._pad, mode='replicate') for x in inputs]\n\n    def unpad(self,x):\n        ht, wd = x.shape[-2:]\n        c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]\n        return x[..., c[0]:c[1], c[2]:c[3]]\n\ndef forward_interpolate(flow):\n    flow = flow.detach().cpu().numpy()\n    dx, dy = flow[0], flow[1]\n\n    ht, wd = dx.shape\n    x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))\n\n    x1 = x0 + dx\n    y1 = y0 + dy\n    \n    x1 = x1.reshape(-1)\n    y1 = y1.reshape(-1)\n    dx = dx.reshape(-1)\n    dy = dy.reshape(-1)\n\n    valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)\n    x1 = x1[valid]\n    y1 = y1[valid]\n    dx = dx[valid]\n    dy = dy[valid]\n\n    flow_x = interpolate.griddata(\n        (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)\n\n    flow_y = interpolate.griddata(\n        (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)\n\n    flow = np.stack([flow_x, flow_y], axis=0)\n    return torch.from_numpy(flow).float()\n\n\ndef bilinear_sampler(img, coords, mode='bilinear', mask=False):\n    \"\"\" Wrapper for grid_sample, uses pixel coordinates \"\"\"\n    H, W = img.shape[-2:]\n    xgrid, ygrid = coords.split([1,1], dim=-1)\n    xgrid = 2*xgrid/(W-1) - 1\n    ygrid = 2*ygrid/(H-1) - 1\n\n    grid = torch.cat([xgrid, ygrid], dim=-1)\n    img = F.grid_sample(img, grid, align_corners=True)\n\n    if mask:\n        mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)\n        return img, mask.float()\n\n    return img\n\n\ndef coords_grid(batch, ht, wd):\n    coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))\n    coords = torch.stack(coords[::-1], dim=0).float()\n    return coords[None].repeat(batch, 1, 1, 1)\n\n\ndef upflow8(flow, mode='bilinear'):\n    new_size = (8 * flow.shape[2], 8 * flow.shape[3])\n    return  8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)\n\ndef upflow(flow, mode='bilinear', scale=8):\n    new_size = (scale * flow.shape[2], scale * flow.shape[3])\n    return  scale * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)"
  },
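  {
    "path": "examples/input_padder_demo.py",
    "content": "# Added illustrative example (not part of the original sources): a minimal\n# sketch of InputPadder from the RAFT utils above, which pads inputs so both\n# spatial dimensions are divisible by 8 and can undo the padding afterwards.\n# Assumes thirdparty/raft/utils is on sys.path so `utils` imports directly.\nimport torch\n\nfrom utils import InputPadder\n\nimg = torch.randn(1, 3, 436, 1024)  # Sintel-sized frame; 436 is not divisible by 8\npadder = InputPadder(img.shape, mode='sintel')\nimg_padded, = padder.pad(img)  # pad() accepts any number of tensors\n\nassert img_padded.shape[-2] % 8 == 0 and img_padded.shape[-1] % 8 == 0\nassert padder.unpad(img_padded).shape == img.shape\nprint(tuple(img.shape), '->', tuple(img_padded.shape))\n"
  },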
  {
    "path": "thirdparty/vsd/inout.py",
    "content": "# Author: Tomas Hodan (hodantom@cmp.felk.cvut.cz)\n# Center for Machine Perception, Czech Technical University in Prague\n\nimport struct\nimport itertools\nimport numpy as np\nimport scipy.misc\n\n\ndef load_depth(path):\n    d = scipy.misc.imread(path)\n    d = d.astype(np.float32)\n    return d\n\n\ndef load_ply(path):\n    \"\"\"\n    Loads a 3D mesh model from a PLY file.\n    :param path: Path to a PLY file.\n    :return: The loaded model given by a dictionary with items:\n    'pts' (nx3 ndarray), 'normals' (nx3 ndarray), 'colors' (nx3 ndarray),\n    'faces' (mx3 ndarray) - the latter three are optional.\n    \"\"\"\n    f = open(path, 'r')\n\n    n_pts = 0\n    n_faces = 0\n    face_n_corners = 3 # Only triangular faces are supported\n    pt_props = []\n    face_props = []\n    is_binary = False\n    header_vertex_section = False\n    header_face_section = False\n\n    # Read header\n    while True:\n        line = f.readline().rstrip('\\n').rstrip('\\r') # Strip the newline character(s)\n        if line.startswith('element vertex'):\n            n_pts = int(line.split()[-1])\n            header_vertex_section = True\n            header_face_section = False\n        elif line.startswith('element face'):\n            n_faces = int(line.split()[-1])\n            header_vertex_section = False\n            header_face_section = True\n        elif line.startswith('element'): # Some other element\n            header_vertex_section = False\n            header_face_section = False\n        elif line.startswith('property') and header_vertex_section:\n            # (name of the property, data type)\n            pt_props.append((line.split()[-1], line.split()[-2]))\n        elif line.startswith('property list') and header_face_section:\n            elems = line.split()\n            if elems[-1] == 'vertex_indices':\n                # (name of the property, data type)\n                face_props.append(('n_corners', elems[2]))\n                for i in range(face_n_corners):\n                    face_props.append(('ind_' + str(i), elems[3]))\n            else:\n                print(('Warning: Not supported face property: ' + elems[-1]))\n        elif line.startswith('format'):\n            if 'binary' in line:\n                is_binary = True\n        elif line.startswith('end_header'):\n            break\n\n    # Prepare data structures\n    model = {}\n    model['pts'] = np.zeros((n_pts, 3), np.float)\n    if n_faces > 0:\n        model['faces'] = np.zeros((n_faces, face_n_corners), np.float)\n\n    pt_props_names = [p[0] for p in pt_props]\n    is_normal = False\n    if {'nx', 'ny', 'nz'}.issubset(set(pt_props_names)):\n        is_normal = True\n        model['normals'] = np.zeros((n_pts, 3), np.float)\n\n    is_color = False\n    if {'red', 'green', 'blue'}.issubset(set(pt_props_names)):\n        is_color = True\n        model['colors'] = np.zeros((n_pts, 3), np.float)\n\n    is_texture = False\n    if {'texture_u', 'texture_v'}.issubset(set(pt_props_names)):\n        is_texture = True\n        model['texture_uv'] = np.zeros((n_pts, 2), np.float)\n\n    formats = { # For binary format\n        'float': ('f', 4),\n        'double': ('d', 8),\n        'int': ('i', 4),\n        'uchar': ('B', 1)\n    }\n\n    # Load vertices\n    for pt_id in range(n_pts):\n        prop_vals = {}\n        load_props = ['x', 'y', 'z', 'nx', 'ny', 'nz',\n                      'red', 'green', 'blue', 'texture_u', 'texture_v']\n        if is_binary:\n            for prop in pt_props:\n            
    format = formats[prop[1]]\n                val = struct.unpack(format[0], f.read(format[1]))[0]\n                if prop[0] in load_props:\n                    prop_vals[prop[0]] = val\n        else:\n            elems = f.readline().rstrip('\\n').rstrip('\\r').split()\n            for prop_id, prop in enumerate(pt_props):\n                if prop[0] in load_props:\n                    prop_vals[prop[0]] = elems[prop_id]\n\n        model['pts'][pt_id, 0] = float(prop_vals['x'])\n        model['pts'][pt_id, 1] = float(prop_vals['y'])\n        model['pts'][pt_id, 2] = float(prop_vals['z'])\n\n        if is_normal:\n            model['normals'][pt_id, 0] = float(prop_vals['nx'])\n            model['normals'][pt_id, 1] = float(prop_vals['ny'])\n            model['normals'][pt_id, 2] = float(prop_vals['nz'])\n\n        if is_color:\n            model['colors'][pt_id, 0] = float(prop_vals['red'])\n            model['colors'][pt_id, 1] = float(prop_vals['green'])\n            model['colors'][pt_id, 2] = float(prop_vals['blue'])\n\n        if is_texture:\n            model['texture_uv'][pt_id, 0] = float(prop_vals['texture_u'])\n            model['texture_uv'][pt_id, 1] = float(prop_vals['texture_v'])\n\n    # Load faces\n    for face_id in range(n_faces):\n        prop_vals = {}\n        if is_binary:\n            for prop in face_props:\n                format = formats[prop[1]]\n                val = struct.unpack(format[0], f.read(format[1]))[0]\n                if prop[0] == 'n_corners':\n                    if val != face_n_corners:\n                        print('Error: Only triangular faces are supported.')\n                        print(('Number of face corners: ' + str(val)))\n                        exit(-1)\n                else:\n                    prop_vals[prop[0]] = val\n        else:\n            elems = f.readline().rstrip('\\n').rstrip('\\r').split()\n            for prop_id, prop in enumerate(face_props):\n                if prop[0] == 'n_corners':\n                    if int(elems[prop_id]) != face_n_corners:\n                        print('Error: Only triangular faces are supported.')\n                        print(('Number of face corners: ' + str(int(elems[prop_id]))))\n                        exit(-1)\n                else:\n                    prop_vals[prop[0]] = elems[prop_id]\n\n        model['faces'][face_id, 0] = int(prop_vals['ind_0'])\n        model['faces'][face_id, 1] = int(prop_vals['ind_1'])\n        model['faces'][face_id, 2] = int(prop_vals['ind_2'])\n\n    f.close()\n\n    return model\n"
  },
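  {
    "path": "examples/load_ply_demo.py",
    "content": "# Added illustrative example (not part of the original sources): writes a\n# minimal ASCII PLY triangle to disk and parses it with load_ply above.\n# Assumes thirdparty/vsd is on sys.path so `inout` imports directly.\nimport os\nimport tempfile\n\nfrom inout import load_ply\n\nPLY_TEXT = (\n    \"ply\\n\"\n    \"format ascii 1.0\\n\"\n    \"element vertex 3\\n\"\n    \"property float x\\n\"\n    \"property float y\\n\"\n    \"property float z\\n\"\n    \"element face 1\\n\"\n    \"property list uchar int vertex_indices\\n\"\n    \"end_header\\n\"\n    \"0 0 0\\n\"\n    \"1 0 0\\n\"\n    \"0 1 0\\n\"\n    \"3 0 1 2\\n\"\n)\n\npath = os.path.join(tempfile.mkdtemp(), \"triangle.ply\")\nwith open(path, \"w\") as f:\n    f.write(PLY_TEXT)\n\nmodel = load_ply(path)\nprint(model[\"pts\"])    # (3, 3) vertex coordinates\nprint(model[\"faces\"])  # (1, 3) vertex indices of the single triangle\n"
  },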
  {
    "path": "tools/eval.py",
    "content": "#CERTIFICATED\nimport torch\nimport numpy as np \nimport tensorboard\nfrom pathlib import Path\nimport json\nimport random\nimport re\nimport torch.backends.cudnn as cudnn\nimport torch.multiprocessing as mp\nimport time\nimport fire\nimport torch.distributed as dist\nimport os\nfrom collections import defaultdict\nimport ast \nimport flow_vis\nimport copy\n\nfrom utils.progress_bar import ProgressBar\nfrom utils.log_tool import SimpleModelLog\nfrom data.preprocess import merge_batch, get_dataloader, get_dataloader_deepim # merge_second_batch_multigpu\nfrom utils.config_io import merge_cfg, save_cfg\nimport torchplus\nfrom builder import (\n    dataset_builder,\n    input_reader_builder,\n    lr_scheduler_builder,\n    optimizer_builder,\n    rnnpose_builder\n)\nfrom utils.distributed_utils import dist_init, average_gradients, DistModule, ParallelWrapper, DistributedSequatialSampler, DistributedGivenIterationSampler, DistributedGivenIterationSamplerEpoch \nfrom utils.util import modify_parameter_name_with_map\nfrom utils.eval_metric import *\nfrom config.default import get_cfg\n\n\n\nGLOBAL_GPUS_PER_DEVICE = 1  \nGLOBAL_STEP = 0\nRANK=-1\nWORLD_SIZE=-1\n\n\ndef load_example_to_device(example,\n                             device=None) -> dict:\n    example_torch = {}\n\n    for k, v in example.items():  \n        if k in ['idx', 'class_name']:\n            example_torch[k]=v\n            continue\n\n        if type(v) == list:\n            example_torch[k] = [item.to(device=device) for item in v]\n        else:\n            example_torch[k] = v.to(device=device)\n\n    return example_torch\ndef build_network(model_cfg, measure_time=False, testing=False):\n    net = rnnpose_builder.build(\n        model_cfg, measure_time=measure_time, testing=testing)\n    return net\n\n\ndef _worker_init_fn(worker_id):\n    global GLOBAL_STEP\n    time_seed = GLOBAL_STEP\n    np.random.seed(time_seed + worker_id)\n    print(f\"WORKER {worker_id} seed:\", np.random.get_state()[1][0])\n\n\ndef freeze_params(params: dict, include: str = None, exclude: str = None):\n    assert isinstance(params, dict)\n    include_re = None\n    if include is not None:\n        include_re = re.compile(include)\n    exclude_re = None\n    if exclude is not None:\n        exclude_re = re.compile(exclude)\n    remain_params = []\n    for k, p in params.items():\n        if include_re is not None:\n            if include_re.match(k) is not None:\n                continue\n        if exclude_re is not None:\n            if exclude_re.match(k) is None:\n                continue\n        remain_params.append(p)\n    return remain_params\n\n\ndef freeze_params_v2(params: dict, include: str = None, exclude: str = None):\n    assert isinstance(params, dict)\n    include_re = None\n    if include is not None:\n        include_re = re.compile(include)\n    exclude_re = None\n    if exclude is not None:\n        exclude_re = re.compile(exclude)\n    for k, p in params.items():\n        if include_re is not None:\n            if include_re.match(k) is not None:\n                p.requires_grad = False\n        if exclude_re is not None:\n            if exclude_re.match(k) is None:\n                p.requires_grad = False\n\n\ndef filter_param_dict(state_dict: dict, include: str = None, exclude: str = None):\n    assert isinstance(state_dict, dict)\n    include_re = None\n    if include is not None:\n        include_re = re.compile(include)\n    exclude_re = None\n    if exclude is not None:\n        exclude_re = 
re.compile(exclude)\n    res_dict = {}\n    for k, p in state_dict.items():\n        if include_re is not None:\n            if include_re.match(k) is None:\n                continue\n        if exclude_re is not None:\n            if exclude_re.match(k) is not None:\n                continue\n        res_dict[k] = p\n    return res_dict\n\ndef chk_rank(rank_, use_dist=False):\n    if not use_dist:\n        return True\n    global RANK\n    if RANK<0:\n        RANK=dist.get_rank()\n    cur_rank = RANK#dist.get_rank()\n    # self.world_size = dist.get_world_size()\n    return cur_rank == rank_\n\ndef get_rank(use_dist=False):\n    if not use_dist:\n        return 0\n    else:\n        # return dist.get_rank()\n        global RANK \n        if RANK<0:\n            RANK=dist.get_rank()\n        return RANK \n\ndef get_world(use_dist):\n    if not use_dist:\n        return 1\n    else:\n        global WORLD_SIZE \n        if WORLD_SIZE<0:\n            WORLD_SIZE=dist.get_world_size()\n        return WORLD_SIZE #dist.get_world_size()\ndef get_ngpus_per_node():\n    global GLOBAL_GPUS_PER_DEVICE\n    return GLOBAL_GPUS_PER_DEVICE\n\n\ndef multi_proc_train(\n          config_path,\n          model_dir,\n          use_apex,\n          world_size,\n          result_path=None,\n          create_folder=False,\n          display_step=50,\n          summary_step=5,\n          pretrained_path=None,\n          pretrained_include=None,\n          pretrained_exclude=None,\n          pretrained_param_map=None,\n          freeze_include=None,\n          freeze_exclude=None,\n          measure_time=False,\n          resume=False,\n          use_dist=False,\n          gpus_per_node=1,\n          start_gpu_id=0,\n          optim_eval=False,\n          seed=7,\n          dist_port=\"23335\",\n         force_resume_step=None,\n         batch_size=None,\n         apex_opt_level='O0'\n          ):\n    \n    params = {\n          \"config_path\": config_path,\n          \"model_dir\": model_dir,\n          \"use_apex\": use_apex,\n          \"result_path\": result_path,\n          \"create_folder\": create_folder,\n          \"display_step\": display_step,\n          \"summary_step\": summary_step,\n          \"pretrained_path\": pretrained_path,\n          \"pretrained_include\": pretrained_include,\n          \"pretrained_exclude\": pretrained_exclude,\n          \"pretrained_param_map\": pretrained_param_map,\n          \"freeze_include\": freeze_include,\n          \"freeze_exclude\": freeze_exclude,\n        #   \"multi_gpu\": multi_gpu,\n          \"measure_time\": measure_time,\n          \"resume\": resume,\n          \"use_dist\": use_dist,\n          \"gpus_per_node\": gpus_per_node,\n          \"optim_eval\": optim_eval,\n          \"seed\": seed,\n          \"dist_port\": dist_port,\n          \"world_size\": world_size,\n          \"force_resume_step\":force_resume_step,\n          \"batch_size\": batch_size,\n          \"apex_opt_level\":apex_opt_level\n    }\n    from types import SimpleNamespace \n    params = SimpleNamespace(**params)\n\n    os.environ[\"CUDA_VISIBLE_DEVICES\"] = ','.join(\n        str(x) for x in range(start_gpu_id, start_gpu_id+gpus_per_node))\n    print(f\"CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']}\"  )\n\n    mp.spawn(train_worker, nprocs=gpus_per_node,\n                args=( params,) )\n\ndef train_worker(rank, params):\n    global RANK, WORLD_SIZE\n    RANK = rank\n    WORLD_SIZE=params.world_size\n    \n    eval(config_path=params.config_path,\n          
model_dir=params.model_dir,\n          use_apex=params.use_apex,\n          result_path=params.result_path,\n          create_folder=params.create_folder,\n          display_step=params.display_step,\n          pretrained_path=params.pretrained_path,\n          pretrained_include=params.pretrained_include,\n          pretrained_exclude=params.pretrained_exclude,\n          pretrained_param_map=params.pretrained_param_map,\n          freeze_include=params.freeze_include,\n          freeze_exclude=params.freeze_exclude,\n          measure_time=params.measure_time,\n          resume=params.resume,\n          use_dist=params.use_dist,\n          dist_port=params.dist_port,\n          gpus_per_node=params.gpus_per_node,\n          optim_eval=params.optim_eval,\n          seed=params.seed,\n          force_resume_step=params.force_resume_step,\n          batch_size = params.batch_size,\n          apex_opt_level=params.apex_opt_level\n          ) \n\n\ndef eval(\n         config_path,\n          model_dir,\n          use_apex,\n          result_path=None,\n          create_folder=False,\n          display_step=50,\n          summary_step=5,\n          pretrained_path=None,\n          pretrained_include=None,\n          pretrained_exclude=None,\n          pretrained_param_map=None,\n          freeze_include=None,\n          freeze_exclude=None,\n          multi_gpu=False,\n          measure_time=False,\n          resume=False,\n          use_dist=False,\n          dist_port=\"23335\",\n          gpus_per_node=1,\n          optim_eval=False,\n          seed=7,\n          force_resume_step=None,\n          batch_size=None,\n          apex_opt_level='O0',\n          verbose=False\n          ):\n    \"\"\"Evaluate an RNNPose model specified by a config file.\n    \"\"\"\n\n    print(\"force_resume_step:\", force_resume_step)\n    print(\"torch.cuda.is_available()=\", torch.cuda.is_available())\n    print(\"torch.version.cuda=\",torch.version.cuda) \n    dist_url=f\"tcp://127.0.0.1:{dist_port}\"\n    print(f\"dist_url={dist_url}\", flush=True)\n    global RANK, WORLD_SIZE\n    # RANK, WORLD_SIZE=rank, world_size\n    if RANK<0:\n        RANK=0\n    if WORLD_SIZE<0:\n        WORLD_SIZE=1\n\n    global GLOBAL_GPUS_PER_DEVICE\n    GLOBAL_GPUS_PER_DEVICE = gpus_per_node\n\n    ######################################## initialize the distributed env #########################################\n    if use_dist:\n        if use_apex:\n            dist.init_process_group(\n                backend=\"nccl\", init_method=dist_url, world_size=get_world(use_dist), rank=get_rank(use_dist))\n        else:\n            # rank, world_size = dist_init(str(dist_port))\n            dist.init_process_group(\n                backend=\"nccl\", init_method=dist_url, world_size=get_world(use_dist), rank=get_rank(use_dist))\n    \n    print(get_rank(use_dist)%GLOBAL_GPUS_PER_DEVICE, flush=True)\n    #set cuda device number\n    torch.cuda.set_device(get_rank(use_dist)%GLOBAL_GPUS_PER_DEVICE)\n\n    ############################################ create folders ############################################\n    print(f\"Set seed={seed}\", flush=True)\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n\n    model_dir = str(Path(model_dir).resolve())\n    model_dir = Path(model_dir)\n    if chk_rank(0, use_dist):\n        if not resume and model_dir.exists():\n            raise ValueError(\"model dir exists and you don't specify resume.\")\n\n        model_dir.mkdir(parents=True, exist_ok=True)\n    if result_path is None:\n        result_path = model_dir / 'results'\n    config_file_bkp = \"pipeline.config\"\n\n    ############################################# read config proto ############################################\n    config = merge_cfg(\n        [config_path], intersection=True)\n    if chk_rank(0, use_dist):\n        print(json.dumps(config, indent=4))\n\n    if chk_rank(0, use_dist):\n        # save_cfg([default_config_path, custom_config_path],\n        save_cfg([config_path, config_path],\n                 str(model_dir / config_file_bkp))\n\n    #update the global config object\n    get_cfg().merge(config.get(\"BASIC\",{}),\"BASIC\" )  \n\n    input_cfg = config.train_input_reader\n    eval_input_cfg = config.eval_input_reader\n    model_cfg = config.model\n    train_cfg = config.train_config\n    optimizer_cfg = train_cfg.optimizer\n    loss_scale = train_cfg.loss_scale_factor\n\n\n    ############################################# Update default options ############################################\n\n    if batch_size is not None:\n        input_cfg.batch_size = batch_size \n        eval_input_cfg.batch_size = batch_size\n    print(input_cfg.batch_size)\n    \n    ############################################# build network, optimizer etc. ############################################\n    #dummy dataset\n    dataset_tmp = input_reader_builder.build(\n        eval_input_cfg,\n        training=False,\n    )\n \n    model_cfg.obj_seqs = copy.copy(dataset_tmp._dataset.infos[\"seqs\"])\n\n    net = build_network(model_cfg, measure_time)  # .to(device)\n    net.cuda()\n\n    fastai_optimizer = optimizer_builder.build(\n        optimizer_cfg, net,\n        mixed=False,\n        loss_scale=loss_scale)\n\n    print(\"# parameters:\", len(list(net.parameters())))\n\n    ############################################# load pretrained model ############################################\n    if pretrained_path is not None:\n        model_dict = net.state_dict()\n        pretrained_dict = torch.load(pretrained_path, map_location='cpu')\n\n        if verbose:\n            print(\"Pretrained keys:\", pretrained_dict.keys())\n            print(\"Model keys:\", model_dict.keys())\n\n        pretrained_dict = filter_param_dict(\n            pretrained_dict, pretrained_include, pretrained_exclude)\n\n        pretrained_dict = modify_parameter_name_with_map(\n            pretrained_dict, ast.literal_eval(str(pretrained_param_map)))\n        new_pretrained_dict = {}\n        \n        for k, v in pretrained_dict.items():\n            if k in model_dict and v.shape == model_dict[k].shape:\n                new_pretrained_dict[k] = v\n            else:\n                print(\"Failed to load:\", k )\n\n        model_dict.update(new_pretrained_dict)\n        net.load_state_dict(model_dict)\n        freeze_params_v2(dict(net.named_parameters()),\n                         freeze_include, freeze_exclude)\n        net.clear_global_step()\n        # net.clear_metrics()\n        del pretrained_dict\n    else:\n    ############################################# try to resume from the latest chkpt ############################################\n        torchplus.train.try_restore_latest_checkpoints(model_dir, [net])\n    # torchplus.train.try_restore_latest_checkpoints(model_dir,\n    #                                                [fastai_optimizer])\n\n    ######################################## parallel the network  #########################################\n    if use_dist:\n        if use_apex:\n            import apex\n            net, amp_optimizer = apex.amp.initialize(net.cuda(\n            ), fastai_optimizer, opt_level=\"O0\", keep_batchnorm_fp32=None, loss_scale=None)\n            net_parallel = apex.parallel.DistributedDataParallel(net)\n        else:\n            # net_parallel = ParallelWrapper(net.cuda(), 'dist')\n            # amp_optimizer = fastai_optimizer\n            amp_optimizer = optimizer_builder.build(\n                optimizer_cfg, net,\n                mixed=False,\n                loss_scale=loss_scale)\n            net_parallel = torch.nn.parallel.DistributedDataParallel(net, device_ids=[get_rank(use_dist)], output_device=get_rank(use_dist) ,find_unused_parameters=True)\n    else:\n        # Non-distributed path: keep an optimizer handle so the lr_scheduler\n        # build below does not fail with a NameError.\n        amp_optimizer = fastai_optimizer\n        net_parallel = net.cuda()\n\n    ############################################# build lr_scheduler ############################################\n    lr_scheduler = lr_scheduler_builder.build(optimizer_cfg, amp_optimizer,\n                                              train_cfg.steps)\n\n    float_dtype = torch.float32\n    ######################################## build dataloaders #########################################\n    if use_dist:\n        num_gpu = 1\n        collate_fn = merge_batch\n    else:\n        raise NotImplementedError\n\n    print(f\"MULTI-GPU: using {num_gpu} GPU(s)\")\n\n    ######################\n    # PREPARE INPUT\n    ######################\n    dataset = input_reader_builder.build(\n        input_cfg,\n        training=True,\n    )\n    eval_dataset = input_reader_builder.build(\n        eval_input_cfg,\n        training=False,\n    )\n\n\n    if use_dist:\n        train_sampler = DistributedGivenIterationSamplerEpoch(\n            dataset, train_cfg.steps, input_cfg.batch_size, last_iter=net.get_global_step()-1, review_cycle=-1)\n        # train_sampler=DistributedSequatialSampler(dataset)\n        shuffle = False\n        eval_sampler = DistributedSequatialSampler(eval_dataset)\n\n    else:\n        train_sampler = None\n        eval_sampler = None\n        eval_train_sampler = None\n        shuffle = True\n\n    dataloader, neighborhood_limits=get_dataloader_deepim(dataset=dataset,\n                                                kpconv_config=model_cfg.descriptor_net.keypoints_detector_3d,\n                                                batch_size=input_cfg.batch_size * num_gpu,\n                                                shuffle=shuffle,\n                                                num_workers= 2 ,#input_cfg.preprocess.num_workers * num_gpu,\n                                                sampler=train_sampler \n                                        )\n    eval_dataloader, _ =get_dataloader_deepim(dataset=eval_dataset,\n                                                kpconv_config=model_cfg.descriptor_net.keypoints_detector_3d,\n                                    batch_size=eval_input_cfg.batch_size,\n                                    shuffle=False,\n                                    num_workers= 2, #eval_input_cfg.preprocess.num_workers,\n                                    sampler=eval_sampler,\n                                    neighborhood_limits=neighborhood_limits\n                                   )\n\n    #########################################################################################\n    #                                            EVALUATION\n    
##########################################################################################\n    model_logging = SimpleModelLog(model_dir, disable=get_rank(use_dist) != 0)\n    model_logging.open()\n    start_step = net.get_global_step()\n    total_step = train_cfg.steps\n    t = time.time()\n    steps_per_eval = train_cfg.steps_per_eval\n\n    amp_optimizer.zero_grad()\n    step = start_step\n    epoch = 0\n    net_parallel.eval()\n\n    classes=list(set(eval_dataset.dataset.infos['seqs']))\n    \n    evaluator = dict([(c,LineMODEvaluator(f\"{c}\",result_path) ) for c in classes ] )\n\n    ang_errs=[]\n    trans_errs=[]\n    try:\n        for example in eval_dataloader:\n            global GLOBAL_STEP\n            GLOBAL_STEP = step\n\n\n            lr_scheduler.step(net.get_global_step())\n            example_torch=load_example_to_device(example, device=torch.device(\"cuda\"))\n\n            batch_size = example[\"image\"].shape[0]\n            with torch.no_grad():\n                # t1=time.time()\n                ret_dict = net_parallel(example_torch)\n                # print(\"time:\", time.time()-t1)\n\n            eval_results=evaluator[example['class_name'][0]].evaluate_rnnpose(ret_dict, example_torch )\n            ang_errs.append(eval_results['ang_err'])\n            trans_errs.append(eval_results['trans_err'])\n\n            if step%10 ==0:\n                model_logging.log_metrics({\n                    \"ang_err\": float(ang_errs[-1]), \n                    \"trans_err\": float(trans_errs[-1]), \n                }, GLOBAL_STEP)\n\n                model_logging.log_images(\n                    {\n                        \"pc_proj_vis\": eval_results['pc_proj_vis'].transpose([2,0,1])[None], # HWC->NCHW\n                        \"pc_proj_vis_pred\": eval_results['pc_proj_vis_pred'].transpose([2,0,1])[None], # HWC->NCHW\n                        \"syn_img\": torch.cat(ret_dict[\"syn_img\"], dim=0).detach().cpu(),\n                        # \"image\": example[\"image\"].cpu(),\n                        # \"image_f\": (example[\"image\"].cpu()+ret_dict[\"syn_img\"][0].detach().cpu())/2,\n                    }, GLOBAL_STEP, prefix=\"\")\n\n\n\n\n            net.update_global_step()\n            metrics = defaultdict(dict)\n            GLOBAL_STEP = net.get_global_step()\n\n            step += 1\n\n\n        for k in evaluator:\n            print(f\"###############Evaluation results of class {k}###############\")\n            evaluator[k].summarize()\n                \n    except Exception as e:\n        model_logging.log_text(str(e), step)\n        raise e\n    finally:\n        model_logging.close()\n\n\n\nif __name__ == '__main__':\n    fire.Fire()\n"
  },
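  {
    "path": "examples/filter_param_dict_demo.py",
    "content": "# Added illustrative example (not part of the original sources): a\n# self-contained sketch of the include/exclude regex filtering that\n# tools/eval.py applies to a pretrained state dict (filter_param_dict) before\n# partially loading it. The toy parameter names below are hypothetical.\nimport re\n\n\ndef filter_param_dict(state_dict, include=None, exclude=None):\n    # Same pattern as tools/eval.py: keep keys whose name matches `include`\n    # (when given) and drop keys whose name matches `exclude` (when given).\n    include_re = re.compile(include) if include is not None else None\n    exclude_re = re.compile(exclude) if exclude is not None else None\n    res = {}\n    for k, v in state_dict.items():\n        if include_re is not None and include_re.match(k) is None:\n            continue\n        if exclude_re is not None and exclude_re.match(k) is not None:\n            continue\n        res[k] = v\n    return res\n\n\nparams = {\n    \"backbone.conv1.weight\": 0,\n    \"backbone.conv2.weight\": 1,\n    \"head.fc.weight\": 2,\n}\nprint(filter_param_dict(params, include=\"backbone\"))\n# {'backbone.conv1.weight': 0, 'backbone.conv2.weight': 1}\nprint(filter_param_dict(params, include=\"backbone\", exclude=\".*conv2\"))\n# {'backbone.conv1.weight': 0}\n"
  },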
  {
    "path": "tools/generate_data_info_deepim_0_orig.py",
    "content": "import os\nimport numpy as np\nimport copy\nimport pickle\nimport fire\nimport glob\nimport re\nfrom data.linemod import linemod_config\n\n\ndef parse_pose_file(file):\n    with open(file) as f:\n        lines = [line.strip() for line in f.readlines()]\n\n    poses = []\n    for line in lines:\n        poses.append(\n            np.array([np.float32(l) for l in line.split()],\n                     dtype=np.float32).reshape((3, 4))\n        )\n    return poses\n\n\ndef parse_calib_file(file):\n    info = {}\n    with open(file) as f:\n        lines = [line.strip() for line in f.readlines()]\n    for i, l in enumerate(lines):\n        nums = np.array([np.float32(x)\n                         for x in l.split(' ')[1:]], dtype=np.float32)\n        if i < 4:\n            info[f\"calib/P{i}\"] = nums.reshape((3, 4))\n        else:\n            info[f\"calib/Tr_velo_to_cam\"] = nums.reshape((3, 4))\n    return info\n\n\n# def create_data_info(data_root, saving_path, is_test_data=False):\n# def create_data_info(data_root, saving_path, data_type='train'):\n# def create_data_info(data_root, saving_path, training_data_ratio=0.8, shuffle=True, ):\ndef create_data_info(data_root, saving_path, with_assertion=True):\n    \"\"\"[summary]\n        info structure:\n        {\n            0:[\n                {\n                \"index\": idx,\n                \"lidar_bin_path\": lidar_bin_paths[idx],\n                \"RT\": poses[idx],\n                \"K\": poses[idx],\n                },\n                {\n                \"index\": idx,\n                \"lidar_bin_path\": lidar_bin_paths[idx],\n                \"RT\": poses[idx],\n                \"K\": poses[idx],\n               \n                },\n            }\n            ...\n            ],\n            1:[\n\n            ]\n            ...\n        }\n\n    \"\"\"\n    idx2class = {\n        1: \"ape\",\n        2: \"benchvise\",\n        # 3: 'bowl',\n        4: \"camera\",\n        5: \"can\",\n        6: \"cat\",\n        # 7: 'cup',\n        8: \"driller\",\n        9: \"duck\",\n        10: \"eggbox\",\n        11: \"glue\",\n        12: \"holepuncher\",\n        13: \"iron\",\n        14: \"lamp\",\n        15: \"phone\",\n    }\n    class2idx = dict([[idx2class[k],k ] for k in idx2class.keys() ])\n\n    seqs=class2idx.keys()\n\n    observed_dir = os.path.join('', 'data/observed')\n    gt_observed_dir = os.path.join('', 'data/gt_observed')\n    rendered_dir = os.path.join('', 'data/rendered')\n    set_split_dir = os.path.join('','image_set/observed')\n\n    # max_items_per_seq=10000#8000#100#10000#2000\n    # create training data\n    res = {}\n   \n    for seq in seqs:\n        res[seq] = []\n\n        rgb_orig_dir = os.path.join(observed_dir, f\"{class2idx[seq]:02d}\")        \n        rgb_noisy_rendered_dir = os.path.join(rendered_dir, seq )        \n\n        depth_orig_dir= os.path.join(observed_dir, f\"{class2idx[seq]:02d}\")    \n        depth_rendered_dir= os.path.join(gt_observed_dir, seq)    \n        depth_noisy_rendered_dir= os.path.join(rendered_dir, seq)    \n\n        gt_pose_dir = os.path.join(gt_observed_dir, seq)\n        noisy_pose_dir = os.path.join(rendered_dir, seq)\n\n        label_dir = os.path.join(observed_dir,  f\"{class2idx[seq]:02d}\" )\n\n\n        rgb_orig_paths = glob.glob(r'{}/*color.png'.format(rgb_orig_dir) ) \n        train_split_file=os.path.join(data_root,set_split_dir, f\"{seq}_train.txt\")\n        \n        with open(train_split_file, 'r') as f:\n            train_split = 
f.readlines()\n            train_split = [ int(t.split('/')[-1] ) for t in train_split]\n\n        rgb_orig_paths.sort(key=lambda s: int(re.split( '\\.|_|-' ,os.path.basename(s))[0]) )\n\n\n        NUM_RENDERED=10\n        \n        for idx in train_split:\n            \n            #original data paths\n            gt_pose=np.loadtxt(os.path.join(data_root,gt_pose_dir, f\"{idx:06d}-pose.txt\"), skiprows=1).reshape(3,4)\n            rgb_orig = os.path.join(rgb_orig_dir, f\"{idx:06d}-color.png\" )\n            depth_orig=os.path.join(depth_orig_dir, f\"{idx:06d}-depth.png\" )\n            depth_rendered = os.path.join(depth_rendered_dir, f\"{idx:06d}-depth.png\")\n            label_orig=os.path.join(label_dir, f\"{idx:06d}-label.png\" )\n\n            #rendered data paths\n            rgb_noisy_rendered = [os.path.join(rgb_noisy_rendered_dir, f\"{idx:06d}_{i}-color.png\" ) for i in range(NUM_RENDERED) ]\n            depth_noisy_rendered = [os.path.join(depth_noisy_rendered_dir, f\"{idx:06d}_{i}-depth.png\" ) for i in range(NUM_RENDERED) ]\n\n            pose_noisy_rendered = [os.path.join(data_root, noisy_pose_dir, f\"{idx:06d}_{i}-pose.txt\" ) for i in range(NUM_RENDERED) ]\n            pose_noisy_rendered = [np.loadtxt(p, skiprows=1).reshape(3,4) for p in pose_noisy_rendered ]\n\n            #generate data pairs\n\n            for noisy_data_idx in range(NUM_RENDERED):\n                if with_assertion:\n                    assert os.path.exists(os.path.join(data_root, rgb_orig) ), os.path.join(data_root, rgb_orig) \n                    assert os.path.exists(os.path.join(data_root, depth_orig) ), os.path.join(data_root, depth_orig) \n                    assert os.path.exists(os.path.join(data_root, label_orig) ), os.path.join(data_root, label_orig) \n                    assert os.path.exists(os.path.join(data_root, rgb_noisy_rendered[noisy_data_idx]) ), os.path.join(data_root, rgb_noisy_rendered[noisy_data_idx]) \n                    assert os.path.exists(os.path.join(data_root, depth_noisy_rendered[noisy_data_idx] ) ), os.path.join(data_root, depth_noisy_rendered[noisy_data_idx] ) \n\n                info = {\n                    \"index\": idx,\n                    # \"rgb_orig_path\": rgb_orig,\n                    \"rgb_observed_path\": rgb_orig,\n                    \"depth_observed_path\": depth_orig,\n                    \"depth_gt_observed_path\": depth_rendered,\n                    \"gt_pose\": gt_pose,\n\n                    \"rgb_noisy_rendered\": rgb_noisy_rendered[noisy_data_idx],\n                    \"depth_noisy_rendered\": depth_noisy_rendered[noisy_data_idx],\n                    \"pose_noisy_rendered\": pose_noisy_rendered[noisy_data_idx],\n\n                    \"model_points_path\": f\"{seq}.bin\",\n                    #legacy\n                    \"RT\": gt_pose,\n                    \"K\": linemod_config.linemod_K,\n                }\n                res[seq].append(info)\n\n                print(info['rgb_observed_path'], info['rgb_noisy_rendered'])\n    \n    train_saving_path=saving_path+'.train'\n    with open(train_saving_path, 'wb+') as f:\n        print(\"Total data amount:\", np.sum([len(res[r]) for r in res]))\n        pickle.dump(res, f)\n\n    # eval_saving_path=saving_path+'.eval'\n    # with open(eval_saving_path, 'wb+') as f:\n    #     print(\"Total data amount:\", np.sum([len(test_res[r]) for r in test_res]))\n    #     pickle.dump(test_res, f)\n\n\nif __name__ == '__main__':\n    fire.Fire()\n"
  },
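  {
    "path": "examples/data_info_layout_demo.py",
    "content": "# Added illustrative example (not part of the original sources): a toy\n# sketch of the pickle layout produced by the tools/generate_data_info_*\n# scripts, i.e. a dict mapping each LineMOD class name to a list of\n# per-sample info dicts. All paths and matrices below are placeholders.\nimport pickle\n\nimport numpy as np\n\ninfo = {\n    \"index\": 0,\n    \"rgb_observed_path\": \"data/observed/01/000000-color.png\",\n    \"depth_observed_path\": \"data/observed/01/000000-depth.png\",\n    \"depth_gt_observed_path\": \"data/gt_observed/ape/000000-depth.png\",\n    \"gt_pose\": np.eye(3, 4, dtype=np.float32),   # placeholder 3x4 [R|t]\n    \"rgb_noisy_rendered\": \"data/rendered/ape/000000_0-color.png\",\n    \"depth_noisy_rendered\": \"data/rendered/ape/000000_0-depth.png\",\n    \"pose_noisy_rendered\": np.eye(3, 4, dtype=np.float32),\n    \"model_points_path\": \"ape.bin\",\n    \"RT\": np.eye(3, 4, dtype=np.float32),        # legacy alias of gt_pose\n    \"K\": np.eye(3, dtype=np.float32),            # placeholder intrinsics\n}\n\nwith open(\"demo_info.train\", \"wb\") as f:\n    pickle.dump({\"ape\": [info]}, f)\nprint(\"wrote demo_info.train with 1 sample\")\n"
  },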
  {
    "path": "tools/generate_data_info_deepim_1_syn.py",
    "content": "import os\nimport numpy as np\nimport copy\nimport pickle\nimport fire\nimport glob\nimport re\nfrom data.linemod import linemod_config\n\n\ndef parse_pose_file(file):\n    with open(file) as f:\n        lines = [line.strip() for line in f.readlines()]\n\n    poses = []\n    for line in lines:\n        poses.append(\n            np.array([np.float32(l) for l in line.split()],\n                     dtype=np.float32).reshape((3, 4))\n        )\n    return poses\n\n\ndef parse_calib_file(file):\n    info = {}\n    with open(file) as f:\n        lines = [line.strip() for line in f.readlines()]\n    for i, l in enumerate(lines):\n        nums = np.array([np.float32(x)\n                         for x in l.split(' ')[1:]], dtype=np.float32)\n        if i < 4:\n            info[f\"calib/P{i}\"] = nums.reshape((3, 4))\n        else:\n            info[f\"calib/Tr_velo_to_cam\"] = nums.reshape((3, 4))\n    return info\n\n\ndef create_data_info(data_root, saving_path, with_assertion=False ):\n    \"\"\"[summary]\n        info structure:\n        {\n            0:[\n                {\n                \"index\": idx,\n                \"lidar_bin_path\": lidar_bin_paths[idx],\n                \"RT\": poses[idx],\n                \"K\": poses[idx],\n                },\n                {\n                \"index\": idx,\n                \"lidar_bin_path\": lidar_bin_paths[idx],\n                \"RT\": poses[idx],\n                \"K\": poses[idx],\n               \n                },\n            }\n            ...\n            ],\n            1:[\n\n            ]\n            ...\n        }\n\n    \"\"\"\n\n\n    idx2class = {\n        1: \"ape\",\n        2: \"benchvise\",\n        # 3: 'bowl',\n        4: \"camera\",\n        5: \"can\",\n        6: \"cat\",\n        # 7: 'cup',\n        8: \"driller\",\n        9: \"duck\",\n        10: \"eggbox\",\n        11: \"glue\",\n        12: \"holepuncher\",\n        13: \"iron\",\n        14: \"lamp\",\n        15: \"phone\",\n    }\n    class2idx = dict([[idx2class[k],k ] for k in idx2class.keys() ])\n\n    seqs=class2idx.keys()\n\n    observed_dir = os.path.join('', 'data/observed')\n    gt_observed_dir = os.path.join('', 'data/gt_observed')\n    rendered_dir = os.path.join('', 'data/rendered')\n    set_split_dir = os.path.join('','image_set/observed')\n\n    # create training data\n    res = {}\n    test_res = {}\n    for seq in seqs:\n        res[seq] = []\n\n        # rgb_orig_dir = os.path.join(observed_dir, f\"{class2idx[seq]:02d}\")        \n        rgb_orig_dir = os.path.join(observed_dir, seq)        \n        rgb_noisy_rendered_dir = os.path.join(rendered_dir, seq )        \n\n        # depth_orig_dir= os.path.join(observed_dir, f\"{class2idx[seq]:02d}\")    \n        depth_orig_dir= os.path.join(observed_dir, seq)    \n\n        # depth_renderd_dir= os.path.join(gt_observed_dir, seq)    \n        depth_rendered_dir= os.path.join(gt_observed_dir, seq)    \n        depth_noisy_rendered_dir= os.path.join(rendered_dir, seq)    \n\n        gt_pose_dir = os.path.join(gt_observed_dir, seq)\n        noisy_pose_dir = os.path.join(rendered_dir, seq)\n\n        label_dir = os.path.join(observed_dir, seq)\n\n\n        rgb_orig_paths = glob.glob(r'{}/*color.png'.format(rgb_orig_dir) ) \n\n        # train_split_file=os.path.join(data_root,set_split_dir, f\"{seq}_train.txt\")\n        train_split_file=os.path.join(data_root,set_split_dir, f\"LM6d_data_syn_train_observed_{seq}.txt\")\n\n        with open(train_split_file, 'r') as f:\n            
train_split = f.readlines()\n            train_split = [ int(t.split('/')[-1] ) for t in train_split]\n\n\n\n        rgb_orig_paths.sort(key=lambda s: int(re.split( '\\.|_|-' ,os.path.basename(s))[0]) )\n\n\n        # data_num=len(image_paths[:max_items_per_seq] ) \n        # if shuffle: \n            \n        #     permute=np.random.permutation(data_num)\n        # else:\n        #     permute = np.arange(data_num)\n        # train_split=permute[:int(data_num*training_data_ratio)]\n        # eval_split= permute[int(data_num*training_data_ratio):]\n\n        # for idx in range(len(image_paths[:max_items_per_seq])):\n        NUM_RENDERED=1\n        for idx in train_split:\n            \n            #original data paths\n            gt_pose=np.loadtxt(os.path.join(data_root,gt_pose_dir, f\"{idx:06d}-pose.txt\"), skiprows=1).reshape(3,4)\n            rgb_orig = os.path.join(rgb_orig_dir, f\"{idx:06d}-color.png\" )\n            depth_orig=os.path.join(depth_orig_dir, f\"{idx:06d}-depth.png\" )\n            label_orig=os.path.join(label_dir, f\"{idx:06d}-label.png\" )\n\n            depth_rendered = os.path.join(depth_rendered_dir, f\"{idx:06d}-depth.png\")\n\n           \n\n\n            #rendered data paths\n            rgb_noisy_rendered = [os.path.join(rgb_noisy_rendered_dir, f\"{seq}_{idx:06d}_{i}-color.png\" ) for i in range(NUM_RENDERED) ]\n            depth_noisy_rendered = [os.path.join(depth_noisy_rendered_dir, f\"{seq}_{idx:06d}_{i}-depth.png\" ) for i in range(NUM_RENDERED) ]\n\n            pose_noisy_rendered = [os.path.join(data_root, noisy_pose_dir, f\"{seq}_{idx:06d}_{i}-pose.txt\" ) for i in range(NUM_RENDERED) ]\n            pose_noisy_rendered = [np.loadtxt(p, skiprows=1).reshape(3,4) for p in pose_noisy_rendered ]\n\n            #generate data pairs\n\n            for noisy_data_idx in range(NUM_RENDERED):\n                if with_assertion:\n                    assert os.path.exists(os.path.join(data_root, rgb_orig) ), os.path.join(data_root, rgb_orig) \n                    assert os.path.exists(os.path.join(data_root, depth_orig) ), os.path.join(data_root, depth_orig) \n                    assert os.path.exists(os.path.join(data_root, label_orig) ), os.path.join(data_root, label_orig)\n                    assert os.path.exists(os.path.join(data_root, rgb_noisy_rendered[noisy_data_idx]) ), os.path.join(data_root, rgb_noisy_rendered[noisy_data_idx])\n                    assert os.path.exists(os.path.join(data_root, depth_noisy_rendered[noisy_data_idx] ) ), os.path.join(data_root, depth_noisy_rendered[noisy_data_idx] )\n\n                    assert os.path.exists(os.path.join(data_root, depth_rendered) ), os.path.join(data_root, depth_rendered) \n                    # assert os.path.exists(os.path.join(data_root, pose_noisy_rendered[noisy_data_idx] ) ), os.path.join(data_root, pose_noisy_rendered[noisy_data_idx] )\n\n                info = {\n                     \"index\": idx,\n                    # \"rgb_orig_path\": rgb_orig,\n                    \"rgb_observed_path\": rgb_orig,\n                    \"depth_observed_path\": depth_orig,\n                    \"depth_gt_observed_path\": depth_rendered,\n                    \"gt_pose\": gt_pose,\n\n                    \"rgb_noisy_rendered\": rgb_noisy_rendered[noisy_data_idx],\n                    \"depth_noisy_rendered\": depth_noisy_rendered[noisy_data_idx],\n                    \"pose_noisy_rendered\": pose_noisy_rendered[noisy_data_idx],\n\n                    \"model_points_path\": f\"{seq}.bin\",\n                    
#legacy\n                    \"RT\": gt_pose,\n                    \"K\": linemod_config.linemod_K,\n                }\n\n            print(info['rgb_observed_path'], info['rgb_noisy_rendered'])\n\n            res[seq].append(info)\n\n\n    train_saving_path=saving_path+'.train'\n    with open(train_saving_path, 'wb+') as f:\n        print(\"Total data amount:\", np.sum([len(res[r]) for r in res]))\n        pickle.dump(res, f)\n\n    # eval_saving_path=saving_path+'.eval'\n    # with open(eval_saving_path, 'wb+') as f:\n    #     print(\"Total data amount:\", np.sum([len(test_res[r]) for r in test_res]))\n    #     pickle.dump(test_res, f)\n\n\nif __name__ == '__main__':\n    fire.Fire()\n"
  },
  {
    "path": "tools/generate_data_info_deepim_2_posecnnval.py",
    "content": "import os\nimport numpy as np\nimport copy\nimport pickle\nimport fire\nimport glob\nimport re\nfrom data.linemod import linemod_config\nimport scipy.io as sio\n\n\n\ndef parse_pose_file(file):\n    with open(file) as f:\n        lines = [line.strip() for line in f.readlines()]\n\n    poses = []\n    for line in lines:\n        poses.append(\n            np.array([np.float32(l) for l in line.split()],\n                     dtype=np.float32).reshape((3, 4))\n        )\n    return poses\n\n\ndef parse_calib_file(file):\n    info = {}\n    with open(file) as f:\n        lines = [line.strip() for line in f.readlines()]\n    for i, l in enumerate(lines):\n        nums = np.array([np.float32(x)\n                         for x in l.split(' ')[1:]], dtype=np.float32)\n        if i < 4:\n            info[f\"calib/P{i}\"] = nums.reshape((3, 4))\n        else:\n            info[f\"calib/Tr_velo_to_cam\"] = nums.reshape((3, 4))\n    return info\n\n\ndef create_data_info(data_root, saving_path, with_assertion=True):\n    \"\"\"[summary]\n        info structure:\n        {\n            0:[\n                {\n                \"index\": idx,\n                \"lidar_bin_path\": lidar_bin_paths[idx],\n                \"RT\": poses[idx],\n                \"K\": poses[idx],\n                },\n                {\n                \"index\": idx,\n                \"lidar_bin_path\": lidar_bin_paths[idx],\n                \"RT\": poses[idx],\n                \"K\": poses[idx],\n               \n                },\n            }\n            ...\n            ],\n            1:[\n\n            ]\n            ...\n        }\n\n    \"\"\"\n    # seqs=['cat', 'ape', 'camera', 'duck', 'glue', 'iron', 'phone','benchvise', 'can', 'driller', 'eggbox', 'holepuncher', 'lamp']\n    idx2class = {\n        1: \"ape\",\n        2: \"benchvise\",\n        # 3: 'bowl',\n        4: \"camera\",\n        5: \"can\",\n        6: \"cat\",\n        # 7: 'cup',\n        8: \"driller\",\n        9: \"duck\",\n        10: \"eggbox\",\n        11: \"glue\",\n        12: \"holepuncher\",\n        13: \"iron\",\n        14: \"lamp\",\n        15: \"phone\",\n    }\n    class2idx = dict([[idx2class[k],k ] for k in idx2class.keys() ])\n\n    seqs=class2idx.keys()\n\n    observed_dir = os.path.join('', 'data/observed')\n    gt_observed_dir = os.path.join('', 'data/gt_observed')\n    # rendered_dir = os.path.join('', 'data/rendered')\n    rendered_dir = os.path.join('', 'data/rendered_val_PoseCNN')\n    set_split_dir = os.path.join('','image_set/observed')\n\n    # max_items_per_seq=10000#8000#100#10000#2000\n    # create training data\n    res = {}\n   \n    for seq in seqs:\n        res[seq] = []\n\n        rgb_orig_dir = os.path.join(observed_dir, f\"{class2idx[seq]:02d}\")        \n        rgb_noisy_rendered_dir = os.path.join(rendered_dir,  f\"{class2idx[seq]:02d}\", seq )        \n\n        depth_orig_dir= os.path.join(observed_dir, f\"{class2idx[seq]:02d}\")    \n        depth_rendered_dir= os.path.join(gt_observed_dir, seq)    \n        depth_noisy_rendered_dir= os.path.join(rendered_dir, f\"{class2idx[seq]:02d}\", seq)    \n\n        gt_pose_dir = os.path.join(gt_observed_dir, seq)\n        noisy_pose_dir = os.path.join(rendered_dir,  f\"{class2idx[seq]:02d}\" ,seq)\n\n        label_dir = os.path.join(observed_dir,  f\"{class2idx[seq]:02d}\" )\n        meta_dir = os.path.join(observed_dir,  f\"{class2idx[seq]:02d}\" )\n\n\n        rgb_orig_paths = glob.glob(r'{}/*color.png'.format(rgb_orig_dir) ) \n        
test_split_file=os.path.join(data_root,set_split_dir, f\"{seq}_test.txt\")\n        \n        with open(test_split_file, 'r') as f:\n            test_split = f.readlines()\n            test_split = [ int(t.split('/')[-1] ) for t in test_split]\n\n        rgb_orig_paths.sort(key=lambda s: int(re.split( '\\.|_|-' ,os.path.basename(s))[0]) )\n\n\n        # for idx in range(len(image_paths[:max_items_per_seq])):\n        NUM_RENDERED=1\n        \n        for idx in test_split:\n            \n            #original data paths\n            gt_pose=np.loadtxt(os.path.join(data_root,gt_pose_dir, f\"{idx:06d}-pose.txt\"), skiprows=1).reshape(3,4)\n            rgb_orig = os.path.join(rgb_orig_dir, f\"{idx:06d}-color.png\" )\n            depth_orig=os.path.join(depth_orig_dir, f\"{idx:06d}-depth.png\" )\n            depth_rendered = os.path.join(depth_rendered_dir, f\"{idx:06d}-depth.png\")\n            label_orig=os.path.join(label_dir, f\"{idx:06d}-label.png\" )\n            meta_path=os.path.join(meta_dir, f\"{idx:06d}-meta.mat\")\n\n            #rendered data paths\n            rgb_noisy_rendered = [os.path.join(rgb_noisy_rendered_dir, f\"{seq}_{idx:06d}_{i}-color.png\" ) for i in range(NUM_RENDERED) ]\n            depth_noisy_rendered = [os.path.join(depth_noisy_rendered_dir, f\"{seq}_{idx:06d}_{i}-depth.png\" ) for i in range(NUM_RENDERED) ]\n\n            pose_noisy_rendered = [os.path.join(data_root, noisy_pose_dir, f\"{seq}_{idx:06d}_{i}-pose.txt\" ) for i in range(NUM_RENDERED) ]\n            pose_noisy_rendered = [np.loadtxt(p, skiprows=1).reshape(3,4) for p in pose_noisy_rendered ]\n\n            # meta_data=sio.loadmat(meta_path)\n            # assert meta_data[\"boxes\"].shape[0] == 1\n            #generate data pairs\n            for noisy_data_idx in range(NUM_RENDERED):\n                if with_assertion:\n                    assert os.path.exists(os.path.join(data_root, rgb_orig) ), os.path.join(data_root, rgb_orig) \n                    assert os.path.exists(os.path.join(data_root, depth_orig) ), os.path.join(data_root, depth_orig) \n                    assert os.path.exists(os.path.join(data_root, label_orig) ), os.path.join(data_root, label_orig) \n                    assert os.path.exists(os.path.join(data_root, rgb_noisy_rendered[noisy_data_idx]) ), os.path.join(data_root, rgb_noisy_rendered[noisy_data_idx]) \n                    assert os.path.exists(os.path.join(data_root, depth_noisy_rendered[noisy_data_idx] ) ), os.path.join(data_root, depth_noisy_rendered[noisy_data_idx] ) \n\n                info = {\n                    \"index\": idx,\n                    # \"rgb_orig_path\": rgb_orig,\n                    \"rgb_observed_path\": rgb_orig,\n                    \"depth_observed_path\": depth_orig,\n                    \"depth_gt_observed_path\": depth_rendered,\n                    \"gt_pose\": gt_pose,\n\n                    \"rgb_noisy_rendered\": rgb_noisy_rendered[noisy_data_idx],\n                    \"depth_noisy_rendered\": depth_noisy_rendered[noisy_data_idx],\n                    \"pose_noisy_rendered\": pose_noisy_rendered[noisy_data_idx],\n\n                    \"model_points_path\": f\"{seq}.bin\",\n                    #legacy\n                    \"RT\": gt_pose,\n                    \"K\": linemod_config.linemod_K,\n                    # \"boxes\":  meta_data[\"boxes\"]\n                }\n                res[seq].append(info)\n\n                print(info['rgb_observed_path'], info['rgb_noisy_rendered'])\n    \n    # train_saving_path=saving_path+'.train'\n  
  train_saving_path=saving_path+'.eval'\n    with open(train_saving_path, 'wb+') as f:\n        print(\"Total data amount:\", np.sum([len(res[r]) for r in res]))\n        pickle.dump(res, f)\n\n    # eval_saving_path=saving_path+'.eval'\n    # with open(eval_saving_path, 'wb+') as f:\n    #     print(\"Total data amount:\", np.sum([len(test_res[r]) for r in test_res]))\n    #     pickle.dump(test_res, f)\n\n\nif __name__ == '__main__':\n    fire.Fire()\n"
  },
  {
    "path": "tools/generate_data_info_v2_deepim.py",
    "content": "#The version compatible with deepim   \nimport os\nimport numpy as np\nimport copy\nimport pickle\nimport fire\nimport glob\nimport re\n\n\ndef parse_pose_file(file):\n    with open(file) as f:\n        lines = [line.strip() for line in f.readlines()]\n\n    poses = []\n    for line in lines:\n        poses.append(\n            np.array([np.float32(l) for l in line.split()],\n                     dtype=np.float32).reshape((3, 4))\n        )\n    return poses\n\n\ndef parse_calib_file(file):\n    info = {}\n    with open(file) as f:\n        lines = [line.strip() for line in f.readlines()]\n    for i, l in enumerate(lines):\n        nums = np.array([np.float32(x)\n                         for x in l.split(' ')[1:]], dtype=np.float32)\n        if i < 4:\n            info[f\"calib/P{i}\"] = nums.reshape((3, 4))\n        else:\n            info[f\"calib/Tr_velo_to_cam\"] = nums.reshape((3, 4))\n    return info\n\n\n# def create_data_info(data_root, saving_path, is_test_data=False):\n# def create_data_info(data_root, saving_path, data_type='train'):\ndef create_data_info(data_root, saving_path, training_data_ratio=0.8, shuffle=True, ):\n    \"\"\"[summary]\n        info structure:\n        {\n            0:[\n                {\n                \"index\": idx,\n                \"lidar_bin_path\": lidar_bin_paths[idx],\n                \"RT\": poses[idx],\n                \"K\": poses[idx],\n                },\n                {\n                \"index\": idx,\n                \"lidar_bin_path\": lidar_bin_paths[idx],\n                \"RT\": poses[idx],\n                \"K\": poses[idx],\n               \n                },\n            }\n            ...\n            ],\n            1:[\n\n            ]\n            ...\n        }\n\n    \"\"\"\n\n    image_dir = os.path.join(data_root, )\n    pose_dir = os.path.join(data_root )\n    depth_dir = os.path.join(data_root)\n    # blender_to_bop_pose=np.load(\"/DATA/yxu/LINEMOD/metricpose/blender2bop_RT.npy\", allow_pickle=True).flat[0]\n    blender_to_bop_pose=np.load(f\"{os.path.dirname(os.path.abspath(__file__)) }/../EXPDATA/init_poses/pose_conversion/blender2bop_RT.npy\", allow_pickle=True).flat[0]\n    # seqs=['cat']\n    seqs=['cat', 'ape', 'cam', 'duck', 'glue', 'iron', 'phone','benchvise', 'can', 'driller', 'eggbox', 'holepuncher', 'lamp']\n    print(seqs)\n    max_items_per_seq=10000\n    # create training data\n    res = {}\n    eval_res = {}\n    for seq in seqs:\n        res[seq] = []\n        eval_res[seq]=[]\n\n        image_path_dir = os.path.join(image_dir, seq)\n        depth_path_dir = os.path.join(depth_dir, seq)\n        pose_path_dir = os.path.join(pose_dir, seq)\n        # image_paths = os.listdir(lidar_bin_dir)\n        image_paths = glob.glob(r'{}/*.jpg'.format(image_path_dir) ) \n        depth_paths = glob.glob(r'{}/*depth*.npy'.format(depth_path_dir) ) \n\n        pose_paths = glob.glob(r'{}/*RT.pkl'.format(pose_path_dir) ) \n        #for compatibility\n        if len(pose_paths) ==0:\n            pose_paths = glob.glob(r'{}/*params.pkl'.format(pose_path_dir) ) \n\n\n        # image_paths.sort(key=lambda s: int(os.path.basename(s).split('.')[0]) )\n\n        image_paths.sort(key=lambda s: int(re.split( '\\.|_' ,os.path.basename(s))[0]) )\n        depth_paths.sort(key=lambda s: int(os.path.basename(s).split('_')[0]))\n        pose_paths.sort(key=lambda s: int(os.path.basename(s).split('_')[0]))\n\n\n        data_num=len(image_paths[:max_items_per_seq] ) \n        if shuffle: \n            \n            
def create_data_info(data_root, saving_path, training_data_ratio=0.8, shuffle=True):\n    \"\"\"Collect the per-frame annotations of each sequence into a data-info pickle.\n\n        info structure:\n        {\n            'cat': [\n                {\n                    \"index\": idx,\n                    \"rgb_observed_path\": ...,\n                    \"depth_gt_observed_path\": ...,\n                    \"rgb_noisy_rendered\": None,\n                    \"depth_noisy_rendered\": None,\n                    \"pose_noisy_rendered\": None,\n                    \"model_points_path\": \"cat.bin\",\n                    \"gt_pose\": 3x4 pose,\n                    \"K\": 3x3 intrinsics,\n                    \"bbox\": [hmin, wmin, hmax, wmax] or None,\n                },\n                ...\n            ],\n            'ape': [...],\n            ...\n        }\n    \"\"\"\n\n    image_dir = data_root\n    pose_dir = data_root\n    depth_dir = data_root\n    blender_to_bop_pose = np.load(f\"{os.path.dirname(os.path.abspath(__file__))}/../EXPDATA/init_poses/pose_conversion/blender2bop_RT.npy\", allow_pickle=True).flat[0]\n    seqs = ['cat', 'ape', 'cam', 'duck', 'glue', 'iron', 'phone', 'benchvise', 'can', 'driller', 'eggbox', 'holepuncher', 'lamp']\n    print(seqs)\n    max_items_per_seq = 10000\n    res = {}\n    eval_res = {}\n    for seq in seqs:\n        res[seq] = []\n        eval_res[seq] = []\n\n        image_path_dir = os.path.join(image_dir, seq)\n        depth_path_dir = os.path.join(depth_dir, seq)\n        pose_path_dir = os.path.join(pose_dir, seq)\n        image_paths = glob.glob(r'{}/*.jpg'.format(image_path_dir))\n        depth_paths = glob.glob(r'{}/*depth*.npy'.format(depth_path_dir))\n\n        pose_paths = glob.glob(r'{}/*RT.pkl'.format(pose_path_dir))\n        # for compatibility with the older naming scheme\n        if len(pose_paths) == 0:\n            pose_paths = glob.glob(r'{}/*params.pkl'.format(pose_path_dir))\n\n        image_paths.sort(key=lambda s: int(re.split(r'\.|_', os.path.basename(s))[0]))\n        depth_paths.sort(key=lambda s: int(os.path.basename(s).split('_')[0]))\n        pose_paths.sort(key=lambda s: int(os.path.basename(s).split('_')[0]))\n\n        data_num = len(image_paths[:max_items_per_seq])\n        if shuffle:\n            permute = np.random.permutation(data_num)\n        else:\n            permute = np.arange(data_num)\n        train_split = permute[:int(data_num * training_data_ratio)]\n        eval_split = permute[int(data_num * training_data_ratio):]\n\n        for split, container in ((train_split, res), (eval_split, eval_res)):\n            for idx in split:\n                with open(pose_paths[idx], 'rb') as f:\n                    pose = pickle.load(f)\n\n                if seq == 'cam':\n                    bl2bo = blender_to_bop_pose['camera']\n                else:\n                    bl2bo = blender_to_bop_pose[seq]\n\n                pose[\"RT\"][:3, :3] = pose[\"RT\"][:3, :3] @ bl2bo[:3, :3].T\n                pose[\"RT\"][:3, 3:] = -pose[\"RT\"][:3, :3] @ bl2bo[:3, 3:] + pose[\"RT\"][:3, 3:]\n                info = {\n                    \"index\": idx,\n                    \"rgb_observed_path\": image_paths[idx].replace(image_dir, './'),\n                    \"depth_gt_observed_path\": depth_paths[idx].replace(depth_dir, './'),\n                    \"rgb_noisy_rendered\": None,\n                    \"depth_noisy_rendered\": None,\n                    \"pose_noisy_rendered\": None,\n                    \"model_points_path\": f\"{seq}.bin\",\n                    \"gt_pose\": pose[\"RT\"],\n                    \"K\": pose[\"K\"],\n                    \"bbox\": pose.get('bbox', None)\n                }\n\n                print(info['rgb_observed_path'], info['depth_gt_observed_path'], image_dir)\n                container[seq].append(info)\n\n    train_saving_path = saving_path + '.train'\n    eval_saving_path = saving_path + '.eval'\n    with open(train_saving_path, 'wb+') as f:\n        print(\"Total training data amount:\", np.sum([len(res[r]) for r in res]))\n        pickle.dump(res, f)\n\n    with open(eval_saving_path, 'wb+') as f:\n        print(\"Total evaluation data amount:\", np.sum([len(eval_res[r]) for r in eval_res]))\n        pickle.dump(eval_res, f)\n\n\nif __name__ == '__main__':\n    fire.Fire()\n"
  },
  {
    "path": "tools/train.py",
    "content": "import numpy as np \nimport torch\n\nfrom pathlib import Path\nimport json\nimport random\nimport re\nimport torch.backends.cudnn as cudnn\nimport torch.multiprocessing as mp\nimport time\nimport fire\nimport torch.distributed as dist\nimport os\nfrom collections import defaultdict\nimport kornia\nimport flow_vis\nimport copy\nimport ast \n\nfrom utils.progress_bar import ProgressBar\nfrom utils.log_tool import SimpleModelLog\nfrom data.preprocess import merge_batch, get_dataloader_deepim  # merge_second_batch_multigpu\nfrom utils.config_io import merge_cfg, save_cfg\nimport torchplus\nfrom builder import (\n    dataset_builder,\n    input_reader_builder,\n    lr_scheduler_builder,\n    optimizer_builder,\n    rnnpose_builder\n)\nfrom utils.distributed_utils import dist_init, average_gradients, DistModule, ParallelWrapper, DistributedSequatialSampler, DistributedGivenIterationSampler, DistributedGivenIterationSamplerEpoch \n# from utils.visualize import vis_pointclouds_cv2, vis_2d_keypoints_cv2\nfrom utils.util import modify_parameter_name_with_map\nfrom config.default import get_cfg\nfrom data.ycb.basic import bop_ycb_class2idx\n\n\nGLOBAL_GPUS_PER_DEVICE = 1  # None\nGLOBAL_STEP = 0\nRANK=-1\nWORLD_SIZE=-1\n\n\ndef load_example_to_device(example,\n                             device=None) -> dict:\n    # global GLOBAL_GPUS_PER_DEVICE\n    # device = device % GLOBAL_GPUS_PER_DEVICE or torch.device(\"cuda:0\")\n    # example_torch = defaultdict(list)\n\n    example_torch = {}\n\n    for k, v in example.items():  \n        if k in ['class_name', 'idx'] or example[k] is None:\n            example_torch[k] = v\n            continue\n\n        if type(v) == list:\n            example_torch[k] = [item.to(device=device) for item in v]\n        else:\n            example_torch[k] = v.to(device=device)\n\n    return example_torch\ndef build_network(model_cfg, measure_time=False, testing=False):\n    net = rnnpose_builder.build(\n        model_cfg, measure_time=measure_time, testing=testing)\n    return net\n\n\ndef _worker_init_fn(worker_id):\n    global GLOBAL_STEP\n    time_seed = GLOBAL_STEP\n    np.random.seed(time_seed + worker_id)\n    print(f\"WORKER {worker_id} seed:\", np.random.get_state()[1][0])\n\n\ndef freeze_params(params: dict, include: str = None, exclude: str = None):\n    assert isinstance(params, dict)\n    include_re = None\n    if include is not None:\n        include_re = re.compile(include)\n    exclude_re = None\n    if exclude is not None:\n        exclude_re = re.compile(exclude)\n    remain_params = []\n    for k, p in params.items():\n        if include_re is not None:\n            if include_re.match(k) is not None:\n                continue\n        if exclude_re is not None:\n            if exclude_re.match(k) is None:\n                continue\n        remain_params.append(p)\n    return remain_params\n\n\ndef freeze_params_v2(params: dict, include: str = None, exclude: str = None):\n    assert isinstance(params, dict)\n    include_re = None\n    if include is not None:\n        include_re = re.compile(include)\n    exclude_re = None\n    if exclude is not None:\n        exclude_re = re.compile(exclude)\n    for k, p in params.items():\n        if include_re is not None:\n            if include_re.match(k) is not None:\n                p.requires_grad = False\n        if exclude_re is not None:\n            if exclude_re.match(k) is None:\n                p.requires_grad = False\n\n\ndef filter_param_dict(state_dict: dict, include: str = None, 
def chk_rank(rank_, use_dist=False):\n    if not use_dist:\n        return True\n    global RANK\n    if RANK < 0:\n        RANK = dist.get_rank()\n    return RANK == rank_\n\n\ndef get_rank(use_dist=False):\n    if not use_dist:\n        return 0\n    global RANK\n    if RANK < 0:\n        RANK = dist.get_rank()\n    return RANK\n\n\ndef get_world(use_dist):\n    if not use_dist:\n        return 1\n    global WORLD_SIZE\n    if WORLD_SIZE < 0:\n        WORLD_SIZE = dist.get_world_size()\n    return WORLD_SIZE\n\n\ndef get_ngpus_per_node():\n    global GLOBAL_GPUS_PER_DEVICE\n    return GLOBAL_GPUS_PER_DEVICE\n\n\ndef get_logger():\n    logger_name = \"main-logger\"\n    logger = logging.getLogger(logger_name)\n    logger.setLevel(logging.INFO)\n    handler = logging.StreamHandler()\n    fmt = \"[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s\"\n    handler.setFormatter(logging.Formatter(fmt))\n    logger.addHandler(handler)\n    return logger\n\n\ndef multi_proc_train(\n          config_path,\n          model_dir,\n          use_apex,\n          world_size,\n          result_path=None,\n          create_folder=False,\n          display_step=50,\n          summary_step=5,\n          pretrained_path=None,\n          pretrained_include=None,\n          pretrained_exclude=None,\n          pretrained_param_map=None,\n          freeze_include=None,\n          freeze_exclude=None,\n          measure_time=False,\n          resume=False,\n          use_dist=False,\n          gpus_per_node=1,\n          start_gpu_id=0,\n          optim_eval=False,\n          seed=7,\n          dist_port=\"23335\",\n          force_resume_step=None,\n          batch_size=None,\n          apex_opt_level='O0'\n          ):\n\n    params = {\n          \"config_path\": config_path,\n          \"model_dir\": model_dir,\n          \"use_apex\": use_apex,\n          \"result_path\": result_path,\n          \"create_folder\": create_folder,\n          \"display_step\": display_step,\n          \"summary_step\": summary_step,\n          \"pretrained_path\": pretrained_path,\n          \"pretrained_include\": pretrained_include,\n          \"pretrained_exclude\": pretrained_exclude,\n          \"pretrained_param_map\": pretrained_param_map,\n          \"freeze_include\": freeze_include,\n          \"freeze_exclude\": freeze_exclude,\n          \"measure_time\": measure_time,\n          \"resume\": resume,\n          \"use_dist\": use_dist,\n          \"gpus_per_node\": gpus_per_node,\n          \"optim_eval\": optim_eval,\n          \"seed\": seed,\n          \"dist_port\": dist_port,\n          \"world_size\": world_size,\n          \"force_resume_step\": force_resume_step,\n          \"batch_size\": batch_size,\n          \"apex_opt_level\": apex_opt_level,\n          \"start_gpu_id\": start_gpu_id\n    }\n    from types import SimpleNamespace\n    params = SimpleNamespace(**params)\n\n    os.environ[\"CUDA_VISIBLE_DEVICES\"] = ','.join(\n        str(x) for x in range(start_gpu_id, start_gpu_id + gpus_per_node))\n    print(f\"CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']}\")\n\n    mp.spawn(train_worker, nprocs=gpus_per_node, args=(params,))\n\n\ndef train_worker(rank, params):\n    global RANK, WORLD_SIZE\n    RANK = rank\n    WORLD_SIZE = params.world_size\n\n    print(f\"CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']}\")\n    torch.cuda.set_device(rank % params.gpus_per_node)\n\n    train(config_path=params.config_path,\n          model_dir=params.model_dir,\n          use_apex=params.use_apex,\n          result_path=params.result_path,\n          create_folder=params.create_folder,\n          display_step=params.display_step,\n          summary_step=params.summary_step,\n          pretrained_path=params.pretrained_path,\n          pretrained_include=params.pretrained_include,\n          pretrained_exclude=params.pretrained_exclude,\n          pretrained_param_map=params.pretrained_param_map,\n          freeze_include=params.freeze_include,\n          freeze_exclude=params.freeze_exclude,\n          measure_time=params.measure_time,\n          resume=params.resume,\n          use_dist=params.use_dist,\n          dist_port=params.dist_port,\n          gpus_per_node=params.gpus_per_node,\n          optim_eval=params.optim_eval,\n          seed=params.seed,\n          force_resume_step=params.force_resume_step,\n          batch_size=params.batch_size,\n          apex_opt_level=params.apex_opt_level,\n          gpu_id=params.start_gpu_id + rank,\n          )\n\n
def train(\n          config_path,\n          model_dir,\n          use_apex,\n          result_path=None,\n          create_folder=False,\n          display_step=50,\n          summary_step=5,\n          pretrained_path=None,\n          pretrained_include=None,\n          pretrained_exclude=None,\n          pretrained_param_map=None,\n          freeze_include=None,\n          freeze_exclude=None,\n          multi_gpu=False,\n          measure_time=False,\n          resume=False,\n          use_dist=False,\n          dist_port=\"23335\",\n          gpus_per_node=1,\n          optim_eval=False,\n          seed=7,\n          force_resume_step=None,\n          batch_size=None,\n          apex_opt_level='O0',\n          gpu_id=None\n          ):\n    \"\"\"Train an RNNPose model specified by a config file.\n    \"\"\"\n\n    print(\"force_resume_step:\", force_resume_step)\n    print(\"torch.cuda.is_available()=\", torch.cuda.is_available())\n    print(\"torch.version.cuda=\", torch.version.cuda)\n    dist_url = f\"tcp://127.0.0.1:{dist_port}\"\n    print(f\"dist_url={dist_url}\", flush=True)\n    global RANK, WORLD_SIZE\n    if RANK < 0:\n        RANK = 0\n    if WORLD_SIZE < 0:\n        WORLD_SIZE = 1\n\n    global GLOBAL_GPUS_PER_DEVICE\n    GLOBAL_GPUS_PER_DEVICE = gpus_per_node\n\n    # fix the seeds\n    print(f\"Set seed={seed}\", flush=True)\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n\n    ######################################## initialize the distributed env #########################################\n    if use_dist:\n        # the same NCCL initialization works with and without apex\n        dist.init_process_group(\n            backend=\"nccl\", init_method=dist_url, world_size=get_world(use_dist), rank=get_rank(use_dist))\n\n    ############################################ create folders ############################################\n\n    model_dir = Path(model_dir).resolve()\n    if chk_rank(0, use_dist):\n        if not resume and model_dir.exists():\n            raise ValueError(\"model dir exists and you don't specify resume.\")\n        model_dir.mkdir(parents=True, exist_ok=True)\n    if result_path is None:\n        result_path = model_dir / 'results'\n    config_file_bkp = \"pipeline.config\"\n\n    ############################################# read config proto ############################################\n    config = merge_cfg([config_path], intersection=True)\n    if chk_rank(0, use_dist):\n        print(json.dumps(config, indent=4))\n        save_cfg([config_path], str(model_dir / config_file_bkp))\n    # update the global config object\n    get_cfg().merge(config.get(\"BASIC\", {}), \"BASIC\")\n\n    input_cfg = config.train_input_reader\n    eval_input_cfg = config.eval_input_reader\n    model_cfg = config.model\n    train_cfg = config.train_config\n    optimizer_cfg = train_cfg.optimizer\n    loss_scale = train_cfg.loss_scale_factor\n\n    ############################################# update default options ############################################\n\n    if batch_size is not None:\n        input_cfg.batch_size = batch_size\n    print(input_cfg.batch_size)\n\n
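    # Resulting directory layout (a sketch based on the code above and eval_once below):\n    #   model_dir/pipeline.config        merged config backup\n    #   model_dir/results/step_<N>/      per-evaluation outputs\n    #   model_dir/...                    checkpoints saved by torchplus.train.save_models\n\n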
    ############################################# build network, optimizer etc. ############################################\n    # dummy dataset to get obj_seqs\n    dataset_tmp = input_reader_builder.build(\n        input_cfg,\n        training=True,\n    )\n\n    model_cfg.obj_seqs = copy.copy(dataset_tmp._dataset.infos[\"seqs\"])\n    print(model_cfg.obj_seqs, type(model_cfg.obj_seqs), flush=True)\n\n    model_cfg.gpu_id = gpu_id\n    net = build_network(model_cfg, measure_time)\n    net.cuda()\n\n    fastai_optimizer = optimizer_builder.build(\n        optimizer_cfg, net,\n        mixed=False,\n        loss_scale=loss_scale)\n\n    print(\"num parameters:\", len(list(net.parameters())))\n\n    ############################################# load pretrained model ############################################\n    if pretrained_path is not None:\n        model_dict = net.state_dict()\n        pretrained_dict = torch.load(pretrained_path, map_location='cpu')\n        print(\"Pretrained keys:\", pretrained_dict.keys())\n        print(\"Model keys:\", model_dict.keys())\n\n        pretrained_dict = filter_param_dict(\n            pretrained_dict, pretrained_include, pretrained_exclude)\n\n        pretrained_dict = modify_parameter_name_with_map(\n            pretrained_dict, ast.literal_eval(str(pretrained_param_map)))\n        # keep only the parameters that exist in the model with matching shapes\n        new_pretrained_dict = {}\n        for k, v in pretrained_dict.items():\n            if k in model_dict and v.shape == model_dict[k].shape:\n                new_pretrained_dict[k] = v\n        print(\"Load pretrained parameters:\")\n        for k, v in new_pretrained_dict.items():\n            print(k, v.shape)\n        model_dict.update(new_pretrained_dict)\n        net.load_state_dict(model_dict)\n        freeze_params_v2(dict(net.named_parameters()),\n                         freeze_include, freeze_exclude)\n        net.clear_global_step()\n        del pretrained_dict\n\n    ############################################# try to resume from the latest chkpt ############################################\n    torchplus.train.try_restore_latest_checkpoints(model_dir, [net])\n    torchplus.train.try_restore_latest_checkpoints(model_dir,\n                                                   [fastai_optimizer])\n\n    ######################################## parallelize the network #########################################\n    if use_dist:\n        if use_apex:\n            import apex\n            net = apex.parallel.convert_syncbn_model(net)\n            net, amp_optimizer = apex.amp.initialize(\n                net.cuda(), fastai_optimizer, opt_level=apex_opt_level,\n                keep_batchnorm_fp32=None, loss_scale=None)\n            net_parallel = apex.parallel.DistributedDataParallel(net)\n        else:\n            amp_optimizer = optimizer_builder.build(\n                optimizer_cfg, net,\n                mixed=False,\n                loss_scale=loss_scale)\n            net_parallel = torch.nn.parallel.DistributedDataParallel(net, device_ids=[get_rank(use_dist)], output_device=get_rank(use_dist), find_unused_parameters=True)\n    else:\n        amp_optimizer = fastai_optimizer\n        net_parallel = net.cuda()\n\n    ############################################# build lr_scheduler ############################################\n    lr_scheduler = lr_scheduler_builder.build(optimizer_cfg, amp_optimizer,\n                                              train_cfg.steps)\n\n    float_dtype = torch.float32  # TODO: optional fp16 support\n\n
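    # Note: get_dataloader_deepim computes KPConv neighborhood limits from the training\n    # set, and the same limits are passed to the eval loader below, so both loaders\n    # sample point neighborhoods consistently.\n\n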
    ######################################## build dataloaders #########################################\n    # batching is identical with and without distributed training; the distinction lives in the samplers below\n    num_gpu = 1\n    collate_fn = merge_batch\n    print(f\"MULTI-GPU: use {num_gpu} gpu\")\n\n    ######################\n    # PREPARE INPUT\n    ######################\n    dataset = input_reader_builder.build(\n        input_cfg,\n        training=True,\n    )\n\n    eval_dataset = input_reader_builder.build(\n        eval_input_cfg,\n        training=False,\n    )\n\n    if use_dist:\n        train_sampler = DistributedGivenIterationSamplerEpoch(\n            dataset, train_cfg.steps, input_cfg.batch_size, last_iter=net.get_global_step() - 1, review_cycle=-1)\n        shuffle = False\n        eval_sampler = DistributedSequatialSampler(eval_dataset)\n    else:\n        train_sampler = None\n        eval_sampler = None\n        shuffle = True\n\n    dataloader, neighborhood_limits = get_dataloader_deepim(dataset=dataset,\n                                                            kpconv_config=model_cfg.descriptor_net.keypoints_detector_3d,\n                                                            batch_size=input_cfg.batch_size * num_gpu,\n                                                            shuffle=shuffle,\n                                                            num_workers=input_cfg.preprocess.num_workers * num_gpu,\n                                                            sampler=train_sampler)\n    eval_dataloader, _ = get_dataloader_deepim(dataset=eval_dataset,\n                                               kpconv_config=model_cfg.descriptor_net.keypoints_detector_3d,\n                                               batch_size=eval_input_cfg.batch_size,\n                                               shuffle=False,\n                                               num_workers=eval_input_cfg.preprocess.num_workers,\n                                               sampler=eval_sampler,\n                                               neighborhood_limits=neighborhood_limits)\n\n    ##########################################################################################\n    #                                            TRAINING\n    ##########################################################################################\n    model_logging = SimpleModelLog(model_dir, disable=get_rank(use_dist) != 0)\n    model_logging.open()\n    start_step = net.get_global_step()\n    total_step = train_cfg.steps\n    t = time.time()\n    steps_per_eval = train_cfg.steps_per_eval\n\n    amp_optimizer.zero_grad()\n    step_times = []\n    step = start_step\n    epoch = 0\n    net_parallel.train()\n\n    try:\n        while True:\n            if use_dist:\n                epoch = (net.get_global_step() *\n                         input_cfg.batch_size) // len(dataloader)\n                dataloader.sampler.set_epoch(epoch)\n            else:\n                epoch += 1\n            for example in dataloader:\n                global GLOBAL_STEP\n                GLOBAL_STEP = step\n\n                lr_scheduler.step(net.get_global_step())\n                example_torch = load_example_to_device(example, device=torch.device(\"cuda\"))\n\n                batch_size = example[\"image\"].shape[0]\n\n                ret_dict = net_parallel(example_torch)\n\n                loss = ret_dict[\"loss\"].mean()\n                recall = ret_dict['recall'].mean()\n\n                # all-reducing value/world_size yields the global mean for logging\n                reduced_loss = loss.data.clone() / get_world(use_dist)\n                reduced_recall = recall.data.clone() / get_world(use_dist)\n                if use_dist:\n                    dist.all_reduce_multigpu([reduced_loss])\n                    dist.all_reduce_multigpu([reduced_recall])\n\n                amp_optimizer.zero_grad()\n                if use_apex:\n                    with apex.amp.scale_loss(loss, amp_optimizer) as scaled_loss:\n                        scaled_loss.backward()\n                else:\n                    loss = loss / get_world(use_dist)\n                    loss.backward()\n                    if use_dist:\n                        average_gradients(net_parallel)\n\n                torch.nn.utils.clip_grad_norm_(net.parameters(), 10.0)\n                amp_optimizer.step()\n\n                net.update_global_step()\n                step_time = (time.time() - t)\n                step_times.append(step_time)\n                t = time.time()\n                metrics = defaultdict(dict)\n                GLOBAL_STEP = net.get_global_step()\n\n                if chk_rank(0, use_dist) and GLOBAL_STEP % display_step == 0:\n                    print(f'Model directory: {str(model_dir)}')\n                    if measure_time:\n                        for name, val in net.get_avg_time_dict().items():\n                            print(f\"avg {name} time = {val * 1000:.3f} ms\")\n\n                    metrics[\"runtime\"] = {\n                        \"step\": GLOBAL_STEP,\n                        \"steptime\": np.mean(step_times),\n                    }\n\n                    metrics[\"loss\"][\"loss\"] = float(\n                        reduced_loss.detach().cpu().numpy())\n                    metrics[\"recall\"] = float(reduced_recall.detach().cpu().numpy())\n\n                    metrics['learning_rate'] = amp_optimizer.lr\n                    metrics['reproj_loss'] = float(ret_dict['reproj_loss'].median().detach().cpu().numpy())\n                    metrics['loss_3d_proj'] = float(ret_dict['loss_3d_proj'].median().detach().cpu().numpy())\n                    if hasattr(net.motion_net, \"sigma\"):\n                        metrics['sigma'] = float(net.motion_net.sigma[0].detach().cpu().numpy())\n\n                    metrics['epoch'] = epoch\n\n                    model_logging.log_metrics(metrics, GLOBAL_STEP)\n\n                    if isinstance(ret_dict[\"flow\"], (list, tuple)):\n                        ret_dict[\"flow\"] = ret_dict[\"flow\"][-1]\n                    flow = flow_vis.flow_to_color(ret_dict[\"flow\"][0].squeeze().permute(1, 2, 0).detach().cpu().numpy(), convert_to_bgr=False)\n\n                    model_logging.log_images(\n                        {\n                            \"image\": example[\"image\"].cpu(),\n                            \"flow\": flow.transpose(2, 0, 1)[None],\n                            \"weight\": ret_dict[\"weight\"].squeeze(1).mean(1, keepdim=True).detach().cpu(),\n                            \"syn_img\": torch.cat(ret_dict[\"syn_img\"], dim=0)[:, :3].detach().cpu(),\n                            \"syn_depth\": (ret_dict[\"syn_depth\"][-1] / ret_dict[\"syn_depth\"][-1].max()).detach().cpu(),\n                            \"valid_mask\": ret_dict[\"valid_mask\"].detach().squeeze(1).permute(0, 3, 1, 2).cpu(),\n                            \"ren_mask\": example[\"ren_mask\"][:, None].cpu()\n                        }, GLOBAL_STEP, prefix=\"\")\n\n
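                # With optim_eval, evaluation runs half as often during the first half of\n                # training (the interval doubles until GLOBAL_STEP reaches total_step/2).\n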
example[\"ren_mask\"][:,None].cpu()\n                    }, GLOBAL_STEP, prefix=\"\") \n\n                if optim_eval and GLOBAL_STEP < total_step/2 and GLOBAL_STEP > train_cfg.steps_per_eval*2:\n                    steps_per_eval = 2*train_cfg.steps_per_eval\n                else:\n                    steps_per_eval = train_cfg.steps_per_eval\n\n                if GLOBAL_STEP % steps_per_eval == 0:\n                    if chk_rank(0, use_dist):  # logging\n                        torchplus.train.save_models(model_dir, [net, amp_optimizer],\n                                                    net.get_global_step())\n                    eval_once(net,\n                              eval_dataset=eval_dataset, eval_dataloader=eval_dataloader, eval_input_cfg=eval_input_cfg,\n                              result_path=result_path,\n                              global_step=GLOBAL_STEP,\n                              model_logging=model_logging,\n                              metrics=metrics,\n                              float_dtype=float_dtype,\n                              use_dist=use_dist,\n                              prefix='eval_')\n\n                    net.train()\n\n                step += 1\n                if step >= total_step:\n                    break\n            if step >= total_step:\n                break\n    except Exception as e:\n        model_logging.log_text(str(e), step)\n        raise e\n    finally:\n        model_logging.close()\n    torchplus.train.save_models_cpu(model_dir, [net, amp_optimizer],\n                                    net.get_global_step())\n\n\n\n\n\ndef eval_once(net,\n              eval_dataset, eval_dataloader, eval_input_cfg,\n              result_path, global_step, model_logging, metrics, float_dtype, use_dist, prefix='eval_'):\n    from utils.eval_metric import LineMODEvaluator#, YCBEvaluator\n\n    net.eval()\n    result_path_step = result_path / \\\n        f\"step_{global_step}\"\n    if chk_rank(0, use_dist):\n        result_path_step.mkdir(parents=True, exist_ok=True)\n        model_logging.log_text(\"#################################\",\n                               global_step)\n        model_logging.log_text(\"# EVAL\", global_step)\n        model_logging.log_text(\"#################################\",\n                               global_step)\n        model_logging.log_text(\n            \"Generate output labels...\", global_step)\n        prog_bar = ProgressBar()\n        prog_bar.start(len(eval_dataloader))\n\n    # RES={\n    #     # \"recall\":[]\n    #     \"recall\":[],\n    #     \"3d_proj_error\":[],\n    #     \"2d_proj_error\":[],\n    # }\n    t = 0\n    detections = []\n    \n    #construct evaluator\n    classes=list(set(eval_dataset.dataset.infos['seqs']))\n    if classes[0] in bop_ycb_class2idx.keys():\n        evaluator = dict([(c,YCBEvaluator(f\"{c}\",result_path) ) for c in classes ] )\n    else:\n        evaluator = dict([(c,LineMODEvaluator(f\"{c}\",result_path) ) for c in classes ] )\n\n    # cnt = 0  \n    for i, example in enumerate(eval_dataloader):\n        example_torch=load_example_to_device(example, device=torch.device(\"cuda\"))\n        batch_size = example[\"image\"].shape[0]\n        device = example_torch[\"image\"].device\n        with torch.no_grad():\n            ret_dict = net(example_torch)\n\n        # RES['recall'].append(recall.detach())\n        # RES['3d_proj_error'].append(ret_dict['loss_3d_proj'].mean().detach() )\n        # 
RES['2d_proj_error'].append(ret_dict['reproj_loss'].mean().detach() )\n\n        #do evaluation\n        evaluator[example['class_name'][0]].evaluate_rnnpose(ret_dict, example)\n\n\n        if chk_rank(0, use_dist) and i % 10 == 0:\n            prog_bar.print_bar(finished_size=10)\n    eval_res={} \n    for k in evaluator:\n        eval_res[k]=evaluator[k].summarize()\n\n    if use_dist:  # chk_rank(0, use_dist):\n        #gather and summarize evaluation results\n        for k in eval_res:\n            eval_res[k]['seq_len'] =  torch.tensor(eval_res[k]['seq_len'], device=device)\n            gather_list = [torch.zeros_like(eval_res[k]['seq_len'])\n                           for i in range(get_world(use_dist))]\n            dist.all_gather(gather_list, eval_res[k]['seq_len'])\n            seq_len = torch.stack(gather_list, dim=-1).sum()\n\n            for kk in eval_res[k]:\n                if kk =='seq_len':\n                    continue\n                eval_res[k][kk] =  torch.tensor(eval_res[k][kk], device=device)\n                eval_res[k][kk] =  eval_res[k][kk]*eval_res[k]['seq_len'] # calculate the sum \n                gather_list = [torch.zeros_like(eval_res[k][kk] )\n                           for i in range(get_world(use_dist))]\n                dist.all_gather(gather_list, eval_res[k][kk])\n                eval_res[k][kk] = torch.stack(gather_list, dim=-1).sum()/seq_len #divide with the whole length to get the mean value \n\n\n\n    if chk_rank(0, use_dist):\n        for k in eval_res:\n            for kk in eval_res[k]:\n                if kk == 'seq_len':\n                    metrics[f'{k}_{kk}'] = float(seq_len.cpu().numpy() )\n                metrics[f'{k}_{kk}'] = float(eval_res[k][kk].cpu().numpy())\n\n\n    if chk_rank(0, use_dist):\n        model_logging.log_metrics(metrics, global_step)\n    # del RES \n    del eval_res\n    net.train()\n\n\nif __name__ == '__main__':\n    fire.Fire()\n"
  },
  {
    "path": "tools/transform_data_format.py",
    "content": "import numpy as np \nimport cv2\nimport pickle\nimport fire\nimport os\nimport argparse\n\n\nlinemod_K = np.array([[572.4114, 0., 325.2611],\n                  [0., 573.57043, 242.04899],\n                  [0., 0., 1.]])\n\n\nblender_K = np.array([[700., 0., 320.],\n                    [0., 700., 240.],\n                    [0., 0., 1.]])\n\n\ndef range_to_depth(mask, range, K):\n    '''\n       Transform the range image to depth image\n    '''\n    f=K[0,0]\n    cx=K[0,2]\n    cy=K[1,2]\n\n    ys_, xs_=np.nonzero(mask)\n    rngs=range[ys_,xs_]\n    xs,ys=np.asarray(xs_,np.float32)+0.5,np.asarray(ys_,np.float32)+0.5\n\n    Zs=f*rngs/( f**2 + (cx-xs)**2 + (cy-ys)**2 )**0.5\n    depth = np.zeros_like(range)\n    depth[ys_,xs_] = Zs\n    return  depth\n\ndef crop(image, depth, mask, K_old, margin_ratio=0.1, output_size=128 ):\n    '''\n        image: HxWx3\n        mask: HxW\n        K_old: 3x3\n    '''\n\n    H,W, _ = image.shape\n    \n\n    mask=mask.astype('uint8')*255\n    _x,_y,_w,_h = cv2.boundingRect(mask) \n\n    center=[_x+_w/2, _y+_h/2]\n\n    L=int (max(_w,_h)* (1+2*margin_ratio))\n\n\n    x=max(0, int(center[0]- L/2) )\n    y=max(0, int(center[1]- L/2) )\n    \n    crop=image[y:y+L, x:x+L]\n    depth_crop=depth[y:y+L, x:x+L]\n\n    w=h=L # actual crop size\n\n    #automatically handle the \"out of range\" problem\n    patch=np.zeros([h,w,3], dtype=image.dtype)\n    depth_patch=np.ones([h,w], dtype=depth.dtype)\n    try:\n        xp = 0\n        yp = 0\n        patch[xp : xp+crop.shape[0], yp:yp+crop.shape[1] ] =  crop\n        depth_patch[xp : xp+crop.shape[0], yp:yp+crop.shape[1] ] = depth_crop \n    except:\n        import pdb \n        pdb.set_trace()\n    patch=cv2.resize(patch, (output_size,output_size), interpolation=cv2.INTER_LINEAR )\n    depth_patch=cv2.resize(depth_patch, (output_size,output_size), interpolation=cv2.INTER_NEAREST )\n\n    #update the intrinsic parameters\n    K_new=np.zeros_like(K_old)\n    scale=output_size/L\n    K_new[0,2] = (K_old[0,2]-x)*scale\n    K_new[1,2] = (K_old[1,2]-y)*scale\n    K_new[0,0] = K_old[0,0]*scale\n    K_new[1,1] = K_old[1,1]*scale\n    K_new[2,2] = 1\n\n    return patch, depth_patch, K_new \n\nclass DataFormatter(object):\n    def __init__(self, data_type, data_info_path, crop_param=None ):\n        assert data_type in ['LM_SYN_PVNET', \"LM_SYN_PVNET_LMK\",'LM_FUSE_PVNET','LM_FUSE_SINGLE_PVNET' ]\n        self.data_type=data_type\n        self.crop_param=crop_param\n        with open(data_info_path, 'rb') as f:\n            self.data_info=pickle.load(f)\n        pass \n\n    \n    def process(self, data_root,depth_root, save_root):\n\n        if self.data_type == \"LM_SYN_PVNET\":\n            self._proc_LM_SYN_PVNET(self.data_info, data_root, save_root)\n        elif self.data_type=='LM_SYN_PVNET_LMK':\n            self._proc_LM_SYN_PVNET_LMK(self.data_info, data_root, save_root)\n        elif self.data_type=='LM_FUSE_PVNET':\n            self._proc_LM_FUSE_PVNET(self.data_info, data_root,depth_root, save_root)\n        elif self.data_type=='LM_FUSE_SINGLE_PVNET':\n            self._proc_LM_FUSE_SINGLE_PVNET(self.data_info, data_root,depth_root, save_root)\n        else:\n            raise NotImplementedError\n\n    def _proc_LM_SYN_PVNET(self, data_info, data_root, save_root):\n\n        for seq in data_info:\n            for idx in range(len(data_info[seq]) ):\n                # info = {\n                #     \"index\": idx,\n                #     \"image_path\": 
class DataFormatter(object):\n    def __init__(self, data_type, data_info_path, crop_param=None):\n        assert data_type in ['LM_SYN_PVNET', 'LM_SYN_PVNET_LMK', 'LM_FUSE_PVNET', 'LM_FUSE_SINGLE_PVNET']\n        self.data_type = data_type\n        self.crop_param = crop_param\n        with open(data_info_path, 'rb') as f:\n            self.data_info = pickle.load(f)\n\n    def process(self, data_root, depth_root, save_root):\n        if self.data_type == 'LM_SYN_PVNET':\n            self._proc_LM_SYN_PVNET(self.data_info, data_root, save_root)\n        elif self.data_type == 'LM_SYN_PVNET_LMK':\n            self._proc_LM_SYN_PVNET_LMK(self.data_info, data_root, save_root)\n        elif self.data_type == 'LM_FUSE_PVNET':\n            self._proc_LM_FUSE_PVNET(self.data_info, data_root, depth_root, save_root)\n        elif self.data_type == 'LM_FUSE_SINGLE_PVNET':\n            self._proc_LM_FUSE_SINGLE_PVNET(self.data_info, data_root, depth_root, save_root)\n        else:\n            raise NotImplementedError\n\n    def _proc_LM_SYN_PVNET(self, data_info, data_root, save_root):\n        for seq in data_info:\n            for idx in range(len(data_info[seq])):\n                info = data_info[seq][idx]\n                image = cv2.imread(os.path.join(data_root, info['image_path']))\n                depth = np.load(os.path.join(data_root, info['depth_path']))\n                K_old = blender_K.copy()\n\n                # maximum depth value = 1, which indicates the invalid regions\n                hs, ws = np.nonzero(depth < 1)\n                hmin, hmax = np.min(hs), np.max(hs)\n                wmin, wmax = np.min(ws), np.max(ws)\n                bbox = [hmin, wmin, hmax, wmax]\n\n                mask = depth < 1\n                # transform the range map into a depth map\n                depth = range_to_depth(mask, depth * 2, K_old)\n                if self.crop_param is not None:\n                    image, depth, K_new = crop(image, depth, mask, K_old, margin_ratio=self.crop_param['margin_ratio'], output_size=self.crop_param['output_size'])\n                else:\n                    K_new = K_old\n\n                print(info['image_path'], info['depth_path'])\n                patch_save_path = os.path.join(save_root, seq, f\"{info['index']:05d}.jpg\")\n                depth_patch_save_path = os.path.join(save_root, seq, f\"{info['index']:05d}_depth.npy\")\n                pose_save_path = os.path.join(save_root, seq, f\"{info['index']:05d}_params.pkl\")\n\n                os.makedirs(os.path.join(save_root, seq), exist_ok=True)\n                # save\n                cv2.imwrite(patch_save_path, image)\n                np.save(depth_patch_save_path, depth)\n                with open(pose_save_path, 'wb+') as f:\n                    pickle.dump({\n                        \"RT\": info[\"RT\"],\n                        \"K\": K_new,\n                        \"bbox\": bbox\n                    }, f)\n\n    def _proc_LM_SYN_PVNET_LMK(self, data_info, data_root, save_root):\n        # the camera intrinsics come from the saved per-frame params here\n        for seq in data_info:\n            for idx in range(len(data_info[seq])):\n                info = data_info[seq][idx]\n\n                image = cv2.imread(os.path.join(data_root, info['image_path']))\n                depth = np.load(os.path.join(data_root, info['depth_path']))\n                with open(os.path.join(data_root, info['image_path'].replace(\".jpg\", \"_RT.pkl\")), 'rb') as f:\n                    old_params = pickle.load(f)\n                K_old = old_params[\"K\"]\n\n                # maximum depth value = 1, which indicates the invalid regions\n                hs, ws = np.nonzero(depth < 1)\n                hmin, hmax = np.min(hs), np.max(hs)\n                wmin, wmax = np.min(ws), np.max(ws)\n                bbox = [hmin, wmin, hmax, wmax]\n\n                mask = depth < 1\n                # transform the range map into a depth map\n                depth = range_to_depth(mask, depth * 2, K_old)\n                if self.crop_param is not None:\n                    image, depth, K_new = crop(image, depth, mask, K_old, margin_ratio=self.crop_param['margin_ratio'], output_size=self.crop_param['output_size'])\n                else:\n                    K_new = K_old\n\n                print(info['image_path'], info['depth_path'], \"...\")\n                patch_save_path = os.path.join(save_root, seq, f\"{info['index']:05d}.jpg\")\n                depth_patch_save_path = os.path.join(save_root, seq, f\"{info['index']:05d}_depth.npy\")\n                pose_save_path = os.path.join(save_root, seq, f\"{info['index']:05d}_params.pkl\")\n\n                os.makedirs(os.path.join(save_root, seq), exist_ok=True)\n                # save\n                cv2.imwrite(patch_save_path, image)\n                np.save(depth_patch_save_path, depth)\n                with open(pose_save_path, 'wb+') as f:\n                    pickle.dump({\n                        \"RT\": old_params[\"RT\"],\n                        \"K\": K_new,\n                        \"bbox\": bbox\n                    }, f)\n\n
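    # Conventions of the \"fuse\" data handled by the two methods below (as implied by the code):\n    # - the fused mask stores class_index + 1 per pixel (0 is background), so object i is\n    #   selected with fuse_mask[..., 0] == i + 1;\n    # - fuse_info[0][i] holds the (row, col) paste offset of object i in the fused image,\n    #   so the principal point is shifted by this offset before cropping.\n\n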
    def _proc_LM_FUSE_PVNET(self, data_info, data_root, depth_root, save_root):\n        # The class-name list used during the fusing process; it determines the mask index of each class.\n        linemod_cls_names = ['ape', 'cam', 'cat', 'duck', 'glue', 'iron', 'phone', 'benchvise', 'can', 'driller', 'eggbox', 'holepuncher', 'lamp']\n\n        for seq in data_info:\n            seq_idx = linemod_cls_names.index(seq)\n            for idx in range(len(data_info[seq])):\n                info = data_info[seq][idx]\n\n                with open(os.path.join(data_root, info['image_path']).split('.jpg')[0].replace(seq, '') + '_info.pkl', 'rb') as f:\n                    fuse_info = pickle.load(f)\n\n                image = cv2.imread(os.path.join(data_root, info['image_path']).split('.jpg')[0].replace(seq, '') + '_rgb.jpg')\n                depth_idx = fuse_info[2][seq_idx]['img_idx']\n\n                rendered_depth = np.load(os.path.dirname(os.path.join(depth_root, info['image_path'])) + f'/{depth_idx}_depth.png.npy')\n\n                fuse_mask = cv2.imread(os.path.join(data_root, info['image_path']).split('.jpg')[0].replace(seq, '') + '_mask.png')\n                fuse_mask = fuse_mask[..., 0] == (seq_idx + 1)  # fuse mask ids start from 1\n\n                hs, ws = np.nonzero(rendered_depth < 1)\n                hmin, hmax = np.min(hs), np.max(hs)\n                wmin, wmax = np.min(ws), np.max(ws)\n                bbox = [hmin + fuse_info[0][seq_idx][0], wmin + fuse_info[0][seq_idx][1], hmax + fuse_info[0][seq_idx][0], wmax + fuse_info[0][seq_idx][1]]\n\n                depth = np.ones_like(rendered_depth)\n\n                try:\n                    depth[hmin + fuse_info[0][seq_idx][0]: fuse_info[0][seq_idx][0] + hmax + 1, wmin + fuse_info[0][seq_idx][1]: wmax + fuse_info[0][seq_idx][1] + 1] = rendered_depth[hmin:hmax + 1, wmin:wmax + 1]\n                except ValueError:\n                    # the pasted object may exceed the fused image boundary; skip such frames\n                    print(info['image_path'], \"failed!\")\n                    continue\n\n                K_old = linemod_K.copy()\n                K_old[0, 2] = K_old[0, 2] + fuse_info[0][seq_idx][1]\n                K_old[1, 2] = K_old[1, 2] + fuse_info[0][seq_idx][0]\n\n                # maximum depth value = 1, which indicates the invalid regions\n                mask = depth < 1\n                # transform the range map into a depth map; keep all depths here, including occluded ones\n                depth = range_to_depth(mask, depth * 2, K_old)\n\n                if self.crop_param is not None:\n                    image, depth, K_new = crop(image, depth, mask, K_old, margin_ratio=self.crop_param['margin_ratio'], output_size=self.crop_param['output_size'])\n                else:\n                    K_new = K_old\n\n                print(info['image_path'], info['depth_path'], bbox)\n                patch_save_path = os.path.join(save_root, seq, f\"{info['index']:05d}.jpg\")\n                depth_patch_save_path = os.path.join(save_root, seq, f\"{info['index']:05d}_depth.npy\")\n                pose_save_path = os.path.join(save_root, seq, f\"{info['index']:05d}_params.pkl\")\n                mask_visb_save_path = os.path.join(save_root, seq, f\"{info['index']:05d}_mask_visb.png\")\n\n                os.makedirs(os.path.join(save_root, seq), exist_ok=True)\n                # save\n                cv2.imwrite(patch_save_path, image)\n                cv2.imwrite(mask_visb_save_path, fuse_mask.astype(np.uint8) * 255)\n\n                np.save(depth_patch_save_path, depth)\n                with open(pose_save_path, 'wb+') as f:\n                    pickle.dump({\n                        \"RT\": fuse_info[1][seq_idx],\n                        \"K\": K_new,\n                        \"bbox\": bbox\n                    }, f)\n\n    def _proc_LM_FUSE_SINGLE_PVNET(self, data_info, data_root, depth_root, save_root):\n        # single-object fuse data: each file contains one object, so the mask index is always 0\n        seq_idx = 0\n\n        for seq in data_info:\n            for idx in range(len(data_info[seq])):\n                info = data_info[seq][idx]\n                with open(os.path.join(data_root, info['image_path'].split('.jpg')[0] + '_info.pkl'), 'rb') as f:\n                    fuse_info = pickle.load(f)\n\n                image = cv2.imread(os.path.join(data_root, info['image_path'].split('.jpg')[0] + '_rgb.jpg'))\n                depth_idx = fuse_info[2][seq_idx]['img_idx']\n\n                rendered_depth = np.load(os.path.dirname(os.path.join(depth_root, info['image_path'])) + f'/{depth_idx}_depth.png.npy')\n\n                fuse_mask = cv2.imread(os.path.join(data_root, info['image_path']).split('.jpg')[0] + '_mask.png')\n                fuse_mask = fuse_mask[..., 0] == (seq_idx + 1)  # fuse mask ids start from 1\n\n                hs, ws = np.nonzero(rendered_depth < 1)\n                hmin, hmax = np.min(hs), np.max(hs)\n                wmin, wmax = np.min(ws), np.max(ws)\n                bbox = [hmin + fuse_info[0][seq_idx][0], wmin + fuse_info[0][seq_idx][1], hmax + fuse_info[0][seq_idx][0], wmax + fuse_info[0][seq_idx][1]]\n\n                depth = np.ones_like(rendered_depth)\n\n                try:\n                    depth[hmin + fuse_info[0][seq_idx][0]: fuse_info[0][seq_idx][0] + hmax + 1, wmin + fuse_info[0][seq_idx][1]: wmax + fuse_info[0][seq_idx][1] + 1] = rendered_depth[hmin:hmax + 1, wmin:wmax + 1]\n                except ValueError:\n                    print(info['image_path'], \"failed!\")\n                    continue\n\n                K_old = linemod_K.copy()\n                K_old[0, 2] = K_old[0, 2] + fuse_info[0][seq_idx][1]\n                K_old[1, 2] = K_old[1, 2] + fuse_info[0][seq_idx][0]\n\n                # maximum depth value = 1, which indicates the invalid regions\n                mask = depth < 1\n                # transform the range map into a depth map\n                depth = range_to_depth(mask, depth * 2, K_old)\n\n
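                # Unlike _proc_LM_FUSE_PVNET above (which keeps occluded depths), the\n                # single-object variant masks the depth with the visibility mask and\n                # uses 0's to mark invalid pixels.\n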
                depth = depth * fuse_mask\n\n                if self.crop_param is not None:\n                    image, depth, K_new = crop(image, depth, mask, K_old, margin_ratio=self.crop_param['margin_ratio'], output_size=self.crop_param['output_size'])\n                else:\n                    K_new = K_old\n\n                print(info['image_path'], info['depth_path'], bbox)\n                patch_save_path = os.path.join(save_root, seq, f\"{info['index']:05d}.jpg\")\n                depth_patch_save_path = os.path.join(save_root, seq, f\"{info['index']:05d}_depth.npy\")\n                pose_save_path = os.path.join(save_root, seq, f\"{info['index']:05d}_params.pkl\")\n\n                os.makedirs(os.path.join(save_root, seq), exist_ok=True)\n                # save\n                cv2.imwrite(patch_save_path, image)\n                np.save(depth_patch_save_path, depth)\n                with open(pose_save_path, 'wb+') as f:\n                    pickle.dump({\n                        \"RT\": fuse_info[1][seq_idx],\n                        \"K\": K_new,\n                        \"bbox\": bbox\n                    }, f)\n\n\ndef run(data_type, data_info_path, image_root, depth_root, save_dir, crop_param=None):\n    df = DataFormatter(data_type, data_info_path, crop_param=crop_param)\n    df.process(image_root, depth_root, save_dir)\n\n\nif __name__ == '__main__':\n    fire.Fire(run)\n"
  },
  {
    "path": "torchplus/__init__.py",
    "content": "from . import train\r\nfrom . import nn\r\nfrom . import metrics\r\nfrom . import tools\r\n\r\nfrom .tools import change_default_args\r\nfrom torchplus.ops.array_ops import scatter_nd, gather_nd, roll\r\n"
  },
  {
    "path": "torchplus/metrics.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\n\nclass Scalar(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.register_buffer('total', torch.FloatTensor([0.0]))\n        self.register_buffer('count', torch.FloatTensor([0.0]))\n\n    def forward(self, scalar):\n        if not scalar.eq(0.0):\n            self.count += 1\n            self.total += scalar.data.float()\n        return self.value.cpu()\n\n    @property\n    def value(self):\n        return self.total / self.count\n\n    def clear(self):\n        self.total.zero_()\n        self.count.zero_()\n\nclass Accuracy(nn.Module):\n    def __init__(self,\n                 dim=1,\n                 ignore_idx=-1,\n                 threshold=0.5,\n                 encode_background_as_zeros=True):\n        super().__init__()\n        self.register_buffer('total', torch.FloatTensor([0.0]))\n        self.register_buffer('count', torch.FloatTensor([0.0]))\n        self._ignore_idx = ignore_idx\n        self._dim = dim\n        self._threshold = threshold\n        self._encode_background_as_zeros = encode_background_as_zeros\n\n    def forward(self, labels, preds, weights=None):\n        # labels: [N, ...]\n        # preds: [N, C, ...]\n        if self._encode_background_as_zeros:\n            scores = torch.sigmoid(preds)\n            labels_pred = torch.max(preds, dim=self._dim)[1] + 1\n            pred_labels = torch.where((scores > self._threshold).any(self._dim),\n                                      labels_pred,\n                                      torch.tensor(0).type_as(labels_pred))\n        else:\n            pred_labels = torch.max(preds, dim=self._dim)[1]\n        N, *Ds = labels.shape\n        labels = labels.view(N, int(np.prod(Ds)))\n        pred_labels = pred_labels.view(N, int(np.prod(Ds)))\n        if weights is None:\n            weights = (labels != self._ignore_idx).float()\n        else:\n            weights = weights.float()\n\n        num_examples = torch.sum(weights)\n        num_examples = torch.clamp(num_examples, min=1.0).float()\n        total = torch.sum((pred_labels == labels.long()).float())\n        self.count += num_examples\n        self.total += total\n        return self.value.cpu()\n        # return (total /  num_examples.data).cpu()\n    @property\n    def value(self):\n        return self.total / self.count\n\n    def clear(self):\n        self.total.zero_()\n        self.count.zero_()\n\n\nclass Precision(nn.Module):\n    def __init__(self, dim=1, ignore_idx=-1, threshold=0.5):\n        super().__init__()\n        self.register_buffer('total', torch.FloatTensor([0.0]))\n        self.register_buffer('count', torch.FloatTensor([0.0]))\n        self._ignore_idx = ignore_idx\n        self._dim = dim\n        self._threshold = threshold\n\n    def forward(self, labels, preds, weights=None):\n        # labels: [N, ...]\n        # preds: [N, C, ...]\n        if preds.shape[self._dim] == 1:  # BCE\n            pred_labels = (torch.sigmoid(preds) >\n                           self._threshold).long().squeeze(self._dim)\n        else:\n            assert preds.shape[\n                self._dim] == 2, \"precision only support 2 class\"\n            pred_labels = torch.max(preds, dim=self._dim)[1]\n        N, *Ds = labels.shape\n        labels = labels.view(N, int(np.prod(Ds)))\n        pred_labels = pred_labels.view(N, int(np.prod(Ds)))\n        if weights is None:\n            weights = (labels != 
class Precision(nn.Module):\n    def __init__(self, dim=1, ignore_idx=-1, threshold=0.5):\n        super().__init__()\n        self.register_buffer('total', torch.FloatTensor([0.0]))\n        self.register_buffer('count', torch.FloatTensor([0.0]))\n        self._ignore_idx = ignore_idx\n        self._dim = dim\n        self._threshold = threshold\n\n    def forward(self, labels, preds, weights=None):\n        # labels: [N, ...]\n        # preds: [N, C, ...]\n        if preds.shape[self._dim] == 1:  # BCE\n            pred_labels = (torch.sigmoid(preds) >\n                           self._threshold).long().squeeze(self._dim)\n        else:\n            assert preds.shape[\n                self._dim] == 2, \"precision only supports 2 classes\"\n            pred_labels = torch.max(preds, dim=self._dim)[1]\n        N, *Ds = labels.shape\n        labels = labels.view(N, int(np.prod(Ds)))\n        pred_labels = pred_labels.view(N, int(np.prod(Ds)))\n        if weights is None:\n            weights = (labels != self._ignore_idx).float()\n        else:\n            weights = weights.float()\n\n        pred_trues = pred_labels > 0\n        trues = labels > 0\n        falses = labels == 0\n        true_positives = (weights * (trues & pred_trues).float()).sum()\n        false_positives = (weights * (falses & pred_trues).float()).sum()\n        count = true_positives + false_positives\n        if count > 0:\n            self.count += count\n            self.total += true_positives\n        return self.value.cpu()\n\n    @property\n    def value(self):\n        return self.total / self.count\n\n    def clear(self):\n        self.total.zero_()\n        self.count.zero_()\n\n\nclass Recall(nn.Module):\n    def __init__(self, dim=1, ignore_idx=-1, threshold=0.5):\n        super().__init__()\n        self.register_buffer('total', torch.FloatTensor([0.0]))\n        self.register_buffer('count', torch.FloatTensor([0.0]))\n        self._ignore_idx = ignore_idx\n        self._dim = dim\n        self._threshold = threshold\n\n    def forward(self, labels, preds, weights=None):\n        # labels: [N, ...]\n        # preds: [N, C, ...]\n        if preds.shape[self._dim] == 1:  # BCE\n            pred_labels = (torch.sigmoid(preds) >\n                           self._threshold).long().squeeze(self._dim)\n        else:\n            assert preds.shape[\n                self._dim] == 2, \"recall only supports 2 classes\"\n            pred_labels = torch.max(preds, dim=self._dim)[1]\n        N, *Ds = labels.shape\n        labels = labels.view(N, int(np.prod(Ds)))\n        pred_labels = pred_labels.view(N, int(np.prod(Ds)))\n        if weights is None:\n            weights = (labels != self._ignore_idx).float()\n        else:\n            weights = weights.float()\n        pred_trues = pred_labels == 1\n        trues = labels == 1\n        true_positives = (weights * (trues & pred_trues).float()).sum()\n        false_negatives = (weights * (trues & ~pred_trues).float()).sum()\n        count = true_positives + false_negatives\n        if count > 0:\n            self.count += count\n            self.total += true_positives\n        return self.value.cpu()\n\n    @property\n    def value(self):\n        return self.total / self.count\n\n    def clear(self):\n        self.total.zero_()\n        self.count.zero_()\n\n\ndef _calc_binary_metrics(labels,\n                         scores,\n                         weights=None,\n                         ignore_idx=-1,\n                         threshold=0.5):\n    pred_labels = (scores > threshold).long()\n    N, *Ds = labels.shape\n    labels = labels.view(N, int(np.prod(Ds)))\n    pred_labels = pred_labels.view(N, int(np.prod(Ds)))\n    if weights is None:\n        weights = (labels != ignore_idx).float()\n    pred_trues = pred_labels > 0\n    pred_falses = pred_labels == 0\n    trues = labels > 0\n    falses = labels == 0\n    true_positives = (weights * (trues & pred_trues).float()).sum()\n    true_negatives = (weights * (falses & pred_falses).float()).sum()\n    false_positives = (weights * (falses & pred_trues).float()).sum()\n    false_negatives = (weights * (trues & pred_falses).float()).sum()\n    return true_positives, true_negatives, false_positives, false_negatives\n\n\nclass PrecisionRecall(nn.Module):\n    def __init__(self,\n                 dim=1,\n                 ignore_idx=-1,\n                 thresholds=0.5,\n                 use_sigmoid_score=False,\n                 encode_background_as_zeros=True):\n        super().__init__()\n        if not isinstance(thresholds, (list, tuple)):\n            thresholds = [thresholds]\n\n        self.register_buffer('prec_total',\n                             torch.FloatTensor(len(thresholds)).zero_())\n        self.register_buffer('prec_count',\n                             torch.FloatTensor(len(thresholds)).zero_())\n        self.register_buffer('rec_total',\n                             torch.FloatTensor(len(thresholds)).zero_())\n        self.register_buffer('rec_count',\n                             torch.FloatTensor(len(thresholds)).zero_())\n\n        self._ignore_idx = ignore_idx\n        self._dim = dim\n        self._thresholds = thresholds\n        self._use_sigmoid_score = use_sigmoid_score\n        self._encode_background_as_zeros = encode_background_as_zeros\n\n    def forward(self, labels, preds, weights=None):\n        # labels: [N, ...]\n        # preds: [N, ..., C]\n        if self._encode_background_as_zeros:\n            # this mode doesn't support softmax scores\n            assert self._use_sigmoid_score is True\n            total_scores = torch.sigmoid(preds)\n        else:\n            if self._use_sigmoid_score:\n                total_scores = torch.sigmoid(preds)[..., 1:]\n            else:\n                total_scores = F.softmax(preds, dim=-1)[..., 1:]\n        scores = torch.max(total_scores, dim=-1)[0]\n        if weights is None:\n            weights = (labels != self._ignore_idx).float()\n        else:\n            weights = weights.float()\n        for i, thresh in enumerate(self._thresholds):\n            tp, tn, fp, fn = _calc_binary_metrics(labels, scores, weights,\n                                                  self._ignore_idx, thresh)\n            rec_count = tp + fn\n            prec_count = tp + fp\n            if rec_count > 0:\n                self.rec_count[i] += rec_count\n                self.rec_total[i] += tp\n            if prec_count > 0:\n                self.prec_count[i] += prec_count\n                self.prec_total[i] += tp\n\n        return self.value\n\n    @property\n    def value(self):\n        prec_count = torch.clamp(self.prec_count, min=1.0)\n        rec_count = torch.clamp(self.rec_count, min=1.0)\n        return ((self.prec_total / prec_count).cpu(),\n                (self.rec_total / rec_count).cpu())\n\n    @property\n    def thresholds(self):\n        return self._thresholds\n\n    def clear(self):\n        self.rec_count.zero_()\n        self.prec_count.zero_()\n        self.prec_total.zero_()\n        self.rec_total.zero_()\n"
  },
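  {
    "path": "examples/metrics_usage.py",
    "content": "# Illustrative usage sketch for torchplus/metrics.py; this example file and\n# its path are hypothetical additions, assuming torchplus is on PYTHONPATH.\nimport torch\n\nfrom torchplus.metrics import Accuracy, PrecisionRecall\n\n# labels: [N, ...] with 0 = background; preds: [N, C, ...] logits where C is\n# the number of foreground classes (encode_background_as_zeros=True).\nlabels = torch.randint(0, 2, (8, 100))\npreds = torch.randn(8, 1, 100)\n\nacc = Accuracy(dim=1, encode_background_as_zeros=True)\nprint('running accuracy:', acc(labels, preds).item())\n\n# PrecisionRecall expects class scores on the last axis: [N, ..., C].\npr = PrecisionRecall(thresholds=[0.3, 0.5, 0.7], use_sigmoid_score=True)\nprec, rec = pr(labels, preds.permute(0, 2, 1))\nprint('precision per threshold:', prec)\nprint('recall per threshold:', rec)\n\nacc.clear()  # metrics accumulate across calls until clear()\n"
  },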
  {
    "path": "torchplus/nn/__init__.py",
    "content": "from torchplus.nn.functional import one_hot\nfrom torchplus.nn.modules.common import Empty, Sequential\nfrom torchplus.nn.modules.normalization import GroupNorm\n"
  },
  {
    "path": "torchplus/nn/functional.py",
    "content": "import torch\n\ndef one_hot(tensor, depth, dim=-1, on_value=1.0, dtype=torch.float32):\n    tensor_onehot = torch.zeros(\n        *list(tensor.shape), depth, dtype=dtype, device=tensor.device)\n    tensor_onehot.scatter_(dim, tensor.unsqueeze(dim).long(), on_value)\n    return tensor_onehot\n"
  },
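  {
    "path": "examples/one_hot_usage.py",
    "content": "# Illustrative usage sketch for torchplus.nn.one_hot (a tf.one_hot analogue);\n# this example file and its path are hypothetical additions.\nimport torch\n\nfrom torchplus.nn import one_hot\n\nidx = torch.tensor([[0, 2], [1, 1]])\nprint(one_hot(idx, depth=3))  # shape [2, 2, 3], float32\nprint(one_hot(idx, 3, on_value=5.0, dtype=torch.float64))\n"
  },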
  {
    "path": "torchplus/nn/modules/__init__.py",
    "content": ""
  },
  {
    "path": "torchplus/nn/modules/common.py",
    "content": "import sys\nfrom collections import OrderedDict\n\nimport torch\nfrom torch.nn import functional as F\n\n\nclass Empty(torch.nn.Module):\n    def __init__(self, *args, **kwargs):\n        super(Empty, self).__init__()\n        self.weight = torch.zeros([1, ])  # dummy varaible\n\n    def forward(self, *args, **kwargs):\n        if len(args) == 1:\n            return args[0]\n        elif len(args) == 0:\n            return None\n        return args\n\n\nclass Sequential(torch.nn.Module):\n    r\"\"\"A sequential container.\n    Modules will be added to it in the order they are passed in the constructor.\n    Alternatively, an ordered dict of modules can also be passed in.\n\n    To make it easier to understand, given is a small example::\n\n        # Example of using Sequential\n        model = Sequential(\n                  nn.Conv2d(1,20,5),\n                  nn.ReLU(),\n                  nn.Conv2d(20,64,5),\n                  nn.ReLU()\n                )\n\n        # Example of using Sequential with OrderedDict\n        model = Sequential(OrderedDict([\n                  ('conv1', nn.Conv2d(1,20,5)),\n                  ('relu1', nn.ReLU()),\n                  ('conv2', nn.Conv2d(20,64,5)),\n                  ('relu2', nn.ReLU())\n                ]))\n\n        # Example of using Sequential with kwargs(python 3.6+)\n        model = Sequential(\n                  conv1=nn.Conv2d(1,20,5),\n                  relu1=nn.ReLU(),\n                  conv2=nn.Conv2d(20,64,5),\n                  relu2=nn.ReLU()\n                )\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super(Sequential, self).__init__()\n        if len(args) == 1 and isinstance(args[0], OrderedDict):\n            for key, module in args[0].items():\n                self.add_module(key, module)\n        else:\n            for idx, module in enumerate(args):\n                self.add_module(str(idx), module)\n        for name, module in kwargs.items():\n            if sys.version_info < (3, 6):\n                raise ValueError(\"kwargs only supported in py36+\")\n            if name in self._modules:\n                raise ValueError(\"name exists.\")\n            self.add_module(name, module)\n\n    def __getitem__(self, idx):\n        if not (-len(self) <= idx < len(self)):\n            raise IndexError('index {} is out of range'.format(idx))\n        if idx < 0:\n            idx += len(self)\n        it = iter(self._modules.values())\n        for i in range(idx):\n            next(it)\n        return next(it)\n\n    def __len__(self):\n        return len(self._modules)\n\n    def add(self, module, name=None):\n        if name is None:\n            name = str(len(self._modules))\n            if name in self._modules:\n                raise KeyError(\"name exists\")\n        self.add_module(name, module)\n\n    def forward(self, input):\n        # i = 0\n        for module in self._modules.values():\n            # print(i)\n            input = module(input)\n            # i += 1\n        return input\n"
  },
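  {
    "path": "examples/sequential_usage.py",
    "content": "# Illustrative usage sketch for torchplus.nn.Sequential and Empty; this\n# example file and its path are hypothetical additions.\nimport torch\nfrom torch import nn\n\nfrom torchplus.nn import Empty, Sequential\n\nnet = Sequential(conv1=nn.Conv2d(3, 8, 3), relu1=nn.ReLU())\nnet.add(nn.Conv2d(8, 16, 3), name='conv2')  # incremental add()\nx = torch.randn(1, 3, 32, 32)\nprint(net(x).shape)     # torch.Size([1, 16, 28, 28])\nprint(net[0])           # indexing works like nn.Sequential\nprint(Empty()(x) is x)  # True: Empty passes a single input through unchanged\n"
  },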
  {
    "path": "torchplus/nn/modules/normalization.py",
    "content": "import torch\n\n\nclass GroupNorm(torch.nn.GroupNorm):\n    def __init__(self, num_channels, num_groups, eps=1e-5, affine=True):\n        super().__init__(\n            num_groups=num_groups,\n            num_channels=num_channels,\n            eps=eps,\n            affine=affine)\n"
  },
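  {
    "path": "examples/groupnorm_usage.py",
    "content": "# Illustrative usage sketch: the torchplus GroupNorm wrapper swaps the\n# argument order of torch.nn.GroupNorm so num_channels comes first. This\n# example file and its path are hypothetical additions.\nimport torch\n\nfrom torchplus.nn import GroupNorm\n\ngn = GroupNorm(num_channels=32, num_groups=8)  # same as nn.GroupNorm(8, 32)\nx = torch.randn(4, 32, 16, 16)\nprint(gn(x).shape)  # torch.Size([4, 32, 16, 16])\n"
  },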
  {
    "path": "torchplus/ops/__init__.py",
    "content": ""
  },
  {
    "path": "torchplus/ops/array_ops.py",
    "content": "import ctypes\nimport math\nimport time\nimport torch\nfrom typing import Optional\n\n\ndef scatter_nd(indices, updates, shape):\n    \"\"\"pytorch edition of tensorflow scatter_nd.\n    this function don't contain except handle code. so use this carefully\n    when indice repeats, don't support repeat add which is supported\n    in tensorflow.\n    \"\"\"\n    ret = torch.zeros(*shape, dtype=updates.dtype, device=updates.device)\n    ndim = indices.shape[-1]\n    output_shape = list(indices.shape[:-1]) + shape[indices.shape[-1]:]\n    flatted_indices = indices.view(-1, ndim)\n    slices = [flatted_indices[:, i] for i in range(ndim)]\n    slices += [Ellipsis]\n    ret[slices] = updates.view(*output_shape)\n    return ret\n\n\ndef gather_nd(params, indices):\n    # this function has a limit that MAX_ADVINDEX_CALC_DIMS=5\n    ndim = indices.shape[-1]\n    output_shape = list(indices.shape[:-1]) + list(params.shape[indices.shape[-1]:])\n    flatted_indices = indices.view(-1, ndim)\n    slices = [flatted_indices[:, i] for i in range(ndim)]\n    slices += [Ellipsis]\n    return params[slices].view(*output_shape)\n\n\ndef roll(x: torch.Tensor, shift: int, dim: int = -1, fill_pad: Optional[int] = None):\n    \"\"\"\n        shift<0: left roll\n        shift>0: right roll\n    \"\"\"  \n    device = x.device\n    \n    if 0 == shift:\n        return x\n\n    elif shift < 0:\n        shift = -shift\n        gap = x.index_select(dim, torch.arange(shift, device=device))\n        if fill_pad is not None:\n            gap = fill_pad * torch.ones_like(gap, device=device)\n        return torch.cat([x.index_select(dim, torch.arange(shift, x.size(dim), device=device)), gap], dim=dim)\n\n    else:\n        shift = x.size(dim) - shift\n        gap = x.index_select(dim, torch.arange(shift, x.size(dim), device=device))\n        if fill_pad is not None:\n            gap = fill_pad * torch.ones_like(gap, device=device)\n        return torch.cat([gap, x.index_select(dim, torch.arange(shift, device=device))], dim=dim) "
  },
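  {
    "path": "examples/array_ops_usage.py",
    "content": "# Illustrative usage sketch for the TF-style array ops; this example file\n# and its path are hypothetical additions.\nimport torch\n\nfrom torchplus.ops.array_ops import gather_nd, roll, scatter_nd\n\nindices = torch.tensor([[0, 1], [2, 3]])  # two (row, col) coordinates\nupdates = torch.tensor([9.0, 7.0])\ndense = scatter_nd(indices, updates, shape=[4, 4])\nprint(dense[0, 1].item(), dense[2, 3].item())  # 9.0 7.0\nprint(gather_nd(dense, indices))               # tensor([9., 7.])\n\nx = torch.arange(5)\nprint(roll(x, -2))              # tensor([2, 3, 4, 0, 1])\nprint(roll(x, -2, fill_pad=0))  # tensor([2, 3, 4, 0, 0])\n"
  },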
  {
    "path": "torchplus/tools.py",
    "content": "import functools\nimport inspect\nimport sys\nfrom collections import OrderedDict\n\nimport numba\nimport numpy as np\nimport torch\n\n\ndef get_pos_to_kw_map(func):\n    pos_to_kw = {}\n    fsig = inspect.signature(func)\n    pos = 0\n    for name, info in fsig.parameters.items():\n        if info.kind is info.POSITIONAL_OR_KEYWORD:\n            pos_to_kw[pos] = name\n        pos += 1\n    return pos_to_kw\n\n\ndef get_kw_to_default_map(func):\n    kw_to_default = {}\n    fsig = inspect.signature(func)\n    for name, info in fsig.parameters.items():\n        if info.kind is info.POSITIONAL_OR_KEYWORD:\n            if info.default is not info.empty:\n                kw_to_default[name] = info.default\n    return kw_to_default\n\n\n# def change_default_args(**kwargs):\n#     def layer_wrapper(layer_class):\n#         class DefaultArgLayer(layer_class):\n#             def __init__(self, *args, **kw):\n#                 pos_to_kw = get_pos_to_kw_map(layer_class.__init__)\n#                 kw_to_pos = {kw: pos for pos, kw in pos_to_kw.items()}\n#                 for key, val in kwargs.items():\n#                     if key not in kw and kw_to_pos[key] > len(args):\n#                         kw[key] = val\n#                 super().__init__(*args, **kw)\n\n#         return DefaultArgLayer\n\n#     return layer_wrapper\n\ndef change_default_args(**kwargs):\n    def layer_wrapper(layer_class):\n        class DefaultArgLayer(layer_class):\n            def __init__(self, *args, **kw):\n                pos_to_kw = get_pos_to_kw_map(layer_class.__init__)\n                kw_to_pos = {kw: pos for pos, kw in pos_to_kw.items()}\n                for key, val in kwargs.items():\n                    if key not in kw and kw_to_pos[key] > len(args):\n                        kw[key] = val\n                super(DefaultArgLayer,self).__init__(*args, **kw)\n\n        return DefaultArgLayer\n\n    return layer_wrapper\ndef torch_to_np_dtype(ttype):\n    type_map = {\n        torch.float16: np.dtype(np.float16),\n        torch.float32: np.dtype(np.float32),\n        torch.float16: np.dtype(np.float64),\n        torch.int32: np.dtype(np.int32),\n        torch.int64: np.dtype(np.int64),\n        torch.uint8: np.dtype(np.uint8),\n    }\n    return type_map[ttype]\n"
  },
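  {
    "path": "examples/change_default_args_usage.py",
    "content": "# Illustrative usage sketch for torchplus.tools.change_default_args; this\n# example file and its path are hypothetical additions.\nfrom torch import nn\n\nfrom torchplus.tools import change_default_args\n\n# Build a Conv2d variant whose bias defaults to False (typical before a norm).\nBiasFreeConv2d = change_default_args(bias=False)(nn.Conv2d)\nprint(BiasFreeConv2d(3, 16, kernel_size=3).bias)  # None: injected default\nprint(BiasFreeConv2d(3, 16, kernel_size=3, bias=True).bias is None)  # False\n"
  },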
  {
    "path": "torchplus/train/__init__.py",
    "content": "from torchplus.train.checkpoint import (latest_checkpoint, restore,\n                                        restore_latest_checkpoints,\n                                        restore_models, save, save_models,\n                                        try_restore_latest_checkpoints,\n                                        save_models_cpu\n                                        )\nfrom torchplus.train.common import create_folder\nfrom torchplus.train.optim import MixedPrecisionWrapper\n"
  },
  {
    "path": "torchplus/train/checkpoint.py",
    "content": "import json\nimport logging\nimport os\nimport signal\nfrom pathlib import Path\n\nimport torch\n\n\nclass DelayedKeyboardInterrupt(object):\n    def __enter__(self):\n        self.signal_received = False\n        self.old_handler = signal.signal(signal.SIGINT, self.handler)\n\n    def handler(self, sig, frame):\n        self.signal_received = (sig, frame)\n        logging.debug('SIGINT received. Delaying KeyboardInterrupt.')\n\n    def __exit__(self, type, value, traceback):\n        signal.signal(signal.SIGINT, self.old_handler)\n        if self.signal_received:\n            self.old_handler(*self.signal_received)\n\n\ndef latest_checkpoint(model_dir, model_name):\n    \"\"\"return path of latest checkpoint in a model_dir\n    Args:\n        model_dir: string, indicate your model dir(save ckpts, summarys,\n            logs, etc).\n        model_name: name of your model. we find ckpts by name\n    Returns:\n        path: None if isn't exist or latest checkpoint path.\n    \"\"\"\n    ckpt_info_path = Path(model_dir) / \"checkpoints.json\"\n    if not ckpt_info_path.is_file():\n        return None\n    with open(ckpt_info_path, 'r') as f:\n        ckpt_dict = json.loads(f.read())\n    if model_name not in ckpt_dict['latest_ckpt']:\n        return None\n    latest_ckpt = ckpt_dict['latest_ckpt'][model_name]\n    ckpt_file_name = Path(model_dir) / latest_ckpt\n    if not ckpt_file_name.is_file():\n        return None\n\n    return str(ckpt_file_name)\n\n\ndef _ordered_unique(seq):\n    seen = set()\n    return [x for x in seq if not (x in seen or seen.add(x))]\n\n\ndef save(model_dir,\n         model,\n         model_name,\n         global_step,\n         max_to_keep=8,\n         keep_latest=True):\n    \"\"\"save a model into model_dir.\n    Args:\n        model_dir: string, indicate your model dir(save ckpts, summarys,\n            logs, etc).\n        model: torch.nn.Module instance.\n        model_name: name of your model. we find ckpts by name\n        global_step: int, indicate current global step.\n        max_to_keep: int, maximum checkpoints to keep.\n        keep_latest: bool, if True and there are too much ckpts, \n            will delete oldest ckpt. 
else will delete ckpt which has\n            smallest global step.\n    Returns:\n        path: None if isn't exist or latest checkpoint path.\n    \"\"\"\n\n    # prevent save incomplete checkpoint due to key interrupt\n    with DelayedKeyboardInterrupt():\n        ckpt_info_path = Path(model_dir) / \"checkpoints.json\"\n        ckpt_filename = \"{}-{}.tckpt\".format(model_name, global_step)\n        ckpt_path = Path(model_dir) / ckpt_filename\n        if not ckpt_info_path.is_file():\n            ckpt_info_dict = {'latest_ckpt': {}, 'all_ckpts': {}}\n        else:\n            with open(ckpt_info_path, 'r') as f:\n                ckpt_info_dict = json.loads(f.read())\n        ckpt_info_dict['latest_ckpt'][model_name] = ckpt_filename\n        if model_name in ckpt_info_dict['all_ckpts']:\n            ckpt_info_dict['all_ckpts'][model_name].append(ckpt_filename)\n        else:\n            ckpt_info_dict['all_ckpts'][model_name] = [ckpt_filename]\n        all_ckpts = ckpt_info_dict['all_ckpts'][model_name]\n\n        torch.save(model.state_dict(), ckpt_path)\n        # check ckpt in all_ckpts is exist, if not, delete it from all_ckpts\n        all_ckpts_checked = []\n        for ckpt in all_ckpts:\n            ckpt_path_uncheck = Path(model_dir) / ckpt\n            if ckpt_path_uncheck.is_file():\n                all_ckpts_checked.append(str(ckpt_path_uncheck))\n        all_ckpts = all_ckpts_checked\n        if len(all_ckpts) > max_to_keep:\n           \n            if keep_latest:\n                ckpt_to_delete = all_ckpts.pop(0)\n            else:\n                # delete smallest step\n                def get_step(name): return int(\n                    name.split('.')[0].split('-')[1])\n                min_step = min([get_step(name) for name in all_ckpts])\n                ckpt_to_delete = \"{}-{}.tckpt\".format(model_name, min_step)\n                all_ckpts.remove(ckpt_to_delete)\n            os.remove(str(Path(model_dir) / ckpt_to_delete))\n        all_ckpts_filename = _ordered_unique([Path(f).name for f in all_ckpts])\n        ckpt_info_dict['all_ckpts'][model_name] = all_ckpts_filename\n        with open(ckpt_info_path, 'w') as f:\n            f.write(json.dumps(ckpt_info_dict, indent=2))\n\n\ndef restore(ckpt_path, model, map_func=None, map_location='cpu'):\n    if not Path(ckpt_path).is_file():\n        raise ValueError(\"checkpoint {} not exist.\".format(ckpt_path))\n    state_dict = torch.load(ckpt_path, map_location=map_location)\n    if map_func is not None:\n        # map_func(state_dict)\n        state_dict = map_func(state_dict) #modified @01/01/2020\n    model.load_state_dict(state_dict)\n    print(\"Restoring parameters from {}\".format(ckpt_path))\n\n\ndef _check_model_names(models):\n    model_names = []\n    for model in models:\n        if not hasattr(model, \"name\"):\n            raise ValueError(\"models must have name attr\")\n        model_names.append(model.name)\n    if len(model_names) != len(set(model_names)):\n        raise ValueError(\"models must have unique name: {}\".format(\n            \", \".join(model_names)))\n\n\ndef _get_name_to_model_map(models):\n    if isinstance(models, dict):\n        name_to_model = {name: m for name, m in models.items()}\n    else:\n        _check_model_names(models)\n        name_to_model = {m.name: m for m in models}\n    return name_to_model\n\n\ndef try_restore_latest_checkpoints(model_dir, models, map_func=None, map_location='cpu'):\n    name_to_model = _get_name_to_model_map(models)\n    for name, model in 
name_to_model.items():\n        latest_ckpt = latest_checkpoint(model_dir, name)\n        if latest_ckpt is not None:\n            restore(latest_ckpt, model, map_func, map_location)\n\n\ndef restore_latest_checkpoints(model_dir, models, map_func=None, map_location='cpu'):\n    name_to_model = _get_name_to_model_map(models)\n    for name, model in name_to_model.items():\n        latest_ckpt = latest_checkpoint(model_dir, name)\n        if latest_ckpt is not None:\n            restore(latest_ckpt, model, map_func, map_location)\n        else:\n            raise ValueError(\"model {}'s ckpt doesn't exist\".format(name))\n\n\ndef restore_models(model_dir, models, global_step, map_func=None, map_location='cpu'):\n    name_to_model = _get_name_to_model_map(models)\n    for name, model in name_to_model.items():\n        ckpt_filename = \"{}-{}.tckpt\".format(name, global_step)\n        ckpt_path = model_dir + \"/\" + ckpt_filename\n        restore(ckpt_path, model, map_func, map_location)\n\n\ndef save_models(model_dir,\n                models,\n                global_step,\n                max_to_keep=15,\n                keep_latest=True):\n    with DelayedKeyboardInterrupt():\n        name_to_model = _get_name_to_model_map(models)\n        for name, model in name_to_model.items():\n            save(model_dir, model, name, global_step, max_to_keep, keep_latest)\n\n\ndef gpu_to_cpu(models):\n    # models[0] is the network; models[1], if present, is assumed to be an\n    # optimizer whose state tensors should move together with the model.\n    if len(models) == 1:\n        models[0].cpu()\n        return\n    for state in models[1].state.values():\n        for k, v in state.items():\n            if isinstance(v, torch.Tensor):\n                state[k] = v.cpu()\n    models[0].cpu()\n\n\ndef cpu_to_gpu(models):\n    if len(models) == 1:\n        models[0].cuda()\n        return\n    for state in models[1].state.values():\n        for k, v in state.items():\n            if torch.is_tensor(v):\n                state[k] = v.cuda()\n    models[0].cuda()\n\n\ndef save_models_cpu(model_dir,\n                    models,\n                    global_step,\n                    max_to_keep=15,\n                    keep_latest=True):\n    with DelayedKeyboardInterrupt():\n        name_to_model = _get_name_to_model_map(models)\n        gpu_to_cpu(models)\n        for name, model in name_to_model.items():\n            save(model_dir, model, name, global_step, max_to_keep, keep_latest)\n        cpu_to_gpu(models)\n"
  },
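  {
    "path": "examples/checkpoint_usage.py",
    "content": "# Illustrative usage sketch for the checkpoint helpers; this example file\n# and its path are hypothetical additions. Uses a temporary directory.\nimport tempfile\n\nfrom torch import nn\n\nfrom torchplus.train import latest_checkpoint, restore, save\n\nmodel_dir = tempfile.mkdtemp()\nnet = nn.Linear(4, 2)\nfor step in (100, 200, 300):  # with max_to_keep=2, step-100 gets deleted\n    save(model_dir, net, 'linear', global_step=step, max_to_keep=2)\n\nckpt = latest_checkpoint(model_dir, 'linear')\nprint(ckpt)  # .../linear-300.tckpt\nrestore(ckpt, nn.Linear(4, 2))\n"
  },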
  {
    "path": "torchplus/train/common.py",
    "content": "import datetime\nimport os\nimport shutil\n\ndef create_folder(prefix, add_time=True, add_str=None, delete=False):\n    additional_str = ''\n    if delete is True:\n        if os.path.exists(prefix):\n            shutil.rmtree(prefix)\n        os.makedirs(prefix)\n    folder = prefix\n    if add_time is True:\n        # additional_str has a form such as '170903_220351'\n        additional_str += datetime.datetime.now().strftime(\"%y%m%d_%H%M%S\")\n        if add_str is not None:\n            folder += '/' + additional_str + '_' + add_str\n        else:\n            folder += '/' + additional_str\n    if delete is True:\n        if os.path.exists(folder):\n            shutil.rmtree(folder)\n    os.makedirs(folder)\n    return folder"
  },
  {
    "path": "torchplus/train/fastai_optim.py",
    "content": "from collections import Iterable, defaultdict\nfrom copy import deepcopy\nfrom itertools import chain\n\nimport torch\nfrom torch import nn\nfrom torch._utils import _unflatten_dense_tensors\nfrom torch.autograd import Variable\nfrom torch.nn.utils import parameters_to_vector\n\nbn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, )\n\n\ndef split_bn_bias(layer_groups):\n    \"Split the layers in `layer_groups` into batchnorm (`bn_types`) and non-batchnorm groups.\"\n    split_groups = []\n    for l in layer_groups:\n        l1, l2 = [], []\n        for c in l.children():\n            if isinstance(c, bn_types):\n                l2.append(c)\n            else:\n                l1.append(c)\n        split_groups += [nn.Sequential(*l1), nn.Sequential(*l2)]\n    return split_groups\n\n\ndef get_master(layer_groups, flat_master: bool = False):\n    \"Return two lists, one for the model parameters in FP16 and one for the master parameters in FP32.\"\n    split_groups = split_bn_bias(layer_groups)\n    model_params = [[\n        param for param in lg.parameters() if param.requires_grad\n    ] for lg in split_groups]\n    if flat_master:\n        master_params = []\n        for lg in model_params:\n            if len(lg) != 0:\n                mp = parameters_to_vector([param.data.float() for param in lg])\n                mp = torch.nn.Parameter(mp, requires_grad=True)\n                if mp.grad is None:\n                    mp.grad = mp.new(*mp.size())\n                master_params.append([mp])\n            else:\n                master_params.append([])\n        return model_params, master_params\n    else:\n        master_params = [[param.clone().float().detach() for param in lg]\n                         for lg in model_params]\n        for mp in master_params:\n            for param in mp:\n                param.requires_grad = True\n        return model_params, master_params\n\n\ndef model_g2master_g(model_params, master_params,\n                     flat_master: bool = False) -> None:\n    \"Copy the `model_params` gradients to `master_params` for the optimizer step.\"\n    if flat_master:\n        for model_group, master_group in zip(model_params, master_params):\n            if len(master_group) != 0:\n                master_group[0].grad.data.copy_(\n                    parameters_to_vector(\n                        [p.grad.data.float() for p in model_group]))\n    else:\n        for model_group, master_group in zip(model_params, master_params):\n            for model, master in zip(model_group, master_group):\n                if model.grad is not None:\n                    if master.grad is None:\n                        master.grad = master.data.new(*master.data.size())\n                    master.grad.data.copy_(model.grad.data)\n                else:\n                    master.grad = None\n\n\ndef master2model(model_params, master_params,\n                 flat_master: bool = False) -> None:\n    \"Copy `master_params` to `model_params`.\"\n    if flat_master:\n        for model_group, master_group in zip(model_params, master_params):\n            if len(model_group) != 0:\n                for model, master in zip(\n                        model_group,\n                        _unflatten_dense_tensors(master_group[0].data,\n                                                 model_group)):\n                    model.data.copy_(master)\n    else:\n        for model_group, master_group in zip(model_params, master_params):\n            for model, master in 
zip(model_group, master_group):\n                model.data.copy_(master.data)\n\n\ndef listify(p=None, q=None):\n    \"Make `p` listy and the same length as `q`.\"\n    if p is None:\n        p = []\n    elif isinstance(p, str):\n        p = [p]\n    elif not isinstance(p, Iterable):\n        p = [p]\n    n = q if type(q) == int else len(p) if q is None else len(q)\n    if len(p) == 1:\n        p = p * n\n    assert len(p) == n, f'List len mismatch ({len(p)} vs {n})'\n    return list(p)\n\n\ndef trainable_params(m: nn.Module):\n    \"Return list of trainable params in `m`.\"\n\n    res = filter(lambda p: p.requires_grad, m.parameters())\n    return res\n\n\ndef is_tuple(x) -> bool:\n    return isinstance(x, tuple)\n\n\n# copy from fastai.\nclass OptimWrapper(torch.optim.Optimizer):\n    \"Basic wrapper around `opt` to simplify hyper-parameters changes.\"\n\n    def __init__(self, opt, wd, true_wd: bool = False, bn_wd: bool = True):\n        # super().__init__(opt.param_groups, dict())\n        self.opt, self.true_wd, self.bn_wd = opt, true_wd, bn_wd\n        self.opt_keys = list(self.opt.param_groups[0].keys())\n        self.opt_keys.remove('params')\n        self.read_defaults()\n        self.wd = wd\n        self.param_segs=[]\n    @classmethod\n    def create(cls, opt_func, lr, layer_groups, **kwargs):\n        \"Create an `optim.Optimizer` from `opt_func` with `lr`. Set lr on `layer_groups`.\"\n\n        # param_segs=[]\n        if len(layer_groups)==1:\n            split_groups = split_bn_bias(layer_groups) #non-bn and bns\n        else:\n            split_groups = []\n            for lg in layer_groups:\n                split_groups+=split_bn_bias(lg)\n                # param_segs.append(len(lg ))\n            #\n            # buf = layer_groups[0] \n            # for i in range(1,len(layer_groups)):\n            #     buf+=layer_groups[i]\n\n            # layer_groups =buf # nn.Sequential(*buf)\n                \n        opt = opt_func([{\n            'params': trainable_params(l),\n            'lr': 0\n        } for l in split_groups])\n        # import pdb \n        # pdb.set_trace()\n        opt = cls(opt, **kwargs)\n        opt.lr, opt.opt_func = listify(lr, layer_groups), opt_func\n        \n        # opt.param_segs=param_segs\n        return opt\n\n    def new(self, layer_groups):\n        \"Create a new `OptimWrapper` from `self` with another `layer_groups` but the same hyper-parameters.\"\n        opt_func = getattr(self, 'opt_func', self.opt.__class__)\n        split_groups = split_bn_bias(layer_groups)\n        opt = opt_func([{\n            'params': trainable_params(l),\n            'lr': 0\n        } for l in split_groups])\n        return self.create(\n            opt_func,\n            self.lr,\n            layer_groups,\n            wd=self.wd,\n            true_wd=self.true_wd,\n            bn_wd=self.bn_wd)\n\n    def __repr__(self) -> str:\n        return f'OptimWrapper over {repr(self.opt)}.\\nTrue weight decay: {self.true_wd}'\n\n    # Pytorch optimizer methods\n    def step(self) -> None:\n        \"Set weight decay and step optimizer.\"\n        # weight decay outside of optimizer step (AdamW)\n        if self.true_wd:\n            for lr, wd, pg1, pg2 in zip(self._lr, self._wd,\n                                        self.opt.param_groups[::2],\n                                        self.opt.param_groups[1::2]):\n                for p in pg1['params']:\n                    p.data.mul_(1 - wd * lr)\n                if self.bn_wd:\n                    for p 
in pg2['params']:\n                        p.data.mul_(1 - wd * lr)\n            self.set_val('weight_decay', listify(0, self._wd))\n        self.opt.step()\n\n    def zero_grad(self) -> None:\n        \"Clear optimizer gradients.\"\n        self.opt.zero_grad()\n\n    # Passthrough to the inner opt.\n    def __getstate__(self):\n        return self.opt.__getstate__()\n\n    def __setstate__(self, state):\n        return self.opt.__setstate__(state)\n\n    def state_dict(self):\n        return self.opt.state_dict()\n\n    def load_state_dict(self, state_dict):\n        return self.opt.load_state_dict(state_dict)\n\n    def add_param_group(self, param_group):\n        return self.opt.add_param_group(param_group)\n\n    def clear(self):\n        \"Reset the state of the inner optimizer.\"\n        sd = self.state_dict()\n        sd['state'] = {}\n        self.load_state_dict(sd)\n\n    @property\n    def param_groups(self):\n        return self.opt.param_groups\n\n    @property\n    def defaults(self):\n        return self.opt.defaults\n\n    @property\n    def state(self):\n        return self.opt.state\n\n    # Hyperparameters as properties\n    @property\n    def lr(self) -> float:\n        return self._lr[-1]\n\n    @lr.setter\n    def lr(self, val: float) -> None:\n        self._lr = self.set_val('lr', listify(val, self._lr))\n\n    @property\n    def mom(self) -> float:\n        return self._mom[-1]\n\n    @mom.setter\n    def mom(self, val: float) -> None:\n        if 'momentum' in self.opt_keys:\n            self.set_val('momentum', listify(val, self._mom))\n        elif 'betas' in self.opt_keys:\n            self.set_val('betas', (listify(val, self._mom), self._beta))\n        self._mom = listify(val, self._mom)\n\n    @property\n    def beta(self) -> float:\n        return None if self._beta is None else self._beta[-1]\n\n    @beta.setter\n    def beta(self, val: float) -> None:\n        \"Set beta (or alpha as makes sense for given optimizer).\"\n        if val is None:\n            return\n        if 'betas' in self.opt_keys:\n            self.set_val('betas', (self._mom, listify(val, self._beta)))\n        elif 'alpha' in self.opt_keys:\n            self.set_val('alpha', listify(val, self._beta))\n        self._beta = listify(val, self._beta)\n\n    @property\n    def wd(self) -> float:\n        return self._wd[-1]\n\n    @wd.setter\n    def wd(self, val: float) -> None:\n        \"Set weight decay.\"\n        if not self.true_wd:\n            self.set_val(\n                'weight_decay', listify(val, self._wd), bn_groups=self.bn_wd)\n        self._wd = listify(val, self._wd)\n\n    # Helper functions\n    def read_defaults(self) -> None:\n        \"Read the values inside the optimizer for the hyper-parameters.\"\n        self._beta = None\n        if 'lr' in self.opt_keys:\n            self._lr = self.read_val('lr')\n        if 'momentum' in self.opt_keys:\n            self._mom = self.read_val('momentum')\n        if 'alpha' in self.opt_keys:\n            self._beta = self.read_val('alpha')\n        if 'betas' in self.opt_keys:\n            self._mom, self._beta = self.read_val('betas')\n        if 'weight_decay' in self.opt_keys:\n            self._wd = self.read_val('weight_decay')\n\n    def set_val(self, key: str, val, bn_groups: bool = True):\n        \"Set `val` inside the optimizer dictionary at `key`.\"\n        if is_tuple(val):\n            val = [(v1, v2) for v1, v2 in zip(*val)]\n        for v, pg1, pg2 in zip(val, self.opt.param_groups[::2],\n                     
          self.opt.param_groups[1::2]):\n            pg1[key] = v\n            if bn_groups:\n                pg2[key] = v\n        return val\n\n    def read_val(self, key: str):\n        \"Read a hyperparameter `key` in the optimizer dictionary.\"\n        val = [pg[key] for pg in self.opt.param_groups[::2]]\n        if is_tuple(val[0]):\n            val = [o[0] for o in val], [o[1] for o in val]\n        return val\n\n\nclass FastAIMixedOptim(OptimWrapper):\n    @classmethod\n    def create(cls,\n               opt_func,\n               lr,\n               layer_groups,\n               model,\n               flat_master=False,\n               loss_scale=512.0,\n               **kwargs):\n        \"Create an `optim.Optimizer` from `opt_func` with `lr`. Set lr on `layer_groups`.\"\n        opt = OptimWrapper.create(opt_func, lr, layer_groups, **kwargs)\n        opt.model_params, opt.master_params = get_master(\n            layer_groups, flat_master)\n        opt.flat_master = flat_master\n        opt.loss_scale = loss_scale\n        opt.model = model\n        # Changes the optimizer so that the optimization step is done in FP32.\n        # opt = self.learn.opt\n        mom, wd, beta = opt.mom, opt.wd, opt.beta\n        lrs = [lr for lr in opt._lr for _ in range(2)]\n        opt_params = [{\n            'params': mp,\n            'lr': lr\n        } for mp, lr in zip(opt.master_params, lrs)]\n        opt.opt = opt_func(opt_params)\n        opt.mom, opt.wd, opt.beta = mom, wd, beta\n        return opt\n\n    def step(self):\n        model_g2master_g(self.model_params, self.master_params,\n                         self.flat_master)\n        for group in self.master_params:\n            for param in group:\n                param.grad.div_(self.loss_scale)\n        super(FastAIMixedOptim, self).step()\n        self.model.zero_grad()\n        # Update the params from master to model.\n        master2model(self.model_params, self.master_params, self.flat_master)\n"
  },
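  {
    "path": "examples/fastai_optim_usage.py",
    "content": "# Illustrative usage sketch for OptimWrapper; this example file and its\n# path are hypothetical additions.\nimport torch\nfrom torch import nn\n\nfrom torchplus.train.fastai_optim import OptimWrapper\n\nnet = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8), nn.Linear(8, 2))\nopt = OptimWrapper.create(\n    torch.optim.SGD, lr=0.01, layer_groups=[net], wd=1e-4, true_wd=True)\n\nx, y = torch.randn(16, 8), torch.randn(16, 2)\nloss = ((net(x) - y) ** 2).mean()\nopt.zero_grad()\nloss.backward()\nopt.step()      # decoupled weight decay, then the inner SGD step\nopt.lr = 0.001  # hyper-parameters are plain scalar properties\nprint(opt.lr)\n"
  },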
  {
    "path": "torchplus/train/learning_schedules.py",
    "content": "\"\"\"PyTorch edition of TensorFlow learning schedule in tensorflow object\ndetection API. \n\"\"\"\nimport numpy as np\nfrom torch.optim.optimizer import Optimizer\nclass _LRSchedulerStep(object):\n    def __init__(self, optimizer, last_step=-1):\n        if not isinstance(optimizer, Optimizer):\n            raise TypeError('{} is not an Optimizer'.format(\n                type(optimizer).__name__))\n        self.optimizer = optimizer\n        if last_step == -1:\n            for group in optimizer.param_groups:\n                group.setdefault('initial_lr', group['lr'])\n        else:\n            for i, group in enumerate(optimizer.param_groups):\n                if 'initial_lr' not in group:\n                    raise KeyError(\n                        \"param 'initial_lr' is not specified \"\n                        \"in param_groups[{}] when resuming an optimizer\".\n                        format(i))\n        self.base_lrs = list(\n            map(lambda group: group['initial_lr'], optimizer.param_groups))\n        self.step(last_step + 1)\n        self.last_step = last_step\n\n    \"\"\"\n    def get_lr(self):\n        raise NotImplementedError\n    \"\"\"\n\n    def get_lr(self):\n        ret = [self._get_lr_per_group(base_lr) for base_lr in self.base_lrs]\n        return ret\n\n    def _get_lr_per_group(self, base_lr):\n        raise NotImplementedError\n\n    def step(self, step=None):\n        if step is None:\n            step = self.last_step + 1\n        self.last_step = step\n        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):\n            param_group['lr'] = lr\n\n\nclass Constant(_LRSchedulerStep):\n    def __init__(self, optimizer, last_step=-1):\n        super().__init__(optimizer, last_step)\n\n    def _get_lr_per_group(self, base_lr):\n        return base_lr\n\n\nclass ManualStepping(_LRSchedulerStep):\n    \"\"\"Pytorch edition of manual_stepping in tensorflow.\n    DON'T SUPPORT PARAM GROUPS.\n    \"\"\"\n\n    def __init__(self, optimizer, boundaries, rates, last_step=-1):\n        self._boundaries = boundaries\n        self._num_boundaries = len(boundaries)\n        self._learning_rates = rates\n\n        if any([b < 0 for b in boundaries]) or any(\n            [not isinstance(b, int) for b in boundaries]):\n            raise ValueError('boundaries must be a list of positive integers')\n        if any(\n            [bnext <= b for bnext, b in zip(boundaries[1:], boundaries[:-1])]):\n            raise ValueError(\n                'Entries in boundaries must be strictly increasing.')\n        if any([not isinstance(r, float) for r in rates]):\n            raise ValueError('Learning rates must be floats')\n        if len(rates) != len(boundaries) + 1:\n            raise ValueError('Number of provided learning rates must exceed '\n                             'number of boundary points by exactly 1.')\n        super().__init__(optimizer, last_step)\n\n    def _get_lr_per_group(self, base_lr):\n        step = self.last_step\n        ret = None\n        for i, bound in enumerate(self._boundaries):\n            if step > bound:\n                ret = self._learning_rates[i + 1]\n        if ret is not None:\n            return ret\n        return self._learning_rates[0]\n\n\nclass ExponentialDecayWithBurnin(_LRSchedulerStep):\n    \"\"\"Pytorch edition of manual_stepping in tensorflow.\n    \"\"\"\n\n    def __init__(self,\n                 optimizer,\n                 learning_rate_decay_steps,\n                 
learning_rate_decay_factor,\n                 burnin_learning_rate,\n                 burnin_steps,\n                 last_step=-1):\n        self._decay_steps = learning_rate_decay_steps\n        self._decay_factor = learning_rate_decay_factor\n        self._burnin_learning_rate = burnin_learning_rate\n        self._burnin_steps = burnin_steps\n\n        super().__init__(optimizer, last_step)\n\n    def _get_lr_per_group(self, base_lr):\n        if self._burnin_learning_rate == 0:\n            burnin_learning_rate = base_lr\n        step = self.last_step\n        post_burnin_learning_rate = (base_lr * self._decay_factor ^\n                                     (step // self._decay_steps))\n        if step < self._burnin_steps:\n            return burnin_learning_rate\n        else:\n            return post_burnin_learning_rate\n\n\nclass ExponentialDecay(_LRSchedulerStep):\n    def __init__(self,\n                 optimizer,\n                 learning_rate_decay_steps,\n                 learning_rate_decay_factor,\n                 staircase=True,\n                 last_step=-1):\n        self._decay_steps = learning_rate_decay_steps\n        self._decay_factor = learning_rate_decay_factor\n        self._staircase = staircase\n\n        super().__init__(optimizer, last_step)\n\n    def _get_lr_per_group(self, base_lr):\n        step = self.last_step\n        if self._staircase:\n            post_burnin_learning_rate = base_lr * pow(self._decay_factor,\n                                         (step // self._decay_steps))\n        else:\n            post_burnin_learning_rate = base_lr * pow(self._decay_factor,\n                                         (step / self._decay_steps))\n\n        return post_burnin_learning_rate\n\n\nclass CosineDecayWithWarmup(_LRSchedulerStep):\n    def __init__(self,\n                 optimizer,\n                 total_steps,\n                 warmup_learning_rate,\n                 warmup_steps,\n                 last_step=-1):\n        if total_steps < warmup_steps:\n            raise ValueError('total_steps must be larger or equal to '\n                             'warmup_steps.')\n        self._total_steps = total_steps\n        self._warmup_learning_rate = warmup_learning_rate\n        self._warmup_steps = warmup_steps\n\n        super().__init__(optimizer, last_step)\n\n    def _get_lr_per_group(self, base_lr):\n        if base_lr < self._warmup_learning_rate:\n            raise ValueError('learning_rate_base must be larger '\n                             'or equal to warmup_learning_rate.')\n\n        step = self.last_step\n        learning_rate = 0.5 * base_lr * (\n            1 + np.cos(np.pi *\n                       (float(step) - self._warmup_steps\n                        ) / float(self._total_steps - self._warmup_steps)))\n        if self._warmup_steps > 0:\n            slope = (base_lr - self._warmup_learning_rate) / self._warmup_steps\n            pre_cosine_learning_rate = slope * float(\n                step) + self._warmup_learning_rate\n            if step < self._warmup_steps:\n                return pre_cosine_learning_rate\n            else:\n                return learning_rate\n\n\nclass OneCycle(_LRSchedulerStep):\n    def __init__(self,\n                 optimizer,\n                 total_steps,\n                 lr_max,\n                 moms,\n                 div_factor=25,\n                 pct_start=0.3,\n                 last_step=-1):\n        if total_steps < warmup_steps:\n            raise ValueError('total_steps must be 
larger or equal to '\n                             'warmup_steps.')\n        self._total_steps = total_steps\n        self._lr_max = lr_max\n        self._moms = moms\n\n        self._warmup_learning_rate = warmup_learning_rate\n        self._warmup_steps = warmup_steps\n\n        super().__init__(optimizer, last_step)\n\n    def _get_lr_per_group(self, base_lr):\n        if base_lr < self._warmup_learning_rate:\n            raise ValueError('learning_rate_base must be larger '\n                             'or equal to warmup_learning_rate.')\n\n        step = self.last_step\n        learning_rate = 0.5 * base_lr * (\n            1 + np.cos(np.pi *\n                       (float(step) - self._warmup_steps\n                        ) / float(self._total_steps - self._warmup_steps)))\n        if self._warmup_steps > 0:\n            slope = (base_lr - self._warmup_learning_rate) / self._warmup_steps\n            pre_cosine_learning_rate = slope * float(\n                step) + self._warmup_learning_rate\n            if step < self._warmup_steps:\n                return pre_cosine_learning_rate\n            else:\n                return learning_rate\n"
  },
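  {
    "path": "examples/learning_schedules_usage.py",
    "content": "# Illustrative usage sketch for the TF-object-detection-style schedules;\n# this example file and its path are hypothetical additions.\nimport torch\nfrom torch import nn\n\nfrom torchplus.train.learning_schedules import ManualStepping\n\nnet = nn.Linear(4, 2)\nopt = torch.optim.SGD(net.parameters(), lr=0.01)\nsched = ManualStepping(opt, boundaries=[30, 60], rates=[0.01, 0.001, 0.0001])\nfor step in range(80):\n    # forward / backward / opt.step() would go here\n    sched.step()\nprint(opt.param_groups[0]['lr'])  # 0.0001 once past both boundaries\n"
  },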
  {
    "path": "torchplus/train/learning_schedules_fastai.py",
    "content": "import numpy as np\nimport math\nfrom functools import partial\nimport torch\n\n\nclass LRSchedulerStep(object):\n    def __init__(self, fai_optimizer, total_step, lr_phases, mom_phases):\n        self.optimizer = fai_optimizer\n        self.total_step = total_step\n        self.lr_phases = []\n\n        for i, (start, lambda_func) in enumerate(lr_phases):\n            if len(self.lr_phases) != 0:\n                assert self.lr_phases[-1][0] < int(start * total_step)\n            if isinstance(lambda_func, str):\n                lambda_func = eval(lambda_func)\n            if i < len(lr_phases) - 1:\n                self.lr_phases.append((int(start * total_step),\n                                       int(lr_phases[i + 1][0] * total_step),\n                                       lambda_func))\n            else:\n                self.lr_phases.append((int(start * total_step), total_step,\n                                       lambda_func))\n        assert self.lr_phases[0][0] == 0\n        self.mom_phases = []\n\n        for i, (start, lambda_func) in enumerate(mom_phases):\n            if len(self.mom_phases) != 0:\n                assert self.mom_phases[-1][0] < int(start * total_step)\n            if isinstance(lambda_func, str):\n                lambda_func = eval(lambda_func)\n            if i < len(mom_phases) - 1:\n                self.mom_phases.append((int(start * total_step),\n                                        int(mom_phases[i + 1][0] * total_step),\n                                        lambda_func))\n            else:\n                self.mom_phases.append((int(start * total_step), total_step,\n                                        lambda_func))\n        if len(mom_phases) > 0:\n            assert self.mom_phases[0][0] == 0\n\n    def step(self, step):\n        lrs = []\n        moms = []\n        for start, end, func in self.lr_phases:\n            if step >= start:\n                # func: lr decay function\n                lrs.append(func((step - start) / (end - start)))\n        if len(lrs) > 0:\n            self.optimizer.lr = lrs[-1]\n            # import pdb \n            # pdb.set_trace()\n            # print(self.optimizer.lr,'!!!!')\n\n        for start, end, func in self.mom_phases:\n            if step >= start:\n                moms.append(func((step - start) / (end - start)))\n                self.optimizer.mom = func((step - start) / (end - start))\n        if len(moms) > 0:\n            self.optimizer.mom = moms[-1]\n\n    @property\n    def learning_rate(self):\n        return self.optimizer.lr\n\n\ndef annealing_cos(start, end, pct):\n    # print(pct, start, end)\n    \"Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0.\"\n    cos_out = np.cos(np.pi * pct) + 1\n    return end + (start - end) / 2 * cos_out\n\n\nclass OneCycle(LRSchedulerStep):\n    def __init__(self, fai_optimizer, total_step, lr_max, moms, div_factor,\n                 pct_start):\n        self.lr_max = lr_max\n        self.moms = moms\n        self.div_factor = div_factor\n        self.pct_start = pct_start\n        a1 = int(total_step * self.pct_start)\n        a2 = total_step - a1\n        low_lr = self.lr_max / self.div_factor\n        lr_phases = ((0, partial(annealing_cos, low_lr, self.lr_max)),\n                     (self.pct_start,\n                      partial(annealing_cos, self.lr_max, low_lr / 1e4))\n                      )\n        mom_phases = ((0, partial(annealing_cos, *self.moms)),\n                      (self.pct_start, 
partial(annealing_cos,\n                                               *self.moms[::-1])))\n        fai_optimizer.lr, fai_optimizer.mom = low_lr, self.moms[0]\n\n        super().__init__(fai_optimizer, total_step, lr_phases, mom_phases)\n\n\nclass ExponentialDecayWarmup(LRSchedulerStep):\n    def __init__(self,\n                 fai_optimizer,\n                 total_step,\n                 initial_learning_rate,\n                 decay_length,\n                 decay_factor,\n                 div_factor=1,\n                 pct_start=0,\n                 staircase=True):\n        \"\"\"\n        Args:\n            decay_length: must in (0, 1)\n        \"\"\"\n        assert decay_length > 0\n        assert decay_length < 1\n        self._decay_steps_unified = decay_length\n        self._decay_factor = decay_factor\n        self._staircase = staircase\n        self.div_factor = div_factor\n        self.pct_start = pct_start\n        step = pct_start*total_step  # 0\n        stage = 1\n        lr_phases = [\n            (0, partial(annealing_cos, initial_learning_rate/div_factor, initial_learning_rate))]\n        if staircase:\n            while step <= total_step:\n                func = lambda p, _d=initial_learning_rate * stage: _d\n                lr_phases.append((step / total_step, func))\n                stage *= decay_factor\n                step += int(decay_length * total_step)\n        else:\n            def func(p): return pow(decay_factor, (p / decay_length))\n            lr_phases.append((pct_start, func))\n            # lr_phases.append((step/total_step, func))\n        super().__init__(fai_optimizer, total_step, lr_phases, [])\n\n\nclass ExponentialDecay(LRSchedulerStep):\n    def __init__(self,\n                 fai_optimizer,\n                 total_step,\n                 initial_learning_rate,\n                 decay_length,\n                 decay_factor,\n                 staircase=True):\n        \"\"\"\n        Args:\n            decay_length: must in (0, 1)\n        \"\"\"\n        assert decay_length > 0\n        assert decay_length < 1\n        self._decay_steps_unified = decay_length\n        self._decay_factor = decay_factor\n        self._staircase = staircase\n        step = 0\n        stage = 1\n        lr_phases = []\n        if staircase:\n            while step <= total_step:\n                func = lambda p, _d=initial_learning_rate * stage: _d\n                lr_phases.append((step / total_step, func))\n                stage *= decay_factor\n                step += int(decay_length * total_step)\n        else:\n            def func(p): return pow(decay_factor, (p / decay_length))\n            lr_phases.append((0, func))\n        super().__init__(fai_optimizer, total_step, lr_phases, [])\n\n\nclass ManualStepping(LRSchedulerStep):\n    def __init__(self, fai_optimizer, total_step, boundaries, rates):\n        assert all([b > 0 and b < 1 for b in boundaries])\n        assert len(boundaries) + 1 == len(rates)\n        boundaries.insert(0, 0.0)\n        lr_phases = []\n        for start, rate in zip(boundaries, rates):\n            def func(p, _d=rate): return _d\n            lr_phases.append((start, func))\n        super().__init__(fai_optimizer, total_step, lr_phases, [])\n\n\nclass FakeOptim:\n    def __init__(self):\n        self.lr = 0\n        self.mom = 0\n\n\nif __name__ == \"__main__\":\n    import matplotlib.pyplot as plt\n    opt = FakeOptim()  # 3e-3, wd=0.4, div_factor=10\n    # schd = OneCycle(opt, 100, 3e-3, (0.95, 0.85), 10.0, 0.4)\n    schd 
= ExponentialDecay(opt, 100, 3e-4, 0.1, 0.8, staircase=True)\n    schd = ManualStepping(opt, 100, [0.8, 0.9], [0.001, 0.0001, 0.00005])\n    lrs = []\n    moms = []\n    for i in range(100):\n        schd.step(i)\n        lrs.append(opt.lr)\n        moms.append(opt.mom)\n\n    plt.plot(lrs)\n    # plt.plot(moms)\n    # plt.show()\n    # plt.plot(moms)\n    plt.show()\n"
  },
  {
    "path": "torchplus/train/optim.py",
    "content": "from collections import defaultdict, Iterable\n\nimport torch\nfrom copy import deepcopy\nfrom itertools import chain\nfrom torch.autograd import Variable\n\nrequired = object()\n\ndef param_fp32_copy(params):\n    param_copy = [\n        param.clone().type(torch.cuda.FloatTensor).detach() for param in params\n    ]\n    for param in param_copy:\n        param.requires_grad = True\n    return param_copy\n\ndef set_grad(params, params_with_grad, scale=1.0):\n    for param, param_w_grad in zip(params, params_with_grad):\n        if param.grad is None:\n            param.grad = torch.nn.Parameter(\n                param.data.new().resize_(*param.data.size()))\n        grad = param_w_grad.grad.data\n        if scale is not None:\n            grad /= scale\n        if torch.isnan(grad).any() or torch.isinf(grad).any():\n            return True # invalid grad\n        param.grad.data.copy_(grad)\n    return False\n\nclass MixedPrecisionWrapper(object):\n    \"\"\"mixed precision optimizer wrapper.\n    Arguments:\n        optimizer (torch.optim.Optimizer): an instance of \n            :class:`torch.optim.Optimizer`\n        scale: (float): a scalar for grad scale.\n        auto_scale: (bool): whether enable auto scale.\n            The algorihm of auto scale is discribled in \n            http://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html\n    \"\"\"\n\n    def __init__(self,\n                 optimizer,\n                 scale=None,\n                 auto_scale=True,\n                 inc_factor=2.0,\n                 dec_factor=0.5,\n                 num_iters_be_stable=500):\n        # if not isinstance(optimizer, torch.optim.Optimizer):\n        #     raise ValueError(\"must provide a torch.optim.Optimizer\")\n        self.optimizer = optimizer\n        if hasattr(self.optimizer, 'name'):\n            self.name = self.optimizer.name  # for ckpt system\n        param_groups_copy = []\n        for i, group in enumerate(optimizer.param_groups):\n            group_copy = {n: v for n, v in group.items() if n != 'params'}\n            group_copy['params'] = param_fp32_copy(group['params'])\n            param_groups_copy.append(group_copy)\n\n        # switch param_groups, may be dangerous\n        self.param_groups = optimizer.param_groups\n        optimizer.param_groups = param_groups_copy\n        self.grad_scale = scale\n        self.auto_scale = auto_scale\n        self.inc_factor = inc_factor\n        self.dec_factor = dec_factor\n        self.stable_iter_count = 0\n        self.num_iters_be_stable = num_iters_be_stable\n\n    def __getstate__(self):\n        return self.optimizer.__getstate__()\n\n    def __setstate__(self, state):\n        return self.optimizer.__setstate__(state)\n\n    def __repr__(self):\n        return self.optimizer.__repr__()\n\n    def state_dict(self):\n        return self.optimizer.state_dict()\n\n    def load_state_dict(self, state_dict):\n        return self.optimizer.load_state_dict(state_dict)\n\n    def zero_grad(self):\n        return self.optimizer.zero_grad()\n\n    def step(self, closure=None):\n        for g, g_copy in zip(self.param_groups, self.optimizer.param_groups):\n            invalid = set_grad(g_copy['params'], g['params'], self.grad_scale)\n            if invalid:\n                if self.grad_scale is None or self.auto_scale is False:\n                    raise ValueError(\"nan/inf detected but auto_scale disabled.\")\n                self.grad_scale *= self.dec_factor\n                print('scale decay 
to {}'.format(self.grad_scale))\n                return\n        if self.auto_scale is True:\n            self.stable_iter_count += 1\n            if self.stable_iter_count > self.num_iters_be_stable:\n                if self.grad_scale is not None:\n                    self.grad_scale *= self.inc_factor\n                self.stable_iter_count = 0\n\n        if closure is None:\n            self.optimizer.step()\n        else:\n            self.optimizer.step(closure)\n        for g, g_copy in zip(self.param_groups, self.optimizer.param_groups):\n            for p_copy, p in zip(g_copy['params'], g['params']):\n                p.data.copy_(p_copy.data)\n"
  },
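  {
    "path": "examples/mixed_precision_usage.py",
    "content": "# Illustrative usage sketch for MixedPrecisionWrapper; this example file\n# and its path are hypothetical additions. Requires a CUDA device, since the\n# wrapper keeps FP32 master weights on the GPU.\nimport torch\nfrom torch import nn\n\nfrom torchplus.train import MixedPrecisionWrapper\n\nnet = nn.Linear(16, 4).cuda().half()\nopt = MixedPrecisionWrapper(torch.optim.SGD(net.parameters(), lr=0.01),\n                            scale=128.0)\n\nx = torch.randn(8, 16, device='cuda').half()\nloss = net(x).float().pow(2).mean() * 128.0  # scale the loss before backward\nopt.zero_grad()\nloss.backward()\nopt.step()  # unscales grads, steps the FP32 masters, copies back to FP16\n"
  },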
  {
    "path": "utils/__init__.py",
    "content": ""
  },
  {
    "path": "utils/config_io.py",
    "content": "\nfrom easydict import EasyDict as edict\nimport os\nimport shutil\nimport yaml\n\n\ndef mkdir_if_not_exists(path):\n    \"\"\"Make a directory if it does not exist.\n    Args:\n        path: directory to create\n    \"\"\"\n    if not os.path.exists(path):\n        os.makedirs(path)\n\n\ndef read_yaml(filename):\n    \"\"\"Load yaml file as a dictionary item\n    Args:\n        filename (str): .yaml file path\n    Returns:\n        cfg (dict): configuration\n    \"\"\"\n    if filename is not None:\n        with open(filename, 'r') as f:\n            return yaml.load(f, Loader=yaml.FullLoader)\n    else:\n        return {}\n\n\ndef copy_file(src_file, tgt_file):\n    \"\"\"Copy a file\n    Args:\n        src_file (str): source file\n        tgt_file (str): target file\n    \"\"\"\n    shutil.copyfile(src_file, tgt_file)\n\n\ndef update_dict(dict1, dict2, intersection=False):\n    \"\"\"update dict1 according to dict2\n    Args:\n        dict1 (dict): reference dictionary\n        dict2 (dict): new dictionary\n    return\n        dict1 (dict): updated reference dictionary\n    \"\"\"\n    for item in dict2:\n        # if dict1.get(item, -1) != -1:\n        if item in dict1:\n            if isinstance(dict1[item], dict):\n                dict1[item] = update_dict(dict1[item], dict2[item], intersection)\n            else:\n                dict1[item] = dict2[item]\n        else:\n            if not intersection:\n                dict1[item] = dict2[item]\n            else:\n                raise ValueError(f\"Key '{item}' is in the second dict but not in the first dict!\")\n    #inverse check\n    for item in dict1:\n        if item not in dict2:\n            print(f\"Warning: key {item} is not given and will use the default values\")\n    \n    return dict1\n\n\ndef merge_cfg(cfg_files, intersection=False):\n    \"\"\"merge default configuration and custom configuration\n    Args:\n        cfg_files (str): configuration file paths [default, custom]\n    Returns:\n        cfg (edict): merged EasyDict\n    \"\"\"\n    edict_items = []\n    # cfg = {}\n    cfg = read_yaml(cfg_files[0])\n    # for f in cfg_files:\n    for f in cfg_files[1:]:\n        # if f is not None:\n        if os.path.exists(f):\n            cfg = update_dict(cfg, read_yaml(f), intersection)\n        else:\n            raise ValueError(f\"File {f} does not exist.\" )\n    return edict(cfg)\n\n\ndef write_cfg(default, custom, f, level_cnt=0):\n    \"\"\"write configuration to file\n    Args:\n        default (dict): default configuration dictionary\n        custom (dict): custom configuration dictionary\n        file (TextIOWrapper)\n    \"\"\"\n    offset_len = 100\n    for item in default:\n        if isinstance(default[item], dict):\n            if custom.get(item, -1) == -1:\n                custom[item] = {}\n            line = \"  \"*level_cnt + item + \": \"\n            offset = offset_len - len(line)\n            line += \" \"*offset + \" # |\"\n            f.writelines(line + \"\\n\")\n            write_cfg(default[item], custom[item], f, level_cnt+1)\n        else:\n            line = \"  \" * level_cnt + item + \": \"\n            if custom.get(item, -1) == -1:\n                if default[item] is not None:\n                    line += str(default[item])\n                offset = offset_len - len(line)\n                line += \" \"*offset + \" # | \"\n            else: \n                if custom[item] is not None:\n                    line += str(custom[item])\n                offset = 
offset_len - len(line)\n                line += \" \"*offset + \" # | \"\n                if custom[item] != default[item]:\n                    line += str(default[item])\n            f.writelines(line)\n            f.writelines(\"\\n\")\n\n\ndef save_cfg(cfg_files, file_path):\n    \"\"\"Save configuration file\n    Args:\n        cfg_files (str): configuration file paths [default, custom]\n    Returns:\n        cfg (edict): merged EasyDict\n    \"\"\"\n    # read configurations\n    default = read_yaml(cfg_files[0])\n    custom = read_yaml(cfg_files[1])\n\n    # create file to be written\n    f = open(file_path, 'w')\n\n    # write header line\n    line = \"# \" + \"-\"*20 + \" Setup \" + \"-\"*74\n    line += \"|\" + \"-\"*10 + \" Default \" + \"-\"*20 + \"\\n\"\n    f.writelines(line)\n\n    # write configurations\n    write_cfg(default, custom, f)\n    f.close()\n"
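\n\n# Minimal usage sketch, assuming hypothetical configs/default.yaml and\n# configs/custom.yaml files (keys present in custom override the defaults):\n#   cfg = merge_cfg(['configs/default.yaml', 'configs/custom.yaml'])\n#   print(cfg.model)  # EasyDict allows attribute-style access\n#   save_cfg(['configs/default.yaml', 'configs/custom.yaml'], 'merged_cfg.yaml')\n"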
  },
  {
    "path": "utils/distributed_utils.py",
    "content": "from torch.utils.data import Sampler\nimport math\nimport os\nimport pdb\nimport torch\nimport torch.distributed as dist\nfrom torch.nn import Module\nimport multiprocessing as mp\nimport numpy as np\n\n\nclass ParallelWrapper(Module):\n    def __init__(self, net, parallel_mode='none'):\n        super(ParallelWrapper, self).__init__()\n        assert parallel_mode in ['dist', 'data_parallel', 'none']\n        self.parallel_mode = parallel_mode\n        if parallel_mode == 'none':\n            self.net = net\n            self.module = net\n        elif parallel_mode == 'dist':\n            self.net = DistModule(net)\n            self.module = self.net.module\n        else:\n            self.net = torch.nn.DataParallel(net)\n            self.module = self.net.module\n\n    def forward(self, *inputs, **kwargs):\n        return self.net.forward(*inputs, **kwargs)\n\n    def train(self, mode=True):\n        super(ParallelWrapper, self).train(mode)\n        self.net.train(mode)\n\n\nclass DistModule(Module):\n    def __init__(self, module):\n        super(DistModule, self).__init__()\n        self.module = module\n        broadcast_params(self.module)\n\n    def forward(self, *inputs, **kwargs):\n        return self.module(*inputs, **kwargs)\n\n    def train(self, mode=True):\n        super(DistModule, self).train(mode)\n        self.module.train(mode)\n\ndef gradients_multiply(model, multiplier=1):\n    for param in model.parameters():\n        if param.requires_grad and param.grad is not None:\n            param.grad.data *= multiplier\n    \ndef average_gradients(model):\n    \"\"\" average gradients \"\"\"\n\n    # for n, param, in model.named_parameters():\n    #     if 'dynamic_sigma' in n:\n    #         print(param.requires_grad, param.grad.data, param.data)\n    #     if param.requires_grad and param.grad is not None:\n    #         dist.all_reduce(param.grad.data)\n\n    for param in model.parameters():\n        if param.requires_grad and param.grad is not None:\n            dist.all_reduce(param.grad.data)\n            # param.grad.data *= multiplier\n\n\ndef broadcast_params(model):\n    \"\"\" broadcast model parameters \"\"\"\n    for p in model.state_dict().values():\n        dist.broadcast(p, 0)\n\n\ndef dist_init(port):\n    # os.environ[\"OMP_NUM_THREADS\"] = \"1\"\n    if mp.get_start_method(allow_none=True) != 'spawn':\n        mp.set_start_method('spawn', force=True)\n    proc_id = int(os.environ['SLURM_PROCID'])\n    ntasks = int(os.environ['SLURM_NTASKS'])\n    node_list = os.environ['SLURM_NODELIST']\n    num_gpus = torch.cuda.device_count()\n    torch.cuda.set_device(proc_id % num_gpus)\n\n    if '[' in node_list:\n        beg = node_list.find('[')\n        pos1 = node_list.find('-', beg)\n        if pos1 < 0:\n            pos1 = 1000\n        pos2 = node_list.find(',', beg)\n        if pos2 < 0:\n            pos2 = 1000\n        node_list = node_list[:min(pos1, pos2)].replace('[', '')\n    #added by dy\n    # addr = node_list[8:].replace('-', '.')\n    addr = node_list.replace('-', '.')\n    if ',' in addr:\n        addr = addr.split(',')[0]\n    addr = addr[8:]\n    # addr = ','.join([ad[8:] for ad in addrs ])\n\n    print(addr)\n\n    os.environ['MASTER_PORT'] = port\n    os.environ['MASTER_ADDR'] = addr\n    os.environ['WORLD_SIZE'] = str(ntasks)\n    os.environ['RANK'] = str(proc_id)\n    dist.init_process_group(backend='nccl')\n\n    rank = dist.get_rank()\n    world_size = dist.get_world_size()\n    return rank, world_size\n\n\n# from . 
\n\nclass DistributedSequentialSampler(Sampler):\n    \"\"\"Sampler that restricts data loading to a subset of the dataset.\n\n    It is especially useful in conjunction with\n    :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each\n    process can pass a DistributedSampler instance as a DataLoader sampler,\n    and load a subset of the original dataset that is exclusive to it.\n\n    .. note::\n        Dataset is assumed to be of constant size.\n\n    Arguments:\n        dataset: Dataset used for sampling.\n        num_replicas (optional): Number of processes participating in\n            distributed training.\n        rank (optional): Rank of the current process within num_replicas.\n    \"\"\"\n\n    def __init__(self, dataset, num_replicas=None, rank=None):\n        if num_replicas is None:\n            if not dist.is_available():\n                raise RuntimeError(\n                    \"Requires distributed package to be available\")\n            num_replicas = dist.get_world_size()\n        if rank is None:\n            if not dist.is_available():\n                raise RuntimeError(\n                    \"Requires distributed package to be available\")\n            rank = dist.get_rank()\n        self.dataset = dataset\n        self.num_replicas = num_replicas\n        self.rank = rank\n        self.epoch = 0\n        self.num_samples = int(\n            math.ceil(len(self.dataset) * 1.0 / self.num_replicas))\n        self.total_size = self.num_samples * self.num_replicas\n\n    def __iter__(self):\n        # sequential order (no shuffling); indices are strided across replicas\n        indices = list(range(len(self.dataset)))\n        # add extra samples to make the total evenly divisible\n        indices += indices[:(self.total_size - len(indices))]\n        assert len(indices) == self.total_size\n\n        # subsample this rank's share\n        indices = indices[self.rank:self.total_size:self.num_replicas]\n        assert len(indices) == self.num_samples\n\n        return iter(indices)\n\n    def __len__(self):\n        return self.num_samples\n\n    def set_epoch(self, epoch):\n        self.epoch = epoch\n\n\n# backward-compatible alias for the original (misspelled) class name\nDistributedSequatialSampler = DistributedSequentialSampler\n\n\n
class DistributedGivenIterationSampler(Sampler):\n    def __init__(self, dataset, total_iter, batch_size, world_size=None, rank=None, last_iter=-1):\n        if world_size is None:\n            world_size = dist.get_world_size()\n        if rank is None:\n            rank = dist.get_rank()\n        assert rank < world_size\n        self.dataset = dataset\n        self.total_iter = total_iter\n        self.batch_size = batch_size\n        self.world_size = world_size\n        self.rank = rank\n        self.last_iter = last_iter\n\n        self.total_size = self.total_iter*self.batch_size\n\n        self.indices = self.gen_new_list()\n        self.call = 0\n\n    def __iter__(self):\n        # NOTE: the original guard against iterating more than once was\n        # disabled; every call resumes from (last_iter+1)*batch_size.\n        self.call = 1\n        return iter(self.indices[(self.last_iter+1)*self.batch_size:])\n\n    def gen_new_list(self):\n        # each process shuffles the full list with the same seed,\n        # then picks its own slice according to its rank\n        np.random.seed(7)\n\n        all_size = self.total_size * self.world_size\n        indices = np.arange(len(self.dataset))\n        indices = indices[:all_size]\n        num_repeat = (all_size-1) // indices.shape[0] + 1\n        indices = np.tile(indices, num_repeat)\n        indices = indices[:all_size]\n\n        np.random.shuffle(indices)\n        beg = self.total_size * self.rank\n        indices = indices[beg:beg+self.total_size]\n\n        assert len(indices) == self.total_size\n\n        return indices\n\n    def __len__(self):\n        # note that last_iter is not taken into account here: __len__ is only\n        # used for display; the dataloader handles the remaining size\n        return self.total_size\n\n    def set_epoch(self, epoch):\n        pass\n\n\n
class DistributedGivenIterationSamplerEpoch(Sampler):\n    def __init__(self, dataset, total_iter, batch_size, world_size=None, rank=None, last_iter=-1, review_cycle=-1):\n        if world_size is None:\n            world_size = dist.get_world_size()\n        if rank is None:\n            rank = dist.get_rank()\n        assert rank < world_size\n        self.dataset = dataset\n        self.total_iter = total_iter\n        self.batch_size = batch_size\n        self.world_size = world_size\n        self.rank = rank\n        self.last_iter = last_iter\n\n        self.total_size = self.total_iter*self.batch_size\n        self.review_cycle = review_cycle  # in units of epochs\n\n        self.indices = self.gen_new_list()\n        self.call = 0\n\n    def __iter__(self):\n        # see DistributedGivenIterationSampler.__iter__\n        self.call = 1\n        return iter(self.indices[(self.last_iter+1)*self.batch_size:])\n\n    def gen_new_list(self):\n        # each process shuffles the full list with the same seed,\n        # then picks its own slice according to its rank\n        np.random.seed(7)\n\n        all_size = self.total_size * self.world_size\n        indices = np.arange(len(self.dataset))\n        indices = indices[:all_size]\n        num_repeat = (all_size-1) // indices.shape[0] + 1\n\n        indices = np.concatenate([np.random.permutation(indices) for i in range(num_repeat)])\n        seeds = np.arange(indices.size).reshape(indices.shape)\n\n        if self.review_cycle > 0:\n            # review_cycle is a fraction of an epoch; its reciprocal must be an integer\n            assert (1/self.review_cycle) % 1 == 0\n            chunk = int(self.review_cycle*len(self.dataset))\n            h = len(indices) // chunk\n            # duplicate each chunk so that every sample is revisited once per cycle\n            indices = indices[:h*chunk].reshape([h, -1])\n            seeds = seeds[:h*chunk].reshape([h, -1])\n\n            indices = np.concatenate([indices, indices], axis=1).reshape(-1)\n            seeds = np.concatenate([seeds, seeds], axis=1).reshape(-1)\n\n        indices = indices[:all_size]\n        seeds = seeds[:all_size]\n\n        beg = self.total_size * self.rank\n        indices = indices[beg:beg+self.total_size]\n        seeds = seeds[beg:beg+self.total_size]\n\n        assert len(indices) == self.total_size\n\n        # each item is an (index, seed) pair; the dataset is expected to unpack it\n        return list(zip(list(indices), list(seeds)))\n\n    def __len__(self):\n        # note that last_iter is not taken into account here: __len__ is only\n        # used for display; the dataloader handles the remaining size\n        return self.total_size\n\n    def set_epoch(self, epoch):\n        pass\n
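\n\n# Minimal usage sketch (dataset and batch size are hypothetical; the sampler\n# already shards by world_size, so the DataLoader needs no extra logic):\n#   sampler = DistributedGivenIterationSampler(dataset, total_iter=10000, batch_size=32)\n#   loader = torch.utils.data.DataLoader(dataset, batch_size=32, sampler=sampler)\n"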
  },
  {
    "path": "utils/eval_metric.py",
    "content": "import os\nimport numpy as np\nfrom plyfile import PlyData\n# from utils import icp_utils\nfrom data.linemod import linemod_config\nfrom thirdparty.vsd import inout\nfrom thirdparty.nn import nn_utils\nfrom utils.img_utils import read_depth\nfrom thirdparty.kpconv.lib.utils import square_distance\nfrom utils.geometric import rotation_angle\nfrom utils.visualize import *\n# from thirdparty.fps import fps_utils\nimport torch\nimport open3d as o3d\nfrom transforms3d.quaternions import mat2quat, quat2mat, qmult\n# import data.bop_ycb.ycb_config as ycb_config #import bop_ycb_class2idx, model_info\n\ndef get_ply_model(model_path, scale=1):\n    ply = PlyData.read(model_path)\n    data = ply.elements[0].data\n    x = data['x']*scale\n    y = data['y']*scale\n    z = data['z']*scale\n    model = np.stack([x, y, z], axis=-1)\n    return model\n\n\ndef project(xyz, K, RT):\n    \"\"\"\n    xyz: [N, 3]\n    K: [3, 3]\n    RT: [3, 4]\n    \"\"\"\n    xyz = np.dot(xyz, RT[:, :3].T) + RT[:, 3:].T\n    xyz = np.dot(xyz, K.T)\n    xy = xyz[:, :2] / xyz[:, 2:]\n    return xy\n\n\ndef find_nearest_point_idx(ref_pts, que_pts):\n    assert(ref_pts.shape[1] == que_pts.shape[1] and 1 < que_pts.shape[1] <= 3)\n    pn1 = ref_pts.shape[0]\n    pn2 = que_pts.shape[0]\n    dim = ref_pts.shape[1]\n\n    ref_pts = np.ascontiguousarray(ref_pts[None, :, :], np.float32)\n    que_pts = np.ascontiguousarray(que_pts[None, :, :], np.float32)\n    idxs = np.zeros([1, pn2], np.int32)\n\n    ref_pts_ptr = ffi.cast('float *', ref_pts.ctypes.data)\n    que_pts_ptr = ffi.cast('float *', que_pts.ctypes.data)\n    idxs_ptr = ffi.cast('int *', idxs.ctypes.data)\n    lib.findNearestPointIdxLauncher(\n        ref_pts_ptr, que_pts_ptr, idxs_ptr, 1, pn1, pn2, dim, 0)\n\n    return idxs[0]\n\n\nclass LineMODEvaluator:\n    def __init__(self, class_name, result_dir, icp_refine=False):\n\n        # self.result_dir = os.path.join(result_dir, cfg.test.dataset)\n        self.result_dir = os.path.join(result_dir, \"LINEMOD\")\n        os.system('mkdir -p {}'.format(self.result_dir))\n\n\n        # data_root = args['data_root']\n        # cls = cfg.cls_type\n        self.class_name = class_name\n        self.icp_refine = icp_refine\n\n        # model_path = os.path.join(os.path.dirname(os.path.abspath(\n        #     __file__)), '../EXPDATA/LINEMOD', class_name, class_name + '.ply')\n        model_path = os.path.join(os.path.dirname(os.path.abspath(\n            __file__)), '../EXPDATA/LM6d_converted/models', class_name, class_name + '.ply')\n        # self.model = pvnet_data_utils.get_ply_model(model_path)\n        self.model = get_ply_model(model_path)\n        self.diameter = linemod_config.diameters[class_name] / 100\n\n        self.proj2d = []\n        self.add = []\n        self.adds = [] #force sym\n        self.add2 = []\n        self.add5 = []\n        self.cmd5 = []\n\n        self.icp_proj2d = []\n        self.icp_add = []\n        self.icp_cmd5 = []\n\n        self.mask_ap = []\n        self.pose_preds=[]\n\n        self.height = 480\n        self.width = 640\n\n        model = inout.load_ply(model_path)\n        model['pts'] = model['pts'] * 1000\n        self.icp_refiner = icp_utils.ICPRefiner(\n            model, (self.width, self.height)) if icp_refine else None\n\n    def projection_2d(self, pose_pred, pose_targets, K, icp=False, threshold=5):\n        model_2d_pred = project(self.model, K, pose_pred)\n        model_2d_targets = project(self.model, K, pose_targets)\n        proj_mean_diff = 
np.mean(np.linalg.norm(\n            model_2d_pred - model_2d_targets, axis=-1))\n        if icp:\n            self.icp_proj2d.append(proj_mean_diff < threshold)\n        else:\n            self.proj2d.append(proj_mean_diff < threshold)\n\n    def projection_2d_sym(self, pose_pred, pose_targets, K, threshold=5):\n        model_2d_pred = project(self.model, K, pose_pred)\n        model_2d_targets = project(self.model, K, pose_targets)\n        proj_mean_diff=np.mean(find_nearest_point_distance(model_2d_pred,model_2d_targets))\n\n        self.proj_mean_diffs.append(proj_mean_diff)\n        self.projection_2d_recorder.append(proj_mean_diff < threshold)\n\n    def add2_metric(self, pose_pred, pose_targets, icp=False, syn=False, percentage=0.02):\n        diameter = self.diameter * percentage\n        model_pred = np.dot(self.model, pose_pred[:, :3].T) + pose_pred[:, 3]\n        model_targets = np.dot(\n            self.model, pose_targets[:, :3].T) + pose_targets[:, 3]\n\n        if syn:\n            idxs = nn_utils.find_nearest_point_idx(model_pred, model_targets)\n            # idxs = find_nearest_point_idx(model_pred, model_targets)\n            mean_dist = np.mean(np.linalg.norm(\n                model_pred[idxs] - model_targets, 2, 1))\n        else:\n            mean_dist = np.mean(np.linalg.norm(\n                model_pred - model_targets, axis=-1))\n\n        if icp:\n            self.icp_add.append(mean_dist < diameter)\n        else:\n            self.add2.append(mean_dist < diameter)\n\n    def add5_metric(self, pose_pred, pose_targets, icp=False, syn=False, percentage=0.05):\n        diameter = self.diameter * percentage\n        model_pred = np.dot(self.model, pose_pred[:, :3].T) + pose_pred[:, 3]\n        model_targets = np.dot(\n            self.model, pose_targets[:, :3].T) + pose_targets[:, 3]\n\n        if syn:\n            idxs = nn_utils.find_nearest_point_idx(model_pred, model_targets)\n            # idxs = find_nearest_point_idx(model_pred, model_targets)\n            mean_dist = np.mean(np.linalg.norm(\n                model_pred[idxs] - model_targets, 2, 1))\n        else:\n            mean_dist = np.mean(np.linalg.norm(\n                model_pred - model_targets, axis=-1))\n\n        if icp:\n            self.icp_add.append(mean_dist < diameter)\n        else:\n            self.add5.append(mean_dist < diameter)\n\n    \n    def add_metric(self, pose_pred, pose_targets, icp=False, syn=False, percentage=0.1):\n        diameter = self.diameter * percentage\n        model_pred = np.dot(self.model, pose_pred[:, :3].T) + pose_pred[:, 3]\n        model_targets = np.dot(\n            self.model, pose_targets[:, :3].T) + pose_targets[:, 3]\n\n        if syn:\n            idxs = nn_utils.find_nearest_point_idx(model_pred, model_targets)\n            # idxs = find_nearest_point_idx(model_pred, model_targets)\n            mean_dist = np.mean(np.linalg.norm(\n                model_pred[idxs] - model_targets, 2, 1))\n        else:\n            mean_dist = np.mean(np.linalg.norm(\n                model_pred - model_targets, axis=-1))\n\n        if icp:\n            self.icp_add.append(mean_dist < diameter)\n        else:\n            self.add.append(mean_dist < diameter)\n\n    def cm_degree_5_metric(self, pose_pred, pose_targets, icp=False):\n        translation_distance = np.linalg.norm(\n            pose_pred[:, 3] - pose_targets[:, 3]) * 100\n        rotation_diff = np.dot(pose_pred[:, :3], pose_targets[:, :3].T)\n        trace = np.trace(rotation_diff)\n        trace = trace 
if trace <= 3 else 3\n        angular_distance = np.rad2deg(np.arccos((trace - 1.) / 2.))\n        if icp:\n            self.icp_cmd5.append(translation_distance <\n                                 5 and angular_distance < 5)\n        else:\n            self.cmd5.append(translation_distance < 5 and angular_distance < 5)\n\n    def mask_iou(self, output, batch):\n        mask_pred = torch.argmax(output['seg'], dim=1)[\n            0].detach().cpu().numpy()\n        mask_gt = batch['mask'][0].detach().cpu().numpy()\n        iou = (mask_pred & mask_gt).sum() / (mask_pred | mask_gt).sum()\n        self.mask_ap.append(iou > 0.7)\n\n    def icp_refine(self, pose_pred, anno, output, K):\n        depth = read_depth(anno['depth_path'])\n        mask = torch.argmax(output['seg'], dim=1)[0].detach().cpu().numpy()\n        if pose_pred[2, 3] <= 0:\n            return pose_pred\n        depth[mask != 1] = 0\n        pose_pred_tmp = pose_pred.copy()\n        pose_pred_tmp[:3, 3] = pose_pred_tmp[:3, 3] * 1000\n\n        R_refined, t_refined = self.icp_refiner.refine(\n            depth, pose_pred_tmp[:3, :3], pose_pred_tmp[:3, 3], K.copy(), depth_only=True, max_mean_dist_factor=5.0)\n        R_refined, _ = self.icp_refiner.refine(\n            depth, R_refined, t_refined, K.copy(), no_depth=True)\n\n        pose_pred = np.hstack((R_refined, t_refined.reshape((3, 1)) / 1000))\n\n        return pose_pred\n\n\n    def icp_refine_(self, pose, anno, output):\n        depth = read_depth(anno['depth_path']).astype(np.uint16)\n        mask = torch.argmax(output['seg'], dim=1)[0].detach().cpu().numpy()\n        mask = mask.astype(np.int32)\n        pose = pose.astype(np.float32)\n\n        poses = np.zeros([1, 7], dtype=np.float32)\n        poses[0, :4] = mat2quat(pose[:, :3])\n        poses[0, 4:] = pose[:, 3]\n\n        poses_new = np.zeros([1, 7], dtype=np.float32)\n        poses_icp = np.zeros([1, 7], dtype=np.float32)\n\n        fx = 572.41140\n        fy = 573.57043\n        px = 325.26110\n        py = 242.04899\n        zfar = 6.0\n        znear = 0.25\n        factor = 1000.0\n        error_threshold = 0.01\n\n        rois = np.zeros([1, 6], dtype=np.float32)\n        rois[:, :] = 1\n\n        self.icp_refiner.solveICP(mask, depth,\n                                  self.height, self.width,\n                                  fx, fy, px, py,\n                                  znear, zfar,\n                                  factor,\n                                  rois.shape[0], rois,\n                                  poses, poses_new, poses_icp,\n                                  error_threshold\n                                  )\n\n        pose_icp = np.zeros([3, 4], dtype=np.float32)\n        pose_icp[:, :3] = quat2mat(poses_icp[0, :4])\n        pose_icp[:, 3] = poses_icp[0, 4:]\n\n        return pose_icp\n\n    def summarize(self):\n        proj2d = np.mean(self.proj2d)\n        add = np.mean(self.add)\n        # adds = np.mean(self.adds)\n        add2 = np.mean(self.add2)\n        add5 = np.mean(self.add5)\n        cmd5 = np.mean(self.cmd5)\n        ap = np.mean(self.mask_ap)\n        seq_len=len(self.add)\n        print('2d projections metric: {}'.format(proj2d))\n        print('ADD metric: {}'.format(add))\n        print('ADD2 metric: {}'.format(add2))\n        print('ADD5 metric: {}'.format(add5))\n        # print('ADDS metric: {}'.format(adds))\n        print('5 cm 5 degree metric: {}'.format(cmd5))\n        print('mask ap70: {}'.format(ap))\n        print('seq_len: {}'.format(seq_len))\n     
   # if cfg.test.icp:\n        if self.icp_refine:\n            print('2d projections metric after icp: {}'.format(\n                np.mean(self.icp_proj2d)))\n            print('ADD metric after icp: {}'.format(np.mean(self.icp_add)))\n            print('5 cm 5 degree metric after icp: {}'.format(\n                np.mean(self.icp_cmd5)))\n        self.proj2d = []\n        self.add = []\n        self.add2 = []\n        self.add5 = []\n        # self.adds = []\n        self.cmd5 = []\n        self.mask_ap = []\n        self.icp_proj2d = []    \n        self.icp_add = []\n        self.icp_cmd5 = []\n        \n\n        #save pose predictions\n        if len(self.pose_preds)> 0:\n            np.save(f\"{self.class_name}_pose_preds.npy\",self.pose_preds)\n        self.pose_preds=[]\n\n        return {'proj2d': proj2d, 'add': add, 'add2': add2, 'add5': add5,'cmd5': cmd5, 'ap': ap, \"seq_len\": seq_len}\n\n\n\n    def evaluate_rnnpose(self, preds_dict, example): # sample_correspondence_pairs=False, direct_align=False, use_cnnpose=True):\n        len_src_f = example['stack_lengths'][0][0]\n        # lifted_points = example['lifted_points'].squeeze(0)\n        assert len( example['lifted_points']) == 1, \"TODO: support bs>1\"\n        lifted_points = example['lifted_points'][0].squeeze(0)\n        model_points = example['original_model_points'][:len_src_f]\n\n        K = example[\"K\"].cpu().numpy().squeeze()\n        R_pred = preds_dict['Ti_pred'].G[:,0, :3,:3].squeeze().detach().cpu().numpy()\n        t_pred = preds_dict['Ti_pred'].G[:,0, :3,3:].squeeze(0).detach().cpu().numpy()\n        pose_pred= preds_dict['Ti_pred'].G[:,0, :3].squeeze().detach().cpu().numpy()\n#             print(example['POSECNN_RT'].dtype, example['rendered_RT'].dtype, flush=True)\n#             R_pred = example['POSECNN_RT'][:,:3,:3].squeeze().detach().cpu().numpy()\n#             t_pred = example['POSECNN_RT'][:,:3,3:].squeeze(0).detach().cpu().numpy()\n#             pose_pred= example['POSECNN_RT'][:, :3].squeeze().detach().cpu().numpy()\n\n\n        pose_gt = example['original_RT'].squeeze()[:3].cpu().numpy()\n        \n        \n        ang_err = rotation_angle(pose_gt[:3, :3], R_pred)\n        trans_err = np.linalg.norm(t_pred-pose_gt[:3, -1:])  # 3x1\n\n        if self.class_name in ['eggbox', 'glue']:\n            self.add_metric(pose_pred, pose_gt, syn=True)\n            self.add2_metric(pose_pred, pose_gt, syn=True)\n            self.add5_metric(pose_pred, pose_gt, syn=True)\n        else:\n            self.add_metric(pose_pred, pose_gt)\n            self.add2_metric(pose_pred, pose_gt)\n            self.add5_metric(pose_pred, pose_gt)\n\n        self.projection_2d(pose_pred, pose_gt, K=linemod_config.linemod_K)\n        self.cm_degree_5_metric(pose_pred, pose_gt)\n        # self.mask_iou(output, batch)\n\n        # vis\n        pc_proj_vis = vis_pointclouds_cv2((pose_gt[:3, :3]@model_points.cpu().numpy(\n        ).T+pose_gt[:3, -1:]).T, example[\"K\"].cpu().numpy().squeeze(), [480,640])\n        pc_proj_vis_pred = vis_pointclouds_cv2((pose_pred[:3, :3]@model_points.cpu().numpy(\n        ).T+pose_pred[:3, -1:]).T, example[\"K\"].cpu().numpy().squeeze(), [ 480, 640])\n\n\n        return {\n            \"ang_err\": ang_err,\n            \"trans_err\": trans_err,\n            \"pnp_inliers\": -1,#len(inliers),\n            \"pc_proj_vis\": pc_proj_vis,\n            \"pc_proj_vis_pred\": pc_proj_vis_pred,\n            \"keypoints_2d_vis\": np.zeros_like(pc_proj_vis_pred) #keypoints_2d_vis\n        }\n\n        
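\n\n# Minimal usage sketch inside an evaluation loop (preds_dict and example are\n# hypothetical batches from the pose network and the dataloader):\n#   evaluator = LineMODEvaluator('cat', 'results')\n#   stats = evaluator.evaluate_rnnpose(preds_dict, example)\n#   metrics = evaluator.summarize()  # proj2d / add / add2 / add5 / cmd5 / ap\n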
\n\n\n# class YCBEvaluator:\n#     def __init__(self, class_name, result_dir, icp_refine=False):\n\n#         self.result_dir = os.path.join(result_dir, \"LINEMOD\")\n#         os.system('mkdir -p {}'.format(self.result_dir))\n\n#         self.class_name = class_name\n#         self.icp_refine = icp_refine\n        \n#         model_path = os.path.join(os.path.dirname(os.path.abspath(\n#             __file__)), '../EXPDATA/BOP_YCB/models', f'obj_{ycb_config.bop_ycb_class2idx[class_name]:06d}.ply' )\n#         self.model = get_ply_model(model_path, scale=0.001)\n#         # self.diameter = linemod_config.diameters[class_name] / 100\n#         self.diameter = ycb_config.model_info[ str(ycb_config.bop_ycb_class2idx[class_name]) ][\"diameter\"]*0.001  # in mm # / 1000\n\n#         self.proj2d = []\n#         self.add = []\n#         self.adds=[]\n#         self.cmd5 = []\n#         self.add_dist=[]\n#         self.adds_dist=[]\n\n#         self.icp_proj2d = []\n#         self.icp_add = []\n#         self.icp_cmd5 = []\n\n#         self.mask_ap = []\n#         self.pose_preds=[]\n\n#         self.height = 480\n#         self.width = 640\n\n#         model = inout.load_ply(model_path)\n#         model['pts'] = model['pts'] * 1000\n#         # self.icp_refiner = icp_utils.ICPRefiner(model, (self.width, self.height)) if cfg.test.icp else None\n#         self.icp_refiner = icp_utils.ICPRefiner(\n#             model, (self.width, self.height)) if icp_refine else None\n#         self.direct_align_module = DirectAlignment(None)\n#         # if cfg.test.icp:\n#         #     self.icp_refiner = ext_.Synthesizer(os.path.realpath(model_path))\n#         #     self.icp_refiner.setup(self.width, self.height)\n\n#     def projection_2d(self, pose_pred, pose_targets, K, icp=False, threshold=5):\n#         model_2d_pred = project(self.model, K, pose_pred)\n#         model_2d_targets = project(self.model, K, pose_targets)\n#         proj_mean_diff = np.mean(np.linalg.norm(\n#             model_2d_pred - model_2d_targets, axis=-1))\n#         if icp:\n#             self.icp_proj2d.append(proj_mean_diff < threshold)\n#         else:\n#             self.proj2d.append(proj_mean_diff < threshold)\n\n#     def projection_2d_sym(self, pose_pred, pose_targets, K, threshold=5):\n#         model_2d_pred = project(self.model, K, pose_pred)\n#         model_2d_targets = project(self.model, K, pose_targets)\n#         proj_mean_diff=np.mean(find_nearest_point_distance(model_2d_pred,model_2d_targets))\n\n#         self.proj_mean_diffs.append(proj_mean_diff)\n#         self.projection_2d_recorder.append(proj_mean_diff < threshold)\n\n#     def add_metric(self, pose_pred, pose_targets, icp=False, syn=False, percentage=0.1):\n#         diameter = self.diameter * percentage\n#         model_pred = np.dot(self.model, pose_pred[:, :3].T) + pose_pred[:, 3]\n#         model_targets = np.dot(\n#             self.model, pose_targets[:, :3].T) + pose_targets[:, 3]\n\n#         if syn:\n#             idxs = nn_utils.find_nearest_point_idx(model_pred, model_targets)\n#             # idxs = find_nearest_point_idx(model_pred, model_targets)\n#             mean_dist = np.mean(np.linalg.norm(\n#                 model_pred[idxs] - model_targets, 2, 1))\n#         else:\n#             mean_dist = np.mean(np.linalg.norm(\n#                 model_pred - model_targets, axis=-1))\n#         self.add_dist.append(mean_dist)\n#         if icp:\n#             self.icp_add.append(mean_dist < diameter)\n#         else:\n#             
self.add.append(mean_dist < diameter)\n#     def auc_add(self, max_thresh=0.1):\n#         add_dist = np.array(self.add_dist)\n#         interval=0.001\n#         acc=0\n#         for k in range(int(max_thresh/interval)):\n#             acc+= interval* np.sum( ((k+1)*interval)>=add_dist)/ add_dist.shape[0]\n\n#         return acc/max_thresh\n#     def auc_adds(self, max_thresh=0.1):\n#         add_dist = np.array(self.adds_dist)\n#         interval=0.001\n#         acc=0\n#         for k in range(int(max_thresh/interval)):\n#             acc+= interval* np.sum( ((k+1)*interval)>=add_dist )/ add_dist.shape[0]\n\n#         return acc/max_thresh\n#     def adds_metric(self, pose_pred, pose_targets, icp=False, syn=False, percentage=0.1):\n#         diameter = self.diameter * percentage\n#         model_pred = np.dot(self.model, pose_pred[:, :3].T) + pose_pred[:, 3]\n#         model_targets = np.dot(\n#             self.model, pose_targets[:, :3].T) + pose_targets[:, 3]\n\n#         if syn:\n#             idxs = nn_utils.find_nearest_point_idx(model_pred, model_targets)\n#             # idxs = find_nearest_point_idx(model_pred, model_targets)\n#             mean_dist = np.mean(np.linalg.norm(\n#                 model_pred[idxs] - model_targets, 2, 1))\n#         else:\n#             mean_dist = np.mean(np.linalg.norm(\n#                 model_pred - model_targets, axis=-1))\n#         self.adds_dist.append(mean_dist)\n#         if icp:\n#             self.icp_add.append(mean_dist < diameter)\n#         else:\n#             self.adds.append(mean_dist < diameter)\n\n#     def cm_degree_5_metric(self, pose_pred, pose_targets, icp=False):\n#         translation_distance = np.linalg.norm(\n#             pose_pred[:, 3] - pose_targets[:, 3]) * 100\n#         rotation_diff = np.dot(pose_pred[:, :3], pose_targets[:, :3].T)\n#         trace = np.trace(rotation_diff)\n#         trace = trace if trace <= 3 else 3\n#         angular_distance = np.rad2deg(np.arccos((trace - 1.) 
/ 2.))\n#         if icp:\n#             self.icp_cmd5.append(translation_distance <\n#                                  5 and angular_distance < 5)\n#         else:\n#             self.cmd5.append(translation_distance < 5 and angular_distance < 5)\n\n#     def mask_iou(self, output, batch):\n#         mask_pred = torch.argmax(output['seg'], dim=1)[\n#             0].detach().cpu().numpy()\n#         mask_gt = batch['mask'][0].detach().cpu().numpy()\n#         iou = (mask_pred & mask_gt).sum() / (mask_pred | mask_gt).sum()\n#         self.mask_ap.append(iou > 0.7)\n\n#     def icp_refine(self, pose_pred, anno, output, K):\n#         depth = read_depth(anno['depth_path'])\n#         mask = torch.argmax(output['seg'], dim=1)[0].detach().cpu().numpy()\n#         if pose_pred[2, 3] <= 0:\n#             return pose_pred\n#         depth[mask != 1] = 0\n#         pose_pred_tmp = pose_pred.copy()\n#         pose_pred_tmp[:3, 3] = pose_pred_tmp[:3, 3] * 1000\n\n#         R_refined, t_refined = self.icp_refiner.refine(\n#             depth, pose_pred_tmp[:3, :3], pose_pred_tmp[:3, 3], K.copy(), depth_only=True, max_mean_dist_factor=5.0)\n#         R_refined, _ = self.icp_refiner.refine(\n#             depth, R_refined, t_refined, K.copy(), no_depth=True)\n\n#         pose_pred = np.hstack((R_refined, t_refined.reshape((3, 1)) / 1000))\n\n#         return pose_pred\n\n\n#     def icp_refine_(self, pose, anno, output):\n#         depth = read_depth(anno['depth_path']).astype(np.uint16)\n#         mask = torch.argmax(output['seg'], dim=1)[0].detach().cpu().numpy()\n#         mask = mask.astype(np.int32)\n#         pose = pose.astype(np.float32)\n\n#         poses = np.zeros([1, 7], dtype=np.float32)\n#         poses[0, :4] = mat2quat(pose[:, :3])\n#         poses[0, 4:] = pose[:, 3]\n\n#         poses_new = np.zeros([1, 7], dtype=np.float32)\n#         poses_icp = np.zeros([1, 7], dtype=np.float32)\n\n#         fx = 572.41140\n#         fy = 573.57043\n#         px = 325.26110\n#         py = 242.04899\n#         zfar = 6.0\n#         znear = 0.25\n#         factor = 1000.0\n#         error_threshold = 0.01\n\n#         rois = np.zeros([1, 6], dtype=np.float32)\n#         rois[:, :] = 1\n\n#         self.icp_refiner.solveICP(mask, depth,\n#                                   self.height, self.width,\n#                                   fx, fy, px, py,\n#                                   znear, zfar,\n#                                   factor,\n#                                   rois.shape[0], rois,\n#                                   poses, poses_new, poses_icp,\n#                                   error_threshold\n#                                   )\n\n#         pose_icp = np.zeros([3, 4], dtype=np.float32)\n#         pose_icp[:, :3] = quat2mat(poses_icp[0, :4])\n#         pose_icp[:, 3] = poses_icp[0, 4:]\n\n#         return pose_icp\n\n    \n\n#     def summarize(self):\n#         proj2d = np.mean(self.proj2d)\n#         add = np.mean(self.add)\n#         adds = np.mean(self.adds)\n#         cmd5 = np.mean(self.cmd5)\n#         ap = np.mean(self.mask_ap)\n#         try:\n#             auc_add=self.auc_add(max_thresh=0.1)\n#         except:\n#             auc_add=0\n#         try:\n#             auc_adds=self.auc_adds(max_thresh=0.1)\n#         except:\n#             auc_adds=0\n#         seq_len=len(self.add)\n#         print('2d projections metric: {}'.format(proj2d))\n#         print('ADD metric: {}'.format(add))\n#         print('AUC ADD metric: {}'.format(auc_add))\n#         print('ADDS 
metric: {}'.format(adds))\n#         print('AUC ADDS metric: {}'.format(auc_adds))\n#         print('5 cm 5 degree metric: {}'.format(cmd5))\n#         print('mask ap70: {}'.format(ap))\n#         print('seq_len: {}'.format(seq_len))\n#         # if cfg.test.icp:\n#         if self.icp_refine:\n#             print('2d projections metric after icp: {}'.format(\n#                 np.mean(self.icp_proj2d)))\n#             print('ADD metric after icp: {}'.format(np.mean(self.icp_add)))\n#             print('5 cm 5 degree metric after icp: {}'.format(\n#                 np.mean(self.icp_cmd5)))\n#         self.proj2d = []\n#         self.add = []\n#         self.adds = []\n#         self.cmd5 = []\n#         self.mask_ap = []\n#         self.icp_proj2d = []    \n#         self.icp_add = []\n#         self.icp_cmd5 = []\n#         self.add_dist = []\n#         self.adds_dist = []\n        \n\n#         #save pose predictions\n#         if len(self.pose_preds)> 0:\n#             np.save(f\"{self.class_name}_pose_preds.npy\",self.pose_preds)\n#         self.pose_preds=[]\n\n#         return {'proj2d': proj2d, 'add': add, 'adds': adds,'cmd5': cmd5, 'ap': ap, \"seq_len\": seq_len}\n\n#     def evaluate_flowpose(self, preds_dict, example, sample_correspondence_pairs=False, direct_align=False, use_cnnpose=True):\n#         len_src_f = example['stack_lengths'][0][0]\n#         # lifted_points = example['lifted_points'].squeeze(0)\n#         assert len( example['lifted_points']) == 1, \"TODO: support bs>1\"\n#         lifted_points = example['lifted_points'][0].squeeze(0)\n#         model_points = example['original_model_points'][:len_src_f]\n#         K = example[\"K\"].cpu().numpy().squeeze()\n\n#         if not use_cnnpose: # use pnp \n#             len_src_f = example['stack_lengths'][0][0]\n#             descriptors_2d = preds_dict['descriptors_2d']\n#             descriptors_3d = preds_dict['descriptors_3d'][:len_src_f]\n\n#             mask = cv2.erode(( example['depth'].detach().cpu().numpy().squeeze()>0).astype(np.uint8)*255, kernel=np.ones([3,3], np.uint8),iterations = 1)\n#             mask=torch.tensor(mask, device=descriptors_3d.device)\n#             ys_, xs_ = torch.nonzero(mask, as_tuple=True)\n\n#             # ys_, xs_ = torch.nonzero(example['depth'].squeeze(), as_tuple=True)\n#             descriptors_2d = descriptors_2d[:, :,\n#                                             ys_, xs_].squeeze().permute([1, 0])\n#             img_coods = torch.stack([xs_, ys_], dim=-1)\n\n\n\n#             if sample_correspondence_pairs:\n#                 correspondences_2d3d = example['correspondences_2d3d'].squeeze()\n#                 if(correspondences_2d3d.size(0) > 256):\n#                     choice = np.random.permutation(\n#                         correspondences_2d3d.size(0))[:256]\n#                     correspondences_2d3d = correspondences_2d3d[choice]\n\n#                 src_idx = correspondences_2d3d[:, 0]\n#                 tgt_idx = correspondences_2d3d[:, 1]\n#             else:\n#                 # src_idx = np.random.permutation(len(xs_))[:256]\n#                 # src_idx = np.random.permutation(len(xs_))[:256]\n#                 _, idx = fps_utils.farthest_point_sampling_withidx(np.stack([xs_.cpu().numpy(\n#                 ), ys_.cpu().numpy(), np.zeros_like(xs_.cpu().numpy())], axis=-1), 256, False)\n#                 # _, idx = fps_utils.farthest_point_sampling_withidx(np.stack([xs_.cpu().numpy(\n#                 # ), ys_.cpu().numpy(), np.zeros_like(xs_.cpu().numpy())], 
axis=-1), 1024, False)\n#                 # src_idx = np.arange(len(xs_))[::len(xs_)//256]\n#                 src_idx = np.arange(len(xs_))[idx]\n#                 tgt_idx = np.arange(len(model_points))\n\n#             src_pcd, tgt_pcd = lifted_points[src_idx], model_points[tgt_idx]\n#             img_coods = img_coods[src_idx]\n#             src_feats, tgt_feats = descriptors_2d[src_idx], descriptors_3d[tgt_idx]\n\n#             feats_dist = torch.sqrt(square_distance(\n#                 src_feats[None, :, :], tgt_feats[None, :, :], normalised=True)).squeeze(0)\n\n#             _, sel_idx = torch.min(feats_dist, -1)\n#             K = example[\"K\"].cpu().numpy().squeeze()\n#             try:\n#                 # retval, R_pred,t_pred, inliers =cv2.solvePnPRansac(tgt_pcd[sel_idx].cpu().numpy(), img_coods.cpu().numpy().astype(np.float32), K,distCoeffs=np.zeros(4),reprojectionError=1)\n#                 retval, R_pred, t_pred, inliers = cv2.solvePnPRansac(tgt_pcd[sel_idx].cpu().numpy(), img_coods.cpu(\n#                 ).numpy().astype(np.float32), K, distCoeffs=np.zeros(4), reprojectionError=1, iterationsCount=1000)\n\n#                 if inliers is None:\n#                     raise ValueError\n#             except:\n#                 # try:\n#                 print(\"PNP RANSAC reprojectionError threshold =3\")\n#                 retval, R_pred, t_pred, inliers = cv2.solvePnPRansac(tgt_pcd[sel_idx].cpu().numpy(), img_coods.cpu(\n#                 ).numpy().astype(np.float32), K, distCoeffs=np.zeros(4), reprojectionError=3, iterationsCount=1000)\n    \n#             R_pred, _ = cv2.Rodrigues(R_pred)\n#             pose_pred = np.concatenate([R_pred, t_pred], axis=-1)\n#         else:\n#             K = example[\"K\"].cpu().numpy().squeeze()\n#             R_pred = preds_dict['Ti_pred'].G[:,0, :3,:3].squeeze().detach().cpu().numpy()\n#             t_pred = preds_dict['Ti_pred'].G[:,0, :3,3:].squeeze(0).detach().cpu().numpy()\n#             pose_pred= preds_dict['Ti_pred'].G[:,0, :3].squeeze().detach().cpu().numpy()\n# #             print(example['POSECNN_RT'].dtype, example['rendered_RT'].dtype, flush=True)\n# #             R_pred = example['POSECNN_RT'][:,:3,:3].squeeze().detach().cpu().numpy()\n# #             t_pred = example['POSECNN_RT'][:,:3,3:].squeeze(0).detach().cpu().numpy()\n# #             pose_pred= example['POSECNN_RT'][:, :3].squeeze().detach().cpu().numpy()\n\n\n#         # pose_gt = example['RT'].squeeze()[:3].cpu().numpy()\n#         pose_gt = example['original_RT'].squeeze()[:3].cpu().numpy()\n        \n        \n#         ang_err = rotation_angle(pose_gt[:3, :3], R_pred)\n#         trans_err = np.linalg.norm(t_pred-pose_gt[:3, -1:])  # 3x1\n\n#         if self.class_name in ['024_bowl', '036_wood_block', '051_large_clamp', '052_extra_large_clamp', '061_foam_brick']:\n#             self.add_metric(pose_pred, pose_gt, syn=True)\n#         else:\n#             self.add_metric(pose_pred, pose_gt)\n#         self.adds_metric(pose_pred, pose_gt, syn=True)\n#         self.projection_2d(pose_pred, pose_gt, K=linemod_config.linemod_K)\n#         self.cm_degree_5_metric(pose_pred, pose_gt)\n\n#         # self.mask_iou(output, batch)\n\n#         # vis\n#         pc_proj_vis = vis_pointclouds_cv2((pose_gt[:3, :3]@model_points.cpu().numpy(\n#         ).T+pose_gt[:3, -1:]).T, example[\"K\"].cpu().numpy().squeeze(), [480,640])\n#         pc_proj_vis_pred = vis_pointclouds_cv2((pose_pred[:3, :3]@model_points.cpu().numpy(\n#         ).T+pose_pred[:3, -1:]).T, 
example[\"K\"].cpu().numpy().squeeze(), [ 480, 640])\n\n\n#         if trans_err > 0.5:\n#             print(\"translation err>0.5\")\n#             # cv2.imwrite(f'tmp/{len(self.add)}.png',\n#             #             keypoints_2d_vis[..., ::-1]*255,)\n#             # torch.save( example, f'tmp/{len(self.add)}.pt' )\n\n#         return {\n#             \"ang_err\": ang_err,\n#             \"trans_err\": trans_err,\n#             \"pnp_inliers\": -1,#len(inliers),\n#             \"pc_proj_vis\": pc_proj_vis,\n#             \"pc_proj_vis_pred\": pc_proj_vis_pred,\n#             \"keypoints_2d_vis\": np.zeros_like(pc_proj_vis_pred) #keypoints_2d_vis\n#         }\n"
  },
  {
    "path": "utils/furthest_point_sample.py",
    "content": "\nimport numpy as np\nfrom scipy import spatial\n\n\ndef fragmentation_fps(vertices, num_frags):\n  \"\"\"Fragmentation by the furthest point sampling algorithm.\n\n  The fragment centers are found by iterative selection of the vertex from\n  vertices that is furthest from the already selected vertices. The algorithm\n  starts with the centroid of the object model which is then discarded from the\n  final set of fragment centers.\n\n  A fragment is defined by a set of points on the object model that are the\n  closest to the fragment center.\n\n  Args:\n    vertices: [num_vertices, 3] ndarray with 3D vertices of the object model.\n    num_frags: Number of fragments to define.\n\n  Returns:\n    [num_frags, 3] ndarray with fragment centers and [num_vertices] ndarray\n    storing for each vertex the ID of the assigned fragment.\n  \"\"\"\n  # Start with the origin of the model coordinate system.\n  frag_centers = [np.array([0., 0., 0.])]\n\n  # Calculate distances to the center from all the vertices.\n  nn_index = spatial.cKDTree(frag_centers)\n  nn_dists, _ = nn_index.query(vertices, k=1)\n  center_inds=[]\n  for _ in range(num_frags):\n    # Select the furthest vertex as the next center.\n    new_center_ind = np.argmax(nn_dists)\n    new_center = vertices[new_center_ind]\n    frag_centers.append(vertices[new_center_ind])\n    center_inds.append(new_center_ind)\n\n    # Update the distances to the nearest center.\n    nn_dists[new_center_ind] = -1\n    nn_dists = np.minimum(\n      nn_dists, np.linalg.norm(vertices - new_center, axis=1))\n\n  # Remove the origin.\n  frag_centers.pop(0)\n  frag_centers = np.array(frag_centers)\n\n  # Assign vertices to the fragments.\n  # TODO: This information can be maintained during the FPS algorithm.\n  nn_index = spatial.cKDTree(frag_centers)\n  _, vertex_frag_ids = nn_index.query(vertices, k=1)\n\n#   return frag_centers, vertex_frag_ids\n  return frag_centers, np.array(center_inds), vertex_frag_ids\n\n\nif __name__==\"__main__\":\n    #test\n    pass\n"
  },
  {
    "path": "utils/geometric.py",
    "content": "import numpy as np \n\n\ndef range_to_depth(mask, range, K):\n    '''\n       Transform the range image to depth image\n    '''\n    f=K[0,0]\n    cx=K[0,2]\n    cy=K[1,2]\n\n    ys_, xs_=np.nonzero(mask)\n    rngs=range[ys_,xs_]\n    # xs,ys=np.asarray(xs,np.float32),np.asarray(ys,np.float32)\n    xs,ys=np.asarray(xs_,np.float32)+0.5,np.asarray(ys_,np.float32)+0.5\n\n    Zs=f*rngs/( f**2 + (cx-xs)**2 + (cy-ys)**2 )**0.5\n    depth = np.zeros_like(range)\n    depth[ys_,xs_] = Zs\n    return  depth\n\ndef mask_depth_to_point_cloud(mask,depth,K):\n    '''\n        lift the depth under the mask to 3D point clouds\n    '''\n    ys, xs=np.nonzero(mask)\n    dpts=depth[ys,xs]\n    # xs,ys=np.asarray(xs,np.float32),np.asarray(ys,np.float32)\n    xs,ys=np.asarray(xs,np.float32)+0.5,np.asarray(ys,np.float32)+0.5\n    xys=np.concatenate([xs[:,None],ys[:,None]],1)\n    xys*=dpts[:,None]\n    xyds=np.concatenate([xys,dpts[:,None]],1)\n    pts=np.matmul(xyds,np.linalg.inv(K).transpose())\n    return pts.astype(np.float32), np.stack([xs,ys], axis=-1 )\n\ndef chordal_distance(R1,R2):\n    return np.sqrt(np.sum((R1-R2)*(R1-R2))) \n\ndef rotation_angle(R1, R2):\n    return 2*np.arcsin( chordal_distance(R1,R2)/np.sqrt(8) )\n\ndef render_pointcloud(pc, T, K, render_image_size):\n        \"\"\"\n        Args:\n            T: (B,3,4) or (B,4,4)\n            K: (B,3,3)\n            render_image_size (tuple): (h,w)\n            near (float, optional):  Defaults to 0.1.\n            far (int, optional): Defaults to 6.\n            mode: 'bilinear' or 'neareast'\n        \"\"\"\n\n        B=T.shape[0]\n\n        # T = self.cam_opencv2pytch3d.to(device=T.device)@T\n\n        ## X_cam = X_world R + t\n        # R = T[...,:3,:3].transpose(-1,-2)\n        R = T[...,:3,:3].transpose( [0,2,1] )\n        t = T[...,:3,3]\n\n        #render depths\n        # vert_depths= (self.verts@R+t).squeeze(0)[...,2:]\n        X_cam= (pc@R+t)#.squeeze(0)\n\n        x=X_cam@K.transpose([0,2,1])  #BxNx3\n        depth = x[...,-1]\n        x = x/x[...,-1:]\n\n        out = np.zeros([1,1, *render_image_size], dtype=R.dtype)\n        out[:, :, \n            np.round(x[0, :, 1]).astype(np.int64).clip(0, out.shape[2]-1),\n            np.round(x[0, :, 0]).astype(np.int64).clip(0, out.shape[3]-1)] = depth \n\n        return out #1x1xHxW"
  },
  {
    "path": "utils/img_utils.py",
    "content": "import torch\nfrom matplotlib import cm\nimport matplotlib.pyplot as plt\nimport matplotlib.patches as patches\nimport numpy as np\nimport cv2\nfrom PIL import Image\n\ndef read_depth(path):\n    if (path[-3:] == 'dpt'):\n        with open(path) as f:\n            h,w = np.fromfile(f,dtype=np.uint32,count=2)\n            data = np.fromfile(f,dtype=np.uint16,count=w*h)\n            depth = data.reshape((h,w))\n    else:\n        depth = np.asarray(Image.open(path)).copy()\n    return depth\n\ndef unnormalize_img(img, mean, std, in_gpu=True):\n    \"\"\"\n    img: [3, h, w]\n    \"\"\"\n    img = img.detach().cpu().clone()\n    # img = img / 255.\n    img *= torch.tensor(std).view(3, 1, 1)\n    img += torch.tensor(mean).view(3, 1, 1)\n    min_v = torch.min(img)\n    img = (img - min_v) / (torch.max(img) - min_v)\n    return img\n\n\ndef draw_seg_th(seg, num_cls=-1):\n    \"\"\"\n    seg: [h, w]\n    \"\"\"\n    r = seg.clone()\n    g = seg.clone()\n    b = seg.clone()\n    num_cls = len(colors) if num_cls == -1 else num_cls\n    seg_colors = 1 - colors[:, 0, 0]\n    for l in range(num_cls):\n        inds = (seg == l)\n        r[inds] = int(seg_colors[l][0])\n        g[inds] = int(seg_colors[l][1])\n        b[inds] = int(seg_colors[l][2])\n    seg = torch.stack([r, g, b], dim=0).float() / 255.\n    return seg\n\n\ndef draw_seg_prob_th(seg_prob):\n    \"\"\"\n    seg_prob: [num_cls, h, w]\n    \"\"\"\n    num_cls = seg_prob.shape[0]\n    seg = torch.argmax(seg_prob, dim=0).long()\n    return draw_seg_th(seg, num_cls)\n\n\ndef draw_vertex_th(vertex):\n    \"\"\"\n    vertex: [h, w]\n    \"\"\"\n    min_ver = torch.min(vertex)\n    max_ver = torch.max(vertex)\n    vertex = (vertex - min_ver) / (max_ver - min_ver)\n    vertex = cmap(vertex.detach().cpu().numpy())[..., :3]\n    return torch.tensor(vertex).permute(2, 0, 1)\n\n\ndef visualize_coco_bbox(img, boxes):\n    \"\"\"\n    img: [h, w, 3]\n    boxes: [n, 4], [[x, y, x_max, y_max]]\n    \"\"\"\n    _, ax = plt.subplots(1)\n    ax.imshow(img)\n    n = len(boxes)\n    for ni in range(n):\n        x, y, x_max, y_max = boxes[ni]\n        ax.add_patch(patches.Polygon(xy=[[x, y], [x, y_max], [x_max, y_max], [x_max, y]], fill=False, linewidth=1, edgecolor='r'))\n    plt.show()\n\n\ndef visualize_heatmap(img, hm):\n    \"\"\"\n    img: [h, w, 3]\n    hm: [c, h, w]\n    \"\"\"\n    hm = np.max(hm, axis=0)\n    h, w = hm.shape[:2]\n    img = cv2.resize(img, dsize=(w, h), interpolation=cv2.INTER_LINEAR)\n    hm = np.array([255, 255, 255]) - (hm.reshape(h, w, 1) * colors[0]).astype(np.uint8)\n    ratio = 0.5\n    blend = (img * ratio + hm * (1 - ratio)).astype(np.uint8)\n    _, (ax1, ax2) = plt.subplots(1, 2)\n    ax1.imshow(img)\n    ax2.imshow(blend)\n    plt.show()\n\n\ndef visualize_coco_img_mask(img, mask):\n    _, (ax1, ax2) = plt.subplots(1, 2)\n    ax1.imshow(img)\n    ax2.imshow(mask)\n    plt.show()\n\n\ndef visualize_color_aug(orig_img, aug_img):\n    _, (ax1, ax2) = plt.subplots(1, 2)\n    ax1.imshow(orig_img[:, :, [2, 1, 0]])\n    ax2.imshow(aug_img[:, :, [2, 1, 0]])\n    plt.show()\n\n\ndef visualize_coco_ann(coco, img, ann):\n    plt.imshow(img)\n    coco.showAnns(ann)\n    plt.show()\n\n\ndef bgr_to_rgb(img):\n    return img[:, :, [2, 1, 0]]\n\n\ncmap = cm.get_cmap()\ncolor_list = np.array(\n    [\n        0.000, 0.447, 0.741,\n        0.850, 0.325, 0.098,\n        0.929, 0.694, 0.125,\n        0.494, 0.184, 0.556,\n        0.466, 0.674, 0.188,\n        0.301, 0.745, 0.933,\n        0.635, 0.078, 0.184,\n        0.300, 
0.300, 0.300,\n        0.600, 0.600, 0.600,\n        1.000, 0.000, 0.000,\n        1.000, 0.500, 0.000,\n        0.749, 0.749, 0.000,\n        0.000, 1.000, 0.000,\n        0.000, 0.000, 1.000,\n        0.667, 0.000, 1.000,\n        0.333, 0.333, 0.000,\n        0.333, 0.667, 0.000,\n        0.333, 1.000, 0.000,\n        0.667, 0.333, 0.000,\n        0.667, 0.667, 0.000,\n        0.667, 1.000, 0.000,\n        1.000, 0.333, 0.000,\n        1.000, 0.667, 0.000,\n        1.000, 1.000, 0.000,\n        0.000, 0.333, 0.500,\n        0.000, 0.667, 0.500,\n        0.000, 1.000, 0.500,\n        0.333, 0.000, 0.500,\n        0.333, 0.333, 0.500,\n        0.333, 0.667, 0.500,\n        0.333, 1.000, 0.500,\n        0.667, 0.000, 0.500,\n        0.667, 0.333, 0.500,\n        0.667, 0.667, 0.500,\n        0.667, 1.000, 0.500,\n        1.000, 0.000, 0.500,\n        1.000, 0.333, 0.500,\n        1.000, 0.667, 0.500,\n        1.000, 1.000, 0.500,\n        0.000, 0.333, 1.000,\n        0.000, 0.667, 1.000,\n        0.000, 1.000, 1.000,\n        0.333, 0.000, 1.000,\n        0.333, 0.333, 1.000,\n        0.333, 0.667, 1.000,\n        0.333, 1.000, 1.000,\n        0.667, 0.000, 1.000,\n        0.667, 0.333, 1.000,\n        0.667, 0.667, 1.000,\n        0.667, 1.000, 1.000,\n        1.000, 0.000, 1.000,\n        1.000, 0.333, 1.000,\n        1.000, 0.667, 1.000,\n        0.167, 0.000, 0.000,\n        0.333, 0.000, 0.000,\n        0.500, 0.000, 0.000,\n        0.667, 0.000, 0.000,\n        0.833, 0.000, 0.000,\n        1.000, 0.000, 0.000,\n        0.000, 0.167, 0.000,\n        0.000, 0.333, 0.000,\n        0.000, 0.500, 0.000,\n        0.000, 0.667, 0.000,\n        0.000, 0.833, 0.000,\n        0.000, 1.000, 0.000,\n        0.000, 0.000, 0.167,\n        0.000, 0.000, 0.333,\n        0.000, 0.000, 0.500,\n        0.000, 0.000, 0.667,\n        0.000, 0.000, 0.833,\n        0.000, 0.000, 1.000,\n        0.000, 0.000, 0.000,\n        0.143, 0.143, 0.143,\n        0.286, 0.286, 0.286,\n        0.429, 0.429, 0.429,\n        0.571, 0.571, 0.571,\n        0.714, 0.714, 0.714,\n        0.857, 0.857, 0.857,\n        1.000, 1.000, 1.000,\n        0.50, 0.5, 0\n    ]\n).astype(np.float32)\ncolors = color_list.reshape((-1, 3)) * 255\ncolors = np.array(colors, dtype=np.uint8).reshape(len(colors), 1, 1, 3)\n"
  },
  {
    "path": "utils/log_tool.py",
    "content": "import numpy as np\nfrom tensorboardX import SummaryWriter\nimport json\nfrom pathlib import Path\nimport logging\n\n\n\ndef _flat_nested_json_dict(json_dict, flatted, sep=\".\", start=\"\"):\n    for k, v in json_dict.items():\n        if isinstance(v, dict):\n            _flat_nested_json_dict(v, flatted, sep, start + sep + str(k))\n        else:\n            flatted[start + sep + str(k)] = v\n\n\ndef flat_nested_json_dict(json_dict, sep=\".\") -> dict:\n    \"\"\"flat a nested json-like dict. this function make shadow copy.\n    \"\"\"\n    flatted = {}\n    for k, v in json_dict.items():\n        if isinstance(v, dict):\n            _flat_nested_json_dict(v, flatted, sep, str(k))\n        else:\n            flatted[str(k)] = v\n    return flatted\n\n\ndef metric_to_str(metrics, sep='.'):\n    flatted_metrics = flat_nested_json_dict(metrics, sep)\n    metrics_str_list = []\n    for k, v in flatted_metrics.items():\n        if isinstance(v, float):\n            metrics_str_list.append(f\"{k}={v:.5}\")\n        elif isinstance(v, (list, tuple)):\n            if v and isinstance(v[0], float):\n                v_str = ', '.join([f\"{e:.5}\" for e in v])\n                metrics_str_list.append(f\"{k}=[{v_str}]\")\n            else:\n                metrics_str_list.append(f\"{k}={v}\")\n        else:\n            metrics_str_list.append(f\"{k}={v}\")\n    return ', '.join(metrics_str_list)\n\n\nclass SimpleModelLog:\n    \"\"\"For simple log.\n    generate 4 kinds of log: \n    1. simple log.txt, all metric dicts are flattened to produce\n    readable results.\n    2. TensorBoard scalars and texts\n    3. multi-line json file log.json.lst\n    4. tensorboard_scalars.json, all scalars are stored in this file\n        in tensorboard json format.\n    \"\"\"\n\n    def __init__(self, model_dir, disable=False):\n        self.model_dir = Path(model_dir)\n        self.log_file = None\n        self.log_mjson_file = None\n        self.summary_writter = None\n        self.metrics = []\n        self._text_current_gstep = -1\n        self._tb_texts = []\n        self.disable = disable\n\n        logging.basicConfig(level = logging.INFO,format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')\n        self.logger = logging.getLogger(__name__)\n\n    def open(self):\n        if self.disable:\n            return self\n        model_dir = self.model_dir\n        assert model_dir.exists()\n        summary_dir = model_dir / 'summary'\n        summary_dir.mkdir(parents=True, exist_ok=True)\n\n        log_mjson_file_path = model_dir / f'log.json.lst'\n        if log_mjson_file_path.exists():\n            with open(log_mjson_file_path, 'r') as f:\n                for line in f.readlines():\n                    self.metrics.append(json.loads(line))\n        log_file_path = model_dir / f'log.txt'\n        self.log_mjson_file = open(log_mjson_file_path, 'a')\n        self.log_file = open(log_file_path, 'a')\n        self.summary_writter = SummaryWriter(str(summary_dir))\n        return self\n\n    def close(self):\n        if self.disable:\n            return \n        assert self.summary_writter is not None\n        self.log_mjson_file.close()\n        self.log_file.close()\n        tb_json_path = str(self.model_dir / \"tensorboard_scalars.json\")\n        self.summary_writter.export_scalars_to_json(tb_json_path)\n        self.summary_writter.close()\n        self.log_mjson_file = None\n        self.log_file = None\n        self.summary_writter = None\n\n    def log_text(self, text, step, 
tag=\"regular log\"):\n        if self.disable:\n            return \n        \"\"\"This function only add text to log.txt and tensorboard texts\n        \"\"\"\n        print(text,flush=True)\n        print(text, file=self.log_file,flush=True)\n        if step > self._text_current_gstep and self._text_current_gstep != -1:\n            total_text = '\\n'.join(self._tb_texts)\n            self.summary_writter.add_text(tag, total_text, global_step=step)\n            self._tb_texts = []\n            self._text_current_gstep = step\n        else:\n            self._tb_texts.append(text)\n        if self._text_current_gstep == -1:\n            self._text_current_gstep = step\n\n    def log_metrics(self, metrics: dict, step):\n        if self.disable:\n            return \n        flatted_summarys = flat_nested_json_dict(metrics, \"/\")\n        for k, v in flatted_summarys.items():\n            if isinstance(v, (list, tuple)):\n                if any([isinstance(e, str) for e in v]):\n                    continue\n                v_dict = {str(i): e for i, e in enumerate(v)}\n                for k1, v1 in v_dict.items():\n                    self.summary_writter.add_scalar(k + \"/\" + k1, v1, step)\n            else:\n                if isinstance(v, str):\n                    continue\n                self.summary_writter.add_scalar(k, v, step)\n        log_str = metric_to_str(metrics)\n        # print(log_str, flush=True)\n        self.logger.info(log_str)\n        print(log_str, file=self.log_file, flush=True)\n        print(json.dumps(metrics), file=self.log_mjson_file, flush=True)\n\n    def log_images(self, images: dict, step, prefix=''):\n        if self.disable:\n            return \n        for k, v in images.items():\n            self.summary_writter.add_images(prefix+str(k), v, step)\n            print(f\"Summarize images {k}\",flush=True)\n\n    def log_histograms(self, vals: dict, step, prefix=''):\n        if self.disable:\n            return \n        for k, v in vals.items():\n            self.summary_writter.add_histogram(prefix+str(k), v, step)\n            print(f\"Summarize histograms {k}\",flush=True)\n\n    # def log_distributions(self, vals: dict, step, prefix=''):\n    #     if self.disable:\n    #         return \n    #     for k, v in vals.items():\n    #         self.summary_writter.histogram(prefix+str(k), v, step)\n    #         print(f\"Summarize histograms {k}\",flush=True)"
  },
  {
    "path": "utils/pose_utils.py",
    "content": "\"\"\"\nCopyright (C) 2018 NVIDIA Corporation.  All rights reserved.\nLicensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).\n\"\"\"\n\nimport torch\nfrom torch.nn import Module\nfrom torch.autograd import Variable\nfrom torch.nn.functional import pad\nimport numpy as np\nimport scipy.linalg as slin\nimport math\nimport transforms3d.quaternions as txq\nimport transforms3d.euler as txe\n# see for formulas:\n# https://ocw.mit.edu/courses/electrical-engineering-and-computer-science/6-801-machine-vision-fall-2004/readings/quaternions.pdf\n# and \"Quaternion and Rotation\" - Yan-Bin Jia, September 18, 2016\n#from IPython.core.debugger import set_trace\n\n# PYTORCH\n\ndef pose_padding(P):\n    \"\"\"\n    Padding 3x4 SE3 to 4x4\n    Args:\n        P ([type]): [description]\n        dim ([type]): [description]\n    \"\"\"\n    assert (P.shape[-2] == 3 and P.shape[-1] == 4) or (P.shape[-2] == 2 and P.shape[-1] == 3)\n\n    pad = torch.zeros_like(P[...,:1,:] )\n    pad[...,-1]=1\n    return torch.cat([P,pad], dim=-2)\n    \ndef vdot(v1, v2):\n    \"\"\"\n    Dot product along the dim=1\n    :param v1: N x d\n    :param v2: N x d\n    :return: N x 1\n    \"\"\"\n    out = torch.mul(v1, v2)\n    out = torch.sum(out, 1)\n    return out\n\n\ndef normalize(x, p=2, dim=0):\n    \"\"\"\n    Divides a tensor along a certain dim by the Lp norm\n    :param x: \n    :param p: Lp norm\n    :param dim: Dimension to normalize along\n    :return: \n    \"\"\"\n    xn = x.norm(p=p, dim=dim)\n    x = x / xn.unsqueeze(dim=dim)\n    return x\n\n\ndef qmult(q1, q2):\n    \"\"\"\n    Multiply 2 quaternions\n    :param q1: Tensor N x 4\n    :param q2: Tensor N x 4\n    :return: quaternion product, Tensor N x 4\n    \"\"\"\n    q1s, q1v = q1[:, :1], q1[:, 1:]\n    q2s, q2v = q2[:, :1], q2[:, 1:]\n\n    qs = q1s*q2s - vdot(q1v, q2v)\n    qv = q1v.mul(q2s.expand_as(q1v)) + q2v.mul(q1s.expand_as(q2v)) +\\\n        torch.cross(q1v, q2v, dim=1)\n    q = torch.cat((qs, qv), dim=1)\n\n    # normalize\n    q = normalize(q, dim=1)\n\n    return q\n\n\ndef qinv(q):\n    \"\"\"\n    Inverts quaternions\n    :param q: N x 4\n    :return: q*: N x 4 \n    \"\"\"\n    q_inv = torch.cat((q[:, :1], -q[:, 1:]), dim=1)\n    return q_inv\n\n\ndef qexp_t(q):\n    \"\"\"\n    Applies exponential map to log quaternion\n    :param q: N x 3\n    :return: N x 4\n    \"\"\"\n    n = torch.norm(q, p=2, dim=1, keepdim=True)\n    n = torch.clamp(n, min=1e-8)\n    q = q * torch.sin(n)\n    q = q / n\n    q = torch.cat((torch.cos(n), q), dim=1)\n    return q\n\n\ndef qlog_t(q):\n    \"\"\"\n    Applies the log map to a quaternion\n    :param q: N x 4\n    :return: N x 3\n    \"\"\"\n    n = torch.norm(q[:, 1:], p=2, dim=1, keepdim=True)\n    n = torch.clamp(n, min=1e-8)\n    q = q[:, 1:] * torch.acos(torch.clamp(q[:, :1], min=-1.0, max=1.0))\n    q = q / n\n    return q\n\n\ndef qexp_t_safe(q):\n    \"\"\"\n    Applies exponential map to log quaternion (safe implementation that does not\n    maintain gradient flow)\n    :param q: N x 3\n    :return: N x 4\n    \"\"\"\n    q = torch.from_numpy(np.asarray([qexp(qq) for qq in q.numpy()],\n                                    dtype=np.float32))\n    return q\n\n\ndef qlog_t_safe(q):\n    \"\"\"\n    Applies the log map to a quaternion (safe implementation that does not\n    maintain gradient flow)\n    :param q: N x 4\n    :return: N x 3\n    \"\"\"\n    q = torch.from_numpy(np.asarray([qlog(qq) for qq in q.numpy()],\n                        
            dtype=np.float32))\n    return q\n\n\ndef rotate_vec_by_q(t, q):\n    \"\"\"\n    rotates vector t by quaternion q\n    :param t: vector, Tensor N x 3\n    :param q: quaternion, Tensor N x 4\n    :return: t rotated by q: t' = t + 2*qs*(qv x t) + 2*qv x (qv x r) \n    \"\"\"\n    qs, qv = q[:, :1], q[:, 1:]\n    b = torch.cross(qv, t, dim=1)\n    c = 2 * torch.cross(qv, b, dim=1)\n    b = 2 * b.mul(qs.expand_as(b))\n    tq = t + b + c\n    return tq\n\n\ndef compose_pose_quaternion(p1, p2):\n    \"\"\"\n    pyTorch implementation\n    :param p1: input pose, Tensor N x 7 \n    :param p2: pose to apply, Tensor N x 7\n    :return: output pose, Tensor N x 7\n    all poses are translation + quaternion\n    #!comments: first apply p2 and then p1 !!\n    \"\"\"\n    p1t, p1q = p1[:, :3], p1[:, 3:]\n    p2t, p2q = p2[:, :3], p2[:, 3:]\n    q = qmult(p1q, p2q)\n    t = p1t + rotate_vec_by_q(p2t, p1q)\n    return torch.cat((t, q), dim=1)\n\n\ndef invert_pose_quaternion(p):\n    \"\"\"\n    inverts the pose\n    :param p: pose, Tensor N x 7\n    :return: inverted pose\n    \"\"\"\n    t, q = p[:, :3], p[:, 3:]\n    q_inv = qinv(q)\n    tinv = -rotate_vec_by_q(t, q_inv)\n    return torch.cat((tinv, q_inv), dim=1)\n\n\ndef calc_vo(p0, p1):\n    \"\"\"\n    calculates VO (in the p0 frame) from 2 poses\n    :param p0: N x 7\n    :param p1: N x 7\n    \"\"\"\n    # assert p0.shape==p1.shape\n    return compose_pose_quaternion(invert_pose_quaternion(p0), p1)\n\n\ndef calc_vo_logq(p0, p1):\n    \"\"\"\n    VO (in the p0 frame) (logq)\n    :param p0: N x 6\n    :param p1: N x 6\n    :return: N-1 x 6\n    \"\"\"\n    q0 = qexp_t(p0[:, 3:])\n    q1 = qexp_t(p1[:, 3:])\n    vos = calc_vo(torch.cat((p0[:, :3], q0), dim=1), torch.cat((p1[:, :3], q1),\n                                                               dim=1))\n    vos_q = qlog_t(vos[:, 3:])\n    return torch.cat((vos[:, :3], vos_q), dim=1)\n\n\ndef calc_vo_relative(p0, p1):\n    \"\"\"\n    calculates VO (in the world frame) from 2 poses\n    :param p0: N x 7\n    :param p1: N x 7\n    \"\"\"\n    vos_t = p1[:, :3] - p0[:, :3]\n    vos_q = qmult(qinv(p0[:, 3:]), p1[:, 3:])\n    return torch.cat((vos_t, vos_q), dim=1)\n\n\ndef calc_vo_relative_logq(p0, p1):\n    \"\"\"\n    Calculates VO (in the world frame) from 2 poses (log q)\n    :param p0: N x 6\n    :param p1: N x 6\n    :return:\n    \"\"\"\n    q0 = qexp_t(p0[:, 3:])\n    q1 = qexp_t(p1[:, 3:])\n    vos = calc_vo_relative(torch.cat((p0[:, :3], q0), dim=1),\n                           torch.cat((p1[:, :3], q1), dim=1))\n    vos_q = qlog_t(vos[:, 3:])\n    return torch.cat((vos[:, :3], vos_q), dim=1)\n\n\ndef calc_vo_relative_logq_safe(p0, p1):\n    \"\"\"\n    Calculates VO (in the world frame) from 2 poses (log q) through numpy fns\n    :param p0: N x 6\n    :param p1: N x 6\n    :return:\n    \"\"\"\n    vos_t = p1[:, :3] - p0[:, :3]\n    q0 = qexp_t_safe(p0[:, 3:])\n    q1 = qexp_t_safe(p1[:, 3:])\n    vos_q = qmult(qinv(q0), q1)\n    vos_q = qlog_t_safe(vos_q)\n    return torch.cat((vos_t, vos_q), dim=1)\n\n\ndef calc_vo_logq_safe(p0, p1):\n    \"\"\"\n    VO in the p0 frame using numpy fns\n    :param p0:\n    :param p1:\n    :return:\n    \"\"\"\n    vos_t = p1[:, :3] - p0[:, :3]\n    q0 = qexp_t_safe(p0[:, 3:])\n    q1 = qexp_t_safe(p1[:, 3:])\n    vos_t = rotate_vec_by_q(vos_t, qinv(q0))\n    vos_q = qmult(qinv(q0), q1)\n    vos_q = qlog_t_safe(vos_q)\n    return torch.cat((vos_t, vos_q), dim=1)\n\n\ndef calc_vos_simple(poses):\n    \"\"\"\n    calculate the VOs, from a list 
of consecutive poses\n    :param poses: N x T x 7\n    :return: N x (T-1) x 7\n    \"\"\"\n    vos = []\n    for p in poses:\n        pvos = [p[i+1].unsqueeze(0) - p[i].unsqueeze(0)\n                for i in range(len(p)-1)]\n        vos.append(torch.cat(pvos, dim=0))\n    vos = torch.stack(vos, dim=0)\n\n    return vos\n\n\ndef calc_vos(poses):\n    \"\"\"\n    calculate the VOs, from a list of consecutive poses (in the p0 frame)\n    :param poses: N x T x 7\n    :return: N x (T-1) x 7\n    \"\"\"\n    vos = []\n    for p in poses:\n        pvos = [calc_vo_logq(p[i].unsqueeze(0), p[i+1].unsqueeze(0))\n                for i in range(len(p)-1)]\n        vos.append(torch.cat(pvos, dim=0))\n    vos = torch.stack(vos, dim=0)\n    return vos\n\n\ndef calc_vos_relative(poses):\n    \"\"\"\n    calculate the VOs, from a list of consecutive poses (in the world frame)\n    :param poses: N x T x 7\n    :return: N x (T-1) x 7\n    \"\"\"\n    vos = []\n    for p in poses:\n        pvos = [calc_vo_relative_logq(p[i].unsqueeze(0), p[i+1].unsqueeze(0))\n                for i in range(len(p)-1)]\n        vos.append(torch.cat(pvos, dim=0))\n    vos = torch.stack(vos, dim=0)\n    return vos\n\n\ndef calc_vos_safe(poses):\n    \"\"\"\n    calculate the VOs, from a list of consecutive poses\n    :param poses: N x T x 7\n    :return: N x (T-1) x 7\n    \"\"\"\n    vos = []\n    for p in poses:\n        pvos = [calc_vo_logq_safe(p[i].unsqueeze(0), p[i+1].unsqueeze(0))\n                for i in range(len(p)-1)]\n        vos.append(torch.cat(pvos, dim=0))\n    vos = torch.stack(vos, dim=0)\n    return vos\n\n\ndef calc_vos_safe_fc(poses):\n    \"\"\"\n    calculate the VOs, from a list of consecutive poses (fully connected)\n    :param poses: N x T x 7\n    :return: N x TC2 x 7\n    \"\"\"\n    vos = []\n    for p in poses:\n        pvos = []\n        for i in range(p.size(0)):\n            for j in range(i+1, p.size(0)):\n                pvos.append(calc_vo_logq_safe(\n                    p[i].unsqueeze(0), p[j].unsqueeze(0)))\n        vos.append(torch.cat(pvos, dim=0))\n    vos = torch.stack(vos, dim=0)\n    return vos\n\n# NUMPY\n\n\ndef qlog(q):\n    \"\"\"\n    Applies logarithm map to q\n    :param q: (4,)\n    :return: (3,)\n    \"\"\"\n    if all(q[1:] == 0):\n        q = np.zeros(3)\n    else:\n        q = np.arccos(q[0]) * q[1:] / np.linalg.norm(q[1:])\n    return q\n\n\ndef qexp(q):\n    \"\"\"\n    Applies the exponential map to q\n    :param q: (3,)\n    :return: (4,)\n    \"\"\"\n    n = np.linalg.norm(q)\n    q = np.hstack((np.cos(n), np.sinc(n/np.pi)*q))\n    return q\n\n\ndef process_poses(poses_in, mean_t, std_t, align_R, align_t, align_s):\n    \"\"\"\n    processes the 1x12 raw pose from dataset by aligning and then normalizing\n    :param poses_in: N x 12\n    :param mean_t: 3\n    :param std_t: 3\n    :param align_R: 3 x 3\n    :param align_t: 3\n    :param align_s: 1\n    :return: processed poses (translation + quaternion) N x 7\n    \"\"\"\n    poses_out = np.zeros((len(poses_in), 6))\n    poses_out[:, 0:3] = poses_in[:, [3, 7, 11]]\n\n    # align\n    for i in range(len(poses_out)):\n        R = poses_in[i].reshape((3, 4))[:3, :3]\n        q = txq.mat2quat(np.dot(align_R, R))\n        q *= np.sign(q[0])  # constrain to hemisphere\n        q = qlog(q)\n        poses_out[i, 3:] = q\n        t = poses_out[i, :3] - align_t\n        poses_out[i, :3] = align_s * \\\n            np.dot(align_R, t[:, np.newaxis]).squeeze()\n\n    # normalize translation\n    poses_out[:, :3] -= mean_t\n    
poses_out[:, :3] /= std_t\n    return poses_out\n\n\ndef log_quaternion_angular_error(q1, q2):\n    return quaternion_angular_error(qexp(q1), qexp(q2))\n\n\ndef quaternion_angular_error(q1, q2):\n    \"\"\"\n    angular error between two quaternions\n    :param q1: (4, )\n    :param q2: (4, )\n    :return:\n    \"\"\"\n    d = abs(np.dot(q1, q2))\n    d = min(1.0, max(-1.0, d))\n    theta = 2 * np.arccos(d) * 180 / np.pi\n    return theta\n\n\ndef skew(x):\n    \"\"\"\n    returns skew symmetric matrix from vector\n    :param x: 3 x 1\n    :return:\n    \"\"\"\n    s = np.asarray([[0, -x[2], x[1]], [x[2], 0, -x[0]], [-x[1], x[0], 0]])\n    return s\n\n\ndef dpq_q(p):\n    \"\"\"\n    returns the jacobian of quaternion product pq w.r.t. q\n    :param p: 4 x 1\n    :return: 4 x 4\n    \"\"\"\n    J = np.zeros((4, 4))\n    J[0, 0] = p[0]\n    J[0, 1:] = -p[1:].squeeze()\n    J[1:, 0] = p[1:].squeeze()\n    J[1:, 1:] = p[0] * np.eye(3) + skew(p[1:])\n    return J\n\n\ndef dpsq_q(p):\n    \"\"\"\n    returns the jacobian of quaternion product (p*)q w.r.t. q\n    :param p: 4 x 1\n    :return: 4 x 4\n    \"\"\"\n    J = np.zeros((4, 4))\n    J[0, 0] = p[0]\n    J[0, 1:] = -p[1:].squeeze()\n    J[1:, 0] = -p[1:].squeeze()\n    J[1:, 1:] = p[0] * np.eye(3) - skew(p[1:])\n    return J\n\n\ndef dpsq_p(q):\n    \"\"\"\n    returns the jacobian of quaternion product (p*)q w.r.t. p\n    :param q: 4 x 1\n    :return: 4 x 4\n    \"\"\"\n    J = np.zeros((4, 4))\n    J[0, 0] = q[0]\n    J[0, 1:] = q[1:].squeeze()\n    J[1:, 0] = q[1:].squeeze()\n    J[1:, 1:] = -q[0] * np.eye(3) + skew(q[1:])\n    return J\n\n\ndef dqstq_q(q, t):\n    \"\"\"\n    jacobian of q* t q w.r.t. q\n    :param q: 4 x 1\n    :param t: 3 x 1\n    :return: 3 x 4\n    \"\"\"\n    J = np.zeros((3, 4))\n    J[:, :1] = q[0]*t - np.cross(q[1:], t, axis=0)\n    J[:, 1:] = -np.dot(t, q[1:].T) + np.dot(t.T, q[1:])*np.eye(3) + \\\n        np.dot(q[1:], t.T) + q[0]*skew(t)\n    J *= 2\n    return J\n\n\ndef dqstq_t(q):\n    \"\"\"\n    jacobian of q* t q w.r.t. t\n    :param q: 4 x 1\n    :return: 3 x 3\n    \"\"\"\n    J = (q[0]*q[0] - np.dot(q[1:].T, q[1:])) * np.eye(3) + 2*np.dot(q[1:], q[1:].T) -\\\n        2*q[0]*skew(q[1:])\n    return J\n\n\ndef m_rot(x):\n    \"\"\"\n    returns Jacobian of exponential map w.r.t. manifold increment\n    :param x: part of state vector affected by increment, 4 x 1\n    :return: 4 x 3\n    \"\"\"\n    # jacobian of full q wrt qm (quaternion update on manifold),\n    # evaluated at qv = (0, 0, 0)\n    # full q is derived using either the exponential map or q0 = sqrt(1-qm^2)\n    jm = np.vstack((np.zeros((1, 3)), np.eye(3)))  # 4 x 3\n    m = np.dot(dpq_q(p=x), jm)\n    return m\n\n\nclass PoseGraph:\n    def __init__(self):\n        \"\"\"\n        implements pose graph optimization from\n        \"Hybrid Hessians for Optimization of Pose Graphs\" - Y. LeCun et al\n        and \"A Tutorial on Graph-Based SLAM\" - W. 
Burgard et al\n        \"\"\"\n        self.N = 0\n        self.z = np.zeros((0, 0))\n\n    def jacobian(self, L_ax, L_aq, L_rx, L_rq):\n        # 6 because updates for rotation are on manifold\n        J = np.zeros((0, 6*self.N))\n\n        # unary constraints\n        for i in range(self.N):\n            # translation constraint\n            jt = np.zeros((3, J.shape[1]))\n            jt[:, 6*i: 6*i+3] = np.eye(3)\n            J = np.vstack((J, np.dot(L_ax, jt)))\n\n            # rotation constraint\n            jr = np.zeros((4, J.shape[1]))\n            jr[:, 6*i+3: 6*i+6] = m_rot(x=self.z[7*i+3: 7*i+7])\n            J = np.vstack((J, np.dot(L_aq, jr)))\n\n        # pairwise constraints\n        for i in range(self.N-1):\n                # translation constraint\n            jt = np.zeros((3, J.shape[1]))\n            dt = dqstq_t(q=self.z[7*i+3: 7*i+7])\n            # dt = np.eye(3)\n            jt[:, 6*i: 6*i+3] = -dt\n            jt[:, 6*(i+1): 6*(i+1)+3] = dt\n            # m = m_rot(x=self.z[7*i+3 : 7*i+7])\n            # a = dqstq_q(q=self.z[7*i+3 : 7*i+7],\n            #             t=self.z[7*(i+1) : 7*(i+1)+3]-self.z[7*i : 7*i+3])\n            # jt[:, 6*i+3 : 6*i+6] = np.dot(a, m)\n            J = np.vstack((J, np.dot(L_rx, jt)))\n\n            # rotation constraint\n            jr = np.zeros((4, J.shape[1]))\n            m = m_rot(x=self.z[7*i+3: 7*i+7])\n            a = dpsq_p(q=self.z[7*(i+1)+3: 7*(i+1)+7])\n            jr[:, 6*i+3: 6*i+6] = np.dot(a, m)\n            m = m_rot(x=self.z[7*(i+1)+3: 7*(i+1)+7])\n            b = dpsq_q(p=self.z[7*i+3: 7*i+7])\n            jr[:, 6*(i+1)+3: 6*(i+1)+6] = np.dot(b, m)\n            J = np.vstack((J, np.dot(L_rq, jr)))\n\n        return J\n\n    def residuals(self, poses, vos, L_ax, L_aq, L_rx, L_rq):\n        \"\"\"\n        computes the residuals\n        :param poses: N x 7\n        :param vos: (N-1) x 7\n        :param L_ax: 3 x 3\n        :param L_aq: 4 x 4\n        :param L_rx: 3 x 3\n        :param L_rq: 4 x 4\n        :return:\n        \"\"\"\n        r = np.zeros((0, 1))\n\n        # unary residuals\n        L = np.zeros((7, 7))\n        L[:3, :3] = L_ax\n        L[3:, 3:] = L_aq\n        for i in range(self.N):\n            rr = self.z[7*i: 7*(i+1)] - np.reshape(poses[i], (-1, 1))\n            r = np.vstack((r, np.dot(L, rr)))\n\n        # pairwise residuals\n        for i in range(self.N-1):\n            # translation residual\n            v = self.z[7*(i+1):7*(i+1)+3, 0]-self.z[7*i:7*i+3, 0]\n            q = txq.qinverse(self.z[7*i+3:7*i+7, 0])\n            rt = txq.rotate_vector(v, q)\n            rt = rt[:, np.newaxis] - vos[i, :3].reshape((-1, 1))\n            # rt = self.z[7*(i+1) : 7*(i+1)+3] - self.z[7*i : 7*i+3] - \\\n            #     vos[i, :3].reshape((-1, 1))\n            r = np.vstack((r, np.dot(L_rx, rt)))\n\n            # rotation residual\n            q0 = self.z[7*i+3: 7*i+7].squeeze()\n            q1 = self.z[7*(i+1)+3: 7*(i+1)+7].squeeze()\n            qvo = txq.qmult(txq.qinverse(q0), q1).reshape((-1, 1))\n            rq = qvo - vos[i, 3:].reshape((-1, 1))\n            r = np.vstack((r, np.dot(L_rq, rq)))\n\n        return r\n\n    def update_on_manifold(self, x):\n        \"\"\"\n        Updates the state vector on manifold\n        :param x: manifold increment, column vector\n        :return:\n        \"\"\"\n        for i in range(self.N):\n            # update translation\n            t = x[6*i: 6*i+3]\n            self.z[7*i: 7*i+3] += t\n\n            # update rotation\n            qm = x[6*i+3: 
6*i+6]  # quaternion on the manifold\n            dq = np.zeros(4)\n            # method in Burgard paper\n            # dq[1:] = qm.squeeze()\n            # dq[0] = math.sqrt(1 - sum(np.square(qm)))  # incremental quaternion\n            # method of exponential map\n            n = np.linalg.norm(qm)\n            dq[0] = math.cos(n)\n            dq[1:] = np.sinc(n/np.pi) * qm.squeeze()\n            q = self.z[7*i+3: 7*i+7].squeeze()\n            q = txq.qmult(q, dq).reshape((-1, 1))\n            self.z[7*i+3: 7*i+7] = q\n\n    def optimize(self, poses, vos, sax=1, saq=1, srx=1, srq=1, n_iters=10):\n        \"\"\"\n        run PGO, with init = poses\n        :param poses:\n        :param vos:\n        :param sax: sigma for absolute translation\n        :param saq: sigma for absolute rotation\n        :param srx: sigma for relative translation\n        :param srq: sigma for relative rotation\n        :param n_iters:\n        :return:\n        \"\"\"\n        self.N = len(poses)\n        # init state vector with the predicted poses\n        self.z = np.reshape(poses.copy(), (-1, 1))\n\n        # construct the information matrices\n        L_ax = np.linalg.cholesky(np.eye(3) / sax)\n        L_aq = np.linalg.cholesky(np.eye(4) / saq)\n        L_rx = np.linalg.cholesky(np.eye(3) / srx)\n        L_rq = np.linalg.cholesky(np.eye(4) / srq)\n\n        for n_iter in range(n_iters):\n            J = self.jacobian(L_ax.T, L_aq.T, L_rx.T, L_rq.T)\n            r = self.residuals(poses.copy(), vos.copy(), L_ax.T, L_aq.T, L_rx.T,\n                               L_rq.T)\n            H = np.dot(J.T, J)  # hessian\n            b = np.dot(J.T, r)  # residuals\n\n            # solve Hx = -b for x\n            R = slin.cholesky(H)  # H = R' R\n            y = slin.solve_triangular(R.T, -b)\n            x = slin.solve_triangular(R, y)\n\n            self.update_on_manifold(x)\n\n        return self.z.reshape((-1, 7))\n\n\nclass PoseGraphFC:\n    def __init__(self):\n        \"\"\"\n        implements pose graph optimization from\n        \"Hybrid Hessians for Optimization of Pose Graphs\" - Y. LeCun et al\n        and \"A Tutorial on Graph-Based SLAM\" - W. 
Burgard et al\n        fully connected version\n        \"\"\"\n        self.N = 0\n        self.z = np.zeros((0, 0))\n\n    def jacobian(self, L_ax, L_aq, L_rx, L_rq):\n        # 6 because updates for rotation are on manifold\n        J = np.zeros((0, 6*self.N))\n\n        # unary constraints\n        for i in range(self.N):\n            # translation constraint\n            jt = np.zeros((3, J.shape[1]))\n            jt[:, 6*i: 6*i+3] = np.eye(3)\n            J = np.vstack((J, np.dot(L_ax, jt)))\n\n            # rotation constraint\n            jr = np.zeros((4, J.shape[1]))\n            jr[:, 6*i+3: 6*i+6] = m_rot(x=self.z[7*i+3: 7*i+7])\n            J = np.vstack((J, np.dot(L_aq, jr)))\n\n        # pairwise constraints\n        for i in range(self.N):\n            for j in range(i+1, self.N):\n                # translation constraint\n                jt = np.zeros((3, J.shape[1]))\n                dt = dqstq_t(q=self.z[7*i+3: 7*i+7])\n                # dt = np.eye(3)\n                jt[:, 6*i: 6*i+3] = -dt\n                jt[:, 6*j: 6*j+3] = dt\n                # m = m_rot(x=self.z[7*i+3 : 7*i+7])\n                # a = dqstq_q(q=self.z[7*i+3 : 7*i+7],\n                #             t=self.z[7*(i+1) : 7*(i+1)+3]-self.z[7*i : 7*i+3])\n                # jt[:, 6*i+3 : 6*i+6] = np.dot(a, m)\n                J = np.vstack((J, np.dot(L_rx, jt)))\n\n                # rotation constraint\n                jr = np.zeros((4, J.shape[1]))\n                m = m_rot(x=self.z[7*i+3: 7*i+7])\n                a = dpsq_p(q=self.z[7*j+3: 7*j+7])\n                jr[:, 6*i+3: 6*i+6] = np.dot(a, m)\n                m = m_rot(x=self.z[7*j+3: 7*j+7])\n                b = dpsq_q(p=self.z[7*i+3: 7*i+7])\n                jr[:, 6*j+3: 6*j+6] = np.dot(b, m)\n                J = np.vstack((J, np.dot(L_rq, jr)))\n\n        return J\n\n    def residuals(self, poses, vos, L_ax, L_aq, L_rx, L_rq):\n        \"\"\"\n        computes the residuals\n        :param poses: N x 7\n        :param vos: (N-1) x 7\n        :param L_ax: 3 x 3\n        :param L_aq: 4 x 4\n        :param L_rx: 3 x 3\n        :param L_rq: 4 x 4\n        :return: \n        \"\"\"\n        r = np.zeros((0, 1))\n\n        # unary residuals\n        L = np.zeros((7, 7))\n        L[:3, :3] = L_ax\n        L[3:, 3:] = L_aq\n        for i in range(self.N):\n            rr = self.z[7*i: 7*(i+1)] - np.reshape(poses[i], (-1, 1))\n            r = np.vstack((r, np.dot(L, rr)))\n\n        # pairwise residuals\n        k = 0\n        for i in range(self.N):\n            for j in range(i+1, self.N):\n                # translation residual\n                v = self.z[7*j:7*j+3, 0]-self.z[7*i:7*i+3, 0]\n                q = txq.qinverse(self.z[7*i+3:7*i+7, 0])\n                rt = txq.rotate_vector(v, q)\n                rt = rt[:, np.newaxis] - vos[k, :3].reshape((-1, 1))\n                # rt = self.z[7*(i+1) : 7*(i+1)+3] - self.z[7*i : 7*i+3] - \\\n                #     vos[i, :3].reshape((-1, 1))\n                r = np.vstack((r, np.dot(L_rx, rt)))\n\n                # rotation residual\n                q0 = self.z[7*i+3: 7*i+7].squeeze()\n                q1 = self.z[7*j+3: 7*j+7].squeeze()\n                qvo = txq.qmult(txq.qinverse(q0), q1).reshape((-1, 1))\n                rq = qvo - vos[k, 3:].reshape((-1, 1))\n                r = np.vstack((r, np.dot(L_rq, rq)))\n                k += 1\n\n        return r\n\n    def update_on_manifold(self, x):\n        \"\"\"\n        Updates the state vector on manifold\n        :param x: manifold increment, column 
vector\n        :return: \n        \"\"\"\n        for i in range(self.N):\n            # update translation\n            t = x[6*i: 6*i+3]\n            self.z[7*i: 7*i+3] += t\n\n            # update rotation\n            qm = x[6*i+3: 6*i+6]  # quaternion on the manifold\n            dq = np.zeros(4)\n            # method in Burgard paper\n            # dq[1:] = qm.squeeze()\n            # dq[0] = math.sqrt(1 - sum(np.square(qm)))  # incremental quaternion\n            # method of exponential map\n            n = np.linalg.norm(qm)\n            dq[0] = math.cos(n)\n            dq[1:] = np.sinc(n/np.pi) * qm.squeeze()\n            q = self.z[7*i+3: 7*i+7].squeeze()\n            q = txq.qmult(q, dq).reshape((-1, 1))\n            self.z[7*i+3: 7*i+7] = q\n\n    def optimize(self, poses, vos, sax=1, saq=1, srx=1, srq=1, n_iters=10):\n        \"\"\"\n        run PGO, with init = poses\n        :param poses:\n        :param vos:\n        :param sax: sigma for absolute translation\n        :param saq: sigma for absolute rotation\n        :param srx: sigma for relative translation\n        :param srq: sigma for relative rotation\n        :param n_iters:\n        :return:\n        \"\"\"\n        self.N = len(poses)\n        # init state vector with the predicted poses\n        self.z = np.reshape(poses.copy(), (-1, 1))\n\n        # construct the information matrices\n        L_ax = np.linalg.cholesky(np.eye(3) / sax)\n        L_aq = np.linalg.cholesky(np.eye(4) / saq)\n        L_rx = np.linalg.cholesky(np.eye(3) / srx)\n        L_rq = np.linalg.cholesky(np.eye(4) / srq)\n\n        for n_iter in range(n_iters):\n            J = self.jacobian(L_ax.T, L_aq.T, L_rx.T, L_rq.T)\n            r = self.residuals(poses.copy(), vos.copy(), L_ax.T, L_aq.T, L_rx.T,\n                               L_rq.T)\n            H = np.dot(J.T, J)  # hessian\n            b = np.dot(J.T, r)  # residuals\n\n            # solve Hx = -b for x\n            R = slin.cholesky(H)  # H = R' R\n            y = slin.solve_triangular(R.T, -b)\n            x = slin.solve_triangular(R, y)\n\n            self.update_on_manifold(x)\n\n        return self.z.reshape((-1, 7))\n\n\ndef optimize_poses(pred_poses, vos=None, fc_vos=False, target_poses=None,\n                   sax=1, saq=1, srx=1, srq=1):\n    \"\"\"\n    optimizes poses using either the VOs or the target poses (calculates VOs\n    from them)\n    :param pred_poses: N x 7\n    :param vos: (N-1) x 7\n    :param fc_vos: whether to use relative transforms between all frames in a fully\n    connected manner, not just consecutive frames\n    :param target_poses: N x 7\n    :param: sax: covariance of pose translation (1 number)\n    :param: saq: covariance of pose rotation (1 number)\n    :param: srx: covariance of VO translation (1 number)\n    :param: srq: covariance of VO rotation (1 number)\n    :return:\n    \"\"\"\n    pgo = PoseGraphFC() if fc_vos else PoseGraph()\n    if vos is None:\n        if target_poses is not None:\n            # calculate the VOs (in the pred_poses frame)\n            vos = np.zeros((len(target_poses)-1, 7))\n            for i in range(len(vos)):\n                vos[i, :3] = target_poses[i+1, :3] - target_poses[i, :3]\n                q0 = target_poses[i, 3:]\n                q1 = target_poses[i+1, 3:]\n                vos[i, 3:] = txq.qmult(txq.qinverse(q0), q1)\n        else:\n            print('Specify either VO or target poses')\n            return None\n    optim_poses = pgo.optimize(poses=pred_poses, vos=vos, sax=sax, saq=saq,\n                
               srx=srx, srq=srq)\n    return optim_poses\n\n\ndef align_3d_pts(x1, x2):\n    \"\"\"Align two sets of 3d points using the method of Horn (closed-form).\n\n    Find optimal s, R, t, such that\n\n            s*R*(x1-t) = x2\n\n    Input:\n    x1 -- first trajectory (3xn)\n    x2 -- second trajectory (3xn)\n\n    Output:\n    R -- rotation matrix (3x3)\n    t -- translation vector (3x1)\n    s -- scale (1x1)\n    written by Jinwei Gu\n    \"\"\"\n    x1c = x1.mean(1, keepdims=True)\n    x2c = x2.mean(1, keepdims=True)\n\n    x1_zerocentered = x1 - x1c\n    x2_zerocentered = x2 - x2c\n\n    W = np.zeros((3, 3))\n    r1 = 0\n    r2 = 0\n    for i in range(x1.shape[1]):\n        a = x1_zerocentered[:, i]\n        b = x2_zerocentered[:, i]\n        W += np.outer(b, a)\n        r1 += np.dot(a.T, a)\n        r2 += np.dot(b.T, b)\n\n    # np.asscalar was removed in newer numpy; float() is equivalent\n    s = float(np.sqrt(r2/r1))\n\n    U, d, Vh = np.linalg.svd(W)\n    S = np.eye(3)\n    if np.linalg.det(np.dot(U, Vh)) < 0:\n        S[2, 2] = -1\n    R = np.dot(U, np.dot(S, Vh))\n    t = x1c - (1/s) * np.dot(R.transpose(), x2c)\n\n    # ---- align ----\n    #x2a = s * np.dot(R, x1-t)\n    #error = x2a - x2\n\n    return R, t, s\n\n\ndef align_2d_pts(x1, x2):\n    \"\"\"Align two sets of 2d points using the method of Horn (closed-form).\n\n    Find optimal s, R, t, such that\n\n            s*R*(x1-t) = x2\n\n    Input:\n    x1 -- first trajectory (2xn)\n    x2 -- second trajectory (2xn)\n\n    Output:\n    R -- rotation matrix (2x2)\n    t -- translation vector (2x1)\n    s -- scale (1x1)\n    written by Jinwei Gu\n    \"\"\"\n    x1c = x1.mean(1, keepdims=True)\n    x2c = x2.mean(1, keepdims=True)\n\n    x1_zerocentered = x1 - x1c\n    x2_zerocentered = x2 - x2c\n\n    W = np.zeros((2, 2))\n    r1 = 0\n    r2 = 0\n    for i in range(x1.shape[1]):\n        a = x1_zerocentered[:, i]\n        b = x2_zerocentered[:, i]\n        W += np.outer(b, a)\n        r1 += np.dot(a.T, a)\n        r2 += np.dot(b.T, b)\n\n    s = float(np.sqrt(r2/r1))\n\n    U, d, Vh = np.linalg.svd(W)\n    S = np.eye(2)\n    if np.linalg.det(np.dot(U, Vh)) < 0:\n        S[1, 1] = -1\n    R = np.dot(U, np.dot(S, Vh))\n    t = x1c - (1/s) * np.dot(R.transpose(), x2c)\n\n    # ---- align ----\n    #x2a = s * np.dot(R, x1-t)\n    #error = x2a - x2\n\n    return R, t, s\n\n\ndef align_3d_pts_noscale(x1, x2):\n    \"\"\"Align two sets of 3d points using the method of Horn (closed-form).\n\n    Find optimal R, t (with the scale fixed to s = 1), such that\n\n            R*(x1-t) = x2\n\n    Input:\n    x1 -- first trajectory (3xn)\n    x2 -- second trajectory (3xn)\n\n    Output:\n    R -- rotation matrix (3x3)\n    t -- translation vector (3x1)\n    written by Jinwei Gu\n    \"\"\"\n    x1c = x1.mean(1, keepdims=True)\n    x2c = x2.mean(1, keepdims=True)\n\n    x1_zerocentered = x1 - x1c\n    x2_zerocentered = x2 - x2c\n\n    W = np.zeros((3, 3))\n    r1 = 0\n    r2 = 0\n    for i in range(x1.shape[1]):\n        a = x1_zerocentered[:, i]\n        b = x2_zerocentered[:, i]\n        W += np.outer(b, a)\n        r1 += np.dot(a.T, a)\n        r2 += np.dot(b.T, b)\n\n    #s = float(np.sqrt(r2/r1))\n    s = 1\n\n    U, d, Vh = np.linalg.svd(W)\n    S = np.eye(3)\n    if np.linalg.det(np.dot(U, Vh)) < 0:\n        S[2, 2] = -1\n    R = np.dot(U, np.dot(S, Vh))\n    t = x1c - np.dot(R.transpose(), x2c)\n\n    # ---- align ----\n    #x2a = s * np.dot(R, x1-t)\n    #error = x2a - x2\n\n    return R, t, s\n\n\ndef align_2d_pts_noscale(x1, x2):\n    \"\"\"Align two sets of 2d points using the method of Horn 
(closed-form).\n\n    Find optimal s, R, t, such that\n\n            s*R*(x1-t) = x2\n\n    Input:\n    x1 -- first trajectory (2xn)\n    x2 -- second trajectory (2xn)\n\n    Output:\n    R -- rotation matrix (2x2)\n    t -- translation vector (2x1)\n    s -- scale (1x1)\n    written by Jinwei Gu\n    \"\"\"\n    x1c = x1.mean(1, keepdims=True)\n    x2c = x2.mean(1, keepdims=True)\n\n    x1_zerocentered = x1 - x1c\n    x2_zerocentered = x2 - x2c\n\n    W = np.zeros((2, 2))\n    r1 = 0\n    r2 = 0\n    for i in range(x1.shape[1]):\n        a = x1_zerocentered[:, i]\n        b = x2_zerocentered[:, i]\n        W += np.outer(b, a)\n        r1 += np.dot(a.T, a)\n        r2 += np.dot(b.T, b)\n\n    #s = np.asscalar(np.sqrt(r2/r1))\n    s = 1\n\n    U, d, Vh = np.linalg.svd(W)\n    S = np.eye(2)\n    if np.linalg.det(np.dot(U, Vh)) < 0:\n        S[1, 1] = -1\n    R = np.dot(U, np.dot(S, Vh))\n    t = x1c - (1/s) * np.dot(R.transpose(), x2c)\n\n    # ---- align ----\n    #x2a = s * np.dot(R, x1-t)\n    #error = x2a - x2\n\n    return R, t, s\n\n\ndef align_camera_poses(o1, o2, R1, R2, use_rotation_constraint=True):\n    \"\"\"Align two sets of camera poses (R1,o1/R2,o2) using the method of Horn (closed-form).\n\n    Find optimal s, R, t, such that\n\n            s*R*(o1-t) = o2   (1)\n\n            R*R1 = R2         (2)\n\n    where R1/R2 are the camera-to-world matrices, o1/o2 are the center\n    of the cameras.\n\n    Input:\n    o1 -- camera centers (3xn)\n    o2 -- camera centers (3xn)\n    R1 -- camera poses (camera-to-world matrices) (nx3x3)\n    R2 -- camera poses (camera-to-world matrices) (nx3x3)\n    use_rotation_constraint -- if False, uses only Eq(1) to solve.\n\n    Output:\n    R -- rotation matrix (3x3)\n    t -- translation vector (3x1)\n    s -- scale (1x1)\n\n    Note, when use_rotation_constraint=False, it is the same problem as\n    above, i.e., to align two sets of 3D points.\n\n    When use_rotation_constraint=True, we note Eq(2) is the same\n    equation as Eq(1), after we zero-center and remove the scale. 
So, we\n    can use the same approach (SVD).\n    written by Jinwei Gu\n    \"\"\"\n    if not use_rotation_constraint:\n        return align_3d_pts(o1, o2)\n\n    o1c = o1.mean(1, keepdims=True)\n    o2c = o2.mean(1, keepdims=True)\n    o1_zerocentered = o1 - o1c\n    o2_zerocentered = o2 - o2c\n\n    W = np.zeros((3, 3))\n    r1 = 0\n    r2 = 0\n    for i in range(o1.shape[1]):\n        a = o1_zerocentered[:, i]\n        b = o2_zerocentered[:, i]\n        W += np.outer(b, a)\n        r1 += np.dot(a.T, a)\n        r2 += np.dot(b.T, b)\n\n    s = float(np.sqrt(r2/r1))\n\n    # add rotation constraints\n    for i in range(o1.shape[1]):\n        d1 = np.squeeze(R1[i, :, :])\n        d2 = np.squeeze(R2[i, :, :])\n        for c in range(3):\n            a = d1[:, c]\n            b = d2[:, c]\n            W += np.outer(b, a)\n\n    U, d, Vh = np.linalg.svd(W)\n    S = np.eye(3)\n    if np.linalg.det(np.dot(U, Vh)) < 0:\n        S[2, 2] = -1\n    R = np.dot(U, np.dot(S, Vh))\n    t = o1c - (1/s) * np.dot(R.transpose(), o2c)\n\n    # ---- align ----\n    #o2a = s * np.dot(R, o1-t)\n    #R2a = np.dot(R, R1)\n\n    return R, t, s\n\n\ndef test_align_3d_pts():\n    import transforms3d.euler as txe\n    N = 10\n    x1 = np.random.rand(3, N)\n\n    noise = np.random.rand(3, N)*0.01\n\n    s = np.random.rand()\n    t = np.random.rand(3, 1)\n    R = txe.euler2mat(np.random.rand(), np.random.rand(), np.random.rand())\n    R = R[:3, :3]\n\n    x2 = s*np.dot(R, x1-t) + noise\n\n    Re, te, se = align_3d_pts(x1, x2)\n\n    print('scale ', s, se)\n    print('rotation matrix ', R, Re)\n    print('translation ', t, te)\n\n\ndef test_align_camera_poses():\n    import transforms3d.euler as txe\n\n    N = 10\n    o1 = np.random.rand(3, N)\n\n    noise = np.random.rand(3, N)*0.01\n\n    s = np.random.rand()\n    t = np.random.rand(3, 1)\n    R = txe.euler2mat(np.random.rand(), np.random.rand(), np.random.rand())\n    R = R[:3, :3]\n\n    o2 = s*np.dot(R, o1-t) + noise\n\n    R1 = np.zeros((N, 3, 3))\n    R2 = np.zeros((N, 3, 3))\n    for i in range(N):\n        Ri = txe.euler2mat(\n            np.random.rand(), np.random.rand(), np.random.rand())\n        R1[i, :, :] = Ri[:3, :3]\n        R2[i, :, :] = np.dot(R, Ri[:3, :3])\n\n    Re1, te1, se1 = align_camera_poses(o1, o2, R1, R2, False)\n    Re2, te2, se2 = align_camera_poses(o1, o2, R1, R2, True)\n\n    print('scale ', s, se1, se2)\n    print('rotation matrix ', R, Re1, Re2)\n    print('translation ', t, te1, te2)\n\n\ndef pgo_test_poses():\n    \"\"\"\n    generates test poses and vos for the various PGO implementations\n    :return:\n    \"\"\"\n    poses = np.zeros((3, 7))\n    for i in range(poses.shape[0]):\n        poses[i, :3] = i\n        angle = math.radians(10*i)\n        R = txe.euler2mat(angle, angle, angle)\n        q = txq.mat2quat(R)\n        poses[i, 3:] = q\n\n    vos = np.zeros((poses.shape[0]-1, 7))\n    for i in range(vos.shape[0]):\n        vos[i, 0] = 1.5\n        vos[i, 1] = 0.5\n        vos[i, 2] = 1.0\n        R = txe.euler2mat(math.radians(15), math.radians(10), math.radians(5))\n        q = txq.mat2quat(R)\n        vos[i, 3:] = q\n\n    return poses, vos\n\n\ndef pgo_test_poses1():\n    poses = np.zeros((3, 7))\n    R = txe.euler2mat(0, 0, np.deg2rad(45))\n    q = txq.mat2quat(R)\n    poses[:, 3:] = q\n    for i in range(len(poses)):\n        poses[i, :3] = np.asarray([i, i, 0])\n\n    pt = np.zeros((len(poses), 6))\n    pt[:, :3] = poses[:, :3]\n    for i, p in enumerate(poses):\n        pt[i, 3:] = qlog(p[3:])\n    pt = 
torch.from_numpy(pt.astype(np.float32))\n    vost = calc_vos_safe_fc(pt.unsqueeze(0))[0].numpy()\n    vos = np.zeros((len(vost), 7))\n    vos[:, :3] = vost[:, :3]\n    for i, p in enumerate(vost):\n        vos[i, 3:] = qexp(p[3:])\n\n    # perturbation\n    vos[0, 0] = np.sqrt(2) - 0.5\n    vos[1, 0] = np.sqrt(2) - 0.5\n\n    return poses, vos\n\n\ndef print_poses(poses):\n    print('translations')\n    print(poses[:, :3])\n    print('euler')\n    for i in range(poses.shape[0]):\n        a = txe.mat2euler(txq.quat2mat(poses[i, 3:]))\n        print([np.rad2deg(aa) for aa in a])\n\n\ndef test_pgo():\n    \"\"\"\n    Tests the full pose graph optimization implementation\n    :return:\n    \"\"\"\n    pred_poses, vos = pgo_test_poses1()\n    print('pred poses')\n    print_poses(pred_poses)\n    print('vos')\n    print_poses(vos)\n\n    pgo = PoseGraph()\n    optimized_poses = pgo.optimize(pred_poses, vos)\n\n    print('optimized')\n    print_poses(optimized_poses)\n\n\ndef test_pose_utils():\n    \"\"\"\n    Tests the pose utils\n    :return:\n    \"\"\"\n    TEST_COMPOSE = True\n    TEST_INV = True\n\n    def ra(_): return np.random.uniform(0, 2*math.pi)\n\n    if TEST_COMPOSE:\n        print('Testing pose composing...')\n        R1 = txe.euler2mat(ra(1), ra(1), ra(1))\n        t1 = np.random.rand(3)\n        R2 = txe.euler2mat(ra(1), ra(1), ra(1))\n        t2 = np.random.rand(3)\n\n        # homogeneous matrix method\n        R = np.dot(R1, R2)\n        t = t1 + np.dot(R1, t2)\n        print('From homogeneous matrices, t = ')\n        print(t)\n        print('R = ')\n        print(R)\n\n        # quaternion method\n        q1 = txq.mat2quat(R1)\n        q2 = txq.mat2quat(R2)\n\n        p1 = torch.cat((torch.from_numpy(t1), torch.from_numpy(q1)))\n        p2 = torch.cat((torch.from_numpy(t2), torch.from_numpy(q2)))\n        p = compose_pose_quaternion(\n            torch.unsqueeze(p1, 0), torch.unsqueeze(p2, 0))\n        t = p[:, :3].numpy().squeeze()\n        q = p[:, 3:].numpy().squeeze()\n        print('From quaternions, t = ')\n        print(t)\n        print('R = ')\n        # quat2mat lives in transforms3d.quaternions, not transforms3d.euler\n        print(txq.quat2mat(q))\n\n    if TEST_INV:\n        print('Testing pose inversion...')\n        R = txe.euler2mat(ra(1), ra(1), ra(1))\n        t = np.random.rand(3)\n        T = np.eye(4)\n        T[:3, :3] = R\n        T[:3, -1] = t\n\n        q = txq.mat2quat(R)\n        p = torch.cat((torch.from_numpy(t), torch.from_numpy(q)))\n        pinv = invert_pose_quaternion(torch.unsqueeze(p, 0))\n        # q_inv: avoid shadowing the module-level qinv()\n        tinv, q_inv = pinv[:, :3], pinv[:, 3:]\n        Rinv = txq.quat2mat(q_inv.numpy().squeeze())\n        Tinv = np.eye(4)\n        Tinv[:3, :3] = Rinv\n        Tinv[:3, -1] = tinv.numpy().squeeze()\n        print('T * T^(-1) = ')\n        print(np.dot(T, Tinv))\n\n\ndef test_q_error():\n    def ra(_): return np.random.uniform(0, 2*math.pi)\n    # rotation along x axis\n    a1 = ra(1)\n    a2 = ra(1)\n    q1 = txq.mat2quat(txe.euler2mat(a1, 0, 0))\n    q2 = txq.mat2quat(txe.euler2mat(a2, 0, 0))\n    a1 = np.rad2deg(a1)\n    a2 = np.rad2deg(a2)\n    print('Angles: {:f}, {:f}, difference = {:f}'.format(a1, a2, a1-a2))\n    print('Error: {:f}'.format(quaternion_angular_error(q1, q2)))\n\n\ndef test_log_q_error():\n    def ra(_): return np.random.uniform(0, 2*math.pi)\n    # rotation along x axis\n    a1 = ra(1)\n    a2 = ra(1)\n    q1 = txq.mat2quat(txe.euler2mat(0, a1, 0))\n    q2 = txq.mat2quat(txe.euler2mat(0, a2, 0))\n    # apply log map\n    q1 = np.arccos(q1[0]) * q1[1:] / np.linalg.norm(q1[1:])\n    q2 = 
np.arccos(q2[0]) * q2[1:] / np.linalg.norm(q2[1:])\n    a1 = np.rad2deg(a1)\n    a2 = np.rad2deg(a2)\n    print('Angles: {:f}, {:f}, difference = {:f}'.format(a1, a2, a1-a2))\n    print('Error: {:f}'.format(log_quaternion_angular_error(q1, q2)))\n\n\nif __name__ == '__main__':\n    test_pgo()\n    # test_dumb_pgo()\n    # test_align_camera_poses()\n    # test_q_error()\n    # test_log_q_error()\n"
  },
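{
"path": "examples/pose_utils_example.py",
"content": "\"\"\"\nIllustrative example (not part of the original codebase): exercises\noptimize_poses() and align_3d_pts() from utils.pose_utils on synthetic data.\nThe file path and all numbers below are made up for demonstration.\n\"\"\"\nimport numpy as np\nimport transforms3d.euler as txe\n\nfrom utils.pose_utils import align_3d_pts, optimize_poses\n\n\ndef main():\n    # three noisy poses along a line; quaternions are (w, x, y, z), identity here\n    pred_poses = np.zeros((3, 7))\n    pred_poses[:, 0] = [0.0, 1.1, 1.9]\n    pred_poses[:, 3] = 1.0\n\n    # PGO: the VOs are derived from the (here, exact) target poses\n    target_poses = np.zeros((3, 7))\n    target_poses[:, 0] = [0.0, 1.0, 2.0]\n    target_poses[:, 3] = 1.0\n    optim_poses = optimize_poses(pred_poses, target_poses=target_poses)\n    print('optimized poses:')\n    print(optim_poses)\n\n    # Horn alignment: recover a known similarity transform x2 = s*R*(x1-t)\n    x1 = np.random.rand(3, 10)\n    R = txe.euler2mat(0.1, 0.2, 0.3)\n    t = np.random.rand(3, 1)\n    s = 2.0\n    x2 = s * np.dot(R, x1 - t)\n    Re, te, se = align_3d_pts(x1, x2)\n    print('recovered scale:', se, '(expected:', s, ')')\n\n\nif __name__ == '__main__':\n    main()\n"
},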
  {
    "path": "utils/pose_utils_np.py",
    "content": "\"\"\"\nCopyright (C) 2018 NVIDIA Corporation.  All rights reserved.\nLicensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).\n\"\"\"\n\nimport torch\nfrom torch.nn import Module\nfrom torch.autograd import Variable\nfrom torch.nn.functional import pad\nimport numpy as np\nimport scipy.linalg as slin\nimport math\nimport transforms3d.quaternions as txq\nimport transforms3d.euler as txe\nimport quaternion\nimport bisect\n# see for formulas:\n# https://ocw.mit.edu/courses/electrical-engineering-and-computer-science/6-801-machine-vision-fall-2004/readings/quaternions.pdf\n# and \"Quaternion and Rotation\" - Yan-Bin Jia, September 18, 2016\n# from IPython.core.debugger import set_trace\n\n# PYTORCH\n\ndef tq2RT(poses, square=False):\n    \"\"\"\n    :param poses: N x 7, (t,q)\n    :return: (N,3,4) \n    \"\"\"\n    N,_ = poses.shape\n    T = poses[:,:3]\n    q = quaternion.from_float_array(poses[:,3:])\n\n    R = quaternion.as_rotation_matrix(q) #Nx3x3\n    RT = np.concatenate([R,T[...,None]], axis=-1) #Nx3x4\n\n    if square:\n        padding = np.zeros([N,1,4])\n        padding[:,:,-1] = 1 \n        RT = np.concatenate([RT,padding], axis=1) #Nx4x4\n    return RT \ndef RT2tq(poses, square=False):\n    \"\"\"\n    !!NOT TESETED!!\n    :param poses: N x 3 x 4, (R|T)\n    :return: (N, 7) \n    \"\"\"\n    N,_,_ = poses.shape\n    R = poses[:,:,:3]\n    T = poses[:,:,3:] # Nx3x1\n\n    q = quaternion.as_float_array(quaternion.from_rotation_matrix(R)) #Nx4\n    t= T.squeeze(-1) \n\n    tq = np.concatenate([t,q], axis=-1)\n\n    return tq \n\ndef pose_interp(poses, timestamps_in, timestamps_out, r_interp='slerp'):\n    \"\"\"\n    :param poses: N x 7, (t,q)\n    :param timestamps: (N,) \n    :param t: (K,) \n    :return: (K,) \n\n    \"\"\"\n  \n    \n    # assert t_interp in ['linear', 'spline']\n    assert r_interp in ['slerp', 'squad']\n\n    assert len(poses)>1\n    assert len(poses) == len(timestamps_in)\n\n    input_ts = poses[:,:3]\n    input_rs= poses[:,3:] #quaternions\n    timestamps_in = np.array(timestamps_in)\n    #sort the inputs\n    inds = np.argsort(timestamps_in)\n    poses = poses[inds]\n    timestamps_in = timestamps_in[inds]\n\n    if r_interp == 'squad':\n        input_rs_ = quaternion.from_float_array(input_rs) \n        output_rs = quaternion.squad( input_rs, timestamps_in, timestamps_out)\n        output_rs = quaternion.as_float_array(output_rs)\n    elif r_interp == 'slerp':\n        output_rs = []\n        for t in timestamps_out:\n            input_rs_ = quaternion.from_float_array(input_rs) \n            idx = bisect.bisect_left(timestamps_in)\n            output_r =  quaternion.slerp(input_rs_[idx],input_rs_[idx+1], timestamps_in[idx], timestamps_in[idx+1],t )\n            output_r = quaternion.as_float_array(output_r)\n            output_rs.append(output_r)\n\n    output_ts = []\n    for t in timestamps_out:\n        idx = bisect_left.bisect_left(timestamps_in)\n        if idx>=len(timestamps_in)-1:\n            idx -= 1\n        \n        t1 = timestamps_in[idx]\n        t2 = timestamps_in[idx+1]\n\n        output_t = ((t-t1)*input_ts[idx+1] +  (t2-t) *input_ts[idx]) / (t2-t1)\n\n        output_ts.append(output_t)\n\n    output_ts =np.concatenate(output_ts, axis=0 )\n    output_rs =np.concatenate(output_rs, axis=0 )\n\n    new_pose = np.concatenate([output_ts, output_rs], axis=1) \n    return new_pose\n\ndef vdot(v1, v2):\n    \"\"\"\n    Dot product along the dim=1\n    :param v1: N x d\n    :param 
\n\ndef vdot(v1, v2):\n    \"\"\"\n    Dot product along the dim=1\n    :param v1: N x d\n    :param v2: N x d\n    :return: N x 1\n    \"\"\"\n    # out = torch.mul(v1, v2)\n    out = v1 * v2\n    # keepdims so the result is N x 1, as documented (qmult relies on this shape)\n    out = np.sum(out, axis=1, keepdims=True)\n    return out\n\n\ndef normalize(x, p=2, dim=0, eps=1e-6):\n    \"\"\"\n    Divides an array along a certain dim by the Lp norm\n    :param x:\n    :param p: Lp norm\n    :param dim: Dimension to normalize along\n    :return:\n    \"\"\"\n    # xn=x.norm(p = p, dim = dim)\n    xn = np.linalg.norm(x, ord=p, axis=dim, keepdims=True)\n    # x=x / xn.unsqueeze(dim = dim)\n    x = x / (xn+eps)\n\n    # x *= np.sign(x[0])  #added on 22/4/2020\n    return x\n\n\ndef qmult(q1, q2):\n    \"\"\"\n    Multiply 2 quaternions\n    :param q1: array N x 4\n    :param q2: array N x 4\n    :return: quaternion product, array N x 4\n    \"\"\"\n    q1s, q1v = q1[:, :1], q1[:, 1:]\n    q2s, q2v = q2[:, :1], q2[:, 1:]\n\n    qs = q1s*q2s - vdot(q1v, q2v)\n    # qv=q1v.mul(q2s.expand_as(q1v)) + q2v.mul(q1s.expand_as(q2v)) +\n    #     torch.cross(q1v, q2v, dim = 1)\n    qv = q1v*q2s + q2v*q1s + np.cross(q1v, q2v, axis=1)\n    q = np.concatenate((qs, qv), axis=1)\n\n    # normalize\n    q = normalize(q, dim=1)\n\n    return q\n\n\ndef qinv(q):\n    \"\"\"\n    Inverts quaternions\n    :param q: N x 4\n    :return: q*: N x 4\n    \"\"\"\n    # q_inv = torch.cat((q[:, :1], -q[:, 1:]), dim=1)\n    q_inv = np.concatenate((q[:, :1], -q[:, 1:]), axis=1)\n    return q_inv\n\n\ndef qexp_t(q):\n    \"\"\"\n    Applies exponential map to log quaternion\n    :param q: N x 3\n    :return: N x 4\n    \"\"\"\n    n = torch.norm(q, p=2, dim=1, keepdim=True)\n    n = torch.clamp(n, min=1e-8)\n    q = q * torch.sin(n)\n    q = q / n\n    q = torch.cat((torch.cos(n), q), dim=1)\n    return q\n\n\ndef qlog_t(q):\n    \"\"\"\n    Applies the log map to a quaternion\n    :param q: N x 4\n    :return: N x 3\n    \"\"\"\n    n = torch.norm(q[:, 1:], p=2, dim=1, keepdim=True)\n    n = torch.clamp(n, min=1e-8)\n    q = q[:, 1:] * torch.acos(torch.clamp(q[:, :1], min=-1.0, max=1.0))\n    q = q / n\n    return q\n\n\ndef qexp_t_safe(q):\n    \"\"\"\n    Applies exponential map to log quaternion (safe implementation that does not\n    maintain gradient flow)\n    :param q: N x 3\n    :return: N x 4\n    \"\"\"\n    q = torch.from_numpy(np.asarray([qexp(qq) for qq in q.numpy()],\n                                    dtype=np.float32))\n    return q\n\n\ndef qlog_t_safe(q):\n    \"\"\"\n    Applies the log map to a quaternion (safe implementation that does not\n    maintain gradient flow)\n    :param q: N x 4\n    :return: N x 3\n    \"\"\"\n    q = torch.from_numpy(np.asarray([qlog(qq) for qq in q.numpy()],\n                                    dtype=np.float32))\n    return q\n\n\ndef rotate_vec_by_q(t, q):\n    \"\"\"\n    rotates vector t by quaternion q\n    :param t: vector, array N x 3\n    :param q: quaternion, array N x 4\n    :return: t rotated by q: t' = t + 2*qs*(qv x t) + 2*qv x (qv x t)\n    \"\"\"\n    qs, qv = q[:, :1], q[:, 1:]\n    # b = torch.cross(qv, t, dim=1)\n    b = np.cross(qv, t, axis=1)\n    # c = 2 * torch.cross(qv, b, dim=1)\n    c = 2 * np.cross(qv, b, axis=1)\n    # b = 2 * b.mul(qs.expand_as(b))\n    b = 2 * b*qs\n    tq = t + b + c\n    return tq\n\n\ndef compose_pose_quaternion(p1, p2):\n    \"\"\"\n    numpy implementation (the pyTorch version lives in pose_utils.py)\n    :param p1: input pose, array N x 7\n    :param p2: pose to apply, array N x 7\n    :return: output pose, array N x 7\n    all poses are translation + quaternion\n    #!comments: first apply p2 and then p1 !!\n    \"\"\"\n\n    p1t, p1q = p1[:, :3], 
p1[:, 3:]\n    p2t, p2q = p2[:, :3], p2[:, 3:]\n\n    q = qmult(p1q, p2q)\n    t = p1t + rotate_vec_by_q(p2t, p1q)\n    return np.concatenate((t, q), axis=1)\n\n\ndef invert_pose_quaternion(p):\n    \"\"\"\n    inverts the pose\n    :param p: pose, Tensor N x 7\n    :return: inverted pose\n    \"\"\"\n    t, q = p[:, :3], p[:, 3:]\n    q_inv = qinv(q)\n    tinv = -rotate_vec_by_q(t, q_inv)\n    return np.concatenate((tinv, q_inv), axis=1)\n\n\ndef calc_vo(p0, p1):\n    \"\"\"\n    calculates VO (in the p0 frame) from 2 poses\n    :param p0: N x 7\n    :param p1: N x 7\n    \"\"\"\n    # assert p0.shape==p1.shape\n    return compose_pose_quaternion(invert_pose_quaternion(p0), p1)\n\n\ndef calc_vo_logq(p0, p1):\n    \"\"\"\n    VO (in the p0 frame) (logq)\n    :param p0: N x 6\n    :param p1: N x 6\n    :return: N-1 x 6\n    \"\"\"\n    q0 = qexp_t(p0[:, 3:])\n    q1 = qexp_t(p1[:, 3:])\n    vos = calc_vo(torch.cat((p0[:, :3], q0), dim=1), torch.cat((p1[:, :3], q1),\n                                                               dim=1))\n    vos_q = qlog_t(vos[:, 3:])\n    return torch.cat((vos[:, :3], vos_q), dim=1)\n\n\ndef calc_vo_relative(p0, p1):\n    \"\"\"\n    calculates VO (in the world frame) from 2 poses\n    :param p0: N x 7\n    :param p1: N x 7\n    \"\"\"\n    vos_t = p1[:, :3] - p0[:, :3]\n    vos_q = qmult(qinv(p0[:, 3:]), p1[:, 3:])\n    return np.concatenate((vos_t, vos_q), axis=1)\n\n\ndef calc_vo_relative_logq(p0, p1):\n    \"\"\"\n    Calculates VO (in the world frame) from 2 poses (log q)\n    :param p0: N x 6\n    :param p1: N x 6\n    :return:\n    \"\"\"\n    q0 = qexp_t(p0[:, 3:])\n    q1 = qexp_t(p1[:, 3:])\n    vos = calc_vo_relative(torch.cat((p0[:, :3], q0), dim=1),\n                           torch.cat((p1[:, :3], q1), dim=1))\n    vos_q = qlog_t(vos[:, 3:])\n    return torch.cat((vos[:, :3], vos_q), dim=1)\n\n\ndef calc_vo_relative_logq_safe(p0, p1):\n    \"\"\"\n    Calculates VO (in the world frame) from 2 poses (log q) through numpy fns\n    :param p0: N x 6\n    :param p1: N x 6\n    :return:\n    \"\"\"\n    vos_t = p1[:, :3] - p0[:, :3]\n    q0 = qexp_t_safe(p0[:, 3:])\n    q1 = qexp_t_safe(p1[:, 3:])\n    vos_q = qmult(qinv(q0), q1)\n    vos_q = qlog_t_safe(vos_q)\n    return torch.cat((vos_t, vos_q), dim=1)\n\n\ndef calc_vo_logq_safe(p0, p1):\n    \"\"\"\n    VO in the p0 frame using numpy fns\n    :param p0:\n    :param p1:\n    :return:\n    \"\"\"\n    vos_t = p1[:, :3] - p0[:, :3]\n    q0 = qexp_t_safe(p0[:, 3:])\n    q1 = qexp_t_safe(p1[:, 3:])\n    vos_t = rotate_vec_by_q(vos_t, qinv(q0))\n    vos_q = qmult(qinv(q0), q1)\n    vos_q = qlog_t_safe(vos_q)\n    return torch.cat((vos_t, vos_q), dim=1)\n\n\ndef calc_vos_simple(poses):\n    \"\"\"\n    calculate the VOs, from a list of consecutive poses\n    :param poses: N x T x 7\n    :return: N x (T-1) x 7\n    \"\"\"\n    vos = []\n    for p in poses:\n        pvos = [p[i+1].unsqueeze(0) - p[i].unsqueeze(0)\n                for i in range(len(p)-1)]\n        vos.append(torch.cat(pvos, dim=0))\n    vos = torch.stack(vos, dim=0)\n\n    return vos\n\n\ndef calc_vos(poses):\n    \"\"\"\n    calculate the VOs, from a list of consecutive poses (in the p0 frame)\n    :param poses: N x T x 7\n    :return: N x (T-1) x 7\n    \"\"\"\n    vos = []\n    for p in poses:\n        pvos = [calc_vo_logq(p[i].unsqueeze(0), p[i+1].unsqueeze(0))\n                for i in range(len(p)-1)]\n        vos.append(torch.cat(pvos, dim=0))\n    vos = torch.stack(vos, dim=0)\n    return vos\n\n\ndef calc_vos_relative(poses):\n  
  \"\"\"\n    calculate the VOs, from a list of consecutive poses (in the world frame)\n    :param poses: N x T x 7\n    :return: N x (T-1) x 7\n    \"\"\"\n    vos = []\n    for p in poses:\n        pvos = [calc_vo_relative_logq(p[i].unsqueeze(0), p[i+1].unsqueeze(0))\n                for i in range(len(p)-1)]\n        vos.append(torch.cat(pvos, dim=0))\n    vos = torch.stack(vos, dim=0)\n    return vos\n\n\ndef calc_vos_safe(poses):\n    \"\"\"\n    calculate the VOs, from a list of consecutive poses\n    :param poses: N x T x 7\n    :return: N x (T-1) x 7\n    \"\"\"\n    vos = []\n    for p in poses:\n        pvos = [calc_vo_logq_safe(p[i].unsqueeze(0), p[i+1].unsqueeze(0))\n                for i in range(len(p)-1)]\n        vos.append(torch.cat(pvos, dim=0))\n    vos = torch.stack(vos, dim=0)\n    return vos\n\n\ndef calc_vos_safe_fc(poses):\n    \"\"\"\n    calculate the VOs, from a list of consecutive poses (fully connected)\n    :param poses: N x T x 7\n    :return: N x TC2 x 7\n    \"\"\"\n    vos = []\n    for p in poses:\n        pvos = []\n        for i in range(p.size(0)):\n            for j in range(i+1, p.size(0)):\n                pvos.append(calc_vo_logq_safe(\n                    p[i].unsqueeze(0), p[j].unsqueeze(0)))\n        vos.append(torch.cat(pvos, dim=0))\n    vos = torch.stack(vos, dim=0)\n    return vos\n\n# NUMPY\n\n\ndef qlog(q):\n    \"\"\"\n    Applies logarithm map to q\n    :param q: (4,)\n    :return: (3,)\n    \"\"\"\n    if all(q[1:] == 0):\n        q = np.zeros(3)\n    else:\n        q = np.arccos(q[0]) * q[1:] / np.linalg.norm(q[1:])\n    return q\n\n\ndef qexp(q):\n    \"\"\"\n    Applies the exponential map to q\n    :param q: (3,)\n    :return: (4,)\n    \"\"\"\n    n = np.linalg.norm(q)\n    q = np.hstack((np.cos(n), np.sinc(n/np.pi)*q))\n    return q\n\n\ndef process_poses(poses_in, mean_t, std_t, align_R, align_t, align_s):\n    \"\"\"\n    processes the 1x12 raw pose from dataset by aligning and then normalizing\n    :param poses_in: N x 12\n    :param mean_t: 3\n    :param std_t: 3\n    :param align_R: 3 x 3\n    :param align_t: 3\n    :param align_s: 1\n    :return: processed poses (translation + quaternion) N x 7\n    \"\"\"\n    poses_out = np.zeros((len(poses_in), 6))\n    poses_out[:, 0:3] = poses_in[:, [3, 7, 11]]\n\n    # align\n    for i in range(len(poses_out)):\n        R = poses_in[i].reshape((3, 4))[:3, :3]\n        q = txq.mat2quat(np.dot(align_R, R))\n        q *= np.sign(q[0])  # constrain to hemisphere\n        q = qlog(q)\n        poses_out[i, 3:] = q\n        t = poses_out[i, :3] - align_t\n        poses_out[i, :3] = align_s * \\\n            np.dot(align_R, t[:, np.newaxis]).squeeze()\n\n    # normalize translation\n    poses_out[:, :3] -= mean_t\n    poses_out[:, :3] /= std_t\n    return poses_out\n\n\ndef log_quaternion_angular_error(q1, q2):\n    return quaternion_angular_error(qexp(q1), qexp(q2))\n\n\ndef quaternion_angular_error(q1, q2):\n    \"\"\"\n    angular error between two quaternions\n    :param q1: (4, )\n    :param q2: (4, )\n    :return:\n    \"\"\"\n    d = abs(np.dot(q1, q2))\n    d = min(1.0, max(-1.0, d))\n    theta = 2 * np.arccos(d) * 180 / np.pi\n    return theta\n\n\ndef skew(x):\n    \"\"\"\n    returns skew symmetric matrix from vector\n    :param x: 3 x 1\n    :return:\n    \"\"\"\n    s = np.asarray([[0, -x[2], x[1]], [x[2], 0, -x[0]], [-x[1], x[0], 0]])\n    return s\n\n\ndef dpq_q(p):\n    \"\"\"\n    returns the jacobian of quaternion product pq w.r.t. 
q\n    :param p: 4 x 1\n    :return: 4 x 4\n    \"\"\"\n    J = np.zeros((4, 4))\n    J[0, 0] = p[0]\n    J[0, 1:] = -p[1:].squeeze()\n    J[1:, 0] = p[1:].squeeze()\n    J[1:, 1:] = p[0] * np.eye(3) + skew(p[1:])\n    return J\n\n\ndef dpsq_q(p):\n    \"\"\"\n    returns the jacobian of quaternion product (p*)q w.r.t. q\n    :param p: 4 x 1\n    :return: 4 x 4\n    \"\"\"\n    J = np.zeros((4, 4))\n    J[0, 0] = p[0]\n    J[0, 1:] = -p[1:].squeeze()\n    J[1:, 0] = -p[1:].squeeze()\n    J[1:, 1:] = p[0] * np.eye(3) - skew(p[1:])\n    return J\n\n\ndef dpsq_p(q):\n    \"\"\"\n    returns the jacobian of quaternion product (p*)q w.r.t. p\n    :param q: 4 x 1\n    :return: 4 x 4\n    \"\"\"\n    J = np.zeros((4, 4))\n    J[0, 0] = q[0]\n    J[0, 1:] = q[1:].squeeze()\n    J[1:, 0] = q[1:].squeeze()\n    J[1:, 1:] = -q[0] * np.eye(3) + skew(q[1:])\n    return J\n\n\ndef dqstq_q(q, t):\n    \"\"\"\n    jacobian of q* t q w.r.t. q\n    :param q: 4 x 1\n    :param t: 3 x 1\n    :return: 3 x 4\n    \"\"\"\n    J = np.zeros((3, 4))\n    J[:, :1] = q[0]*t - np.cross(q[1:], t, axis=0)\n    J[:, 1:] = -np.dot(t, q[1:].T) + np.dot(t.T, q[1:])*np.eye(3) + \\\n        np.dot(q[1:], t.T) + q[0]*skew(t)\n    J *= 2\n    return J\n\n\ndef dqstq_t(q):\n    \"\"\"\n    jacobian of q* t q w.r.t. t\n    :param q: 4 x 1\n    :return: 3 x 3\n    \"\"\"\n    J = (q[0]*q[0] - np.dot(q[1:].T, q[1:])) * np.eye(3) + 2*np.dot(q[1:], q[1:].T) -\\\n        2*q[0]*skew(q[1:])\n    return J\n\n\ndef m_rot(x):\n    \"\"\"\n    returns Jacobian of exponential map w.r.t. manifold increment\n    :param x: part of state vector affected by increment, 4 x 1\n    :return: 4 x 3\n    \"\"\"\n    # jacobian of full q wrt qm (quaternion update on manifold),\n    # evaluated at qv = (0, 0, 0)\n    # full q is derived using either the exponential map or q0 = sqrt(1-qm^2)\n    jm = np.vstack((np.zeros((1, 3)), np.eye(3)))  # 4 x 3\n    m = np.dot(dpq_q(p=x), jm)\n    return m\n\n\nclass PoseGraph:\n    def __init__(self):\n        \"\"\"\n        implements pose graph optimization from\n        \"Hybrid Hessians for Optimization of Pose Graphs\" - Y. LeCun et al\n        and \"A Tutorial on Graph-Based SLAM\" - W. 
Burgard et al\n        \"\"\"\n        self.N = 0\n        self.z = np.zeros((0, 0))\n\n    def jacobian(self, L_ax, L_aq, L_rx, L_rq):\n        # 6 because updates for rotation are on manifold\n        J = np.zeros((0, 6*self.N))\n\n        # unary constraints\n        for i in range(self.N):\n            # translation constraint\n            jt = np.zeros((3, J.shape[1]))\n            jt[:, 6*i: 6*i+3] = np.eye(3)\n            J = np.vstack((J, np.dot(L_ax, jt)))\n\n            # rotation constraint\n            jr = np.zeros((4, J.shape[1]))\n            jr[:, 6*i+3: 6*i+6] = m_rot(x=self.z[7*i+3: 7*i+7])\n            J = np.vstack((J, np.dot(L_aq, jr)))\n\n        # pairwise constraints\n        for i in range(self.N-1):\n                # translation constraint\n            jt = np.zeros((3, J.shape[1]))\n            dt = dqstq_t(q=self.z[7*i+3: 7*i+7])\n            # dt = np.eye(3)\n            jt[:, 6*i: 6*i+3] = -dt\n            jt[:, 6*(i+1): 6*(i+1)+3] = dt\n            # m = m_rot(x=self.z[7*i+3 : 7*i+7])\n            # a = dqstq_q(q=self.z[7*i+3 : 7*i+7],\n            #             t=self.z[7*(i+1) : 7*(i+1)+3]-self.z[7*i : 7*i+3])\n            # jt[:, 6*i+3 : 6*i+6] = np.dot(a, m)\n            J = np.vstack((J, np.dot(L_rx, jt)))\n\n            # rotation constraint\n            jr = np.zeros((4, J.shape[1]))\n            m = m_rot(x=self.z[7*i+3: 7*i+7])\n            a = dpsq_p(q=self.z[7*(i+1)+3: 7*(i+1)+7])\n            jr[:, 6*i+3: 6*i+6] = np.dot(a, m)\n            m = m_rot(x=self.z[7*(i+1)+3: 7*(i+1)+7])\n            b = dpsq_q(p=self.z[7*i+3: 7*i+7])\n            jr[:, 6*(i+1)+3: 6*(i+1)+6] = np.dot(b, m)\n            J = np.vstack((J, np.dot(L_rq, jr)))\n\n        return J\n\n    def residuals(self, poses, vos, L_ax, L_aq, L_rx, L_rq):\n        \"\"\"\n        computes the residuals\n        :param poses: N x 7\n        :param vos: (N-1) x 7\n        :param L_ax: 3 x 3\n        :param L_aq: 4 x 4\n        :param L_rx: 3 x 3\n        :param L_rq: 4 x 4\n        :return:\n        \"\"\"\n        r = np.zeros((0, 1))\n\n        # unary residuals\n        L = np.zeros((7, 7))\n        L[:3, :3] = L_ax\n        L[3:, 3:] = L_aq\n        for i in range(self.N):\n            rr = self.z[7*i: 7*(i+1)] - np.reshape(poses[i], (-1, 1))\n            r = np.vstack((r, np.dot(L, rr)))\n\n        # pairwise residuals\n        for i in range(self.N-1):\n            # translation residual\n            v = self.z[7*(i+1):7*(i+1)+3, 0]-self.z[7*i:7*i+3, 0]\n            q = txq.qinverse(self.z[7*i+3:7*i+7, 0])\n            rt = txq.rotate_vector(v, q)\n            rt = rt[:, np.newaxis] - vos[i, :3].reshape((-1, 1))\n            # rt = self.z[7*(i+1) : 7*(i+1)+3] - self.z[7*i : 7*i+3] - \\\n            #     vos[i, :3].reshape((-1, 1))\n            r = np.vstack((r, np.dot(L_rx, rt)))\n\n            # rotation residual\n            q0 = self.z[7*i+3: 7*i+7].squeeze()\n            q1 = self.z[7*(i+1)+3: 7*(i+1)+7].squeeze()\n            qvo = txq.qmult(txq.qinverse(q0), q1).reshape((-1, 1))\n            rq = qvo - vos[i, 3:].reshape((-1, 1))\n            r = np.vstack((r, np.dot(L_rq, rq)))\n\n        return r\n\n    def update_on_manifold(self, x):\n        \"\"\"\n        Updates the state vector on manifold\n        :param x: manifold increment, column vector\n        :return:\n        \"\"\"\n        for i in range(self.N):\n            # update translation\n            t = x[6*i: 6*i+3]\n            self.z[7*i: 7*i+3] += t\n\n            # update rotation\n            qm = x[6*i+3: 
6*i+6]  # quaternion on the manifold\n            dq = np.zeros(4)\n            # method in Burgard paper\n            # dq[1:] = qm.squeeze()\n            # dq[0] = math.sqrt(1 - sum(np.square(qm)))  # incremental quaternion\n            # method of exponential map\n            n = np.linalg.norm(qm)\n            dq[0] = math.cos(n)\n            dq[1:] = np.sinc(n/np.pi) * qm.squeeze()\n            q = self.z[7*i+3: 7*i+7].squeeze()\n            q = txq.qmult(q, dq).reshape((-1, 1))\n            self.z[7*i+3: 7*i+7] = q\n\n    def optimize(self, poses, vos, sax=1, saq=1, srx=1, srq=1, n_iters=10):\n        \"\"\"\n        run PGO, with init = poses\n        :param poses:\n        :param vos:\n        :param sax: sigma for absolute translation\n        :param saq: sigma for absolute rotation\n        :param srx: sigma for relative translation\n        :param srq: sigma for relative rotation\n        :param n_iters:\n        :return:\n        \"\"\"\n        self.N = len(poses)\n        # init state vector with the predicted poses\n        self.z = np.reshape(poses.copy(), (-1, 1))\n\n        # construct the information matrices\n        L_ax = np.linalg.cholesky(np.eye(3) / sax)\n        L_aq = np.linalg.cholesky(np.eye(4) / saq)\n        L_rx = np.linalg.cholesky(np.eye(3) / srx)\n        L_rq = np.linalg.cholesky(np.eye(4) / srq)\n\n        for n_iter in range(n_iters):\n            J = self.jacobian(L_ax.T, L_aq.T, L_rx.T, L_rq.T)\n            r = self.residuals(poses.copy(), vos.copy(), L_ax.T, L_aq.T, L_rx.T,\n                               L_rq.T)\n            H = np.dot(J.T, J)  # hessian\n            b = np.dot(J.T, r)  # residuals\n\n            # solve Hx = -b for x\n            R = slin.cholesky(H)  # H = R' R\n            y = slin.solve_triangular(R.T, -b)\n            x = slin.solve_triangular(R, y)\n\n            self.update_on_manifold(x)\n\n        return self.z.reshape((-1, 7))\n\n\nclass PoseGraphFC:\n    def __init__(self):\n        \"\"\"\n        implements pose graph optimization from\n        \"Hybrid Hessians for Optimization of Pose Graphs\" - Y. LeCun et al\n        and \"A Tutorial on Graph-Based SLAM\" - W. 
Burgard et al\n        fully connected version\n        \"\"\"\n        self.N = 0\n        self.z = np.zeros((0, 0))\n\n    def jacobian(self, L_ax, L_aq, L_rx, L_rq):\n        # 6 because updates for rotation are on manifold\n        J = np.zeros((0, 6*self.N))\n\n        # unary constraints\n        for i in range(self.N):\n            # translation constraint\n            jt = np.zeros((3, J.shape[1]))\n            jt[:, 6*i: 6*i+3] = np.eye(3)\n            J = np.vstack((J, np.dot(L_ax, jt)))\n\n            # rotation constraint\n            jr = np.zeros((4, J.shape[1]))\n            jr[:, 6*i+3: 6*i+6] = m_rot(x=self.z[7*i+3: 7*i+7])\n            J = np.vstack((J, np.dot(L_aq, jr)))\n\n        # pairwise constraints\n        for i in range(self.N):\n            for j in range(i+1, self.N):\n                # translation constraint\n                jt = np.zeros((3, J.shape[1]))\n                dt = dqstq_t(q=self.z[7*i+3: 7*i+7])\n                # dt = np.eye(3)\n                jt[:, 6*i: 6*i+3] = -dt\n                jt[:, 6*j: 6*j+3] = dt\n                # m = m_rot(x=self.z[7*i+3 : 7*i+7])\n                # a = dqstq_q(q=self.z[7*i+3 : 7*i+7],\n                #             t=self.z[7*(i+1) : 7*(i+1)+3]-self.z[7*i : 7*i+3])\n                # jt[:, 6*i+3 : 6*i+6] = np.dot(a, m)\n                J = np.vstack((J, np.dot(L_rx, jt)))\n\n                # rotation constraint\n                jr = np.zeros((4, J.shape[1]))\n                m = m_rot(x=self.z[7*i+3: 7*i+7])\n                a = dpsq_p(q=self.z[7*j+3: 7*j+7])\n                jr[:, 6*i+3: 6*i+6] = np.dot(a, m)\n                m = m_rot(x=self.z[7*j+3: 7*j+7])\n                b = dpsq_q(p=self.z[7*i+3: 7*i+7])\n                jr[:, 6*j+3: 6*j+6] = np.dot(b, m)\n                J = np.vstack((J, np.dot(L_rq, jr)))\n\n        return J\n\n    def residuals(self, poses, vos, L_ax, L_aq, L_rx, L_rq):\n        \"\"\"\n        computes the residuals\n        :param poses: N x 7\n        :param vos: (N-1) x 7\n        :param L_ax: 3 x 3\n        :param L_aq: 4 x 4\n        :param L_rx: 3 x 3\n        :param L_rq: 4 x 4\n        :return: \n        \"\"\"\n        r = np.zeros((0, 1))\n\n        # unary residuals\n        L = np.zeros((7, 7))\n        L[:3, :3] = L_ax\n        L[3:, 3:] = L_aq\n        for i in range(self.N):\n            rr = self.z[7*i: 7*(i+1)] - np.reshape(poses[i], (-1, 1))\n            r = np.vstack((r, np.dot(L, rr)))\n\n        # pairwise residuals\n        k = 0\n        for i in range(self.N):\n            for j in range(i+1, self.N):\n                # translation residual\n                v = self.z[7*j:7*j+3, 0]-self.z[7*i:7*i+3, 0]\n                q = txq.qinverse(self.z[7*i+3:7*i+7, 0])\n                rt = txq.rotate_vector(v, q)\n                rt = rt[:, np.newaxis] - vos[k, :3].reshape((-1, 1))\n                # rt = self.z[7*(i+1) : 7*(i+1)+3] - self.z[7*i : 7*i+3] - \\\n                #     vos[i, :3].reshape((-1, 1))\n                r = np.vstack((r, np.dot(L_rx, rt)))\n\n                # rotation residual\n                q0 = self.z[7*i+3: 7*i+7].squeeze()\n                q1 = self.z[7*j+3: 7*j+7].squeeze()\n                qvo = txq.qmult(txq.qinverse(q0), q1).reshape((-1, 1))\n                rq = qvo - vos[k, 3:].reshape((-1, 1))\n                r = np.vstack((r, np.dot(L_rq, rq)))\n                k += 1\n\n        return r\n\n    def update_on_manifold(self, x):\n        \"\"\"\n        Updates the state vector on manifold\n        :param x: manifold increment, column 
vector\n        :return: \n        \"\"\"\n        for i in range(self.N):\n            # update translation\n            t = x[6*i: 6*i+3]\n            self.z[7*i: 7*i+3] += t\n\n            # update rotation\n            qm = x[6*i+3: 6*i+6]  # quaternion on the manifold\n            dq = np.zeros(4)\n            # method in Burgard paper\n            # dq[1:] = qm.squeeze()\n            # dq[0] = math.sqrt(1 - sum(np.square(qm)))  # incremental quaternion\n            # method of exponential map\n            n = np.linalg.norm(qm)\n            dq[0] = math.cos(n)\n            dq[1:] = np.sinc(n/np.pi) * qm.squeeze()\n            q = self.z[7*i+3: 7*i+7].squeeze()\n            q = txq.qmult(q, dq).reshape((-1, 1))\n            self.z[7*i+3: 7*i+7] = q\n\n    def optimize(self, poses, vos, sax=1, saq=1, srx=1, srq=1, n_iters=10):\n        \"\"\"\n        run PGO, with init = poses\n        :param poses:\n        :param vos:\n        :param sax: sigma for absolute translation\n        :param saq: sigma for absolute rotation\n        :param srx: sigma for relative translation\n        :param srq: sigma for relative rotation\n        :param n_iters:\n        :return:\n        \"\"\"\n        self.N = len(poses)\n        # init state vector with the predicted poses\n        self.z = np.reshape(poses.copy(), (-1, 1))\n\n        # construct the information matrices\n        L_ax = np.linalg.cholesky(np.eye(3) / sax)\n        L_aq = np.linalg.cholesky(np.eye(4) / saq)\n        L_rx = np.linalg.cholesky(np.eye(3) / srx)\n        L_rq = np.linalg.cholesky(np.eye(4) / srq)\n\n        for n_iter in range(n_iters):\n            J = self.jacobian(L_ax.T, L_aq.T, L_rx.T, L_rq.T)\n            r = self.residuals(poses.copy(), vos.copy(), L_ax.T, L_aq.T, L_rx.T,\n                               L_rq.T)\n            H = np.dot(J.T, J)  # hessian\n            b = np.dot(J.T, r)  # residuals\n\n            # solve Hx = -b for x\n            R = slin.cholesky(H)  # H = R' R\n            y = slin.solve_triangular(R.T, -b)\n            x = slin.solve_triangular(R, y)\n\n            self.update_on_manifold(x)\n\n        return self.z.reshape((-1, 7))\n\n\ndef optimize_poses(pred_poses, vos=None, fc_vos=False, target_poses=None,\n                   sax=1, saq=1, srx=1, srq=1):\n    \"\"\"\n    optimizes poses using either the VOs or the target poses (calculates VOs\n    from them)\n    :param pred_poses: N x 7\n    :param vos: (N-1) x 7\n    :param fc_vos: whether to use relative transforms between all frames in a fully\n    connected manner, not just consecutive frames\n    :param target_poses: N x 7\n    :param: sax: covariance of pose translation (1 number)\n    :param: saq: covariance of pose rotation (1 number)\n    :param: srx: covariance of VO translation (1 number)\n    :param: srq: covariance of VO rotation (1 number)\n    :return:\n    \"\"\"\n    pgo = PoseGraphFC() if fc_vos else PoseGraph()\n    if vos is None:\n        if target_poses is not None:\n            # calculate the VOs (in the pred_poses frame)\n            vos = np.zeros((len(target_poses)-1, 7))\n            for i in range(len(vos)):\n                vos[i, :3] = target_poses[i+1, :3] - target_poses[i, :3]\n                q0 = target_poses[i, 3:]\n                q1 = target_poses[i+1, 3:]\n                vos[i, 3:] = txq.qmult(txq.qinverse(q0), q1)\n        else:\n            print('Specify either VO or target poses')\n            return None\n    optim_poses = pgo.optimize(poses=pred_poses, vos=vos, sax=sax, saq=saq,\n                
               srx=srx, srq=srq)\n    return optim_poses\n\n\ndef align_3d_pts(x1, x2):\n    \"\"\"Align two sets of 3d points using the method of Horn (closed-form).\n\n    Find optimal s, R, t, such that\n\n            s*R*(x1-t) = x2\n\n    Input:\n    x1 -- first trajectory (3xn)\n    x2 -- second trajectory (3xn)\n\n    Output:\n    R -- rotation matrix (3x3)\n    t -- translation vector (3x1)\n    s -- scale (1x1)\n    written by Jinwei Gu\n    \"\"\"\n    x1c = x1.mean(1, keepdims=True)\n    x2c = x2.mean(1, keepdims=True)\n\n    x1_zerocentered = x1 - x1c\n    x2_zerocentered = x2 - x2c\n\n    W = np.zeros((3, 3))\n    r1 = 0\n    r2 = 0\n    for i in range(x1.shape[1]):\n        a = x1_zerocentered[:, i]\n        b = x2_zerocentered[:, i]\n        W += np.outer(b, a)\n        r1 += np.dot(a.T, a)\n        r2 += np.dot(b.T, b)\n\n    s = float(np.sqrt(r2/r1))  # np.asscalar was removed in NumPy 1.23\n\n    U, d, Vh = np.linalg.svd(W)\n    S = np.eye(3)\n    if np.linalg.det(np.dot(U, Vh)) < 0:\n        S[2, 2] = -1\n    R = np.dot(U, np.dot(S, Vh))\n    t = x1c - (1/s) * np.dot(R.transpose(), x2c)\n\n    # ---- align ----\n    # x2a = s * np.dot(R, x1-t)\n    # error = x2a - x2\n\n    return R, t, s\n\n\ndef align_2d_pts(x1, x2):\n    \"\"\"Align two sets of 2d points using the method of Horn (closed-form).\n\n    Find optimal s, R, t, such that\n\n            s*R*(x1-t) = x2\n\n    Input:\n    x1 -- first trajectory (2xn)\n    x2 -- second trajectory (2xn)\n\n    Output:\n    R -- rotation matrix (2x2)\n    t -- translation vector (2x1)\n    s -- scale (1x1)\n    written by Jinwei Gu\n    \"\"\"\n    x1c = x1.mean(1, keepdims=True)\n    x2c = x2.mean(1, keepdims=True)\n\n    x1_zerocentered = x1 - x1c\n    x2_zerocentered = x2 - x2c\n\n    W = np.zeros((2, 2))\n    r1 = 0\n    r2 = 0\n    for i in range(x1.shape[1]):\n        a = x1_zerocentered[:, i]\n        b = x2_zerocentered[:, i]\n        W += np.outer(b, a)\n        r1 += np.dot(a.T, a)\n        r2 += np.dot(b.T, b)\n\n    s = float(np.sqrt(r2/r1))  # np.asscalar was removed in NumPy 1.23\n\n    U, d, Vh = np.linalg.svd(W)\n    S = np.eye(2)\n    if np.linalg.det(np.dot(U, Vh)) < 0:\n        S[1, 1] = -1\n    R = np.dot(U, np.dot(S, Vh))\n    t = x1c - (1/s) * np.dot(R.transpose(), x2c)\n\n    # ---- align ----\n    # x2a = s * np.dot(R, x1-t)\n    # error = x2a - x2\n\n    return R, t, s\n\n\ndef align_3d_pts_noscale(x1, x2):\n    \"\"\"Align two sets of 3d points using the method of Horn (closed-form).\n\n    Find optimal R, t (the scale s is fixed to 1), such that\n\n            s*R*(x1-t) = x2\n\n    Input:\n    x1 -- first trajectory (3xn)\n    x2 -- second trajectory (3xn)\n\n    Output:\n    R -- rotation matrix (3x3)\n    t -- translation vector (3x1)\n    s -- scale (always 1)\n    written by Jinwei Gu\n    \"\"\"\n    x1c = x1.mean(1, keepdims=True)\n    x2c = x2.mean(1, keepdims=True)\n\n    x1_zerocentered = x1 - x1c\n    x2_zerocentered = x2 - x2c\n\n    W = np.zeros((3, 3))\n    r1 = 0\n    r2 = 0\n    for i in range(x1.shape[1]):\n        a = x1_zerocentered[:, i]\n        b = x2_zerocentered[:, i]\n        W += np.outer(b, a)\n        r1 += np.dot(a.T, a)\n        r2 += np.dot(b.T, b)\n\n    # s = np.asscalar(np.sqrt(r2/r1))\n    s = 1\n\n    U, d, Vh = np.linalg.svd(W)\n    S = np.eye(3)\n    if np.linalg.det(np.dot(U, Vh)) < 0:\n        S[2, 2] = -1\n    R = np.dot(U, np.dot(S, Vh))\n    t = x1c - np.dot(R.transpose(), x2c)\n\n    # ---- align ----\n    # x2a = s * np.dot(R, x1-t)\n    # error = x2a - x2\n\n    return R, t, s\n\n\n
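def example_align_trajectory():\n    # Usage sketch (illustrative, not part of the original file): align a 3xN\n    # predicted trajectory to ground truth with align_3d_pts and report the\n    # post-alignment RMSE; the synthetic data here is arbitrary.\n    gt = np.random.rand(3, 20)\n    pred = 0.5 * np.dot(txe.euler2mat(0.1, 0.2, 0.3), gt - 1.0)\n    R, t, s = align_3d_pts(pred, gt)\n    aligned = s * np.dot(R, pred - t)\n    print('alignment RMSE:', np.sqrt(np.mean(np.sum((aligned - gt)**2, axis=0))))\n\n\n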
def align_2d_pts_noscale(x1, x2):\n    \"\"\"Align two sets of 2d points using the method of Horn (closed-form).\n\n    Find optimal R, t (the scale s is fixed to 1), such that\n\n            s*R*(x1-t) = x2\n\n    Input:\n    x1 -- first trajectory (2xn)\n    x2 -- second trajectory (2xn)\n\n    Output:\n    R -- rotation matrix (2x2)\n    t -- translation vector (2x1)\n    s -- scale (always 1)\n    written by Jinwei Gu\n    \"\"\"\n    x1c = x1.mean(1, keepdims=True)\n    x2c = x2.mean(1, keepdims=True)\n\n    x1_zerocentered = x1 - x1c\n    x2_zerocentered = x2 - x2c\n\n    W = np.zeros((2, 2))\n    r1 = 0\n    r2 = 0\n    for i in range(x1.shape[1]):\n        a = x1_zerocentered[:, i]\n        b = x2_zerocentered[:, i]\n        W += np.outer(b, a)\n        r1 += np.dot(a.T, a)\n        r2 += np.dot(b.T, b)\n\n    # s = np.asscalar(np.sqrt(r2/r1))\n    s = 1\n\n    U, d, Vh = np.linalg.svd(W)\n    S = np.eye(2)\n    if np.linalg.det(np.dot(U, Vh)) < 0:\n        S[1, 1] = -1\n    R = np.dot(U, np.dot(S, Vh))\n    t = x1c - (1/s) * np.dot(R.transpose(), x2c)\n\n    # ---- align ----\n    # x2a = s * np.dot(R, x1-t)\n    # error = x2a - x2\n\n    return R, t, s\n\n\ndef align_camera_poses(o1, o2, R1, R2, use_rotation_constraint=True):\n    \"\"\"Align two sets of camera poses (R1,o1/R2,o2) using the method of Horn (closed-form).\n\n    Find optimal s, R, t, such that\n\n            s*R*(o1-t) = o2   (1)\n\n            R*R1 = R2         (2)\n\n    where R1/R2 are the camera-to-world matrices, o1/o2 are the centers\n    of the cameras.\n\n    Input:\n    o1 -- camera centers (3xn)\n    o2 -- camera centers (3xn)\n    R1 -- camera poses (camera-to-world matrices) (nx3x3)\n    R2 -- camera poses (camera-to-world matrices) (nx3x3)\n    use_rotation_constraint -- if False, uses only Eq(1) to solve.\n\n    Output:\n    R -- rotation matrix (3x3)\n    t -- translation vector (3x1)\n    s -- scale (1x1)\n\n    Note, when use_rotation_constraint=False, it is the same problem as\n    above, i.e., to align two sets of 3D points.\n\n    When use_rotation_constraint=True, we note Eq(2) is the same\n    equation as Eq(1), after we zero-center and remove the scale. 
So, we\n    can use the same approach (SVD).\n    written by Jinwei Gu\n    \"\"\"\n    if not use_rotation_constraint:\n        return align_3d_pts(o1, o2)\n\n    o1c = o1.mean(1, keepdims=True)\n    o2c = o2.mean(1, keepdims=True)\n    o1_zerocentered = o1 - o1c\n    o2_zerocentered = o2 - o2c\n\n    W = np.zeros((3, 3))\n    r1 = 0\n    r2 = 0\n    for i in range(o1.shape[1]):\n        a = o1_zerocentered[:, i]\n        b = o2_zerocentered[:, i]\n        W += np.outer(b, a)\n        r1 += np.dot(a.T, a)\n        r2 += np.dot(b.T, b)\n\n    s = float(np.sqrt(r2/r1))  # np.asscalar was removed in NumPy 1.23\n\n    # add rotation constraints\n    for i in range(o1.shape[1]):\n        d1 = np.squeeze(R1[i, :, :])\n        d2 = np.squeeze(R2[i, :, :])\n        for c in range(3):\n            a = d1[:, c]\n            b = d2[:, c]\n            W += np.outer(b, a)\n\n    U, d, Vh = np.linalg.svd(W)\n    S = np.eye(3)\n    if np.linalg.det(np.dot(U, Vh)) < 0:\n        S[2, 2] = -1\n    R = np.dot(U, np.dot(S, Vh))\n    t = o1c - (1/s) * np.dot(R.transpose(), o2c)\n\n    # ---- align ----\n    # o2a = s * np.dot(R, o1-t)\n    # R2a = np.dot(R, R1)\n\n    return R, t, s\n\n\ndef test_align_3d_pts():\n    import transforms3d.euler as txe\n    N = 10\n    x1 = np.random.rand(3, N)\n\n    noise = np.random.rand(3, N)*0.01\n\n    s = np.random.rand()\n    t = np.random.rand(3, 1)\n    R = txe.euler2mat(np.random.rand(), np.random.rand(), np.random.rand())\n    R = R[:3, :3]\n\n    x2 = s*np.dot(R, x1-t) + noise\n\n    Re, te, se = align_3d_pts(x1, x2)\n\n    print('scale ', s, se)\n    print('rotation matrix ', R, Re)\n    print('translation ', t, te)\n\n\ndef test_align_camera_poses():\n    import transforms3d.euler as txe\n\n    N = 10\n    o1 = np.random.rand(3, N)\n\n    noise = np.random.rand(3, N)*0.01\n\n    s = np.random.rand()\n    t = np.random.rand(3, 1)\n    R = txe.euler2mat(np.random.rand(), np.random.rand(), np.random.rand())\n    R = R[:3, :3]\n\n    o2 = s*np.dot(R, o1-t) + noise\n\n    R1 = np.zeros((N, 3, 3))\n    R2 = np.zeros((N, 3, 3))\n    for i in range(N):\n        Ri = txe.euler2mat(\n            np.random.rand(), np.random.rand(), np.random.rand())\n        R1[i, :, :] = Ri[:3, :3]\n        R2[i, :, :] = np.dot(R, Ri[:3, :3])\n\n    Re1, te1, se1 = align_camera_poses(o1, o2, R1, R2, False)\n    Re2, te2, se2 = align_camera_poses(o1, o2, R1, R2, True)\n\n    print('scale ', s, se1, se2)\n    print('rotation matrix ', R, Re1, Re2)\n    print('translation ', t, te1, te2)\n\n\ndef pgo_test_poses():\n    \"\"\"\n    generates test poses and vos for the various PGO implementations\n    :return:\n    \"\"\"\n    poses = np.zeros((3, 7))\n    for i in range(poses.shape[0]):\n        poses[i, :3] = i\n        angle = math.radians(10*i)\n        R = txe.euler2mat(angle, angle, angle)\n        q = txq.mat2quat(R)\n        poses[i, 3:] = q\n\n    vos = np.zeros((poses.shape[0]-1, 7))\n    for i in range(vos.shape[0]):\n        vos[i, 0] = 1.5\n        vos[i, 1] = 0.5\n        vos[i, 2] = 1.0\n        R = txe.euler2mat(math.radians(15), math.radians(10), math.radians(5))\n        q = txq.mat2quat(R)\n        vos[i, 3:] = q\n\n    return poses, vos\n\n\ndef pgo_test_poses1():\n    poses = np.zeros((3, 7))\n    R = txe.euler2mat(0, 0, np.deg2rad(45))\n    q = txq.mat2quat(R)\n    poses[:, 3:] = q\n    for i in range(len(poses)):\n        poses[i, :3] = np.asarray([i, i, 0])\n\n    pt = np.zeros((len(poses), 6))\n    pt[:, :3] = poses[:, :3]\n    for i, p in enumerate(poses):\n        pt[i, 3:] = qlog(p[3:])\n    pt = 
torch.from_numpy(pt.astype(np.float32))\n    vost = calc_vos_safe_fc(pt.unsqueeze(0))[0].numpy()\n    vos = np.zeros((len(vost), 7))\n    vos[:, :3] = vost[:, :3]\n    for i, p in enumerate(vost):\n        vos[i, 3:] = qexp(p[3:])\n\n    # perturbation\n    vos[0, 0] = np.sqrt(2) - 0.5\n    vos[1, 0] = np.sqrt(2) - 0.5\n\n    return poses, vos\n\n\ndef print_poses(poses):\n    print('translations')\n    print(poses[:, :3])\n    print('euler')\n    for i in range(poses.shape[0]):\n        a = txe.mat2euler(txq.quat2mat(poses[i, 3:]))\n        print([np.rad2deg(aa) for aa in a])\n\n\ndef test_pgo():\n    \"\"\"\n    Tests the full pose graph optimization implementation\n    :return: bool\n    \"\"\"\n    pred_poses, vos = pgo_test_poses1()\n    print('pred poses')\n    print_poses(pred_poses)\n    print('vos')\n    print_poses(vos)\n\n    pgo = PoseGraph()\n    optimized_poses = pgo.optimize(pred_poses, vos)\n\n    print('optimized')\n    print_poses(optimized_poses)\n\n\ndef test_pose_utils():\n    \"\"\"\n    Tests the pose utils\n    :return: \n    \"\"\"\n    TEST_COMPOSE = True\n    TEST_INV = True\n\n    def ra(_): return np.random.uniform(0, 2*math.pi)\n\n    if TEST_COMPOSE:\n        print('Testing pose composing...')\n        R1 = txe.euler2mat(ra(1), ra(1), ra(1))\n        t1 = np.random.rand(3)\n        R2 = txe.euler2mat(ra(1), ra(1), ra(1))\n        t2 = np.random.rand(3)\n\n        # homogeneous matrix method\n        R = np.dot(R1, R2)\n        t = t1 + np.dot(R1, t2)\n        print('From homogeneous matrices, t = ')\n        print(t)\n        print('R = ')\n        print(R)\n\n        # quaternion method\n        q1 = txq.mat2quat(R1)\n        q2 = txq.mat2quat(R2)\n\n        p1 = torch.cat((torch.from_numpy(t1), torch.from_numpy(q1)))\n        p2 = torch.cat((torch.from_numpy(t2), torch.from_numpy(q2)))\n        p = compose_pose_quaternion(\n            torch.unsqueeze(p1, 0), torch.unsqueeze(p2, 0))\n        t = p[:, :3].numpy().squeeze()\n        q = p[:, 3:].numpy().squeeze()\n        print('From quaternions, t = ')\n        print(t)\n        print('R = ')\n        print(txe.quat2mat(q))\n\n    if TEST_INV:\n        print('Testing pose inversion...')\n        R = txe.euler2mat(ra(1), ra(1), ra(1))\n        t = np.random.rand(3)\n        T = np.eye(4)\n        T[:3, :3] = R\n        T[:3, -1] = t\n\n        q = txq.mat2quat(R)\n        p = torch.cat((torch.from_numpy(t), torch.from_numpy(q)))\n        pinv = invert_pose_quaternion(torch.unsqueeze(p, 0))\n        tinv, qinv = pinv[:, :3], pinv[:, 3:]\n        Rinv = txq.quat2mat(qinv.numpy().squeeze())\n        Tinv = np.eye(4)\n        Tinv[:3, :3] = Rinv\n        Tinv[:3, -1] = tinv.numpy().squeeze()\n        print('T * T^(-1) = ')\n        print(np.dot(T, Tinv))\n\n\ndef test_q_error():\n    def ra(_): return np.random.uniform(0, 2*math.pi)\n    # rotation along x axis\n    a1 = ra(1)\n    a2 = ra(1)\n    q1 = txq.mat2quat(txe.euler2mat(a1, 0, 0))\n    q2 = txq.mat2quat(txe.euler2mat(a2, 0, 0))\n    a1 = np.rad2deg(a1)\n    a2 = np.rad2deg(a2)\n    print('Angles: {:f}, {:f}, difference = {:f}'.format(a1, a2, a1-a2))\n    print('Error: {:f}'.format(quaternion_angular_error(q1, q2)))\n\n\ndef test_log_q_error():\n    def ra(_): return np.random.uniform(0, 2*math.pi)\n    # rotation along x axis\n    a1 = ra(1)\n    a2 = ra(1)\n    q1 = txq.mat2quat(txe.euler2mat(0, a1, 0))\n    q2 = txq.mat2quat(txe.euler2mat(0, a2, 0))\n    # apply log map\n    q1 = np.arccos(q1[0]) * q1[1:] / np.linalg.norm(q1[1:])\n    q2 = 
np.arccos(q2[0]) * q2[1:] / np.linalg.norm(q2[1:])\n    a1 = np.rad2deg(a1)\n    a2 = np.rad2deg(a2)\n    print('Angles: {:f}, {:f}, difference = {:f}'.format(a1, a2, a1-a2))\n    print('Error: {:f}'.format(log_quaternion_angular_error(q1, q2)))\n\n\nif __name__ == '__main__':\n    test_pgo()\n    # test_dumb_pgo()\n    # test_align_camera_poses()\n    # test_q_error()\n    # test_log_q_error()\n"
  },
  {
    "path": "utils/progress_bar.py",
    "content": "import contextlib\r\nimport enum\r\nimport math\r\nimport time\r\n\r\nimport numpy as np\r\n\r\n\r\ndef progress_str(val, *args, width=20, with_ptg=True):\r\n    val = max(0., min(val, 1.))\r\n    assert width > 1\r\n    pos = round(width * val) - 1\r\n    if with_ptg is True:\r\n        log = '[{}%]'.format(max_point_str(val * 100.0, 4))\r\n    log += '['\r\n    for i in range(width):\r\n        if i < pos:\r\n            log += '='\r\n        elif i == pos:\r\n            log += '>'\r\n        else:\r\n            log += '.'\r\n    log += ']'\r\n    for arg in args:\r\n        log += '[{}]'.format(arg)\r\n    return log\r\n\r\n\r\ndef second_to_time_str(second, omit_hours_if_possible=True):\r\n    second = int(second)\r\n    m, s = divmod(second, 60)\r\n    h, m = divmod(m, 60)\r\n    if omit_hours_if_possible:\r\n        if h == 0:\r\n            return '{:02d}:{:02d}'.format(m, s)\r\n    return '{:02d}:{:02d}:{:02d}'.format(h, m, s)\r\n\r\n\r\ndef progress_bar_iter(task_list, width=20, with_ptg=True, step_time_average=50, name=None):\r\n    total_step = len(task_list)\r\n    step_times = []\r\n    start_time = 0.0\r\n    name = '' if name is None else f\"[{name}]\"\r\n    for i, task in enumerate(task_list):\r\n        t = time.time()\r\n        yield task\r\n        step_times.append(time.time() - t)\r\n        start_time += step_times[-1]\r\n        start_time_str = second_to_time_str(start_time)\r\n        average_step_time = np.mean(step_times[-step_time_average:]) + 1e-6\r\n        speed_str = \"{:.2f}it/s\".format(1 / average_step_time)\r\n        remain_time = (total_step - i) * average_step_time\r\n        remain_time_str = second_to_time_str(remain_time)\r\n        time_str = start_time_str + '>' + remain_time_str\r\n        prog_str = progress_str(\r\n            (i + 1) / total_step,\r\n            speed_str,\r\n            time_str,\r\n            width=width,\r\n            with_ptg=with_ptg)\r\n        print(name + prog_str + '   ', end='\\r', flush=True)\r\n    print(\"\")\r\n\r\n\r\nlist_bar = progress_bar_iter\r\n\r\ndef enumerate_bar(task_list, width=20, with_ptg=True, step_time_average=50, name=None):\r\n    total_step = len(task_list)\r\n    step_times = []\r\n    start_time = 0.0\r\n    name = '' if name is None else f\"[{name}]\"\r\n    for i, task in enumerate(task_list):\r\n        t = time.time()\r\n        yield i, task\r\n        step_times.append(time.time() - t)\r\n        start_time += step_times[-1]\r\n        start_time_str = second_to_time_str(start_time)\r\n        average_step_time = np.mean(step_times[-step_time_average:]) + 1e-6\r\n        speed_str = \"{:.2f}it/s\".format(1 / average_step_time)\r\n        remain_time = (total_step - i) * average_step_time\r\n        remain_time_str = second_to_time_str(remain_time)\r\n        time_str = start_time_str + '>' + remain_time_str\r\n        prog_str = progress_str(\r\n            (i + 1) / total_step,\r\n            speed_str,\r\n            time_str,\r\n            width=width,\r\n            with_ptg=with_ptg)\r\n        print(name + prog_str + '   ', end='\\r', flush=True)\r\n    print(\"\")\r\n\r\n\r\ndef max_point_str(val, max_point):\r\n    positive = bool(val >= 0.0)\r\n    val = np.abs(val)\r\n    if val == 0:\r\n        point = 1\r\n    else:\r\n        point = max(int(np.log10(val)), 0) + 1\r\n    fmt = \"{:.\" + str(max(max_point - point, 0)) + \"f}\"\r\n    if positive is True:\r\n        return fmt.format(val)\r\n    else:\r\n        return fmt.format(-val)\r\n\r\n\r\nclass 
class Unit(enum.Enum):\r\n    Iter = 'iter'\r\n    Byte = 'byte'\r\n\r\n\r\ndef convert_size(size_bytes):\r\n    # from https://stackoverflow.com/questions/5194057/better-way-to-convert-file-sizes-in-python\r\n    if size_bytes == 0:\r\n        return 0, \"B\"  # keep the (value, unit) tuple shape for the zero case\r\n    size_name = (\"B\", \"KB\", \"MB\", \"GB\", \"TB\", \"PB\", \"EB\", \"ZB\", \"YB\")\r\n    i = int(math.floor(math.log(size_bytes, 1024)))\r\n    p = math.pow(1024, i)\r\n    s = round(size_bytes / p, 2)\r\n    return s, size_name[i]\r\n\r\n\r\nclass ProgressBar:\r\n    def __init__(self,\r\n                 width=20,\r\n                 with_ptg=True,\r\n                 step_time_average=50,\r\n                 speed_unit=Unit.Iter):\r\n        self._width = width\r\n        self._with_ptg = with_ptg\r\n        self._step_time_average = step_time_average\r\n        self._step_times = []\r\n        self._start_time = 0.0\r\n        self._total_size = None\r\n        self._speed_unit = speed_unit\r\n\r\n    def start(self, total_size):\r\n        self._start = True\r\n        self._step_times = []\r\n        self._finished_sizes = []\r\n        self._time_elapsed = 0.0\r\n        self._current_time = time.time()\r\n        self._total_size = total_size\r\n        self._progress = 0\r\n\r\n    def print_bar(self, finished_size=1, pre_string=None, post_string=None):\r\n        self._step_times.append(time.time() - self._current_time)\r\n        self._finished_sizes.append(finished_size)\r\n        self._time_elapsed += self._step_times[-1]\r\n        start_time_str = second_to_time_str(self._time_elapsed)\r\n        time_per_size = np.array(self._step_times[-self._step_time_average:])\r\n        time_per_size /= np.array(\r\n            self._finished_sizes[-self._step_time_average:])\r\n        average_step_time = np.mean(time_per_size) + 1e-6\r\n        if self._speed_unit == Unit.Iter:\r\n            speed_str = \"{:.2f}it/s\".format(1 / average_step_time)\r\n        elif self._speed_unit == Unit.Byte:\r\n            size, size_unit = convert_size(1 / average_step_time)\r\n            speed_str = \"{:.2f}{}/s\".format(size, size_unit)\r\n        else:\r\n            raise ValueError(\"unknown speed unit\")\r\n        remain_time = (self._total_size - self._progress) * average_step_time\r\n        remain_time_str = second_to_time_str(remain_time)\r\n        time_str = start_time_str + '>' + remain_time_str\r\n        prog_str = progress_str(\r\n            (self._progress + 1) / self._total_size,\r\n            speed_str,\r\n            time_str,\r\n            width=self._width,\r\n            with_ptg=self._with_ptg)\r\n        self._progress += finished_size\r\n        if pre_string is not None:\r\n            prog_str = pre_string + prog_str\r\n        if post_string is not None:\r\n            prog_str += post_string\r\n        if self._progress >= self._total_size:\r\n            print(prog_str + '   ', flush=True)\r\n        else:\r\n            print(prog_str + '   ', end='\\r', flush=True)\r\n        self._current_time = time.time()\r\n
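\r\n\r\nif __name__ == '__main__':\r\n    # Usage sketch (illustrative): drive ProgressBar manually; start() takes the\r\n    # total size and each print_bar() call advances the bar by finished_size.\r\n    bar = ProgressBar(width=20, speed_unit=Unit.Iter)\r\n    bar.start(50)\r\n    for _ in range(50):\r\n        time.sleep(0.01)  # placeholder workload\r\n        bar.print_bar(finished_size=1)\r\n"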
  },
  {
    "path": "utils/rand_utils.py",
    "content": "import numpy as np  \n\ndef truncated_normal(u, sigma, min, max, shape=None):\n    \"\"\" Generate data following truncated normal distribution\n\n    Args:\n        u ([type]): mean\n        sigma ([type]): var=sigma^2\n        min ([type]): lower bound of the truncating range\n        max ([type]): higher bound of the truncating range\n        shape ([type], optional): [description]. Defaults to None.\n    \"\"\"\n\n    val = min-1 \n    while val<min or val>max: #iterative sampling until the first qualified data emerge\n        if shape is not None:\n            val = sigma*np.random.randn(shape)+u\n        else:\n            val = sigma*np.random.randn()+u\n\n    assert val != min-1 \n\n    return val\n        \n"
  },
  {
    "path": "utils/singleton.py",
    "content": "# import h5py\n\nclass Singleton(type):\n    _instances = {}\n    def __call__(cls, *args, **kwargs):\n        if cls not in cls._instances:\n            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)\n        return cls._instances[cls]\n"
  },
  {
    "path": "utils/timer.py",
    "content": "import time \nfrom contextlib import contextmanager\n\n\n@contextmanager\ndef simple_timer(name=''):\n    t = time.time()\n    yield \n    print(f\"{name} exec time: {time.time() - t}\")\n\n\ndef singleton(class_):\n    instances = {}\n    def getinstance(*args, **kwargs):\n        if class_ not in instances:\n            instances[class_] = class_(*args, **kwargs)\n        return instances[class_]\n    return getinstance\n\n@singleton\nclass timming(object):\n    def __init__(self):\n        super().__init__()\n        self.items={}\n\n    def start(self, item_name):\n        if item_name not in self.items:\n            self.items[item_name]={\n                \"end_cnt\": 0,\n                \"avg\": 0,\n                \"start_t\": time.time(),\n                \"is_finished\": False\n            }\n        else:\n            assert self.items[item_name][\"is_finished\"]==True\n            self.items[item_name].update(\n                {   \n                    # \"start_cnt\":  self.items[item_name][\"start_cnt\"]+1,\n                    \"end_cnt\": self.items[item_name][\"end_cnt\"], #unchanged\n                    \"avg\": self.items[item_name][\"avg\"], #unchanged \n                    \"start_t\": time.time(),\n                    \"is_finished\": False\n                }\n            )\n\n    def end(self, item_name):\n        assert item_name in self.items\n        t=time.time()\n        interval=t-self.items[item_name]['start_t']\n        self.items[item_name].update(\n            {   \n                \"end_cnt\":  self.items[item_name][\"end_cnt\"]+1,\n                \"avg\": (self.items[item_name][\"avg\"]*self.items[item_name][\"end_cnt\"]+interval)/(self.items[item_name][\"end_cnt\"]+1) ,\n                \"is_finished\": True\n            }\n        )\n    def summarize(self):\n        for k in self.items:\n            print(f\"Average time of {k} = {self.items[k]['avg']*1000} ms\", f\"(averaged on {self.items[k]['end_cnt']} testing cycles.\") \n"
  },
  {
    "path": "utils/util.py",
    "content": "import torch\nimport numpy as np\nimport collections\n\n\n\ndef freeze_params(params: dict, include: str = None, exclude: str = None):\n    assert isinstance(params, dict)\n    include_re = None\n    if include is not None:\n        include_re = re.compile(include)\n    exclude_re = None\n    if exclude is not None:\n        exclude_re = re.compile(exclude)\n    remain_params = []\n    for k, p in params.items():\n        if include_re is not None:\n            if include_re.match(k) is not None:\n                continue\n        if exclude_re is not None:\n            if exclude_re.match(k) is None:\n                continue\n        remain_params.append(p)\n    return remain_params\n\n\ndef freeze_params_v2(params: dict, include: str = None, exclude: str = None):\n    assert isinstance(params, dict)\n    include_re = None\n    if include is not None:\n        include_re = re.compile(include)\n    exclude_re = None\n    if exclude is not None:\n        exclude_re = re.compile(exclude)\n    for k, p in params.items():\n        if include_re is not None:\n            if include_re.match(k) is not None:\n                p.requires_grad = False\n        if exclude_re is not None:\n            if exclude_re.match(k) is None:\n                p.requires_grad = False\n\n\ndef filter_param_dict(state_dict: dict, include: str = None, exclude: str = None):\n    assert isinstance(state_dict, dict)\n    include_re = None\n    if include is not None:\n        include_re = re.compile(include)\n    exclude_re = None\n    if exclude is not None:\n        exclude_re = re.compile(exclude)\n    res_dict = {}\n    for k, p in state_dict.items():\n        if include_re is not None:\n            if include_re.match(k) is None:\n                continue\n        if exclude_re is not None:\n            if exclude_re.match(k) is not None:\n                continue\n        res_dict[k] = p\n    return res_dict\n\ndef modify_parameter_name_with_map(state_dict, parameteter_name_map=None):\n    if parameteter_name_map is None:\n        return state_dict\n    for old,new in parameteter_name_map:\n        for key in list(state_dict.keys()) :\n            if old in key:\n                new_key=key.replace(old, new)\n                state_dict[new_key] = state_dict.pop(key)\n    return state_dict\n\ndef load_pretrained_model_map_func(state_dict,parameteter_name_map = None, include:str=None, exclude:str=None):\n    state_dict = filter_param_dict(state_dict, include, exclude)\n    state_dict = modify_parameter_name_with_map(state_dict, parameteter_name_map)\n    \n\n\ndef list_recursive_op(input_list, op):\n    assert isinstance(input_list, list)\n\n    for i, v in enumerate(input_list):\n        if isinstance(v, list):\n            input_list[i] = list_recursive_op(v, op)\n        elif isinstance(v, dict):\n            input_list[i] = dict_recursive_op(v, op)\n        else:\n            input_list[i] = op(v)\n\n    return input_list\n\n\ndef dict_recursive_op(input_dict, op):\n    assert isinstance(input_dict, dict)\n\n    for k, v in input_dict.items():\n        if isinstance(v, dict):\n            input_dict[k] = dict_recursive_op(v, op)\n        elif isinstance(v, (list,tuple) ):\n            input_dict[k] = list_recursive_op(v, op)\n        else:\n            input_dict[k] = op(v)\n\n    return input_dict\n\n"
  },
  {
    "path": "utils/visualize.py",
    "content": "import numpy as np \nimport cv2 \nimport copy\n\ndef vis_pointclouds_cv2(pc, K, win_size, init_transform=None, color=None, img=None):\n    '''\n    pc: input point cloud of shape Nx3\n    K: camera intrinsic of shape 3x3\n    win_size: visualization window size (Wx,Wy)\n    '''\n    x = (K@pc.T).T  #Nx3\n    \n    x = x/x[:,-1:]\n    x = x.astype(np.int32)\n    x[:,0] = np.where((x[:,0]<0) | (x[:,0]>=win_size[1]), np.zeros_like(x[:,0]), x[:,0])\n    x[:,1] = np.where((x[:,1]<0) | (x[:,1]>=win_size[0]), np.zeros_like(x[:,1]), x[:,1])\n\n    if img is None:\n        img=np.zeros(list(win_size)+[3], dtype=np.uint8)\n    if color is None:\n        # color = [255, 255, 0]\n        color = [255, 255, 0]\n\n    img[x[:, 1], x[:, 0]] = color\n    \n    img[x[pc[:,-1]<0][:,1], x[pc[:,-1]<0][:,0] ] = [255,0,0]\n\n    return img\n   \ndef vis_2d_keypoints_cv2(img, keypoints, color=None):\n    '''\n    img: input point cloud of shape HxWx3\n    keypoints: Nx2 , (x,y)\n    '''\n\n    keypoints = np.around(keypoints).astype(np.int32)\n    img=copy.copy(img)\n    if color is None:\n        color = [255, 255, 0]\n    img[keypoints[:,1], keypoints[:,0]] = color\n\n    return img\n   \n\ndef get_model_corners(model):\n    min_x, max_x = np.min(model[:, 0]), np.max(model[:, 0])\n    min_y, max_y = np.min(model[:, 1]), np.max(model[:, 1])\n    min_z, max_z = np.min(model[:, 2]), np.max(model[:, 2])\n    corners_3d = np.array([\n        [min_x, min_y, min_z],\n        [min_x, min_y, max_z],\n        [min_x, max_y, min_z],\n        [min_x, max_y, max_z],\n        [max_x, min_y, min_z],\n        [max_x, min_y, max_z],\n        [max_x, max_y, min_z],\n        [max_x, max_y, max_z],\n    ])\n    return corners_3d\n\ndef vis_pose_box(RT,K, model, background=None,fig=None, ax=None, title='', label='', x_label='x', y_label='y', color='g', dot='-'):\n    \n    if fig is None:\n        if background is not None:\n            # dpi = float(matplotlib.rcParams['figure.dpi'])\n            dpi=100\n            # print(float(matplotlib.rcParams['figure.dpi']))\n            fig = plt.figure(figsize=[s/dpi for s in background.shape[:2]], dpi=dpi )\n        else:\n            fig = plt.figure()\n\n    if ax is None:\n        ax=fig.gca()\n    # ax.set_axis_off()\n    corner_3d=get_model_corners(model)\n    corner_2d = project(corner_3d, K, RT)\n\n    if background is not None:\n        ax.imshow(background)\n    ax.add_patch(patches.Polygon(\n        xy=corner_2d[[0, 1, 3, 2, 0, 4, 6, 2]], fill=False, linewidth=1, edgecolor=color))\n    ax.add_patch(patches.Polygon(\n        xy=corner_2d[[5, 4, 6, 7, 5, 1, 3, 7]], fill=False, linewidth=1, edgecolor=color))  \n\n    # line,=ax.plot(x, y, dot, color=color, linewidth=1)    \n    # line.set_label(label)\n    # ax.set_title(title)\n    # ax.set_xlabel(x_label)\n    # ax.set_ylabel(y_label)\n    # ax.axis('equal')\n    # if label !='':\n    #     ax.legend()\n    return fig, ax"
  }
]