[
  {
    "path": ".gitignore",
    "content": ".DS_Store\n__pycache__"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. 
For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. 
Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. (Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "<h1 align=\"center\">GeoMVSNet: Learning Multi-View Stereo With Geometry Perception (CVPR 2023)</h1>\n\n<div align=\"center\">\n    <a href=\"https://www.doublez.site\" target='_blank'>Zhe Zhang</a>, \n    <a href=\"https://github.com/prstrive\" target='_blank'>Rui Peng</a>, \n    <a href=\"https://yuhsihu.github.io\" target='_blank'>Yuxi Hu</a>, \n    <a href=\"https://www.ece.pku.edu.cn/info/1046/2147.htm\" target='_blank'>Ronggang Wang</a>*\n</div>\n\n<br />\n\n<div align=\"center\">\n    <a href=\"https://openaccess.thecvf.com//content/CVPR2023/html/Zhang_GeoMVSNet_Learning_Multi-View_Stereo_With_Geometry_Perception_CVPR_2023_paper.html\" target='_blank'><img src=\"https://img.shields.io/badge/CVPR-2023-9cf?logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAADUAAAAjCAMAAAADt7LEAAADAFBMVEVMaXF5mcqXsNZ0lceQq9OXsNaHpM9zlMeHo8+dtNh9nMypvt12lsiFos6mu9xxksZvkcZzlMdlicJxksZqjcRukMUAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACreRF2AAAAEXRSTlMA1mzoi0mi9BQ1wiRqVgeN+nWbeCoAAAAJcEhZcwAALiMAAC4jAXilP3YAAAJASURBVDiNnZTbduIwDEW3ZJOEwBTo/P8ntnRCCQHb8jzkUhgIrDV6iRN7W0eXCP7HZFotJKiPtsjh7pDfHq5f68YPq1/tNksEFfmsD/9iVt/e0lNvsaxivw/vuV01zxUqQHWsU5w+RbPnDHjg7bBLgCqmOZt854euguRhtfBAt8ugshfxMedNd3l4fyvdta/FWjKaDwkCVF/FK309haGyH4Lp6J4TAMqyFsjNywzcUsEMlcXLk8sbhYUmpOkz8BYA9DhtVwJk0wSklQGnnjotgaFaTV05U+w0UrZ2pG8DkLUqNFUHKGqQV8OpAMSUJlVvzkxdJX0wU/mVsRN7qsuGbMYYkkWTj5OLAGZmcaSigkyRBMBGZ6vagQl3ppQKNjkMrQPZ9IrPGvF/1uPWd4yxKPpsXCpwqRprGwDr/7EquEiSoSlbCdOfp4hC3Ewtf+Xssol4K+8FooQPB243lbmPLABZIbXHB5QDLcG0dEOrG2XGVxJ0Z6iLcah1kHhNZVmAOQdFLVV5oQDLrS2LDIczULl8Sykg5iCm1XZ5XhYtoXPgki4TqhngmLk1B4RUuAxmWkiuzrlSsLoAtD1DH8OdL8Lp4BQwMzRz2DswN176MIcAp7JR5xVQ2SlrAwzcx92M+1EInJMcj6U67LNw4dKt+kjO/UNLFXHxR+F1k1USfe4AdJsB/znMxdXFE/2ieUhdmfxOIF9zY0FnKAN/nJ1WM5TtHSnNTqsZCjF1j/r2hcn7k2k65wuZq/BTW/knm38BWrgDGcRH1DMAAAAASUVORK5CYII=\"/></a>&nbsp;\n    <a href=\"https://openaccess.thecvf.com//content/CVPR2023/papers/Zhang_GeoMVSNet_Learning_Multi-View_Stereo_With_Geometry_Perception_CVPR_2023_paper.pdf\" target='_blank'><img src=\"https://img.shields.io/badge/Paper-PDF-f5cac3?logo=adobeacrobatreader&logoColor=red\"/></a>&nbsp;\n    <a href=\"https://openaccess.thecvf.com/content/CVPR2023/supplemental/Zhang_GeoMVSNet_Learning_Multi-View_CVPR_2023_supplemental.pdf\" target='_blank'><img src=\"https://img.shields.io/badge/Supp.-PDF-f5cac3?logo=adobeacrobatreader&logoColor=red\"/></a>&nbsp;\n    <a href=\"https://paperswithcode.com/sota/point-clouds-on-tanks-and-temples?p=geomvsnet-learning-multi-view-stereo-with\" target='_blank'><img 
src=\"https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/geomvsnet-learning-multi-view-stereo-with/point-clouds-on-tanks-and-temples\" /></a>\n</div>\n\n<br />\n\n<div align=\"center\">\n\n<a href=\"https://youtu.be/XqLDgJAZAKc\" target='_blank'><img src=\".github/imgs/geomvsnet-video-cover.png\" width=\"50%\" /></a><a href=\"https://youtu.be/dLyuFMz1tAk\" target='_blank'><img src=\".github/imgs/mvs-demo-video-cover.png\" width=\"50%\" /></a>\n\n</div>\n\n\n## 🔨 Setup\n\n### 1.1 Requirements\n\nUse the following commands to build the `conda` environment.\n\n```bash\nconda create -n geomvsnet python=3.8\nconda activate geomvsnet\npip install -r requirements.txt\n```\n\n### 1.2 Datasets\n\nDownload the following datasets and modify the corresponding local path in `scripts/data_path.sh`.\n\n#### DTU Dataset\n\n**Training data**. We use the same DTU training data as mentioned in MVSNet and CasMVSNet, please refer to [DTU training data](https://drive.google.com/file/d/1eDjh-_bxKKnEuz5h-HXS7EDJn59clx6V/view) and [Depth raw](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/cascade-stereo/CasMVSNet/dtu_data/dtu_train_hr/Depths_raw.zip) for data download. Optional, you should download the [Recitfied raw](http://roboimagedata2.compute.dtu.dk/data/MVS/Rectified.zip) if you want to train the model in raw image resolution. Unzip and organize them as:\n\n```\ndtu/\n├── Cameras\n├── Depths\n├── Depths_raw\n├── Rectified\n└── Rectified_raw (optional)\n```\n\n**Testing data**. For convenience, we use the [DTU testing data](https://drive.google.com/file/d/1rX0EXlUL4prRxrRu2DgLJv2j7-tpUD4D/view?usp=sharing) processed by CVP-MVSNet. Also unzip and organize it as:\n\n```\ndtu-test/\n├── Cameras\n├── Depths\n└── Rectified\n```\n\n> Please note that the images and lighting here are consistent with the original dataset. \n\n#### BlendedMVS Dataset\n\nDownload the low image resolution version of [BlendedMVS dataset](https://drive.google.com/file/d/1ilxls-VJNvJnB7IaFj7P0ehMPr7ikRCb/view) and unzip it as:\n\n```\nblendedmvs/\n└── dataset_low_res\n    ├── ...\n    └── 5c34529873a8df509ae57b58\n```\n\n#### Tanks and Temples Dataset\n\nDownload the intermediate and advanced subsets of [Tanks and Temples dataset](https://drive.google.com/file/d/1YArOJaX9WVLJh4757uE8AEREYkgszrCo/view) and unzip them. If you want to use the short range version of camera parameters for `Intermediate` subset, unzip `short_range_caemeras_for_mvsnet.zip` and move `cam_[]` to the corresponding scenarios.\n\n```\ntnt/\n├── advanced\n│   ├── ...\n│   └── Temple\n│       ├── cams\n│       ├── images\n│       ├── pair.txt\n│       └── Temple.log\n└── intermediate\n    ├── ...\n    └── Train\n        ├── cams\n        ├── cams_train\n        ├── images\n        ├── pair.txt\n        └── Train.log\n```\n\n\n## 🚂 Training\n\nYou can train GeoMVSNet from scratch on DTU dataset and BlendedMVS dataset. After suitable setting and training, you can get the training checkpoints model in `checkpoints/[Dataset]/[THISNAME]`, and the following outputs lied in the folder:\n- `events.out.tfevents*`: you can use `tensorboard` to monitor the training process.\n- `model_[epoch].ckpt`: we save a checkpoint every `--save_freq`.\n- `train-[TIME].log`: logged the detailed training message, you can refer to appropiate indicators to judge the quality of training.\n\n### 2.1 DTU\n\nTo train GeoMVSNet on DTU dataset, you can refer to `scripts/dtu/train_dtu.sh`, specify `THISNAME`, `CUDA_VISIBLE_DEVICES`, `batch_size`, etc. 
### 2.1 DTU\n\nTo train GeoMVSNet on the DTU dataset, you can refer to `scripts/dtu/train_dtu.sh` and specify `THISNAME`, `CUDA_VISIBLE_DEVICES`, `batch_size`, etc. to meet your demand. Then run:\n\n```bash\nbash scripts/dtu/train_dtu.sh\n```\n\nThe default training strategy we provide is the *distributed* training mode. If you want to use the *general* training mode, you can refer to the following code.\n\n<details>\n<summary>general training script</summary>\n\n```bash\nCUDA_VISIBLE_DEVICES=0,1,2,3 python3 train.py ${@} \\\n    --which_dataset=\"dtu\" --epochs=16 --logdir=$LOG_DIR \\\n    --trainpath=$DTU_TRAIN_ROOT --testpath=$DTU_TRAIN_ROOT \\\n    --trainlist=\"datasets/lists/dtu/train.txt\" --testlist=\"datasets/lists/dtu/test.txt\" \\\n    \\\n    --data_scale=\"mid\" --n_views=\"5\" --batch_size=16 --lr=0.025 --robust_train \\\n    --lrepochs=\"1,3,5,7,9,11,13,15:1.5\"\n```\n\n</details>\n\n> Note that the two training strategies require different `batch_size` and `lr` settings to achieve the best training results.\n\n\n### 2.2 BlendedMVS\n\nTo train GeoMVSNet on the BlendedMVS dataset, you can refer to `scripts/blend/train_blend.sh`, and also specify `THISNAME`, `CUDA_VISIBLE_DEVICES`, `batch_size`, etc. to meet your demand. Then run:\n\n```bash\nbash scripts/blend/train_blend.sh\n```\n\nBy default, we use `7` viewpoints as input for the BlendedMVS training. Similarly, you can choose to use the *distributed* training mode or the *general* one as mentioned in 2.1.\n\n## ⚗️ Testing\n\n### 3.1 DTU\n\nFor DTU testing, we use the model trained on the DTU training set. You can download our [DTU pretrained model](https://drive.google.com/file/d/147_UbjE87E-HB9sZ5yLDbckynH825nJd/view?usp=sharing) and put it into `checkpoints/dtu/geomvsnet/`. Then perform *depth map estimation, point cloud fusion, and result evaluation* according to the following steps.\n1. Run `bash scripts/dtu/test_dtu.sh` for depth map estimation. The results will be stored in `outputs/dtu/[THISNAME]/`, with each scan folder holding `depth_est`, `confidence`, etc.\n    - Use `outputs/visual.ipynb` for depth map visualization, or see the sketch after this list.\n2. Run `bash scripts/dtu/fusion_dtu.sh` for point cloud fusion. We provide 3 different fusion methods, and we recommend the `open3d` option by default. After fusion, you can get `[FUSION_METHOD]_fusion_plys` under the experiment output folder, which holds the point clouds of each testing scan.\n\n    <details>\n    <summary>(Optional) If you want to use the \"Gipuma\" fusion method.</summary>\n\n    1. Clone the [edited fusibile repo](https://github.com/YoYo000/fusibile).\n    2. Refer to the [fusibile configuration blog (Chinese)](https://zhuanlan.zhihu.com/p/460212787) for building details.\n    3. Create a new Python 2.7 conda env.\n        ```bash\n        conda create -n fusibile python=2.7\n        conda install scipy matplotlib\n        conda install tensorflow==1.14.0\n        conda install -c https://conda.anaconda.org/menpo opencv\n        ```\n    4. Use the `fusibile` conda environment for the `gipuma` fusion method.\n\n    </details>\n\n3. Download the [ObsMask](http://roboimagedata2.compute.dtu.dk/data/MVS/SampleSet.zip) and [Points](http://roboimagedata2.compute.dtu.dk/data/MVS/Points.zip) of the DTU GT point clouds from the official website and organize them as:\n\n    ```\n    dtu-evaluation/\n    ├── ObsMask\n    └── Points\n    ```\n\n4. Set up `Matlab` in command-line mode, and run `bash scripts/dtu/matlab_quan_dtu.sh`. You can adjust the `num_at_once` config according to your machine's CPU and memory ceiling. After quantitative evaluation, you will get `[FUSION_METHOD]_quantitative/` and `[THISNAME].log`, which store the quantitative results.\n\n
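For a quick look outside the notebook, here is a minimal sketch (ours, not part of the repo; the exact output filename is an assumption based on the zero-padded view-id convention) that loads one estimated depth map with `read_pfm` from `datasets/data_io.py` and displays it:\n\n```python\nimport matplotlib.pyplot as plt\nfrom datasets.data_io import read_pfm\n\n# read_pfm returns (data, scale); we only need the depth array here.\n# Adjust the path to your [THISNAME] and scan.\ndepth, _ = read_pfm(\"outputs/dtu/geomvsnet/scan1/depth_est/00000000.pfm\")\nplt.imshow(depth, cmap=\"viridis\")\nplt.colorbar(label=\"depth\")\nplt.show()\n```\n\n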
### 3.2 Tanks and Temples\n\nFor testing on the [Tanks and Temples benchmark](https://www.tanksandtemples.org/leaderboard/), you can use any of the following configurations:\n- Train only on the DTU training set.\n- Train only on the BlendedMVS dataset.\n- Pretrain on the DTU training set and finetune on the BlendedMVS dataset. (Recommended)\n\nAfter your own training, follow these steps:\n1. Run `bash scripts/tnt/test_tnt.sh` for depth map estimation. The results will be stored in `outputs/[TRAINING_DATASET]/[THISNAME]/`.\n    - Use `outputs/visual.ipynb` for depth map visualization.\n2. Run `bash scripts/tnt/fusion_tnt.sh` for point cloud fusion. We provide the popular dynamic fusion strategy, and you can tune the fusion threshold in `fusions/tnt/dypcd.py`.\n3. Follow the *Upload Instructions* on the [T&T official website](https://www.tanksandtemples.org/submit/) to make online submissions.\n\n### 3.3 Custom Data (TODO)\n\nGeoMVSNet can reconstruct from custom data. At present, you can refer to [MVSNet](https://github.com/YoYo000/MVSNet#file-formats) to organize your data, and follow the same steps as above for *depth estimation* and *point cloud fusion*.\n\n## 💡 Results\n\nOur results on the DTU and Tanks and Temples datasets are listed in the tables below.\n\n| DTU Dataset | Acc. ↓ | Comp. ↓ | Overall ↓ |\n| ----------- | ------ | ------- | --------- |\n| GeoMVSNet   | 0.3309 | 0.2593  | 0.2951    |\n\n| T&T (Intermediate) | Mean ↑ | Family | Francis | Horse | Lighthouse | M60   | Panther | Playground | Train |\n| ------------------ | ------ | ------ | ------- | ----- | ---------- | ----- | ------- | ---------- | ----- |\n| GeoMVSNet          | 65.89  | 81.64  | 67.53   | 55.78 | 68.02      | 65.49 | 67.19   | 63.27      | 58.22 |\n\n| T&T (Advanced) | Mean ↑ | Auditorium | Ballroom | Courtroom | Museum | Palace | Temple |\n| -------------- | ------ | ---------- | -------- | --------- | ------ | ------ | ------ |\n| GeoMVSNet      | 41.52  | 30.23      | 46.53    | 39.98     | 53.05  | 35.98  | 43.34  |\n\nYou can also download our [Point Cloud](https://disk.pku.edu.cn:443/link/69D473126C509C8DCBCC7E233FAAEEAA) and [Estimated Depth](https://disk.pku.edu.cn:443/link/4217EB2F063D2B10EDC711F54A12B5F7) for academic usage.\n\n<details>\n<summary>🌟 About Reproducing Paper Results</summary>\n\n\nIn our experiments, we found that reproducing MVS networks is relatively difficult. Therefore, we summarize some of the problems we encountered below, hoping they will be helpful to you.\n\n**Q1. GPU Architecture Matters.**\n\nThere are two commonly used NVIDIA GPU series: GeForce RTX (e.g. 4090, 3090Ti, 2080Ti) and Tesla (e.g. V100, T4). We find that there is generally no performance degradation when training and testing on the same GPU series. Conversely, if you train on a V100 and test on a 3090Ti, for example, the depth maps look visually identical, but the per-pixel values are not exactly the same. We conjecture that the two series or architectures differ in numerical computation and processing precision.\n\n> Our pretrained model is trained on NVIDIA V100 GPUs.\n\n**Q2. PyTorch Version Matters.**\n\nThe available PyTorch versions depend on your CUDA version, and different PyTorch versions affect the accuracy of network training and testing. One reason we found is that the implementation and default parameters of `F.grid_sample()` vary across PyTorch versions.\n\n
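One concrete instance is the `align_corners` flag: around PyTorch 1.3.0 the default behavior of `grid_sample` changed from the old `align_corners=True` convention to `align_corners=False`, so the same weights can sample cost volumes slightly differently across versions. A minimal sketch of the discrepancy:\n\n```python\nimport torch\nimport torch.nn.functional as F\n\n# A 4x4 feature map and a 2x2 sampling grid with coordinates in [-1, 1].\nfeat = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)\ngrid = torch.tensor([[[[-0.5, -0.5], [0.5, -0.5]],\n                      [[-0.5,  0.5], [0.5,  0.5]]]])\n\na = F.grid_sample(feat, grid, align_corners=True)\nb = F.grid_sample(feat, grid, align_corners=False)\nprint((a - b).abs().max())  # non-zero: the two conventions sample differently\n```\n\nIf you retrain or evaluate across machines, pinning the PyTorch version helps keep results comparable.\n\n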
**Q3. Training Hyperparameters Matter.**\n\nIn the era of neural networks, hyperparameters really matter. We tuned the network hyperparameters, but our settings may not match your configuration. Most fundamentally, due to differences in GPU memory, you need to adjust `batch_size` and `lr` together. The learning rate schedule also matters.\n\n**Q4. Testing Epoch Matters.**\n\nBy default, our model trains for 16 epochs. But how do you select the checkpoint that gives the best testing performance? One solution is to use [PyTorch-lightning](https://lightning.ai/docs/pytorch/latest/starter/introduction.html). For simplicity, you can decide which checkpoint to use based on the `.log` file we provide.\n\n**Q5. Fusion Hyperparameters Matter.**\n\nFor both the DTU and T&T datasets, the hyperparameters of point cloud fusion greatly affect the final performance. We have provided different fusion strategies and easy access to adjust their parameters. You may need to get to know the temperament of your model.\n\nQx. For other problems, you can [raise an issue](https://github.com/doubleZ0108/GeoMVSNet/issues/new/choose).\n\n</details>\n\n<br />\n\n## ⚖️ Citation\n```\n@InProceedings{zhe2023geomvsnet,\n  title={GeoMVSNet: Learning Multi-View Stereo With Geometry Perception},\n  author={Zhang, Zhe and Peng, Rui and Hu, Yuxi and Wang, Ronggang},\n  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},\n  pages={21508--21518},\n  year={2023}\n}\n```\n\n## 💌 Acknowledgements\n\nThis repository is partly based on [MVSNet](https://github.com/YoYo000/MVSNet), [MVSNet-pytorch](https://github.com/xy-guo/MVSNet_pytorch), [CVP-MVSNet](https://github.com/JiayuYANG/CVP-MVSNet), [cascade-stereo](https://github.com/alibaba/cascade-stereo), [MVSTER](https://github.com/JeffWang987/MVSTER).\n\nWe appreciate their contributions to the MVS community."
  },
  {
    "path": "datasets/__init__.py",
    "content": ""
  },
  {
    "path": "datasets/blendedmvs.py",
    "content": "# -*- coding: utf-8 -*-\n# @Description: Data preprocessing and organization for BlendedMVS dataset.\n# @Author: Zhe Zhang (doublez@stu.pku.edu.cn)\n# @Affiliation: Peking University (PKU)\n# @LastEditDate: 2023-09-07\n\nimport os\nimport cv2\nimport random\nimport numpy as np\nfrom PIL import Image\n\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms as T\n\nfrom datasets.data_io import *\n\n\ndef motion_blur(img: np.ndarray, max_kernel_size=3):\n    # Either vertial, hozirontal or diagonal blur\n    mode = np.random.choice(['h', 'v', 'diag_down', 'diag_up'])\n    ksize = np.random.randint(0, (max_kernel_size + 1) / 2) * 2 + 1  # make sure is odd\n    center = int((ksize - 1) / 2)\n    kernel = np.zeros((ksize, ksize))\n    if mode == 'h':\n        kernel[center, :] = 1.\n    elif mode == 'v':\n        kernel[:, center] = 1.\n    elif mode == 'diag_down':\n        kernel = np.eye(ksize)\n    elif mode == 'diag_up':\n        kernel = np.flip(np.eye(ksize), 0)\n    var = ksize * ksize / 16.\n    grid = np.repeat(np.arange(ksize)[:, np.newaxis], ksize, axis=-1)\n    gaussian = np.exp(-(np.square(grid - center) + np.square(grid.T - center)) / (2. * var))\n    kernel *= gaussian\n    kernel /= np.sum(kernel)\n    img = cv2.filter2D(img, -1, kernel)\n    return img\n\n\nclass BlendedMVSDataset(Dataset):\n    def __init__(self, root_dir, list_file, split, n_views, **kwargs):\n        super(BlendedMVSDataset, self).__init__()\n\n        self.levels = 4 \n        self.root_dir = root_dir\n        self.list_file = list_file\n        self.split = split\n        self.n_views = n_views\n\n        assert self.split in ['train', 'val', 'all']\n\n        self.scale_factors = {}\n        self.scale_factor = 0\n\n        self.img_wh = kwargs.get(\"img_wh\", (768, 576))\n        assert self.img_wh[0]%32==0 and self.img_wh[1]%32==0, \\\n            'img_wh must both be multiples of 2^5!'\n        \n        self.robust_train = kwargs.get(\"robust_train\", True)\n        self.augment = kwargs.get(\"augment\", True)\n        if self.augment:\n            self.color_augment = T.ColorJitter(brightness=0.25, contrast=(0.3, 1.5))\n\n        self.metas = self.build_metas()\n\n\n    def build_metas(self):\n        metas = []\n        with open(self.list_file) as f:\n            self.scans = [line.rstrip() for line in f.readlines()]\n        for scan in self.scans:\n            with open(os.path.join(self.root_dir, scan, \"cams/pair.txt\")) as f:\n                num_viewpoint = int(f.readline())\n                for _ in range(num_viewpoint):\n                    ref_view = int(f.readline().rstrip())\n                    src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]\n                    if len(src_views) >= self.n_views-1:\n                        metas += [(scan, ref_view, src_views)]\n        return metas\n\n\n    def read_cam_file(self, scan, filename):\n        with open(filename) as f:\n            lines = f.readlines()\n            lines = [line.rstrip() for line in lines]\n        # extrinsics: line [1,5), 4x4 matrix\n        extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4))\n        # intrinsics: line [7-10), 3x3 matrix\n        intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3))\n        depth_min = float(lines[11].split()[0])\n        depth_max = float(lines[11].split()[-1])\n\n        if scan not in self.scale_factors:\n            self.scale_factors[scan] = 100.0 / 
depth_min\n        depth_min *= self.scale_factors[scan]\n        depth_max *= self.scale_factors[scan]\n        extrinsics[:3, 3] *= self.scale_factors[scan]\n\n        return intrinsics, extrinsics, depth_min, depth_max\n\n\n    def read_depth_mask(self, scan, filename, depth_min, depth_max, scale):\n        depth = np.array(read_pfm(filename)[0], dtype=np.float32)\n        depth = (depth * self.scale_factors[scan]) * scale\n\n        mask = (depth>=depth_min) & (depth<=depth_max)\n        assert mask.sum() > 0\n        mask = mask.astype(np.float32)\n        if self.img_wh is not None:\n            depth = cv2.resize(depth, self.img_wh, interpolation=cv2.INTER_NEAREST)\n        h, w = depth.shape\n        depth_ms = {}\n        mask_ms = {}\n\n        for i in range(self.levels):\n            depth_cur = cv2.resize(depth, (w//(2**i), h//(2**i)), interpolation=cv2.INTER_NEAREST)\n            mask_cur = cv2.resize(mask, (w//(2**i), h//(2**i)), interpolation=cv2.INTER_NEAREST)\n\n            depth_ms[f\"stage{self.levels-i}\"] = depth_cur\n            mask_ms[f\"stage{self.levels-i}\"] = mask_cur\n\n        return depth_ms, mask_ms\n\n\n    def read_img(self, filename):\n        img = Image.open(filename)\n\n        if self.augment:\n            img = self.color_augment(img)\n            img = motion_blur(np.array(img, dtype=np.float32))\n\n        np_img = np.array(img, dtype=np.float32) / 255.\n        return np_img\n\n\n    def __len__(self):\n        return len(self.metas)\n\n\n    def __getitem__(self, idx):\n        meta = self.metas[idx]\n        scan, ref_view, src_views = meta\n        \n        if self.robust_train:\n            num_src_views = len(src_views)\n            index = random.sample(range(num_src_views), self.n_views - 1)\n            view_ids = [ref_view] + [src_views[i] for i in index]\n            scale_ratio = random.uniform(0.8, 1.25)\n        else:\n            view_ids = [ref_view] + src_views[:self.n_views - 1]\n            scale_ratio = 1\n\n        imgs = []\n        mask = None\n        depth = None\n        depth_min = None\n        depth_max = None\n\n        proj={}\n        proj_matrices_0 = []\n        proj_matrices_1 = []\n        proj_matrices_2 = []\n        proj_matrices_3 = []\n\n        for i, vid in enumerate(view_ids):\n            img_filename = os.path.join(self.root_dir, '{}/blended_images/{:0>8}.jpg'.format(scan, vid))\n            depth_filename = os.path.join(self.root_dir, '{}/rendered_depth_maps/{:0>8}.pfm'.format(scan, vid))\n            proj_mat_filename = os.path.join(self.root_dir, '{}/cams/{:0>8}_cam.txt'.format(scan, vid))\n\n            img = self.read_img(img_filename)\n            imgs.append(img.transpose(2,0,1))\n\n            intrinsics, extrinsics, depth_min_, depth_max_ = self.read_cam_file(scan, proj_mat_filename)\n\n            proj_mat_0 = np.zeros(shape=(2, 4, 4), dtype=np.float32)\n            proj_mat_1 = np.zeros(shape=(2, 4, 4), dtype=np.float32)\n            proj_mat_2 = np.zeros(shape=(2, 4, 4), dtype=np.float32)\n            proj_mat_3 = np.zeros(shape=(2, 4, 4), dtype=np.float32)\n            extrinsics[:3, 3] *= scale_ratio\n            intrinsics[:2,:] *= 0.125\n            proj_mat_0[0,:4,:4] = extrinsics.copy()\n            proj_mat_0[1,:3,:3] = intrinsics.copy()\n            int_mat_0 = intrinsics.copy()\n\n            intrinsics[:2,:] *= 2\n            proj_mat_1[0,:4,:4] = extrinsics.copy()\n            proj_mat_1[1,:3,:3] = intrinsics.copy()\n            int_mat_1 = intrinsics.copy()\n\n            
intrinsics[:2,:] *= 2\n            proj_mat_2[0,:4,:4] = extrinsics.copy()\n            proj_mat_2[1,:3,:3] = intrinsics.copy()\n            int_mat_2 = intrinsics.copy()\n\n            intrinsics[:2,:] *= 2\n            proj_mat_3[0,:4,:4] = extrinsics.copy()\n            proj_mat_3[1,:3,:3] = intrinsics.copy()\n            int_mat_3 = intrinsics.copy()\n\n            proj_matrices_0.append(proj_mat_0)\n            proj_matrices_1.append(proj_mat_1)\n            proj_matrices_2.append(proj_mat_2)\n            proj_matrices_3.append(proj_mat_3)\n\n            # reference view\n            if i == 0:\n                depth_min = depth_min_ * scale_ratio\n                depth_max = depth_max_ * scale_ratio\n                depth, mask = self.read_depth_mask(scan, depth_filename, depth_min, depth_max, scale_ratio)\n\n        # stage1 is the coarsest (1/8 resolution) level; stage4 is full resolution\n        proj['stage1'] = np.stack(proj_matrices_0)\n        proj['stage2'] = np.stack(proj_matrices_1)\n        proj['stage3'] = np.stack(proj_matrices_2)\n        proj['stage4'] = np.stack(proj_matrices_3)\n\n        intrinsics_matrices = {\n            \"stage1\": int_mat_0,\n            \"stage2\": int_mat_1,\n            \"stage3\": int_mat_2,\n            \"stage4\": int_mat_3\n        }\n\n        sample = {\n            \"imgs\": imgs,\n            \"proj_matrices\": proj,\n            \"intrinsics_matrices\": intrinsics_matrices,\n            \"depth\": depth,\n            \"depth_values\": np.array([depth_min, depth_max], dtype=np.float32),\n            \"mask\": mask\n        }\n\n        return sample
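\n\n\n# Usage sketch (illustrative comment, not invoked anywhere in the repo; the\n# list-file path is an assumption -- adjust it to your local setup):\n#   ds = BlendedMVSDataset(root_dir=\"/path/to/blendedmvs/dataset_low_res\",\n#                          list_file=\"datasets/lists/blendedmvs/train.txt\",\n#                          split=\"train\", n_views=7)\n#   sample = ds[0]\n#   print(len(sample[\"imgs\"]), sample[\"depth_values\"])  # 7 views, [depth_min, depth_max]"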
  },
  {
    "path": "datasets/data_io.py",
    "content": "# -*- coding: utf-8 -*-\n# @Description: I/O functions for depth maps and camera files.\n# @Author: Zhe Zhang (doublez@stu.pku.edu.cn)\n# @Affiliation: Peking University (PKU)\n# @LastEditDate: 2023-09-07\n\nimport sys, re\nimport numpy as np\n\n\ndef read_pfm(filename):\n    file = open(filename, 'rb')\n    color = None\n    width = None\n    height = None\n    scale = None\n    endian = None\n\n    header = file.readline().decode('utf-8').rstrip()\n    if header == 'PF':\n        color = True\n    elif header == 'Pf':\n        color = False\n    else:\n        raise Exception('Not a PFM file.')\n\n    dim_match = re.match(r'^(\\d+)\\s(\\d+)\\s$', file.readline().decode('utf-8'))\n    if dim_match:\n        width, height = map(int, dim_match.groups())\n    else:\n        raise Exception('Malformed PFM header.')\n\n    scale = float(file.readline().rstrip())\n    if scale < 0:  # little-endian\n        endian = '<'\n        scale = -scale\n    else:\n        endian = '>'  # big-endian\n\n    data = np.fromfile(file, endian + 'f')\n    shape = (height, width, 3) if color else (height, width)\n\n    data = np.reshape(data, shape)\n    data = np.flipud(data)\n    file.close()\n    return data, scale\n\n\ndef save_pfm(filename, image, scale=1):\n    file = open(filename, \"wb\")\n    color = None\n\n    image = np.flipud(image)\n\n    if image.dtype.name != 'float32':\n        raise Exception('Image dtype must be float32.')\n\n    if len(image.shape) == 3 and image.shape[2] == 3:  # color image\n        color = True\n    elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1:  # greyscale\n        color = False\n    else:\n        raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')\n\n    file.write('PF\\n'.encode('utf-8') if color else 'Pf\\n'.encode('utf-8'))\n    file.write('{} {}\\n'.format(image.shape[1], image.shape[0]).encode('utf-8'))\n\n    endian = image.dtype.byteorder\n\n    if endian == '<' or endian == '=' and sys.byteorder == 'little':\n        scale = -scale\n\n    file.write(('%f\\n' % scale).encode('utf-8'))\n\n    image.tofile(file)\n    file.close()\n\n\ndef write_cam(file, cam):\n    f = open(file, \"w\")\n    f.write('extrinsic\\n')\n    for i in range(0, 4):\n        for j in range(0, 4):\n            f.write(str(cam[0][i][j]) + ' ')\n        f.write('\\n')\n    f.write('\\n')\n\n    f.write('intrinsic\\n')\n    for i in range(0, 3):\n        for j in range(0, 3):\n            f.write(str(cam[1][i][j]) + ' ')\n        f.write('\\n')\n\n    f.write('\\n' + str(cam[1][3][0]) + ' ' + str(cam[1][3][1]) + ' ' + str(cam[1][3][2]) + ' ' + str(cam[1][3][3]) + '\\n')\n\n    f.close()"
  },
  {
    "path": "datasets/dtu.py",
    "content": "# -*- coding: utf-8 -*-\n# @Description: Data preprocessing and organization for DTU dataset.\n# @Author: Zhe Zhang (doublez@stu.pku.edu.cn)\n# @Affiliation: Peking University (PKU)\n# @LastEditDate: 2023-09-07\n\nimport os\nimport cv2\nimport random\nimport numpy as np\nfrom PIL import Image\n\nfrom torchvision import transforms\nfrom torch.utils.data import Dataset\n\nfrom datasets.data_io import *\n\n\nclass DTUDataset(Dataset):\n    def __init__(self, root_dir, list_file, mode, n_views, **kwargs):\n        super(DTUDataset, self).__init__()\n        \n        self.root_dir = root_dir\n        self.list_file = list_file\n        self.mode = mode\n        self.n_views = n_views\n\n        assert self.mode in [\"train\", \"val\", \"test\"]\n\n        self.total_depths = 192\n        self.interval_scale = 1.06\n\n        self.data_scale = kwargs.get(\"data_scale\", \"mid\")     # mid / raw\n        self.robust_train = kwargs.get(\"robust_train\", False)   # True / False\n        self.color_augment = transforms.ColorJitter(brightness=0.5, contrast=0.5)\n\n        if self.mode == \"test\":\n            self.max_wh = kwargs.get(\"max_wh\", (1600, 1200))\n\n        self.metas = self.build_metas()\n\n    \n    def build_metas(self):\n        metas = []\n\n        with open(os.path.join(self.list_file)) as f:\n            scans = [line.rstrip() for line in f.readlines()]\n\n        pair_file = \"Cameras/pair.txt\"\n        for scan in scans:\n            with open(os.path.join(self.root_dir, pair_file)) as f:\n                num_viewpoint = int(f.readline())\n\n                # viewpoints (49)\n                for _ in range(num_viewpoint):\n                    ref_view = int(f.readline().rstrip())\n                    src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]\n\n                    if self.mode == \"train\":\n                        # light conditions 0-6\n                        for light_idx in range(7):\n                            metas.append((scan, light_idx, ref_view, src_views))\n                    elif self.mode in [\"test\", \"val\"]:\n                        if len(src_views) < self.n_views:\n                            print(\"{} < num_views:{}\".format(len(src_views), self.n_views))\n                            src_views += [src_views[0]] * (self.n_views - len(src_views))\n                        metas.append((scan, 3, ref_view, src_views))\n\n        print(\"DTU Dataset in\", self.mode, \"mode metas:\", len(metas))\n        return metas\n\n\n    def __len__(self):\n        return len(self.metas)\n\n\n    def read_cam_file(self, filename):\n        with open(filename) as f:\n            lines = f.readlines()\n            lines = [line.rstrip() for line in lines]\n        # extrinsics: line [1,5), 4x4 matrix\n        extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4))\n        # intrinsics: line [7-10), 3x3 matrix\n        intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3))\n\n        if self.mode == \"test\":\n            intrinsics[:2, :] /= 4.0\n\n        # depth_min & depth_interval: line 11\n        depth_min = float(lines[11].split()[0])\n        depth_interval = float(lines[11].split()[1])\n        \n        if len(lines[11].split()) >= 3:\n            num_depth = lines[11].split()[2]\n            depth_max = depth_min + int(float(num_depth)) * depth_interval\n            depth_interval = (depth_max - depth_min) / self.total_depths\n\n        
depth_interval *= self.interval_scale\n\n        return intrinsics, extrinsics, depth_min, depth_interval\n\n\n    def read_img(self, filename):\n        img = Image.open(filename)\n        if self.mode == \"train\" and self.robust_train:\n            img = self.color_augment(img)\n        # scale 0~255 to 0~1\n        np_img = np.array(img, dtype=np.float32) / 255.\n        return np_img\n\n\n    def crop_img(self, img):\n        raw_h, raw_w = img.shape[:2]\n        start_h = (raw_h-1024)//2\n        start_w = (raw_w-1280)//2\n        return img[start_h:start_h+1024, start_w:start_w+1280, :]  # (1024, 1280)\n\n    \n    def prepare_img(self, hr_img):\n        h, w = hr_img.shape\n        if self.data_scale == \"mid\":\n            hr_img_ds = cv2.resize(hr_img, (w//2, h//2), interpolation=cv2.INTER_NEAREST)\n            h, w = hr_img_ds.shape\n            target_h, target_w = 512, 640\n            start_h, start_w = (h - target_h)//2, (w - target_w)//2\n            hr_img_crop = hr_img_ds[start_h: start_h + target_h, start_w: start_w + target_w]\n        elif self.data_scale == \"raw\":\n            hr_img_crop = hr_img[h//2-1024//2:h//2+1024//2, w//2-1280//2:w//2+1280//2]  # (1024, 1280)\n        return hr_img_crop\n\n    \n    def scale_mvs_input(self, img, intrinsics, max_w, max_h, base=64):\n        h, w = img.shape[:2]\n        if h > max_h or w > max_w:\n            scale = 1.0 * max_h / h\n            if scale * w > max_w:\n                scale = 1.0 * max_w / w\n            new_w, new_h = scale * w // base * base, scale * h // base * base\n        else:\n            new_w, new_h = 1.0 * w // base * base, 1.0 * h // base * base\n\n        scale_w = 1.0 * new_w / w\n        scale_h = 1.0 * new_h / h\n        intrinsics[0, :] *= scale_w\n        intrinsics[1, :] *= scale_h\n\n        img = cv2.resize(img, (int(new_w), int(new_h)))\n\n        return img, intrinsics\n\n    \n    def read_mask_hr(self, filename):\n        img = Image.open(filename)\n        np_img = np.array(img, dtype=np.float32)\n        np_img = (np_img > 10).astype(np.float32)\n        np_img = self.prepare_img(np_img)\n\n        h, w = np_img.shape\n        np_img_ms = {\n            \"stage1\": cv2.resize(np_img, (w//8, h//8), interpolation=cv2.INTER_NEAREST),\n            \"stage2\": cv2.resize(np_img, (w//4, h//4), interpolation=cv2.INTER_NEAREST),\n            \"stage3\": cv2.resize(np_img, (w//2, h//2), interpolation=cv2.INTER_NEAREST),\n            \"stage4\": np_img,\n        }\n        return np_img_ms\n\n\n    def read_depth_hr(self, filename, scale):\n        depth_hr = np.array(read_pfm(filename)[0], dtype=np.float32) * scale\n        depth_lr = self.prepare_img(depth_hr)\n\n        h, w = depth_lr.shape\n        depth_lr_ms = {\n            \"stage1\": cv2.resize(depth_lr, (w//8, h//8), interpolation=cv2.INTER_NEAREST),\n            \"stage2\": cv2.resize(depth_lr, (w//4, h//4), interpolation=cv2.INTER_NEAREST),\n            \"stage3\": cv2.resize(depth_lr, (w//2, h//2), interpolation=cv2.INTER_NEAREST),\n            \"stage4\": depth_lr,\n        }\n        return depth_lr_ms\n\n\n    def __getitem__(self, idx):\n        scan, light_idx, ref_view, src_views = self.metas[idx]\n\n        if self.mode == \"train\" and self.robust_train:\n            num_src_views = len(src_views)\n            index = random.sample(range(num_src_views), self.n_views-1)\n            view_ids = [ref_view] + [src_views[i] for i in index]\n            scale_ratio = random.uniform(0.8, 1.25) \n        else:\n            
view_ids = [ref_view] + src_views[:self.n_views-1]\n            scale_ratio = 1\n\n        imgs = []\n        mask = None\n        depth_values = None\n        proj_matrices = []\n\n        for i, vid in enumerate(view_ids):\n            # @Note image & cam\n            if self.mode in [\"train\", \"val\"]:\n                if self.data_scale == \"mid\":\n                    img_filename = os.path.join(self.root_dir, 'Rectified/{}_train/rect_{:0>3}_{}_r5000.png'.format(scan, vid+1, light_idx))\n                elif self.data_scale == \"raw\":\n                    img_filename = os.path.join(self.root_dir, 'Rectified_raw/{}/rect_{:0>3}_{}_r5000.png'.format(scan, vid + 1, light_idx))\n                proj_mat_filename = os.path.join(self.root_dir, 'Cameras/train/{:0>8}_cam.txt').format(vid)\n            elif self.mode == \"test\":\n                img_filename = os.path.join(self.root_dir, 'Rectified/{}/rect_{:0>3}_3_r5000.png'.format(scan, vid+1))\n                proj_mat_filename = os.path.join(self.root_dir, 'Cameras/{:0>8}_cam.txt'.format(vid))\n\n            img = self.read_img(img_filename)\n            intrinsics, extrinsics, depth_min, depth_interval = self.read_cam_file(proj_mat_filename)\n\n            if self.mode in [\"train\", \"val\"]:\n                if self.data_scale == \"raw\":\n                    img = self.crop_img(img)\n                    intrinsics[:2, :] *= 2.0\n                if self.mode == \"train\" and self.robust_train:\n                    extrinsics[:3,3] *= scale_ratio                    \n            elif self.mode == \"test\":\n                img, intrinsics = self.scale_mvs_input(img, intrinsics, self.max_wh[0], self.max_wh[1])\n\n            imgs.append(img.transpose(2,0,1))\n\n            # reference view\n            if i == 0:\n                # @Note depth values\n                diff = 0.5 if self.mode in [\"test\", \"val\"] else 0\n                depth_max = depth_interval * (self.total_depths - diff) + depth_min\n                depth_values = np.array([depth_min * scale_ratio, depth_max * scale_ratio], dtype=np.float32)\n\n                # @Note depth & mask\n                if self.mode in [\"train\", \"val\"]:\n                    depth_filename_hr = os.path.join(self.root_dir, 'Depths_raw/{}/depth_map_{:0>4}.pfm'.format(scan, vid))\n                    depth = self.read_depth_hr(depth_filename_hr, scale_ratio)\n\n                    mask_filename_hr = os.path.join(self.root_dir, 'Depths_raw/{}/depth_visual_{:0>4}.png'.format(scan, vid))\n                    mask = self.read_mask_hr(mask_filename_hr)\n\n            proj_mat = np.zeros(shape=(2, 4, 4), dtype=np.float32)\n            proj_mat[0, :4, :4] = extrinsics\n            proj_mat[1, :3, :3] = intrinsics\n            proj_matrices.append(proj_mat)\n            \n        proj_matrices = np.stack(proj_matrices)\n        intrinsics = np.stack(intrinsics)\n        stage1_pjmats = proj_matrices.copy()\n        stage1_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] / 2.0\n        stage1_ins = intrinsics.copy()\n        stage1_ins[:2, :] = intrinsics[:2, :] / 2.0\n        stage3_pjmats = proj_matrices.copy()\n        stage3_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] * 2\n        stage3_ins = intrinsics.copy()\n        stage3_ins[:2, :] = intrinsics[:2, :] * 2.0\n        stage4_pjmats = proj_matrices.copy()\n        stage4_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] * 4\n        stage4_ins = intrinsics.copy()\n        stage4_ins[:2, :] = intrinsics[:2, :] * 4.0\n        
proj_matrices = {\n            \"stage1\": stage1_pjmats,\n            \"stage2\": proj_matrices,\n            \"stage3\": stage3_pjmats,\n            \"stage4\": stage4_pjmats\n        }\n        intrinsics_matrices = {\n            \"stage1\": stage1_ins,\n            \"stage2\": intrinsics,\n            \"stage3\": stage3_ins,\n            \"stage4\": stage4_ins\n        }\n\n        sample = {\n            \"imgs\": imgs,\n            \"proj_matrices\": proj_matrices,\n            \"intrinsics_matrices\": intrinsics_matrices,\n            \"depth_values\": depth_values\n        }\n        if self.mode in [\"train\", \"val\"]:\n            sample[\"depth\"] = depth\n            sample[\"mask\"] = mask\n        elif self.mode == \"test\":\n            sample[\"filename\"] = scan + '/{}/' + '{:0>8}'.format(view_ids[0]) + \"{}\"\n\n        return sample"
  },
  {
    "path": "datasets/evaluations/dtu_parallel/BaseEval2Obj_web.m",
    "content": "function BaseEval2Obj_web(BaseEval,method_string,outputPath)\r\n\r\nif(nargin<3)\r\n    outputPath='./';\r\nend\r\n\r\n% tresshold for coloring alpha channel in the range of 0-10 mm\r\ndist_tresshold=10;\r\n\r\ncSet=BaseEval.cSet;\r\n\r\nQdata=BaseEval.Qdata;\r\nalpha=min(BaseEval.Ddata,dist_tresshold)/dist_tresshold;\r\n\r\nfid=fopen([outputPath method_string '2Stl_' num2str(cSet) ' .obj'],'w+');\r\n\r\nfor cP=1:size(Qdata,2)\r\n    if(BaseEval.DataInMask(cP))\r\n        C=[1 0 0]*alpha(cP)+[1 1 1]*(1-alpha(cP)); %coloring from red to white in the range of 0-10 mm (0 to dist_tresshold)\r\n    else\r\n        C=[0 1 0]*alpha(cP)+[0 0 1]*(1-alpha(cP)); %green to blue for points outside the mask (which are not included in the analysis)\r\n    end\r\n    fprintf(fid,'v %f %f %f %f %f %f\\n',[Qdata(1,cP) Qdata(2,cP) Qdata(3,cP) C(1) C(2) C(3)]);\r\nend\r\nfclose(fid);\r\n\r\ndisp('Data2Stl saved as obj')\r\n\r\nQstl=BaseEval.Qstl;\r\nfid=fopen([outputPath 'Stl2' method_string '_' num2str(cSet) '.obj'],'w+');\r\n\r\nalpha=min(BaseEval.Dstl,dist_tresshold)/dist_tresshold;\r\n\r\nfor cP=1:size(Qstl,2)\r\n    if(BaseEval.StlAbovePlane(cP))\r\n        C=[1 0 0]*alpha(cP)+[1 1 1]*(1-alpha(cP)); %coloring from red to white in the range of 0-10 mm (0 to dist_tresshold)\r\n    else\r\n        C=[0 1 0]*alpha(cP)+[0 0 1]*(1-alpha(cP)); %green to blue for points below plane (which are not included in the analysis)\r\n    end\r\n    fprintf(fid,'v %f %f %f %f %f %f\\n',[Qstl(1,cP) Qstl(2,cP) Qstl(3,cP) C(1) C(2) C(3)]);\r\nend\r\nfclose(fid);\r\n\r\ndisp('Stl2Data saved as obj')"
  },
  {
    "path": "datasets/evaluations/dtu_parallel/BaseEvalMain_web.m",
    "content": "format compact\r\n\r\nrepresentation_string='Points'; %mvs representation 'Points' or 'Surfaces'\r\n\r\nswitch representation_string\r\n    case 'Points'\r\n        eval_string='_Eval_'; %results naming\r\n        settings_string='';\r\nend\r\n\r\n\r\ndst=0.2;    %Min dist between points when reducing\r\n\r\n% start this evaluation\r\ncSet = str2num(thisset)\r\n\r\n%input data name\r\nDataInName = [plyPath sprintf('%s%03d.ply', lower(method_string), cSet)]\r\n\r\n\r\n\r\n%results name\r\nEvalName=[resultsPath method_string eval_string num2str(cSet) '.mat']\r\n\r\n%check if file is already computed\r\nif(~exist(EvalName,'file'))\r\n    disp(DataInName);\r\n    \r\n    time=clock;time(4:5), drawnow\r\n    \r\n    tic\r\n    Mesh = plyread(DataInName);\r\n    Qdata=[Mesh.vertex.x Mesh.vertex.y Mesh.vertex.z]';\r\n    toc\r\n    \r\n    BaseEval=PointCompareMain(cSet,Qdata,dst,dataPath);\r\n    \r\n    disp('Saving results'), drawnow\r\n    toc\r\n    save(EvalName,'BaseEval');\r\n    toc\r\n    \r\n    % write obj-file of evaluation\r\n    % BaseEval2Obj_web(BaseEval,method_string, resultsPath)\r\n    % toc\r\n    time=clock;time(4:5), drawnow\r\n\r\n    BaseEval.MaxDist=20; %outlier threshold of 20 mm\r\n    \r\n    BaseEval.FilteredDstl=BaseEval.Dstl(BaseEval.StlAbovePlane); %use only points that are above the plane \r\n    BaseEval.FilteredDstl=BaseEval.FilteredDstl(BaseEval.FilteredDstl<BaseEval.MaxDist); % discard outliers\r\n\r\n    BaseEval.FilteredDdata=BaseEval.Ddata(BaseEval.DataInMask); %use only points that within mask\r\n    BaseEval.FilteredDdata=BaseEval.FilteredDdata(BaseEval.FilteredDdata<BaseEval.MaxDist); % discard outliers\r\n    \r\n    fprintf(\"mean/median Data (acc.) %f/%f\\n\", mean(BaseEval.FilteredDdata), median(BaseEval.FilteredDdata));\r\n    fprintf(\"mean/median Stl (comp.) %f/%f\\n\", mean(BaseEval.FilteredDstl), median(BaseEval.FilteredDstl));\r\nend\r\n\r\nfprintf(\"=== %d done! ===\\n\", cSet)\r\n\r\nexit"
  },
  {
    "path": "datasets/evaluations/dtu_parallel/ComputeStat_web.m",
    "content": "format compact\r\n\r\n\r\nMaxDist=20; %outlier thresshold of 20 mm\r\n\r\ntime=clock;\r\n\r\n% method_string='mvsnet';\r\nrepresentation_string='Points'; %mvs representation 'Points' or 'Surfaces'\r\n\r\nswitch representation_string\r\n    case 'Points'\r\n        eval_string='_Eval_'; %results naming\r\n        settings_string='';\r\nend\r\n\r\n\r\nUsedSets=str2num(set)\r\n\r\nnStat=length(UsedSets);\r\n\r\nBaseStat.nStl=zeros(1,nStat);\r\nBaseStat.nData=zeros(1,nStat);\r\nBaseStat.MeanStl=zeros(1,nStat);\r\nBaseStat.MeanData=zeros(1,nStat);\r\nBaseStat.VarStl=zeros(1,nStat);\r\nBaseStat.VarData=zeros(1,nStat);\r\nBaseStat.MedStl=zeros(1,nStat);\r\nBaseStat.MedData=zeros(1,nStat);\r\n\r\nfor cStat=1:length(UsedSets) %Data set number\r\n    \r\n    currentSet=UsedSets(cStat);\r\n\r\n    EvalName=[resultsPath method_string eval_string num2str(currentSet) '.mat'];\r\n    \r\n    disp(EvalName);\r\n    load(EvalName);\r\n    \r\n    Dstl=BaseEval.Dstl(BaseEval.StlAbovePlane); %use only points that are above the plane \r\n    Dstl=Dstl(Dstl<MaxDist); % discard outliers\r\n    \r\n    Ddata=BaseEval.Ddata(BaseEval.DataInMask); %use only points that within mask\r\n    Ddata=Ddata(Ddata<MaxDist); % discard outliers\r\n    \r\n    BaseStat.nStl(cStat)=length(Dstl);\r\n    BaseStat.nData(cStat)=length(Ddata);\r\n    \r\n    BaseStat.MeanStl(cStat)=mean(Dstl);\r\n    BaseStat.MeanData(cStat)=mean(Ddata);\r\n    \r\n    BaseStat.VarStl(cStat)=var(Dstl);\r\n    BaseStat.VarData(cStat)=var(Ddata);\r\n    \r\n    BaseStat.MedStl(cStat)=median(Dstl);\r\n    BaseStat.MedData(cStat)=median(Ddata);\r\n    \r\n    disp(\"acc\");\r\n    disp(mean(Ddata));\r\n    disp(\"comp\");\r\n    disp(mean(Dstl));\r\n    time=clock;\r\nend\r\n\r\ndisp(BaseStat);\r\ndisp(\"mean acc\")\r\ndisp(mean(BaseStat.MeanData));\r\ndisp(\"mean comp\")\r\ndisp(mean(BaseStat.MeanStl));\r\n\r\ntotalStatName=[resultsPath 'TotalStat_' method_string eval_string '.mat']\r\nsave(totalStatName,'BaseStat','time','MaxDist');\r\n\r\n\r\nexit"
  },
  {
    "path": "datasets/evaluations/dtu_parallel/MaxDistCP.m",
    "content": "function Dist = MaxDistCP(Qto,Qfrom,BB,MaxDist)\r\n\r\nDist=ones(1,size(Qfrom,2))*MaxDist;\r\n\r\nRange=floor((BB(2,:)-BB(1,:))/MaxDist);\r\n\r\ntic\r\nDone=0;\r\nLookAt=zeros(1,size(Qfrom,2));\r\nfor x=0:Range(1),\r\n    for y=0:Range(2),\r\n        for z=0:Range(3),\r\n            \r\n            Low=BB(1,:)+[x y z]*MaxDist;\r\n            High=Low+MaxDist;\r\n            \r\n            idxF=find(Qfrom(1,:)>=Low(1) & Qfrom(2,:)>=Low(2) & Qfrom(3,:)>=Low(3) &...\r\n                Qfrom(1,:)<High(1) & Qfrom(2,:)<High(2) & Qfrom(3,:)<High(3));\r\n            SQfrom=Qfrom(:,idxF);\r\n            LookAt(idxF)=LookAt(idxF)+1; %Debug\r\n            \r\n            Low=Low-MaxDist;\r\n            High=High+MaxDist;\r\n            idxT=find(Qto(1,:)>=Low(1) & Qto(2,:)>=Low(2) & Qto(3,:)>=Low(3) &...\r\n                Qto(1,:)<High(1) & Qto(2,:)<High(2) & Qto(3,:)<High(3));\r\n            SQto=Qto(:,idxT);\r\n            \r\n            if(isempty(SQto))\r\n                Dist(idxF)=MaxDist;\r\n            else\r\n                KDstl=KDTreeSearcher(SQto');\r\n                [~,SDist] = knnsearch(KDstl,SQfrom');\r\n                Dist(idxF)=SDist;\r\n                \r\n            end\r\n            \r\n            Done=Done+length(idxF); %Debug\r\n            \r\n        end\r\n    end\r\n    %Complete=Done/size(Qfrom,2);\r\n    %EstTime=(toc/Complete)/60\r\n    %toc\r\n    %LA=[sum(LookAt==0),...\r\n    %\tsum(LookAt==1),...\r\n   % \tsum(LookAt==2),...\r\n   % \tsum(LookAt==3),...\r\n   % \tsum(LookAt>3)]\r\nend"
  },
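  {
    "path": "docs/examples/max_dist_cp_sketch.py",
    "content": "# -*- coding: utf-8 -*-\n# @Description: Illustrative sketch, NOT part of the original toolchain. It\n#     shows what MaxDistCP.m computes: for every point in Qfrom, the distance\n#     to its nearest neighbour in Qto, clamped to MaxDist. The MATLAB code\n#     tiles the bounding box into MaxDist-sized cells to bound memory; with\n#     scipy (assumed available) the same capped query can be expressed\n#     directly via distance_upper_bound.\n\nimport numpy as np\nfrom scipy.spatial import cKDTree\n\n\ndef max_dist_cp(q_to, q_from, max_dist):\n    # q_to, q_from: (3, N) arrays, matching the MATLAB column layout\n    tree = cKDTree(q_to.T)\n    # knnsearch equivalent; points with no neighbour within max_dist get inf\n    dist, _ = tree.query(q_from.T, k=1, distance_upper_bound=max_dist)\n    return np.minimum(dist, max_dist)\n\n\nif __name__ == '__main__':\n    rng = np.random.default_rng(0)\n    print(max_dist_cp(rng.random((3, 1000)), rng.random((3, 500)), 0.05).mean())"
  },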
  {
    "path": "datasets/evaluations/dtu_parallel/PointCompareMain.m",
    "content": "function BaseEval=PointCompareMain(cSet,Qdata,dst,dataPath)\r\n% evaluation function the calculates the distantes from the reference data (stl) to the evalution points (Qdata) and the\r\n% distances from the evaluation points to the reference\r\n\r\ntic\r\n% reduce points 0.2 mm neighbourhood density\r\nQdata=reducePts_haa(Qdata,dst);\r\ntoc\r\n\r\nStlInName=[dataPath '/Points/stl/stl' sprintf('%03d',cSet) '_total.ply'];\r\n\r\nStlMesh = plyread(StlInName);  %STL points already reduced 0.2 mm neighbourhood density\r\nQstl=[StlMesh.vertex.x StlMesh.vertex.y StlMesh.vertex.z]';\r\n\r\n%Load Mask (ObsMask) and Bounding box (BB) and Resolution (Res)\r\nMargin=10;\r\nMaskName=[dataPath '/ObsMask/ObsMask' num2str(cSet) '_' num2str(Margin) '.mat'];\r\nload(MaskName)\r\n\r\nMaxDist=60;\r\ndisp('Computing Data 2 Stl distances')\r\nDdata = MaxDistCP(Qstl,Qdata,BB,MaxDist);\r\ntoc\r\n\r\ndisp('Computing Stl 2 Data distances')\r\nDstl=MaxDistCP(Qdata,Qstl,BB,MaxDist);\r\ndisp('Distances computed')\r\ntoc\r\n\r\n%use mask\r\n%From Get mask - inverted & modified.\r\nOne=ones(1,size(Qdata,2));\r\nQv=(Qdata-BB(1,:)'*One)/Res+1;\r\nQv=round(Qv);\r\n\r\nMidx1=find(Qv(1,:)>0 & Qv(1,:)<=size(ObsMask,1) & Qv(2,:)>0 & Qv(2,:)<=size(ObsMask,2) & Qv(3,:)>0 & Qv(3,:)<=size(ObsMask,3));\r\nMidxA=sub2ind(size(ObsMask),Qv(1,Midx1),Qv(2,Midx1),Qv(3,Midx1));\r\nMidx2=find(ObsMask(MidxA));\r\n\r\nBaseEval.DataInMask(1:size(Qv,2))=false;\r\nBaseEval.DataInMask(Midx1(Midx2))=true; %If Data is within the mask\r\n\r\nBaseEval.cSet=cSet;\r\nBaseEval.Margin=Margin;         %Margin of masks\r\nBaseEval.dst=dst;               %Min dist between points when reducing\r\nBaseEval.Qdata=Qdata;           %Input data points\r\nBaseEval.Ddata=Ddata;           %distance from data to stl\r\nBaseEval.Qstl=Qstl;             %Input stl points\r\nBaseEval.Dstl=Dstl;             %Distance from the stl to data\r\n\r\nload([dataPath '/ObsMask/Plane' num2str(cSet)],'P')\r\nBaseEval.GroundPlane=P;         % Plane used to destinguise which Stl points are 'used'\r\nBaseEval.StlAbovePlane=(P'*[Qstl;ones(1,size(Qstl,2))])>0; %Is stl above 'ground plane'\r\nBaseEval.Time=clock;            %Time when computation is finished"
  },
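  {
    "path": "docs/examples/obs_mask_sketch.py",
    "content": "# -*- coding: utf-8 -*-\n# @Description: Illustrative sketch, NOT part of the original toolchain. It\n#     restates the masking step of PointCompareMain.m: points are snapped to\n#     the ObsMask voxel grid (origin BB(1,:), spacing Res) and kept only where\n#     the mask is set; np.ravel_multi_index plays the role of MATLAB's\n#     sub2ind. Assumes numpy; points are (3, N) columns as in the MATLAB code.\n\nimport numpy as np\n\n\ndef data_in_mask(q_data, bb_low, res, obs_mask):\n    # voxel coordinates, 1-based to mirror the MATLAB code\n    qv = np.round((q_data - bb_low[:, None]) / res + 1).astype(int)\n    # keep only voxel indices that fall inside the mask volume\n    dims = np.array(obs_mask.shape)[:, None]\n    valid = np.all((qv >= 1) & (qv <= dims), axis=0)\n    flat = np.ravel_multi_index(qv[:, valid] - 1, obs_mask.shape)\n    in_mask = np.zeros(q_data.shape[1], dtype=bool)\n    in_mask[np.flatnonzero(valid)] = obs_mask.ravel()[flat]\n    return in_mask"
  },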
  {
    "path": "datasets/evaluations/dtu_parallel/plyread.m",
    "content": "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%\r\nfunction [Elements,varargout] = plyread(Path,Str)\r\n%PLYREAD   Read a PLY 3D data file.\r\n%   [DATA,COMMENTS] = PLYREAD(FILENAME) reads a version 1.0 PLY file\r\n%   FILENAME and returns a structure DATA.  The fields in this structure\r\n%   are defined by the PLY header; each element type is a field and each\r\n%   element property is a subfield.  If the file contains any comments,\r\n%   they are returned in a cell string array COMMENTS.\r\n%\r\n%   [TRI,PTS] = PLYREAD(FILENAME,'tri') or\r\n%   [TRI,PTS,DATA,COMMENTS] = PLYREAD(FILENAME,'tri') converts vertex\r\n%   and face data into triangular connectivity and vertex arrays.  The\r\n%   mesh can then be displayed using the TRISURF command.\r\n%\r\n%   Note: This function is slow for large mesh files (+50K faces),\r\n%   especially when reading data with list type properties.\r\n%\r\n%   Example:\r\n%   [Tri,Pts] = PLYREAD('cow.ply','tri');\r\n%   trisurf(Tri,Pts(:,1),Pts(:,2),Pts(:,3)); \r\n%   colormap(gray); axis equal;\r\n%\r\n%   See also: PLYWRITE\r\n\r\n% Pascal Getreuer 2004\r\n\r\n[fid,Msg] = fopen(Path,'rt');\t% open file in read text mode\r\n\r\nif fid == -1, error(Msg); end\r\n\r\nBuf = fscanf(fid,'%s',1);\r\nif ~strcmp(Buf,'ply')\r\n   fclose(fid);\r\n   error('Not a PLY file.'); \r\nend\r\n\r\n\r\n%%% read header %%%\r\n\r\nPosition = ftell(fid);\r\nFormat = '';\r\nNumComments = 0;\r\nComments = {};\t\t\t\t% for storing any file comments\r\nNumElements = 0;\r\nNumProperties = 0;\r\nElements = [];\t\t\t\t% structure for holding the element data\r\nElementCount = [];\t\t% number of each type of element in file\r\nPropertyTypes = [];\t\t% corresponding structure recording property types\r\nElementNames = {};\t\t% list of element names in the order they are stored in the file\r\nPropertyNames = [];\t\t% structure of lists of property names\r\n\r\nwhile 1\r\n   Buf = fgetl(fid);   \t\t\t\t\t\t\t\t% read one line from file\r\n   BufRem = Buf;\r\n   Token = {};\r\n   Count = 0;\r\n   \r\n   while ~isempty(BufRem)\t\t\t\t\t\t\t\t% split line into tokens\r\n      [tmp,BufRem] = strtok(BufRem);\r\n      \r\n      if ~isempty(tmp)\r\n         Count = Count + 1;\t\t\t\t\t\t\t% count tokens\r\n         Token{Count} = tmp;\r\n      end\r\n   end\r\n   \r\n   if Count \t\t% parse line\r\n      switch lower(Token{1})\r\n      case 'format'\t\t% read data format\r\n         if Count >= 2\r\n            Format = lower(Token{2});\r\n            \r\n            if Count == 3 & ~strcmp(Token{3},'1.0')\r\n               fclose(fid);\r\n               error('Only PLY format version 1.0 supported.');\r\n            end\r\n         end\r\n      case 'comment'\t\t% read file comment\r\n         NumComments = NumComments + 1;\r\n         Comments{NumComments} = '';\r\n         for i = 2:Count\r\n            Comments{NumComments} = [Comments{NumComments},Token{i},' '];\r\n         end\r\n      case 'element'\t\t% element name\r\n         if Count >= 3\r\n            if isfield(Elements,Token{2})\r\n               fclose(fid);\r\n               error(['Duplicate element name, ''',Token{2},'''.']);\r\n            end\r\n            \r\n            NumElements = NumElements + 1;\r\n            NumProperties = 0;\r\n   \t      Elements = setfield(Elements,Token{2},[]);\r\n            PropertyTypes = setfield(PropertyTypes,Token{2},[]);\r\n            ElementNames{NumElements} = Token{2};\r\n            PropertyNames = setfield(PropertyNames,Token{2},{});\r\n          
  CurElement = Token{2};\r\n            ElementCount(NumElements) = str2double(Token{3});\r\n            \r\n            if isnan(ElementCount(NumElements))\r\n               fclose(fid);\r\n               error(['Bad element definition: ',Buf]); \r\n            end            \r\n         else\r\n            error(['Bad element definition: ',Buf]);\r\n         end         \r\n      case 'property'\t% element property\r\n         if ~isempty(CurElement) & Count >= 3            \r\n            NumProperties = NumProperties + 1;\r\n            eval(['tmp=isfield(Elements.',CurElement,',Token{Count});'],...\r\n               'fclose(fid);error([''Error reading property: '',Buf])');\r\n            \r\n            if tmp\r\n               error(['Duplicate property name, ''',CurElement,'.',Token{2},'''.']);\r\n            end            \r\n            \r\n            % add property subfield to Elements\r\n            eval(['Elements.',CurElement,'.',Token{Count},'=[];'], ...\r\n               'fclose(fid);error([''Error reading property: '',Buf])');            \r\n            % add property subfield to PropertyTypes and save type\r\n            eval(['PropertyTypes.',CurElement,'.',Token{Count},'={Token{2:Count-1}};'], ...\r\n               'fclose(fid);error([''Error reading property: '',Buf])');            \r\n            % record property name order \r\n            eval(['PropertyNames.',CurElement,'{NumProperties}=Token{Count};'], ...\r\n               'fclose(fid);error([''Error reading property: '',Buf])');\r\n         else\r\n            fclose(fid);\r\n            \r\n            if isempty(CurElement)            \r\n               error(['Property definition without element definition: ',Buf]);\r\n            else               \r\n               error(['Bad property definition: ',Buf]);\r\n            end            \r\n         end         \r\n      case 'end_header'\t% end of header, break from while loop\r\n         break;\t\t\r\n      end\r\n   end\r\nend\r\n\r\n%%% set reading for specified data format %%%\r\n\r\nif isempty(Format)\r\n\twarning('Data format unspecified, assuming ASCII.');\r\n   Format = 'ascii';\r\nend\r\n\r\nswitch Format\r\ncase 'ascii'\r\n   Format = 0;\r\ncase 'binary_little_endian'\r\n   Format = 1;\r\ncase 'binary_big_endian'\r\n   Format = 2;\r\notherwise\r\n   fclose(fid);\r\n   error(['Data format ''',Format,''' not supported.']);\r\nend\r\n\r\nif ~Format   \r\n   Buf = fscanf(fid,'%f');\t\t% read the rest of the file as ASCII data\r\n   BufOff = 1;\r\nelse\r\n   % reopen the file in read binary mode\r\n   fclose(fid);\r\n   \r\n   if Format == 1\r\n      fid = fopen(Path,'r','ieee-le.l64');\t\t% little endian\r\n   else\r\n      fid = fopen(Path,'r','ieee-be.l64');\t\t% big endian\r\n   end\r\n   \r\n   % find the end of the header again (using ftell on the old handle doesn't give the correct position)   \r\n   BufSize = 8192;\r\n   Buf = [blanks(10),char(fread(fid,BufSize,'uchar')')];\r\n   i = [];\r\n   tmp = -11;\r\n   \r\n   while isempty(i)\r\n   \ti = findstr(Buf,['end_header',13,10]);\t\t\t% look for end_header + CR/LF\r\n   \ti = [i,findstr(Buf,['end_header',10])];\t\t% look for end_header + LF\r\n      \r\n      if isempty(i)\r\n         tmp = tmp + BufSize;\r\n         Buf = [Buf(BufSize+1:BufSize+10),char(fread(fid,BufSize,'uchar')')];\r\n      end\r\n   end\r\n   \r\n   % seek to just after the line feed\r\n   fseek(fid,i + tmp + 11 + (Buf(i + 10) == 13),-1);\r\nend\r\n\r\n\r\n%%% read element data %%%\r\n\r\n% PLY and MATLAB data types 
(for fread)\r\nPlyTypeNames = {'char','uchar','short','ushort','int','uint','float','double', ...\r\n   'char8','uchar8','short16','ushort16','int32','uint32','float32','double64'};\r\nMatlabTypeNames = {'schar','uchar','int16','uint16','int32','uint32','single','double'};\r\nSizeOf = [1,1,2,2,4,4,4,8];\t% size in bytes of each type\r\n\r\nfor i = 1:NumElements\r\n   % get current element property information\r\n   eval(['CurPropertyNames=PropertyNames.',ElementNames{i},';']);\r\n   eval(['CurPropertyTypes=PropertyTypes.',ElementNames{i},';']);\r\n   NumProperties = size(CurPropertyNames,2);\r\n   \r\n%   fprintf('Reading %s...\\n',ElementNames{i});\r\n      \r\n   if ~Format\t%%% read ASCII data %%%\r\n      for j = 1:NumProperties\r\n         Token = getfield(CurPropertyTypes,CurPropertyNames{j});\r\n         \r\n         if strcmpi(Token{1},'list')\r\n            Type(j) = 1;\r\n         else\r\n            Type(j) = 0;\r\n\t\t\tend\r\n      end\r\n      \r\n      % parse buffer\r\n      if ~any(Type)\r\n         % no list types\r\n         Data = reshape(Buf(BufOff:BufOff+ElementCount(i)*NumProperties-1),NumProperties,ElementCount(i))';\r\n         BufOff = BufOff + ElementCount(i)*NumProperties;\r\n      else\r\n         ListData = cell(NumProperties,1);\r\n         \r\n         for k = 1:NumProperties\r\n            ListData{k} = cell(ElementCount(i),1);\r\n         end\r\n         \r\n         % list type\r\n\t\t   for j = 1:ElementCount(i)\r\n   \t      for k = 1:NumProperties\r\n      \t      if ~Type(k)\r\n         \t      Data(j,k) = Buf(BufOff);\r\n            \t   BufOff = BufOff + 1;\r\n\t            else\r\n   \t            tmp = Buf(BufOff);\r\n      \t         ListData{k}{j} = Buf(BufOff+(1:tmp))';\r\n         \t      BufOff = BufOff + tmp + 1;\r\n            \tend\r\n            end\r\n         end\r\n      end\r\n   else\t\t%%% read binary data %%%\r\n      % translate PLY data type names to MATLAB data type names\r\n      ListFlag = 0;\t\t% = 1 if there is a list type \r\n      SameFlag = 1;     % = 1 if all types are the same\r\n      \r\n      for j = 1:NumProperties\r\n         Token = getfield(CurPropertyTypes,CurPropertyNames{j});\r\n         \r\n         if ~strcmp(Token{1},'list')\t\t\t% non-list type\r\n\t         tmp = rem(strmatch(Token{1},PlyTypeNames,'exact')-1,8)+1;\r\n         \r\n            if ~isempty(tmp)\r\n               TypeSize(j) = SizeOf(tmp);\r\n               Type{j} = MatlabTypeNames{tmp};\r\n               TypeSize2(j) = 0;\r\n               Type2{j} = '';\r\n               \r\n               SameFlag = SameFlag & strcmp(Type{1},Type{j});\r\n\t         else\r\n   \t         fclose(fid);\r\n               error(['Unknown property data type, ''',Token{1},''', in ', ...\r\n                     ElementNames{i},'.',CurPropertyNames{j},'.']);\r\n         \tend\r\n         else\t\t\t\t\t\t\t\t\t\t\t% list type\r\n            if length(Token) == 3\r\n               ListFlag = 1;\r\n               SameFlag = 0;\r\n               tmp = rem(strmatch(Token{2},PlyTypeNames,'exact')-1,8)+1;\r\n               tmp2 = rem(strmatch(Token{3},PlyTypeNames,'exact')-1,8)+1;\r\n         \r\n               if ~isempty(tmp) & ~isempty(tmp2)\r\n                  TypeSize(j) = SizeOf(tmp);\r\n                  Type{j} = MatlabTypeNames{tmp};\r\n                  TypeSize2(j) = SizeOf(tmp2);\r\n                  Type2{j} = MatlabTypeNames{tmp2};\r\n\t   \t      else\r\n   \t   \t      fclose(fid);\r\n               \terror(['Unknown property data type, ''list 
',Token{2},' ',Token{3},''', in ', ...\r\n                        ElementNames{i},'.',CurPropertyNames{j},'.']);\r\n               end\r\n            else\r\n               fclose(fid);\r\n               error(['Invalid list syntax in ',ElementNames{i},'.',CurPropertyNames{j},'.']);\r\n            end\r\n         end\r\n      end\r\n      \r\n      % read file\r\n      if ~ListFlag\r\n         if SameFlag\r\n            % no list types, all the same type (fast)\r\n            Data = fread(fid,[NumProperties,ElementCount(i)],Type{1})';\r\n         else\r\n            % no list types, mixed type\r\n            Data = zeros(ElementCount(i),NumProperties);\r\n            \r\n         \tfor j = 1:ElementCount(i)\r\n        \t\t\tfor k = 1:NumProperties\r\n               \tData(j,k) = fread(fid,1,Type{k});\r\n              \tend\r\n         \tend\r\n         end\r\n      else\r\n         ListData = cell(NumProperties,1);\r\n         \r\n         for k = 1:NumProperties\r\n            ListData{k} = cell(ElementCount(i),1);\r\n         end\r\n         \r\n         if NumProperties == 1\r\n            BufSize = 512;\r\n            SkipNum = 4;\r\n            j = 0;\r\n            \r\n            % list type, one property (fast if lists are usually the same length)\r\n            while j < ElementCount(i)\r\n               Position = ftell(fid);\r\n               % read in BufSize count values, assuming all counts = SkipNum\r\n               [Buf,BufSize] = fread(fid,BufSize,Type{1},SkipNum*TypeSize2(1));\r\n               Miss = find(Buf ~= SkipNum);\t\t\t\t\t% find first count that is not SkipNum\r\n               fseek(fid,Position + TypeSize(1),-1); \t\t% seek back to after first count                              \r\n               \r\n               if isempty(Miss)\t\t\t\t\t\t\t\t\t% all counts are SkipNum\r\n                  Buf = fread(fid,[SkipNum,BufSize],[int2str(SkipNum),'*',Type2{1}],TypeSize(1))';\r\n                  fseek(fid,-TypeSize(1),0); \t\t\t\t% undo last skip\r\n                  \r\n                  for k = 1:BufSize\r\n                     ListData{1}{j+k} = Buf(k,:);\r\n                  end\r\n                  \r\n                  j = j + BufSize;\r\n                  BufSize = floor(1.5*BufSize);\r\n               else\r\n                  if Miss(1) > 1\t\t\t\t\t\t\t\t\t% some counts are SkipNum\r\n                     Buf2 = fread(fid,[SkipNum,Miss(1)-1],[int2str(SkipNum),'*',Type2{1}],TypeSize(1))';                     \r\n                     \r\n                     for k = 1:Miss(1)-1\r\n                        ListData{1}{j+k} = Buf2(k,:);\r\n                     end\r\n                     \r\n                     j = j + k;\r\n                  end\r\n                  \r\n                  % read in the list with the missed count\r\n                  SkipNum = Buf(Miss(1));\r\n                  j = j + 1;\r\n                  ListData{1}{j} = fread(fid,[1,SkipNum],Type2{1});\r\n                  BufSize = ceil(0.6*BufSize);\r\n               end\r\n            end\r\n         else\r\n            % list type(s), multiple properties (slow)\r\n            Data = zeros(ElementCount(i),NumProperties);\r\n            \r\n            for j = 1:ElementCount(i)\r\n         \t\tfor k = 1:NumProperties\r\n            \t\tif isempty(Type2{k})\r\n               \t\tData(j,k) = fread(fid,1,Type{k});\r\n            \t\telse\r\n               \t\ttmp = fread(fid,1,Type{k});\r\n               \t\tListData{k}{j} = fread(fid,[1,tmp],Type2{k});\r\n\t\t            end\r\n      
\t\t   end\r\n      \t\tend\r\n         end\r\n      end\r\n   end\r\n   \r\n   % put data into Elements structure\r\n   for k = 1:NumProperties\r\n   \tif (~Format & ~Type(k)) | (Format & isempty(Type2{k}))\r\n      \teval(['Elements.',ElementNames{i},'.',CurPropertyNames{k},'=Data(:,k);']);\r\n      else\r\n      \teval(['Elements.',ElementNames{i},'.',CurPropertyNames{k},'=ListData{k};']);\r\n\t\tend\r\n   end\r\nend\r\n\r\nclear Data ListData;\r\nfclose(fid);\r\n\r\nif (nargin > 1 & strcmpi(Str,'Tri')) | nargout > 2   \r\n   % find vertex element field\r\n   Name = {'vertex','Vertex','point','Point','pts','Pts'};\r\n   Names = [];\r\n   \r\n   for i = 1:length(Name)\r\n      if any(strcmp(ElementNames,Name{i}))\r\n         Names = getfield(PropertyNames,Name{i});\r\n         Name = Name{i};         \r\n         break;\r\n      end\r\n   end\r\n   \r\n   if any(strcmp(Names,'x')) & any(strcmp(Names,'y')) & any(strcmp(Names,'z'))\r\n      eval(['varargout{1}=[Elements.',Name,'.x,Elements.',Name,'.y,Elements.',Name,'.z];']);\r\n   else\r\n      varargout{1} = zeros(1,3);\r\n\tend\r\n           \r\n   varargout{2} = Elements;\r\n   varargout{3} = Comments;\r\n   Elements = [];\r\n   \r\n   % find face element field\r\n   Name = {'face','Face','poly','Poly','tri','Tri'};\r\n   Names = [];\r\n   \r\n   for i = 1:length(Name)\r\n      if any(strcmp(ElementNames,Name{i}))\r\n         Names = getfield(PropertyNames,Name{i});\r\n         Name = Name{i};\r\n         break;\r\n      end\r\n   end\r\n   \r\n   if ~isempty(Names)\r\n      % find vertex indices property subfield\r\n\t   PropertyName = {'vertex_indices','vertex_indexes','vertex_index','indices','indexes'};           \r\n      \r\n   \tfor i = 1:length(PropertyName)\r\n      \tif any(strcmp(Names,PropertyName{i}))\r\n         \tPropertyName = PropertyName{i};\r\n\t         break;\r\n   \t   end\r\n      end\r\n      \r\n      if ~iscell(PropertyName)\r\n         % convert face index lists to triangular connectivity\r\n         eval(['FaceIndices=varargout{2}.',Name,'.',PropertyName,';']);\r\n  \t\t\tN = length(FaceIndices);\r\n   \t\tElements = zeros(N*2,3);\r\n   \t\tExtra = 0;   \r\n\r\n\t\t\tfor k = 1:N\r\n   \t\t\tElements(k,:) = FaceIndices{k}(1:3);\r\n   \r\n   \t\t\tfor j = 4:length(FaceIndices{k})\r\n      \t\t\tExtra = Extra + 1;      \r\n\t      \t\tElements(N + Extra,:) = [Elements(k,[1,j-1]),FaceIndices{k}(j)];\r\n   \t\t\tend\r\n         end\r\n         Elements = Elements(1:N+Extra,:) + 1;\r\n      end\r\n   end\r\nelse\r\n   varargout{1} = Comments;\r\nend"
  },
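  {
    "path": "docs/examples/read_stl_points_sketch.py",
    "content": "# -*- coding: utf-8 -*-\n# @Description: Illustrative sketch, NOT part of the original toolchain. The\n#     evaluation only needs plyread.m for vertex positions; on the Python side\n#     Open3D (already used by fusions/dtu/_open3d.py) reads ASCII and binary\n#     PLY alike. The path below is a placeholder.\n\nimport numpy as np\nimport open3d as o3d\n\npcd = o3d.io.read_point_cloud('[/path/to]/Points/stl/stl001_total.ply')\nq_stl = np.asarray(pcd.points).T  # (3, N), the layout Qstl has in PointCompareMain.m\nprint(q_stl.shape)"
  },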
  {
    "path": "datasets/evaluations/dtu_parallel/reducePts_haa.m",
    "content": "function [ptsOut,indexSet] = reducePts_haa(pts, dst)\n\n%Reduces a point set, pts, in a stochastic manner, such that the minimum sdistance\n% between points is 'dst'. Writen by abd, edited by haa, then by raje\n\nnPoints=size(pts,2);\n\nindexSet=true(nPoints,1);\nRandOrd=randperm(nPoints);\n\n%tic\nNS = KDTreeSearcher(pts');\n%toc\n\n% search the KNTree for close neighbours in a chunk-wise fashion to save memory if point cloud is really big\nChunks=1:min(4e6,nPoints-1):nPoints;\nChunks(end)=nPoints;\n\nfor cChunk=1:(length(Chunks)-1)\n    Range=Chunks(cChunk):Chunks(cChunk+1);\n    \n    idx = rangesearch(NS,pts(:,RandOrd(Range))',dst);\n    \n    for i = 1:size(idx,1)\n        id =RandOrd(i-1+Chunks(cChunk));\n        if (indexSet(id))\n            indexSet(idx{i}) = 0;\n            indexSet(id) = 1;\n        end\n    end\nend\n\nptsOut = pts(:,indexSet);\n\ndisp(['downsample factor: ' num2str(nPoints/sum(indexSet))]);"
  },
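  {
    "path": "docs/examples/reduce_pts_sketch.py",
    "content": "# -*- coding: utf-8 -*-\n# @Description: Illustrative sketch, NOT part of the original toolchain. It\n#     restates the rule in reducePts_haa.m: visit points in random order and,\n#     for each survivor, discard every other point within dst of it, so the\n#     kept set has pairwise spacing >= dst. scipy's cKDTree (assumed) stands\n#     in for MATLAB's KDTreeSearcher/rangesearch; points are (3, N) columns.\n\nimport numpy as np\nfrom scipy.spatial import cKDTree\n\n\ndef reduce_pts(pts, dst, seed=0):\n    n = pts.shape[1]\n    keep = np.ones(n, dtype=bool)\n    tree = cKDTree(pts.T)\n    for i in np.random.default_rng(seed).permutation(n):\n        if keep[i]:\n            # drop the whole dst-ball, then re-mark the centre itself\n            keep[tree.query_ball_point(pts[:, i], dst)] = False\n            keep[i] = True\n    return pts[:, keep], keep\n\n\nif __name__ == '__main__':\n    pts = np.random.default_rng(1).random((3, 2000))\n    _, keep = reduce_pts(pts, dst=0.05)\n    print('downsample factor:', pts.shape[1] / keep.sum())"
  },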
  {
    "path": "datasets/lists/blendedmvs/low_res_all.txt",
    "content": "5c1f33f1d33e1f2e4aa6dda4\n5bfe5ae0fe0ea555e6a969ca\n5bff3c5cfe0ea555e6bcbf3a\n58eaf1513353456af3a1682a\n5bfc9d5aec61ca1dd69132a2\n5bf18642c50e6f7f8bdbd492\n5bf26cbbd43923194854b270\n5bf17c0fd439231948355385\n5be3ae47f44e235bdbbc9771\n5be3a5fb8cfdd56947f6b67c\n5bbb6eb2ea1cfa39f1af7e0c\n5ba75d79d76ffa2c86cf2f05\n5bb7a08aea1cfa39f1a947ab\n5b864d850d072a699b32f4ae\n5b6eff8b67b396324c5b2672\n5b6e716d67b396324c2d77cb\n5b69cc0cb44b61786eb959bf\n5b62647143840965efc0dbde\n5b60fa0c764f146feef84df0\n5b558a928bbfb62204e77ba2\n5b271079e0878c3816dacca4\n5b08286b2775267d5b0634ba\n5afacb69ab00705d0cefdd5b\n5af28cea59bc705737003253\n5af02e904c8216544b4ab5a2\n5aa515e613d42d091d29d300\n5c34529873a8df509ae57b58\n5c34300a73a8df509add216d\n5c1af2e2bee9a723c963d019\n5c1892f726173c3a09ea9aeb\n5c0d13b795da9479e12e2ee9\n5c062d84a96e33018ff6f0a6\n5bfd0f32ec61ca1dd69dc77b\n5bf21799d43923194842c001\n5bf3a82cd439231948877aed\n5bf03590d4392319481971dc\n5beb6e66abd34c35e18e66b9\n5be883a4f98cee15019d5b83\n5be47bf9b18881428d8fbc1d\n5bcf979a6d5f586b95c258cd\n5bce7ac9ca24970bce4934b6\n5bb8a49aea1cfa39f1aa7f75\n5b78e57afc8fcf6781d0c3ba\n5b21e18c58e2823a67a10dd8\n5b22269758e2823a67a3bd03\n5b192eb2170cf166458ff886\n5ae2e9c5fe405c5076abc6b2\n5adc6bd52430a05ecb2ffb85\n5ab8b8e029f5351f7f2ccf59\n5abc2506b53b042ead637d86\n5ab85f1dac4291329b17cb50\n5a969eea91dfc339a9a3ad2c\n5a8aa0fab18050187cbe060e\n5a7d3db14989e929563eb153\n5a69c47d0d5d0a7f3b2e9752\n5a618c72784780334bc1972d\n5a6464143d809f1d8208c43c\n5a588a8193ac3d233f77fbca\n5a57542f333d180827dfc132\n5a572fd9fc597b0478a81d14\n5a563183425d0f5186314855\n5a4a38dad38c8a075495b5d2\n5a48d4b2c7dab83a7d7b9851\n5a489fb1c7dab83a7d7b1070\n5a48ba95c7dab83a7d7b44ed\n5a3ca9cb270f0e3f14d0eddb\n5a3cb4e4270f0e3f14d12f43\n5a3f4aba5889373fbbc5d3b5\n5a0271884e62597cdee0d0eb\n59e864b2a9e91f2c5529325f\n599aa591d5b41f366fed0d58\n59350ca084b7f26bf5ce6eb8\n59338e76772c3e6384afbb15\n5c20ca3a0843bc542d94e3e2\n5c1dbf200843bc542d8ef8c4\n5c1b1500bee9a723c96c3e78\n5bea87f4abd34c35e1860ab5\n5c2b3ed5e611832e8aed46bf\n57f8d9bbe73f6760f10e916a\n5bf7d63575c26f32dbf7413b\n5be4ab93870d330ff2dce134\n5bd43b4ba6b28b1ee86b92dd\n5bccd6beca24970bce448134\n5bc5f0e896b66a2cd8f9bd36\n5b908d3dc6ab78485f3d24a9\n5b2c67b5e0878c381608b8d8\n5b4933abf2b5f44e95de482a\n5b3b353d8d46a939f93524b9\n5acf8ca0f3d8a750097e4b15\n5ab8713ba3799a1d138bd69a\n5aa235f64a17b335eeaf9609\n5aa0f9d7a9efce63548c69a1\n5a8315f624b8e938486e0bd8\n5a48c4e9c7dab83a7d7b5cc7\n59ecfd02e225f6492d20fcc9\n59f87d0bfa6280566fb38c9a\n59f363a8b45be22330016cad\n59f70ab1e5c5d366af29bf3e\n59e75a2ca9e91f2c5526005d\n5947719bf1b45630bd096665\n5947b62af1b45630bd0c2a02\n59056e6760bb961de55f3501\n58f7f7299f5b5647873cb110\n58cf4771d0f5fb221defe6da\n58d36897f387231e6c929903\n58c4bb4f4a69c55606122be4\n5b7a3890fc8fcf6781e2593a\n5c189f2326173c3a09ed7ef3\n5b950c71608de421b1e7318f\n5a6400933d809f1d8200af15\n59d2657f82ca7774b1ec081d\n5ba19a8a360c7c30c1c169df\n59817e4a1bd4b175e7038d19"
  },
  {
    "path": "datasets/lists/blendedmvs/val.txt",
    "content": "5b7a3890fc8fcf6781e2593a\n5c189f2326173c3a09ed7ef3\n5b950c71608de421b1e7318f\n5a6400933d809f1d8200af15\n59d2657f82ca7774b1ec081d\n5ba19a8a360c7c30c1c169df\n59817e4a1bd4b175e7038d19"
  },
  {
    "path": "datasets/lists/dtu/test.txt",
    "content": "scan1\nscan4\nscan9\nscan10\nscan11\nscan12\nscan13\nscan15\nscan23\nscan24\nscan29\nscan32\nscan33\nscan34\nscan48\nscan49\nscan62\nscan75\nscan77\nscan110\nscan114\nscan118"
  },
  {
    "path": "datasets/lists/dtu/train.txt",
    "content": "scan2\nscan6\nscan7\nscan8\nscan14\nscan16\nscan18\nscan19\nscan20\nscan22\nscan30\nscan31\nscan36\nscan39\nscan41\nscan42\nscan44\nscan45\nscan46\nscan47\nscan50\nscan51\nscan52\nscan53\nscan55\nscan57\nscan58\nscan60\nscan61\nscan63\nscan64\nscan65\nscan68\nscan69\nscan70\nscan71\nscan72\nscan74\nscan76\nscan83\nscan84\nscan85\nscan87\nscan88\nscan89\nscan90\nscan91\nscan92\nscan93\nscan94\nscan95\nscan96\nscan97\nscan98\nscan99\nscan100\nscan101\nscan102\nscan103\nscan104\nscan105\nscan107\nscan108\nscan109\nscan111\nscan112\nscan113\nscan115\nscan116\nscan119\nscan120\nscan121\nscan122\nscan123\nscan124\nscan125\nscan126\nscan127\nscan128"
  },
  {
    "path": "datasets/lists/dtu/val.txt",
    "content": "scan3\nscan5\nscan17\nscan21\nscan28\nscan35\nscan37\nscan38\nscan40\nscan43\nscan56\nscan59\nscan66\nscan67\nscan82\nscan86\nscan106\nscan117"
  },
  {
    "path": "datasets/lists/tnt/advanced.txt",
    "content": "Auditorium\nBallroom\nCourtroom\nMuseum\nPalace\nTemple"
  },
  {
    "path": "datasets/lists/tnt/intermediate.txt",
    "content": "Family\nHorse\nFrancis\nLighthouse\nM60\nPanther\nPlayground\nTrain"
  },
  {
    "path": "datasets/tnt.py",
    "content": "# -*- coding: utf-8 -*-\n# @Description: Data preprocessing and organization for Tanks and Temples dataset.\n# @Author: Zhe Zhang (doublez@stu.pku.edu.cn)\n# @Affiliation: Peking University (PKU)\n# @LastEditDate: 2023-09-07\n\nimport os\nimport cv2\nimport numpy as np\nfrom PIL import Image\n\nfrom torch.utils.data import Dataset\n\nfrom datasets.data_io import *\n\n\nclass TNTDataset(Dataset):\n    def __init__(self, root_dir, list_file, split, n_views, **kwargs):\n        super(TNTDataset, self).__init__()\n\n        self.root_dir = root_dir\n        self.list_file = list_file\n        self.split = split\n        self.n_views = n_views\n\n        self.cam_mode = kwargs.get(\"cam_mode\", \"origin\")    # origin / short_range\n        if self.cam_mode == 'short_range': assert self.split == \"intermediate\"\n        self.img_mode = kwargs.get(\"img_mode\", \"resize\")    # resize / crop\n\n        self.total_depths = 192\n        self.depth_interval_table = {\n            # intermediate\n            'Family': 2.5e-3, 'Francis': 1e-2, 'Horse': 1.5e-3, 'Lighthouse': 1.5e-2, 'M60': 5e-3, 'Panther': 5e-3, 'Playground': 7e-3, 'Train': 5e-3, \n            # advanced\n            'Auditorium': 3e-2, 'Ballroom': 2e-2, 'Courtroom': 2e-2, 'Museum': 2e-2, 'Palace': 1e-2, 'Temple': 1e-2\n        }\n        self.img_wh = kwargs.get(\"img_wh\", (-1, 1024))\n\n        self.metas = self.build_metas()\n\n\n    def build_metas(self):\n        metas = []\n\n        with open(os.path.join(self.list_file)) as f:\n            scans = [line.rstrip() for line in f.readlines()]\n\n        for scan in scans:\n            with open(os.path.join(self.root_dir, self.split, scan, 'pair.txt')) as f:\n                num_viewpoint = int(f.readline())\n                for view_idx in range(num_viewpoint):\n                    ref_view = int(f.readline().rstrip())\n                    src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]\n                    if len(src_views) != 0:\n                        metas += [(scan, -1, ref_view, src_views)]\n        return metas\n\n   \n    def read_cam_file(self, filename):\n        with open(filename) as f:\n            lines = [line.rstrip() for line in f.readlines()]\n        # extrinsics: line [1,5), 4x4 matrix\n        extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ')\n        extrinsics = extrinsics.reshape((4, 4))\n        # intrinsics: line [7-10), 3x3 matrix\n        intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ')\n        intrinsics = intrinsics.reshape((3, 3))\n        \n        depth_min = float(lines[11].split()[0])\n        depth_max = float(lines[11].split()[-1])\n\n        return intrinsics, extrinsics, depth_min, depth_max\n\n\n    def read_img(self, filename):\n        img = Image.open(filename)\n        np_img = np.array(img, dtype=np.float32) / 255.\n        return np_img\n\n\n    def scale_tnt_input(self, intrinsics, img):\n        if self.img_mode == \"crop\":\n            intrinsics[1,2] = intrinsics[1,2] - 28  # 1080 -> 1024\n            img = img[28:1080-28, :, :]\n        elif self.img_mode == \"resize\": \n            height, width = img.shape[:2]\n\n            max_w, max_h = self.img_wh[0], self.img_wh[1]\n            if max_w == -1:\n                max_w = width\n\n            img = cv2.resize(img, (max_w, max_h))\n\n            scale_w = 1.0 * max_w / width\n            intrinsics[0, :] *= scale_w\n            scale_h = 1.0 * max_h / height\n            intrinsics[1, 
:] *= scale_h\n\n        return intrinsics, img\n\n\n    def __len__(self):\n        return len(self.metas)\n\n\n    def __getitem__(self, idx):\n        scan, _, ref_view, src_views = self.metas[idx]\n        view_ids = [ref_view] + src_views[:self.n_views-1]\n\n        imgs = []\n        depth_min = None\n        depth_max = None\n\n        proj_matrices_0 = []\n        proj_matrices_1 = []\n        proj_matrices_2 = []\n        proj_matrices_3 = []\n\n        for i, vid in enumerate(view_ids):\n            img_filename = os.path.join(self.root_dir, self.split, scan, f'images/{vid:08d}.jpg')\n            if self.cam_mode == 'short_range':\n                # can only use for Intermediate\n                proj_mat_filename = os.path.join(self.root_dir, self.split, scan, f'cams_{scan.lower()}/{vid:08d}_cam.txt')\n            elif self.cam_mode == 'origin':\n                proj_mat_filename = os.path.join(self.root_dir, self.split, scan, f'cams/{vid:08d}_cam.txt')\n\n            img = self.read_img(img_filename)\n\n            intrinsics, extrinsics, depth_min_, depth_max_ = self.read_cam_file(proj_mat_filename)\n            intrinsics, img = self.scale_tnt_input(intrinsics, img)\n            imgs.append(img.transpose(2,0,1))\n\n            proj_mat_0 = np.zeros(shape=(2, 4, 4), dtype=np.float32)\n            proj_mat_1 = np.zeros(shape=(2, 4, 4), dtype=np.float32)\n            proj_mat_2 = np.zeros(shape=(2, 4, 4), dtype=np.float32)\n            proj_mat_3 = np.zeros(shape=(2, 4, 4), dtype=np.float32)\n\n            intrinsics[:2,:] *= 0.125\n            proj_mat_0[0,:4,:4] = extrinsics.copy()\n            proj_mat_0[1,:3,:3] = intrinsics.copy()\n            int_mat_0 = intrinsics.copy()\n\n            intrinsics[:2,:] *= 2\n            proj_mat_1[0,:4,:4] = extrinsics.copy()\n            proj_mat_1[1,:3,:3] = intrinsics.copy()\n            int_mat_1 = intrinsics.copy()\n\n            intrinsics[:2,:] *= 2\n            proj_mat_2[0,:4,:4] = extrinsics.copy()\n            proj_mat_2[1,:3,:3] = intrinsics.copy()\n            int_mat_2 = intrinsics.copy()\n\n            intrinsics[:2,:] *= 2\n            proj_mat_3[0,:4,:4] = extrinsics.copy()\n            proj_mat_3[1,:3,:3] = intrinsics.copy()\n            int_mat_3 = intrinsics.copy() \n\n            proj_matrices_0.append(proj_mat_0)\n            proj_matrices_1.append(proj_mat_1)\n            proj_matrices_2.append(proj_mat_2)\n            proj_matrices_3.append(proj_mat_3)\n\n            # reference view\n            if i == 0:\n                depth_min =  depth_min_\n                if self.cam_mode == 'short_range':\n                    depth_max = depth_min + self.total_depths * self.depth_interval_table[scan]\n                elif self.cam_mode == 'origin':\n                    depth_max = depth_max_\n\n        proj={}\n        proj['stage1'] = np.stack(proj_matrices_0)\n        proj['stage2'] = np.stack(proj_matrices_1)\n        proj['stage3'] = np.stack(proj_matrices_2)\n        proj['stage4'] = np.stack(proj_matrices_3)\n\n        intrinsics_matrices = {\n            \"stage1\": int_mat_0,\n            \"stage2\": int_mat_1,\n            \"stage3\": int_mat_2,\n            \"stage4\": int_mat_3\n        }\n\n        sample = {\n            \"imgs\": imgs,\n            \"proj_matrices\": proj,\n            \"intrinsics_matrices\": intrinsics_matrices,\n            \"depth_values\": np.array([depth_min, depth_max], dtype=np.float32),\n            \"filename\": scan + '/{}/' + '{:0>8}'.format(view_ids[0]) + \"{}\"\n        }\n\n    
    return sample"
  },
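  {
    "path": "docs/examples/stage_intrinsics_sketch.py",
    "content": "# -*- coding: utf-8 -*-\n# @Description: Illustrative sketch, NOT part of the original loader. It shows\n#     the four-stage intrinsics scaling performed in datasets/tnt.py's\n#     __getitem__: stage1 uses the intrinsics scaled by 0.125 (1/8 resolution)\n#     and each later stage doubles the first two rows, reaching full\n#     resolution at stage4. The K below is a made-up example.\n\nimport numpy as np\n\nK = np.array([[1160.0, 0.0, 960.0],\n              [0.0, 1160.0, 512.0],\n              [0.0, 0.0, 1.0]], dtype=np.float32)\n\nK_stage = K.copy()\nK_stage[:2, :] *= 0.125  # stage1: 1/8 scale\nfor stage in range(1, 5):\n    print('stage{}:'.format(stage))\n    print(K_stage)\n    if stage < 4:\n        K_stage = K_stage.copy()\n        K_stage[:2, :] *= 2  # next stage doubles focal lengths and principal point"
  },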
  {
    "path": "fusions/dtu/_open3d.py",
    "content": "# -*- coding: utf-8 -*-\n# @Description: Point cloud fusion strategy for DTU dataset based on Open3D Library.\n# @Author: Zhe Zhang (doublez@stu.pku.edu.cn)\n# @Affiliation: Peking University (PKU)\n# @LastEditDate: 2023-09-07\n\nimport torch\nimport numpy as np\nimport sys\nimport argparse\nimport errno, os\nimport glob\nimport os.path as osp\nimport re\nimport cv2\nfrom PIL import Image\nimport gc\nimport open3d as o3d\n\nimport torch\nimport torch.nn.functional as F\nimport numpy as np\n\n\nparser = argparse.ArgumentParser(description='Depth fusion with consistency check.')\nparser.add_argument('--root_path', type=str, default='[/path/to/]dtu-test-1200')\nparser.add_argument('--depth_path', type=str, default='')\nparser.add_argument('--data_list', type=str, default='')\nparser.add_argument('--ply_path', type=str, default='')\nparser.add_argument('--dist_thresh', type=float, default=0.001)\nparser.add_argument('--prob_thresh', type=float, default=0.6)\nparser.add_argument('--num_consist', type=int, default=10)\nparser.add_argument('--device', type=str, default='cpu')\n\nargs = parser.parse_args()\n\n\ndef homo_warping(src_fea, src_proj, ref_proj, depth_values):\n    # src_fea: [B, C, H, W]\n    # src_proj: [B, 4, 4]\n    # ref_proj: [B, 4, 4]\n    # depth_values: [B, Ndepth] o [B, Ndepth, H, W]\n    # out: [B, C, Ndepth, H, W]\n    batch, channels = src_fea.shape[0], src_fea.shape[1]\n    height, width = src_fea.shape[2], src_fea.shape[3]\n\n    with torch.no_grad():\n        proj = torch.matmul(src_proj, torch.inverse(ref_proj))\n        rot = proj[:, :3, :3]  # [B,3,3]\n        trans = proj[:, :3, 3:4]  # [B,3,1]\n\n        y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=src_fea.device),\n                               torch.arange(0, width, dtype=torch.float32, device=src_fea.device)])\n        y, x = y.contiguous(), x.contiguous()\n        y, x = y.view(height * width), x.view(height * width)\n        xyz = torch.stack((x, y, torch.ones_like(x)))  # [3, H*W]\n        xyz = torch.unsqueeze(xyz, 0).repeat(batch, 1, 1)  # [B, 3, H*W]\n        rot_xyz = torch.matmul(rot, xyz)  # [B, 3, H*W]\n\n        rot_depth_xyz = rot_xyz.unsqueeze(2) * depth_values.view(-1, 1, 1, height*width)  # [B, 3, 1, H*W]\n\n        proj_xyz = rot_depth_xyz + trans.view(batch, 3, 1, 1)  # [B, 3, Ndepth, H*W]\n        proj_xy = proj_xyz[:, :2, :, :] / proj_xyz[:, 2:3, :, :]  # [B, 2, Ndepth, H*W]\n        proj_x_normalized = proj_xy[:, 0, :, :] / ((width - 1) / 2) - 1\n        proj_y_normalized = proj_xy[:, 1, :, :] / ((height - 1) / 2) - 1\n        proj_xy = torch.stack((proj_x_normalized, proj_y_normalized), dim=3)  # [B, Ndepth, H*W, 2]\n        grid = proj_xy\n\n    warped_src_fea = F.grid_sample(src_fea, grid.view(batch,  height, width, 2), mode='bilinear',\n                                   padding_mode='zeros')\n    warped_src_fea = warped_src_fea.view(batch, channels, height, width)\n\n    return warped_src_fea\n\n\ndef generate_points_from_depth(depth, proj):\n    '''\n    :param depth: (B, 1, H, W)\n    :param proj: (B, 4, 4)\n    :return: point_cloud (B, 3, H, W)\n    '''\n    batch, height, width = depth.shape[0], depth.shape[2], depth.shape[3]\n    inv_proj = torch.inverse(proj)\n\n    rot = inv_proj[:, :3, :3]  # [B,3,3]\n    trans = inv_proj[:, :3, 3:4]  # [B,3,1]\n\n    y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=depth.device),\n                           torch.arange(0, width, dtype=torch.float32, device=depth.device)])\n 
   y, x = y.contiguous(), x.contiguous()\n    y, x = y.view(height * width), x.view(height * width)\n    xyz = torch.stack((x, y, torch.ones_like(x)))  # [3, H*W]\n    xyz = torch.unsqueeze(xyz, 0).repeat(batch, 1, 1)  # [B, 3, H*W]\n    rot_xyz = torch.matmul(rot, xyz)  # [B, 3, H*W]\n    rot_depth_xyz = rot_xyz * depth.view(batch, 1, -1)\n    proj_xyz = rot_depth_xyz + trans.view(batch, 3, 1)  # [B, 3, H*W]\n    proj_xyz = proj_xyz.view(batch, 3, height, width)\n\n    return proj_xyz\n\n\ndef mkdir_p(path):\n    try:\n        os.makedirs(path)\n    except OSError as exc:\n        if exc.errno == errno.EEXIST and os.path.isdir(path):\n            pass\n        else:\n            raise\n\n\ndef read_pfm(filename):\n    file = open(filename, 'rb')\n    color = None\n    width = None\n    height = None\n    scale = None\n    endian = None\n\n    header = file.readline().decode('utf-8').rstrip()\n    if header == 'PF':\n        color = True\n    elif header == 'Pf':\n        color = False\n    else:\n        raise Exception('Not a PFM file.')\n\n    dim_match = re.match(r'^(\\d+)\\s(\\d+)\\s$', file.readline().decode('utf-8'))\n    if dim_match:\n        width, height = map(int, dim_match.groups())\n    else:\n        raise Exception('Malformed PFM header.')\n\n    scale = float(file.readline().rstrip())\n    if scale < 0:  # little-endian\n        endian = '<'\n        scale = -scale\n    else:\n        endian = '>'  # big-endian\n\n    data = np.fromfile(file, endian + 'f')\n    shape = (height, width, 3) if color else (height, width)\n\n    data = np.reshape(data, shape)\n    data = np.flipud(data)\n    file.close()\n    return data, scale\n\n\ndef write_pfm(file, image, scale=1):\n    file = open(file, 'wb')\n    color = None\n    if image.dtype.name != 'float32':\n        raise Exception('Image dtype must be float32.')\n\n    image = np.flipud(image)\n\n    if len(image.shape) == 3 and image.shape[2] == 3: # color image\n        color = True\n    elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale\n        color = False\n    else:\n        raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')\n\n    file.write('PF\\n'.encode() if color else 'Pf\\n'.encode())\n    file.write('%d %d\\n'.encode() % (image.shape[1], image.shape[0]))\n\n    endian = image.dtype.byteorder\n\n    if endian == '<' or endian == '=' and sys.byteorder == 'little':\n        scale = -scale\n\n    file.write('%f\\n'.encode() % scale)\n\n    image_string = image.tostring()\n    file.write(image_string)\n    file.close()\n\n\ndef write_ply(file, points):\n    pcd = o3d.geometry.PointCloud()\n    pcd.points = o3d.utility.Vector3dVector(points[:, :3])\n    pcd.colors = o3d.utility.Vector3dVector(points[:, 3:] / 255.)\n    o3d.io.write_point_cloud(file, pcd, write_ascii=False)\n\n\ndef filter_depth(ref_depth, src_depths, ref_proj, src_projs):\n    '''\n    :param ref_depth: (1, 1, H, W)\n    :param src_depths: (B, 1, H, W)\n    :param ref_proj: (1, 4, 4)\n    :param src_proj: (B, 4, 4)\n    :return: ref_pc: (1, 3, H, W), aligned_pcs: (B, 3, H, W), dist: (B, 1, H, W)\n    '''\n\n    ref_pc = generate_points_from_depth(ref_depth, ref_proj)\n    src_pcs = generate_points_from_depth(src_depths, src_projs)\n\n    aligned_pcs = homo_warping(src_pcs, src_projs, ref_proj, ref_depth)\n\n    x_2 = (ref_pc[:, 0] - aligned_pcs[:, 0])**2\n    y_2 = (ref_pc[:, 1] - aligned_pcs[:, 1])**2\n    z_2 = (ref_pc[:, 2] - aligned_pcs[:, 2])**2\n    dist = torch.sqrt(x_2 + y_2 + 
z_2).unsqueeze(1)\n\n    return ref_pc, aligned_pcs, dist\n\n\ndef parse_cameras(path):\n    cam_txt = open(path).readlines()\n    f = lambda xs: list(map(lambda x: list(map(float, x.strip().split())), xs))\n\n    extr_mat = f(cam_txt[1:5])\n    intr_mat = f(cam_txt[7:10])\n\n    extr_mat = np.array(extr_mat, np.float32)\n    intr_mat = np.array(intr_mat, np.float32)\n\n    return extr_mat, intr_mat\n\n\ndef load_data(root_path, depth_path, scene_name, thresh):\n\n    depths = []\n    projs = []\n    rgbs = []\n\n    for view in range(49):\n        img_filename = \"{}/{}/images/{:08d}.jpg\".format(depth_path, scene_name, view)\n        cam_filename = \"{}/{}/cams/{:08d}_cam.txt\".format(depth_path, scene_name, view)\n        depth_filename = \"{}/{}/depth_est/{:08d}.pfm\".format(depth_path, scene_name, view)\n        confidence_filename = \"{}/{}/confidence/{:08d}.pfm\".format(depth_path, scene_name, view)\n\n\n        extr_mat, intr_mat = parse_cameras(cam_filename)\n        proj_mat = np.eye(4)\n        proj_mat[:3, :4] = np.dot(intr_mat[:3, :3], extr_mat[:3, :4])\n        projs.append(torch.from_numpy(proj_mat))\n\n        dep_map, _ = read_pfm(depth_filename)\n        h, w = dep_map.shape\n        conf_map, _ = read_pfm(confidence_filename)\n        conf_map = cv2.resize(conf_map, (w, h), interpolation=cv2.INTER_LINEAR)\n\n        dep_map = dep_map * (conf_map>thresh).astype(np.float32)\n        depths.append(torch.from_numpy(dep_map).unsqueeze(0))\n\n        rgb = np.array(Image.open(img_filename))\n        rgbs.append(rgb)\n\n    depths = torch.stack(depths).float()\n    projs = torch.stack(projs).float()\n    if args.device == 'cuda' and torch.cuda.is_available():\n        depths = depths.cuda()\n        projs = projs.cuda()\n\n    return depths, projs, rgbs\n\n\ndef extract_points(pc, mask, rgb):\n    pc = pc.cpu().numpy()\n    mask = mask.cpu().numpy()\n\n    mask = np.reshape(mask, (-1,))\n    pc = np.reshape(pc, (-1, 3))\n    rgb = np.reshape(rgb, (-1, 3))\n\n    points = pc[np.where(mask)]\n    colors = rgb[np.where(mask)]\n\n    points_with_color = np.concatenate([points, colors], axis=1)\n\n    return points_with_color\n\n\ndef open3d_filter():\n    with torch.no_grad():\n        mkdir_p(args.ply_path)\n        all_scenes = open(args.data_list, 'r').readlines()\n        all_scenes = list(map(str.strip, all_scenes))\n\n        for i, scene in enumerate(all_scenes):\n\n            print(\"{}/{} {}:\".format(i, len(all_scenes), scene), '------------------------')\n\n            depths, projs, rgbs = load_data(args.root_path, args.depth_path, scene, args.prob_thresh)\n            tot_frame = depths.shape[0]\n            height, width = depths.shape[2], depths.shape[3]\n            points = []\n\n            print('Scene: {} total: {} frames'.format(scene, tot_frame))\n            for i in range(tot_frame):\n                pc_buff = torch.zeros((3, height, width), device=depths.device, dtype=depths.dtype)\n                val_cnt = torch.zeros((1, height, width), device=depths.device, dtype=depths.dtype)\n                j = 0\n                batch_size = 20\n\n                while True:\n                    ref_pc, pcs, dist = filter_depth(ref_depth=depths[i:i+1], src_depths=depths[j:min(j+batch_size, tot_frame)],\n                                                    ref_proj=projs[i:i+1], src_projs=projs[j:min(j+batch_size, tot_frame)])\n                    masks = (dist < args.dist_thresh).float()\n                    masked_pc = pcs * masks\n                    pc_buff += 
masked_pc.sum(dim=0, keepdim=False)\n                    val_cnt += masks.sum(dim=0, keepdim=False)\n\n                    j += batch_size\n                    if j >= tot_frame:\n                        break\n\n                final_mask = (val_cnt >= args.num_consist).squeeze(0)\n                avg_points = torch.div(pc_buff, val_cnt).permute(1, 2, 0)\n\n                final_pc = extract_points(avg_points, final_mask, rgbs[i])\n                points.append(final_pc)\n                if i==0 or i==tot_frame-1:\n                    print('Processing {} {}/{} ...'.format(scene, i+1, tot_frame))\n\n            ply_id = int(scene[4:])\n            write_ply('{}/mvsnet{:03d}.ply'.format(args.ply_path, ply_id), np.concatenate(points, axis=0))\n            del points, depths, rgbs, projs\n\n            gc.collect()\n\n            print('Save {}/mvsnet{:03d}.ply successful.'.format(args.ply_path, ply_id))\n\n\nif __name__ == '__main__':\n    open3d_filter()\n"
  },
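  {
    "path": "docs/examples/backproject_sketch.py",
    "content": "# -*- coding: utf-8 -*-\n# @Description: Illustrative numpy sketch, NOT part of the fusion script. It\n#     restates generate_points_from_depth from fusions/dtu/_open3d.py: with\n#     proj the 4x4 matrix embedding K @ [R|t], a pixel (x, y) at depth d lifts\n#     to X = inv(proj)[:3, :3] @ (x, y, 1) * d + inv(proj)[:3, 3].\n\nimport numpy as np\n\n\ndef backproject(depth, proj):\n    # depth: (H, W); proj: (4, 4) projection matrix (intrinsics @ extrinsics)\n    h, w = depth.shape\n    inv = np.linalg.inv(proj)\n    y, x = np.mgrid[0:h, 0:w].astype(np.float32)\n    xyz = np.stack([x.ravel(), y.ravel(), np.ones(h * w, np.float32)])  # (3, H*W)\n    pts = (inv[:3, :3] @ xyz) * depth.ravel() + inv[:3, 3:4]\n    return pts.reshape(3, h, w)"
  },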
  {
    "path": "fusions/dtu/gipuma.py",
    "content": "# -*- coding: utf-8 -*-\n# @Description: Point cloud fusion strategy for DTU dataset: Gipuma (fusibile).\n#     Refer to: https://github.com/YoYo000/MVSNet/blob/master/mvsnet/depthfusion.py\n# @Author: Zhe Zhang (doublez@stu.pku.edu.cn)\n# @Affiliation: Peking University (PKU)\n# @LastEditDate: 2023-09-07\n\nfrom __future__ import print_function\n\nimport os, re, sys, shutil\nfrom struct import *\nimport numpy as np\nimport argparse\nimport cv2\nfrom tensorflow.python.lib.io import file_io\n\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--root_dir', type=str, default='[/path/to]/dtu-test-1200', help='root directory of dtu dataset')\nparser.add_argument('--list_file', type=str, default='datasets/lists/dtu/train.txt', help='file contains the scans list')\n\nparser.add_argument('--depth_folder', type=str, default = './outputs/')\nparser.add_argument('--out_folder', type=str, default = 'fusibile_fused')\nparser.add_argument('--plydir', type=str, default='')\nparser.add_argument('--quandir', type=str, default='')\nparser.add_argument('--fusibile_exe_path', type=str, default = 'fusion/fusibile')\nparser.add_argument('--prob_threshold', type=float, default = '0.8')\nparser.add_argument('--disp_threshold', type=float, default = '0.13')\nparser.add_argument('--num_consistent', type=float, default = '3')\nparser.add_argument('--downsample_factor', type=int, default='1')\n\nargs = parser.parse_args()\n\n\n# preprocess ====================================\n\ndef load_cam(file, interval_scale=1):\n    \"\"\" read camera txt file \"\"\"\n    cam = np.zeros((2, 4, 4))\n    words = file.read().split()\n    # read extrinsic\n    for i in range(0, 4):\n        for j in range(0, 4):\n            extrinsic_index = 4 * i + j + 1\n            cam[0][i][j] = words[extrinsic_index]\n\n    # read intrinsic\n    for i in range(0, 3):\n        for j in range(0, 3):\n            intrinsic_index = 3 * i + j + 18\n            cam[1][i][j] = words[intrinsic_index]\n\n    if len(words) == 29:\n        cam[1][3][0] = words[27]\n        cam[1][3][1] = float(words[28]) * interval_scale\n        cam[1][3][2] = 1100\n        cam[1][3][3] = cam[1][3][0] + cam[1][3][1] * cam[1][3][2]\n    elif len(words) == 30:\n        cam[1][3][0] = words[27]\n        cam[1][3][1] = float(words[28]) * interval_scale\n        cam[1][3][2] = words[29]\n        cam[1][3][3] = cam[1][3][0] + cam[1][3][1] * cam[1][3][2]\n    elif len(words) == 31:\n        cam[1][3][0] = words[27]\n        cam[1][3][1] = float(words[28]) * interval_scale\n        cam[1][3][2] = words[29]\n        cam[1][3][3] = words[30]\n    else:\n        cam[1][3][0] = 0\n        cam[1][3][1] = 0\n        cam[1][3][2] = 0\n        cam[1][3][3] = 0\n\n    return cam\n\n\ndef load_pfm(file):\n    color = None\n    width = None\n    height = None\n    scale = None\n    data_type = None\n    header = file.readline().decode('UTF-8').rstrip()\n\n    if header == 'PF':\n        color = True\n    elif header == 'Pf':\n        color = False\n    else:\n        raise Exception('Not a PFM file.')\n    dim_match = re.match(r'^(\\d+)\\s(\\d+)\\s$', file.readline().decode('UTF-8'))\n    if dim_match:\n        width, height = map(int, dim_match.groups())\n    else:\n        raise Exception('Malformed PFM header.')\n    # scale = float(file.readline().rstrip())\n    scale = float((file.readline()).decode('UTF-8').rstrip())\n    if scale < 0: # little-endian\n        data_type = '<f'\n    else:\n        data_type = '>f' # big-endian\n    data_string = file.read()\n 
   data = np.frombuffer(data_string, data_type)\n    shape = (height, width, 3) if color else (height, width)\n    data = np.reshape(data, shape)\n    data = cv2.flip(data, 0)\n    return data\n\n\ndef write_pfm(file, image, scale=1):\n    file = file_io.FileIO(file, mode='wb')\n    color = None\n\n    if image.dtype.name != 'float32':\n        raise Exception('Image dtype must be float32.')\n\n    image = np.flipud(image)\n\n    if len(image.shape) == 3 and image.shape[2] == 3: # color image\n        color = True\n    elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale\n        color = False\n    else:\n        raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')\n\n    file.write('PF\\n' if color else 'Pf\\n')\n    file.write('%d %d\\n' % (image.shape[1], image.shape[0]))\n\n    endian = image.dtype.byteorder\n\n    if endian == '<' or endian == '=' and sys.byteorder == 'little':\n        scale = -scale\n\n    file.write('%f\\n' % scale)\n\n    image_string = image.tobytes()\n    file.write(image_string)\n\n    file.close()\n\n# ================================================\n\n\ndef read_gipuma_dmb(path):\n    '''read Gipuma .dmb format image'''\n\n    with open(path, \"rb\") as fid:\n        \n        image_type = unpack('<i', fid.read(4))[0]\n        height = unpack('<i', fid.read(4))[0]\n        width = unpack('<i', fid.read(4))[0]\n        channel = unpack('<i', fid.read(4))[0]\n        \n        array = np.fromfile(fid, np.float32)\n    array = array.reshape((width, height, channel), order=\"F\")\n    return np.transpose(array, (1, 0, 2)).squeeze()\n\n\ndef write_gipuma_dmb(path, image):\n    '''write Gipuma .dmb format image'''\n    \n    image_shape = np.shape(image)\n    width = image_shape[1]\n    height = image_shape[0]\n    if len(image_shape) == 3:\n        channels = image_shape[2]\n    else:\n        channels = 1\n\n    if len(image_shape) == 3:\n        image = np.transpose(image, (2, 0, 1)).squeeze()\n\n    with open(path, \"wb\") as fid:\n        # fid.write(pack(1))\n        fid.write(pack('<i', 1))\n        fid.write(pack('<i', height))\n        fid.write(pack('<i', width))\n        fid.write(pack('<i', channels))\n        image.tofile(fid)\n    return \n\n\ndef mvsnet_to_gipuma_dmb(in_path, out_path):\n    '''convert mvsnet .pfm output to Gipuma .dmb format'''\n    \n    image = load_pfm(open(in_path, 'rb'))\n    write_gipuma_dmb(out_path, image)\n\n    return \n\n\ndef mvsnet_to_gipuma_cam(in_path, out_path):\n    '''convert mvsnet camera to gipuma camera format'''\n\n    cam = load_cam(open(in_path))\n\n    extrinsic = cam[0:4][0:4][0]\n    intrinsic = cam[0:4][0:4][1]\n    intrinsic[3][0] = 0\n    intrinsic[3][1] = 0\n    intrinsic[3][2] = 0\n    intrinsic[3][3] = 0\n\n    intrinsic[:2, :] /= args.downsample_factor\n\n    projection_matrix = np.matmul(intrinsic, extrinsic)\n    projection_matrix = projection_matrix[0:3][:]\n    \n    f = open(out_path, \"w\")\n    for i in range(0, 3):\n        for j in range(0, 4):\n            f.write(str(projection_matrix[i][j]) + ' ')\n        f.write('\\n')\n    f.write('\\n')\n    f.close()\n\n    return\n\n\ndef fake_gipuma_normal(in_depth_path, out_normal_path):\n    \n    depth_image = read_gipuma_dmb(in_depth_path)\n    image_shape = np.shape(depth_image)\n\n    normal_image = np.ones_like(depth_image)\n    normal_image = np.reshape(normal_image, (image_shape[0], image_shape[1], 1))\n    normal_image = np.tile(normal_image, [1, 1, 3])\n    normal_image = 
normal_image / 1.732050808\n\n    mask_image = np.squeeze(np.where(depth_image > 0, 1, 0))\n    mask_image = np.reshape(mask_image, (image_shape[0], image_shape[1], 1))\n    mask_image = np.tile(mask_image, [1, 1, 3])\n    mask_image = np.float32(mask_image)\n\n    normal_image = np.multiply(normal_image, mask_image)\n    normal_image = np.float32(normal_image)\n\n    write_gipuma_dmb(out_normal_path, normal_image)\n    return \n\n\ndef mvsnet_to_gipuma(scan_folder, scan, root_dir, gipuma_point_folder):\n    \n    image_folder = os.path.join(root_dir, 'Rectified', scan)\n    cam_folder = os.path.join(root_dir, 'Cameras')\n    depth_folder = os.path.join(scan_folder, 'depth_est')\n\n    gipuma_cam_folder = os.path.join(gipuma_point_folder, 'cams')\n    gipuma_image_folder = os.path.join(gipuma_point_folder, 'images')\n    if not os.path.isdir(gipuma_point_folder):\n        os.mkdir(gipuma_point_folder)\n    if not os.path.isdir(gipuma_cam_folder):\n        os.mkdir(gipuma_cam_folder)\n    if not os.path.isdir(gipuma_image_folder):\n        os.mkdir(gipuma_image_folder)\n\n    # convert cameras \n    for view in range(0,49):\n        in_cam_file = os.path.join(cam_folder, \"{:08d}_cam.txt\".format(view))\n        out_cam_file = os.path.join(gipuma_cam_folder, \"{:08d}.png.P\".format(view))\n        mvsnet_to_gipuma_cam(in_cam_file, out_cam_file)\n\n    # copy images to gipuma image folder    \n    for view in range(0,49):\n        in_image_file = os.path.join(image_folder, \"rect_{:03d}_3_r5000.png\".format(view+1))  # our images start from 1\n        out_image_file = os.path.join(gipuma_image_folder, \"{:08d}.png\".format(view))\n        # shutil.copy(in_image_file, out_image_file)\n\n        in_image = cv2.imread(in_image_file)\n        out_image = cv2.resize(in_image, None, fx=1.0/args.downsample_factor, fy=1.0/args.downsample_factor, interpolation=cv2.INTER_LINEAR)\n        cv2.imwrite(out_image_file, out_image)\n\n    # convert depth maps and fake normal maps\n    gipuma_prefix = '2333__'\n    for view in range(0,49):\n\n        sub_depth_folder = os.path.join(gipuma_point_folder, gipuma_prefix+\"{:08d}\".format(view))\n        if not os.path.isdir(sub_depth_folder):\n            os.mkdir(sub_depth_folder)\n        in_depth_pfm = os.path.join(depth_folder, \"{:08d}_prob_filtered.pfm\".format(view))\n        out_depth_dmb = os.path.join(sub_depth_folder, 'disp.dmb')\n        fake_normal_dmb = os.path.join(sub_depth_folder, 'normals.dmb')\n        mvsnet_to_gipuma_dmb(in_depth_pfm, out_depth_dmb)\n        fake_gipuma_normal(out_depth_dmb, fake_normal_dmb)\n\n\ndef probability_filter(scan_folder, prob_threshold):\n    depth_folder = os.path.join(scan_folder, 'depth_est')\n    prob_folder = os.path.join(scan_folder, 'confidence')\n    \n    # filter each depth map with its probability map\n    for view in range(0,49):\n        init_depth_map_path = os.path.join(depth_folder, \"{:08d}.pfm\".format(view))  # new dataset outputs depths starting from 0\n        prob_map_path = os.path.join(prob_folder, \"{:08d}.pfm\".format(view))  # same as above\n        out_depth_map_path = os.path.join(depth_folder, \"{:08d}_prob_filtered.pfm\".format(view))  # gipuma starts from 0\n\n        depth_map = load_pfm(open(init_depth_map_path, 'rb'))\n        prob_map = load_pfm(open(prob_map_path, 'rb'))\n        depth_map[prob_map < prob_threshold] = 0\n        write_pfm(out_depth_map_path, depth_map)\n\n\ndef depth_map_fusion(point_folder, fusibile_exe_path, disp_thresh, num_consistent):\n\n    cam_folder = os.path.join(point_folder, 'cams')\n    image_folder 
= os.path.join(point_folder, 'images')\n    depth_min = 0.001\n    depth_max = 100000\n    normal_thresh = 360\n\n    cmd = fusibile_exe_path\n    cmd = cmd + ' -input_folder ' + point_folder + '/'\n    cmd = cmd + ' -p_folder ' + cam_folder + '/'\n    cmd = cmd + ' -images_folder ' + image_folder + '/'\n    cmd = cmd + ' --depth_min=' + str(depth_min)\n    cmd = cmd + ' --depth_max=' + str(depth_max)\n    cmd = cmd + ' --normal_thresh=' + str(normal_thresh)\n    cmd = cmd + ' --disp_thresh=' + str(disp_thresh)\n    cmd = cmd + ' --num_consistent=' + str(num_consistent)\n    print (cmd)\n    os.system(cmd)\n\n    return \n\n\ndef collectPly(point_folder, scan_id):\n    model_name = 'final3d_model.ply'\n    model_dir = [item for item in os.listdir(point_folder) if item.startswith(\"consistencyCheck\")][-1]\n\n    old = os.path.join(point_folder, model_dir, model_name)\n    fresh = os.path.join(args.plydir, \"mvsnet\") + scan_id.zfill(3) + \".ply\"\n    shutil.move(old, fresh)\n\n\nif __name__ == '__main__':\n\n    root_dir = args.root_dir\n    depth_folder = args.depth_folder\n    out_folder = args.out_folder\n    fusibile_exe_path = args.fusibile_exe_path\n    prob_threshold = args.prob_threshold\n    disp_threshold = args.disp_threshold\n    num_consistent = args.num_consistent\n\n    # Read test list\n    testlist = args.list_file\n    with open(testlist) as f:\n        scans = f.readlines()\n        scans = [line.rstrip() for line in scans]\n\n    print(\"Start Gipuma(GPU) fusion!\")\n\n    if not os.path.isdir(args.plydir):\n        os.mkdir(args.plydir)\n\n    # Fusion\n    for i, scan in enumerate(scans):\n        print(\"{}/{} {}:\".format(i, len(scans), scan), '------------------------')\n\n        scan_folder = os.path.join(depth_folder, scan)\n        fusibile_workspace = os.path.join(depth_folder, out_folder, scan)\n\n        if not os.path.isdir(os.path.join(depth_folder, out_folder)):\n            os.mkdir(os.path.join(depth_folder, out_folder))\n\n        if not os.path.isdir(fusibile_workspace):\n            os.mkdir(fusibile_workspace)\n\n        # probability filtering\n        print ('filter depth map with probability map')\n        probability_filter(scan_folder, prob_threshold)\n\n        # convert to gipuma format\n        print ('Convert mvsnet output to gipuma input')\n        mvsnet_to_gipuma(scan_folder, scan, root_dir, fusibile_workspace)\n\n        # depth map fusion with gipuma \n        print ('Run depth map fusion & filter')\n        depth_map_fusion(fusibile_workspace, fusibile_exe_path, disp_threshold, num_consistent)\n\n        # collect .ply results to summary folder\n        print('Collect {} ply'.format(scan))\n        collectPly(fusibile_workspace, scan[4:])\n\n    print(\"Gipuma(GPU) fusion done!\")\n    shutil.rmtree(os.path.join(depth_folder, out_folder))\n    print(\"fusibile_fused remove done!\")"
  },
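The `depth_map_fusion()` helper in the gipuma script above assembles the fusibile command line by string concatenation and runs it through `os.system`. A minimal sketch of the same invocation via `subprocess` with an argv list, which sidesteps shell quoting; the flag names and defaults are copied from the function, while `run_fusibile` itself is an illustrative wrapper, not part of the repo:

```python
# Sketch: the same fusibile call as depth_map_fusion(), but launched with
# subprocess and an argv list instead of a concatenated shell string.
# run_fusibile is an illustrative wrapper; flags/defaults mirror the script.
import subprocess

def run_fusibile(fusibile_exe_path, point_folder, cam_folder, image_folder,
                 disp_thresh, num_consistent,
                 depth_min=0.001, depth_max=100000, normal_thresh=360):
    argv = [
        fusibile_exe_path,
        "-input_folder", point_folder + "/",
        "-p_folder", cam_folder + "/",
        "-images_folder", image_folder + "/",
        "--depth_min={}".format(depth_min),
        "--depth_max={}".format(depth_max),
        "--normal_thresh={}".format(normal_thresh),
        "--disp_thresh={}".format(disp_thresh),
        "--num_consistent={}".format(num_consistent),
    ]
    print(" ".join(argv))
    subprocess.run(argv, check=True)  # raises if fusibile exits non-zero
```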
  {
    "path": "fusions/dtu/pcd.py",
    "content": "# -*- coding: utf-8 -*-\n# @Description: Point cloud fusion strategy for DTU dataset: Basic PCD.\n#     Refer to: https://github.com/xy-guo/MVSNet_pytorch/blob/master/eval.py\n# @Author: Zhe Zhang (doublez@stu.pku.edu.cn)\n# @Affiliation: Peking University (PKU)\n# @LastEditDate: 2023-09-07\n\nimport argparse, os, sys, cv2, re, logging, time\nimport numpy as np\nfrom plyfile import PlyData, PlyElement\nfrom PIL import Image\n\nfrom multiprocessing import Pool\nfrom functools import partial\nimport signal\n\n\nparser = argparse.ArgumentParser(description='filter, and fuse')\n\nparser.add_argument('--testpath', default='[/path/to]/dtu-test-1200', help='testing data dir for some scenes')\nparser.add_argument('--testlist', default=\"datasets/lists/dtu/test.txt\", help='testing scene list')\n\nparser.add_argument('--outdir', default='./outputs/[exp_name]', help='output dir')\nparser.add_argument('--logdir', default='./checkpoints/debug', help='the directory to save checkpoints/logs')\nparser.add_argument('--nolog', action='store_true', help='do not logging into .log file')\nparser.add_argument('--plydir', default='./outputs/[exp_name]/pcd_fusion_plys/', help='output dir')\n\nparser.add_argument('--num_worker', type=int, default=4, help='depth_filer worker')\n\nparser.add_argument('--conf', type=float, default=0.9, help='prob confidence')\nparser.add_argument('--thres_view', type=int, default=5, help='threshold of num view')\n\nargs = parser.parse_args()\n\n\ndef read_pfm(filename):\n    file = open(filename, 'rb')\n    color = None\n    width = None\n    height = None\n    scale = None\n    endian = None\n\n    header = file.readline().decode('utf-8').rstrip()\n    if header == 'PF':\n        color = True\n    elif header == 'Pf':\n        color = False\n    else:\n        raise Exception('Not a PFM file.')\n\n    dim_match = re.match(r'^(\\d+)\\s(\\d+)\\s$', file.readline().decode('utf-8'))\n    if dim_match:\n        width, height = map(int, dim_match.groups())\n    else:\n        raise Exception('Malformed PFM header.')\n\n    scale = float(file.readline().rstrip())\n    if scale < 0:  # little-endian\n        endian = '<'\n        scale = -scale\n    else:\n        endian = '>'  # big-endian\n\n    data = np.fromfile(file, endian + 'f')\n    shape = (height, width, 3) if color else (height, width)\n\n    data = np.reshape(data, shape)\n    data = np.flipud(data)\n    file.close()\n    return data, scale\n\n\ndef read_camera_parameters(filename):\n    with open(filename) as f:\n        lines = f.readlines()\n        lines = [line.rstrip() for line in lines]\n    # extrinsics: line [1,5), 4x4 matrix\n    extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4))\n    # intrinsics: line [7-10), 3x3 matrix\n    intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3))\n    return intrinsics, extrinsics\n\n\ndef read_img(filename):\n    img = Image.open(filename)\n    # scale 0~255 to 0~1\n    np_img = np.array(img, dtype=np.float32) / 255.\n    return np_img\n\n\ndef read_mask(filename):\n    return read_img(filename) > 0.5\n\n\ndef save_mask(filename, mask):\n    assert mask.dtype == np.bool\n    mask = mask.astype(np.uint8) * 255\n    Image.fromarray(mask).save(filename)\n\n\ndef read_pair_file(filename):\n    data = []\n    with open(filename) as f:\n        num_viewpoint = int(f.readline())\n        # 49 viewpoints\n        for view_idx in range(num_viewpoint):\n            ref_view = 
int(f.readline().rstrip())\n            src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]\n            if len(src_views) > 0:\n                data.append((ref_view, src_views))\n    return data\n\n\n# project the reference point cloud into the source view, then project back\ndef reproject_with_depth(depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src):\n    width, height = depth_ref.shape[1], depth_ref.shape[0]\n    ## step1. project reference pixels to the source view\n    # reference view x, y\n    x_ref, y_ref = np.meshgrid(np.arange(0, width), np.arange(0, height))\n    x_ref, y_ref = x_ref.reshape([-1]), y_ref.reshape([-1])\n    # reference 3D space\n    xyz_ref = np.matmul(np.linalg.inv(intrinsics_ref),\n                        np.vstack((x_ref, y_ref, np.ones_like(x_ref))) * depth_ref.reshape([-1]))\n    # source 3D space\n    xyz_src = np.matmul(np.matmul(extrinsics_src, np.linalg.inv(extrinsics_ref)),\n                        np.vstack((xyz_ref, np.ones_like(x_ref))))[:3]\n    # source view x, y\n    K_xyz_src = np.matmul(intrinsics_src, xyz_src)\n    xy_src = K_xyz_src[:2] / K_xyz_src[2:3]\n\n    ## step2. reproject the source view points with source view depth estimation\n    # find the depth estimation of the source view\n    x_src = xy_src[0].reshape([height, width]).astype(np.float32)\n    y_src = xy_src[1].reshape([height, width]).astype(np.float32)\n    sampled_depth_src = cv2.remap(depth_src, x_src, y_src, interpolation=cv2.INTER_LINEAR)\n    # mask = sampled_depth_src > 0\n\n    # source 3D space\n    # NOTE that we should use sampled source-view depth_here to project back\n    xyz_src = np.matmul(np.linalg.inv(intrinsics_src),\n                        np.vstack((xy_src, np.ones_like(x_ref))) * sampled_depth_src.reshape([-1]))\n    # reference 3D space\n    xyz_reprojected = np.matmul(np.matmul(extrinsics_ref, np.linalg.inv(extrinsics_src)),\n                                np.vstack((xyz_src, np.ones_like(x_ref))))[:3]\n    # source view x, y, depth\n    depth_reprojected = xyz_reprojected[2].reshape([height, width]).astype(np.float32)\n    K_xyz_reprojected = np.matmul(intrinsics_ref, xyz_reprojected)\n    xy_reprojected = K_xyz_reprojected[:2] / K_xyz_reprojected[2:3]\n    x_reprojected = xy_reprojected[0].reshape([height, width]).astype(np.float32)\n    y_reprojected = xy_reprojected[1].reshape([height, width]).astype(np.float32)\n\n    return depth_reprojected, x_reprojected, y_reprojected, x_src, y_src\n\n\ndef check_geometric_consistency(depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src):\n    width, height = depth_ref.shape[1], depth_ref.shape[0]\n    x_ref, y_ref = np.meshgrid(np.arange(0, width), np.arange(0, height))\n    depth_reprojected, x2d_reprojected, y2d_reprojected, x2d_src, y2d_src = reproject_with_depth(depth_ref, intrinsics_ref, extrinsics_ref,\n                                                     depth_src, intrinsics_src, extrinsics_src)\n    # check |p_reproj-p_1| < 1\n    dist = np.sqrt((x2d_reprojected - x_ref) ** 2 + (y2d_reprojected - y_ref) ** 2)\n\n    # check |d_reproj-d_1| / d_1 < 0.01\n    depth_diff = np.abs(depth_reprojected - depth_ref)\n    relative_depth_diff = depth_diff / depth_ref\n\n    mask = np.logical_and(dist < 1, relative_depth_diff < 0.01)\n    depth_reprojected[~mask] = 0\n\n    return mask, depth_reprojected, x2d_src, y2d_src\n\n\ndef filter_depth(pair_folder, scan_folder, out_folder, plyfilename):\n    # the pair file\n    pair_file = 
os.path.join(pair_folder, \"pair.txt\")\n    # for the final point cloud\n    vertexs = []\n    vertex_colors = []\n\n    pair_data = read_pair_file(pair_file)\n\n    # for each reference view and the corresponding source views\n    for ref_view, src_views in pair_data:\n        # src_views = src_views[:args.num_view]\n        # load the camera parameters\n        ref_intrinsics, ref_extrinsics = read_camera_parameters(\n            os.path.join(scan_folder, 'cams/{:0>8}_cam.txt'.format(ref_view)))\n        # load the reference image\n        ref_img = read_img(os.path.join(scan_folder, 'images/{:0>8}.jpg'.format(ref_view)))\n        # load the estimated depth of the reference view\n        ref_depth_est = read_pfm(os.path.join(out_folder, 'depth_est/{:0>8}.pfm'.format(ref_view)))[0]\n        # load the photometric mask of the reference view\n        confidence = read_pfm(os.path.join(out_folder, 'confidence/{:0>8}.pfm'.format(ref_view)))[0]\n        photo_mask = confidence > args.conf\n\n        all_srcview_depth_ests = []\n        all_srcview_x = []\n        all_srcview_y = []\n        all_srcview_geomask = []\n\n        # compute the geometric mask\n        geo_mask_sum = 0\n        for src_view in src_views:\n            # camera parameters of the source view\n            src_intrinsics, src_extrinsics = read_camera_parameters(\n                os.path.join(scan_folder, 'cams/{:0>8}_cam.txt'.format(src_view)))\n            # the estimated depth of the source view\n            src_depth_est = read_pfm(os.path.join(out_folder, 'depth_est/{:0>8}.pfm'.format(src_view)))[0]\n\n            geo_mask, depth_reprojected, x2d_src, y2d_src = check_geometric_consistency(ref_depth_est, ref_intrinsics, ref_extrinsics,\n                                                                      src_depth_est,\n                                                                      src_intrinsics, src_extrinsics)\n            geo_mask_sum += geo_mask.astype(np.int32)\n            all_srcview_depth_ests.append(depth_reprojected)\n            all_srcview_x.append(x2d_src)\n            all_srcview_y.append(y2d_src)\n            all_srcview_geomask.append(geo_mask)\n\n        depth_est_averaged = (sum(all_srcview_depth_ests) + ref_depth_est) / (geo_mask_sum + 1)\n        # at least 3 source views matched\n        geo_mask = geo_mask_sum >= args.thres_view\n        final_mask = np.logical_and(photo_mask, geo_mask)\n\n        os.makedirs(os.path.join(out_folder, \"mask\"), exist_ok=True)\n        save_mask(os.path.join(out_folder, \"mask/{:0>8}_photo.png\".format(ref_view)), photo_mask)\n        save_mask(os.path.join(out_folder, \"mask/{:0>8}_geo.png\".format(ref_view)), geo_mask)\n        save_mask(os.path.join(out_folder, \"mask/{:0>8}_final.png\".format(ref_view)), final_mask)\n\n        logger.info(\"processing {}, ref-view{:0>2}, photo/geo/final-mask:{:.3f}/{:.3f}/{:.3f}\".format(scan_folder, ref_view,\n                                                                                    photo_mask.mean(),\n                                                                                    geo_mask.mean(), final_mask.mean()))\n\n        height, width = depth_est_averaged.shape[:2]\n        x, y = np.meshgrid(np.arange(0, width), np.arange(0, height))\n        # valid_points = np.logical_and(final_mask, ~used_mask[ref_view])\n        valid_points = final_mask\n        logger.info(\"valid_points: {}\".format(valid_points.mean()))\n        x, y, depth = x[valid_points], y[valid_points], 
depth_est_averaged[valid_points]\n        #color = ref_img[1:-16:4, 1::4, :][valid_points]  # hardcoded for DTU dataset\n        color = ref_img[valid_points]\n\n        xyz_ref = np.matmul(np.linalg.inv(ref_intrinsics),\n                            np.vstack((x, y, np.ones_like(x))) * depth)\n        xyz_world = np.matmul(np.linalg.inv(ref_extrinsics),\n                              np.vstack((xyz_ref, np.ones_like(x))))[:3]\n        vertexs.append(xyz_world.transpose((1, 0)))\n        vertex_colors.append((color * 255).astype(np.uint8))\n\n    vertexs = np.concatenate(vertexs, axis=0)\n    vertex_colors = np.concatenate(vertex_colors, axis=0)\n    vertexs = np.array([tuple(v) for v in vertexs], dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])\n    vertex_colors = np.array([tuple(v) for v in vertex_colors], dtype=[('red', 'u1'), ('green', 'u1'), ('blue', 'u1')])\n\n    vertex_all = np.empty(len(vertexs), vertexs.dtype.descr + vertex_colors.dtype.descr)\n    for prop in vertexs.dtype.names:\n        vertex_all[prop] = vertexs[prop]\n    for prop in vertex_colors.dtype.names:\n        vertex_all[prop] = vertex_colors[prop]\n\n    el = PlyElement.describe(vertex_all, 'vertex')\n    PlyData([el]).write(plyfilename)\n    logger.info(\"saving the final model to \" + plyfilename)\n\n\ndef init_worker():\n    '''\n    Catch Ctrl+C signal to termiante workers\n    '''\n    signal.signal(signal.SIGINT, signal.SIG_IGN)\n\n\ndef pcd_filter_worker(scan):\n    scan_id = int(scan[4:])\n    save_name = 'mvsnet{:0>3}.ply'.format(scan_id)\n\n    pair_folder = os.path.join(args.testpath, \"Cameras\")\n    scan_folder = os.path.join(args.outdir, scan)\n    out_folder = os.path.join(args.outdir, scan)\n    filter_depth(pair_folder, scan_folder, out_folder, os.path.join(args.plydir, save_name))\n\n\ndef pcd_filter(testlist, number_worker):\n\n    partial_func = partial(pcd_filter_worker)\n\n    p = Pool(number_worker, init_worker)\n    try:\n        p.map(partial_func, testlist)\n    except KeyboardInterrupt:\n        logger.info(\"....\\nCaught KeyboardInterrupt, terminating workers\")\n        p.terminate()\n    else:\n        p.close()\n    p.join()\n\n\ndef initLogger():\n    logger = logging.getLogger()\n    logger.setLevel(logging.INFO)\n    curTime = time.strftime('%Y%m%d-%H%M', time.localtime(time.time()))\n    if not os.path.isdir(args.logdir):\n        os.mkdir(args.logdir)\n    logfile = os.path.join(args.logdir, 'fusion-' + curTime + '.log')\n    formatter = logging.Formatter(\"%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s\")\n    if not args.nolog:\n        fileHandler = logging.FileHandler(logfile, mode='a')\n        fileHandler.setFormatter(formatter)\n        logger.addHandler(fileHandler)\n    consoleHandler = logging.StreamHandler(sys.stdout)\n    consoleHandler.setFormatter(formatter)\n    logger.addHandler(consoleHandler)\n    logger.info(\"Logger initialized.\")\n    logger.info(\"Writing logs to file: {}\".format(logfile))\n    logger.info(\"Current time: {}\".format(curTime))\n\n    return logger\n\n\nif __name__ == '__main__':\n\n    logger = initLogger()\n\n    if not os.path.isdir(args.plydir):\n        os.mkdir(args.plydir)\n\n    with open(args.testlist) as f:\n        content = f.readlines()\n        testlist = [line.rstrip() for line in content]\n\n    pcd_filter(testlist, args.num_worker)"
  },
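The basic PCD strategy above accepts a pixel only if its round-trip reprojection lands within 1 pixel and the reprojected depth differs by less than 1% from the reference estimate; `filter_depth()` then requires at least `--thres_view` agreeing source views on top of the `--conf` photometric threshold. A compact numpy restatement of the per-view test (`consistency_mask` is an illustrative name; thresholds mirror `check_geometric_consistency`):

```python
# Per-view geometric consistency test from pcd.py, restated on plain arrays:
# a pixel passes when the round-trip reprojection moves < 1 px and the
# reprojected depth differs < 1% from the reference depth.
import numpy as np

def consistency_mask(x_ref, y_ref, x_reproj, y_reproj, d_ref, d_reproj,
                     pix_thresh=1.0, rel_thresh=0.01):
    dist = np.sqrt((x_reproj - x_ref) ** 2 + (y_reproj - y_ref) ** 2)
    rel_diff = np.abs(d_reproj - d_ref) / d_ref
    return np.logical_and(dist < pix_thresh, rel_diff < rel_thresh)

# filter_depth() sums these boolean masks over all source views and keeps
# pixels with geo_mask_sum >= args.thres_view AND confidence > args.conf.
```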
  {
    "path": "fusions/tnt/dypcd.py",
    "content": "# -*- coding: utf-8 -*-\n# @Description: Point cloud fusion strategy for Tanks and Temples dataset: DYnamic PCD.\n#     Refer to: https://github.com/yhw-yhw/D2HC-RMVSNet/blob/master/fusion.py\n# @Author: Zhe Zhang (doublez@stu.pku.edu.cn)\n# @Affiliation: Peking University (PKU)\n# @LastEditDate: 2023-09-07\n\nimport os\nimport cv2\nimport signal\nimport numpy as np\nfrom PIL import Image\nfrom functools import partial\nfrom multiprocessing import Pool\nfrom plyfile import PlyData, PlyElement\nimport argparse\nimport re, json\n\nfrom sklearn.preprocessing import scale\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\"--root_dir\", type=str, default=\"[/path/to/]tankandtemples/\")\nparser.add_argument('--out_dir', type=str, default='outputs/[exp_name]')\nparser.add_argument('--ply_path', type=str, default='outputs/[exp_name]/dypcd_fusion_plys')\n\nparser.add_argument('--split', type=str, default='intermediate', choices=['intermediate', 'advanced'])\nparser.add_argument('--list_file', type=str, default='datasets/lists/tnt/intermediate.txt')\nparser.add_argument('--num_workers', type=int, default=1)\nparser.add_argument('--single_processor', action='store_true')\n\nparser.add_argument('--rescale', action='store_true')\nparser.add_argument('--max_w', type=int)\nparser.add_argument('--max_h', type=int)\nparser.add_argument('--cam_mode', type=str, default='origin', choices=['origin', 'short_range'])\nparser.add_argument('--img_mode', type=str, default='resize', choices=['resize', 'crop'])\n\nparser.add_argument('--dist_base', type=float, default=1 / 4)\nparser.add_argument('--rel_diff_base', type=float, default=1 / 1300)\n\nargs = parser.parse_args()\n\n\ntnt_fusion_exps = [\n    {\n        \"ply_path\": \"dypcd_fusion_plys_mean\",\n        \"param_strategy\": \"mean\",\n    },\n    {\n        \"ply_path\": \"dypcd_fusion_plys\",\n        \"param_strategy\": \"hyper_param\",\n        \"hyper_param_table\": {    # -1 -> mean()\n            'Family': 0.6,\n            'Francis': 0.6,\n            'Horse': 0.2,\n            'Lighthouse': 0.7,\n            'M60': 0.6,\n            'Panther': 0.6,\n            'Playground': 0.7,\n            'Train': 0.6,\n\n            'Auditorium': 0.1,\n            'Ballroom': 0.4,\n            'Courtroom': 0.4,\n            'Museum': 0.5,\n            'Palace': 0.5,\n            'Temple': 0.4\n        }\n    },\n]\n\n\ndef read_pfm(filename):\n    file = open(filename, 'rb')\n    color = None\n    width = None\n    height = None\n    scale = None\n    endian = None\n\n    header = file.readline().decode('utf-8').rstrip()\n    if header == 'PF':\n        color = True\n    elif header == 'Pf':\n        color = False\n    else:\n        raise Exception('Not a PFM file.')\n\n    dim_match = re.match(r'^(\\d+)\\s(\\d+)\\s$', file.readline().decode('utf-8'))\n    if dim_match:\n        width, height = map(int, dim_match.groups())\n    else:\n        raise Exception('Malformed PFM header.')\n\n    scale = float(file.readline().rstrip())\n    if scale < 0:  # little-endian\n        endian = '<'\n        scale = -scale\n    else:\n        endian = '>'  # big-endian\n\n    data = np.fromfile(file, endian + 'f')\n    shape = (height, width, 3) if color else (height, width)\n\n    data = np.reshape(data, shape)\n    data = np.flipud(data)\n    file.close()\n    return data, scale\n\n\n# save a binary mask\ndef save_mask(filename, mask):\n    assert mask.dtype == np.bool\n    mask = mask.astype(np.uint8) * 255\n    
Image.fromarray(mask).save(filename)\n\n\n# read an image\ndef read_img(filename):\n    img = Image.open(filename)\n    # scale 0~255 to 0~1\n    np_img = np.array(img, dtype=np.float32) / 255.\n    return np_img\n\n\n# read intrinsics and extrinsics\ndef read_camera_parameters(filename):\n    with open(filename) as f:\n        lines = f.readlines()\n        lines = [line.rstrip() for line in lines]\n    # extrinsics: line [1,5), 4x4 matrix\n    extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4))\n    # intrinsics: line [7-10), 3x3 matrix\n    intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3))\n    # TODO: assume the feature is 1/4 of the original image size\n    # intrinsics[:2, :] /= 4\n    return intrinsics, extrinsics\n\n\n# read a pair file, [(ref_view1, [src_view1-1, ...]), (ref_view2, [src_view2-1, ...]), ...]\ndef read_pair_file(filename):\n    data = []\n    with open(filename) as f:\n        num_viewpoint = int(f.readline())\n        # 49 viewpoints\n        for view_idx in range(num_viewpoint):\n            ref_view = int(f.readline().rstrip())\n            src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]\n            if len(src_views) > 0:\n                data.append((ref_view, src_views))\n    return data\n\n\n# project the reference point cloud into the source view, then project back\ndef reproject_with_depth(depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src):\n    width, height = depth_ref.shape[1], depth_ref.shape[0]\n    ## step1. project reference pixels to the source view\n    # reference view x, y\n    x_ref, y_ref = np.meshgrid(np.arange(0, width), np.arange(0, height))\n    x_ref, y_ref = x_ref.reshape([-1]), y_ref.reshape([-1])\n    # reference 3D space\n    xyz_ref = np.matmul(np.linalg.inv(intrinsics_ref),\n                        np.vstack((x_ref, y_ref, np.ones_like(x_ref))) * depth_ref.reshape([-1]))\n    # source 3D space\n    xyz_src = np.matmul(np.matmul(extrinsics_src, np.linalg.inv(extrinsics_ref)),\n                        np.vstack((xyz_ref, np.ones_like(x_ref))))[:3]\n    # source view x, y\n    K_xyz_src = np.matmul(intrinsics_src, xyz_src)\n    xy_src = K_xyz_src[:2] / K_xyz_src[2:3]\n\n    ## step2. 
reproject the source view points with source view depth estimation\n    # find the depth estimation of the source view\n    x_src = xy_src[0].reshape([height, width]).astype(np.float32)\n    y_src = xy_src[1].reshape([height, width]).astype(np.float32)\n    sampled_depth_src = cv2.remap(depth_src, x_src, y_src, interpolation=cv2.INTER_LINEAR)\n    # mask = sampled_depth_src > 0\n\n    # source 3D space\n    # NOTE that we should use sampled source-view depth_here to project back\n    xyz_src = np.matmul(np.linalg.inv(intrinsics_src),\n                        np.vstack((xy_src, np.ones_like(x_ref))) * sampled_depth_src.reshape([-1]))\n    # reference 3D space\n    xyz_reprojected = np.matmul(np.matmul(extrinsics_ref, np.linalg.inv(extrinsics_src)),\n                                np.vstack((xyz_src, np.ones_like(x_ref))))[:3]\n    # source view x, y, depth\n    depth_reprojected = xyz_reprojected[2].reshape([height, width]).astype(np.float32)\n    K_xyz_reprojected = np.matmul(intrinsics_ref, xyz_reprojected)\n    K_xyz_reprojected[2:3][K_xyz_reprojected[2:3]==0] += 0.00001\n    xy_reprojected = K_xyz_reprojected[:2] / K_xyz_reprojected[2:3]\n    x_reprojected = xy_reprojected[0].reshape([height, width]).astype(np.float32)\n    y_reprojected = xy_reprojected[1].reshape([height, width]).astype(np.float32)\n\n    return depth_reprojected, x_reprojected, y_reprojected, x_src, y_src\n\n\ndef check_geometric_consistency(depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src):\n    width, height = depth_ref.shape[1], depth_ref.shape[0]\n    x_ref, y_ref = np.meshgrid(np.arange(0, width), np.arange(0, height))\n    depth_reprojected, x2d_reprojected, y2d_reprojected, x2d_src, y2d_src = reproject_with_depth(depth_ref, intrinsics_ref, extrinsics_ref,\n                                                                                                 depth_src, intrinsics_src, extrinsics_src)\n    # check |p_reproj-p_1| < 1\n    dist = np.sqrt((x2d_reprojected - x_ref) ** 2 + (y2d_reprojected - y_ref) ** 2)\n\n    # check |d_reproj-d_1| / d_1 < 0.01\n    depth_diff = np.abs(depth_reprojected - depth_ref)\n    relative_depth_diff = depth_diff / depth_ref\n\n    mask = None\n    masks = []\n    for i in range(2, 11):\n        # mask = np.logical_and(dist < i / 4, relative_depth_diff < i / 1300)\n        mask = np.logical_and(dist < i * args.dist_base, relative_depth_diff < i * args.rel_diff_base)\n        masks.append(mask)\n    depth_reprojected[~mask] = 0\n\n    return masks, mask, depth_reprojected, x2d_src, y2d_src\n\n\ndef scale_input(intrinsics, img):\n    if args.img_mode == \"crop\":\n        intrinsics[1,2] = intrinsics[1,2] - 28  # 1080 -> 1024\n        img = img[28:1080-28, :, :]\n    elif args.img_mode == \"resize\": \n        height, width = img.shape[:2]\n        img = cv2.resize(img, (width, 1024))\n        scale_h = 1.0 * 1024 / height\n        intrinsics[1, :] *= scale_h\n\n    return intrinsics, img\n\n\ndef filter_depth(scene, root_dir, split, out_dir, plyfilename, fusion_exp):\n    # num_stage = len(args.ndepths)\n\n    # the pair file\n    pair_file = os.path.join(root_dir, split, scene, \"pair.txt\")\n    # for the final point cloud\n    vertexs = []\n    vertex_colors = []\n\n    pair_data = read_pair_file(pair_file)\n    nviews = len(pair_data)\n\n    # for each reference view and the corresponding source views\n    for ref_view, src_views in pair_data:\n        # src_views = src_views[:args.num_view]\n        # load the camera parameters\n        if 
args.cam_mode == 'short_range':\n            ref_intrinsics, ref_extrinsics = read_camera_parameters(\n                os.path.join(root_dir, split, scene, 'cams_{}/{:0>8}_cam.txt'.format(scene.lower(), ref_view)))\n        elif args.cam_mode == 'origin':\n            ref_intrinsics, ref_extrinsics = read_camera_parameters(\n                os.path.join(root_dir, split, scene, 'cams/{:0>8}_cam.txt'.format(ref_view)))\n\n        ref_img = read_img(os.path.join(root_dir, split, scene, 'images/{:0>8}.jpg'.format(ref_view)))\n        ref_depth_est = read_pfm(os.path.join(out_dir, scene, 'depth_est/{:0>8}.pfm'.format(ref_view)))[0]\n        confidence = read_pfm(os.path.join(out_dir, scene, 'confidence/{:0>8}.pfm'.format(ref_view)))[0]\n\n        if fusion_exp['param_strategy'] == 'mean':\n            if ref_view % 50 == 0: print(\"-- thresh: {}\".format(confidence.mean()))\n            photo_mask = confidence > confidence.mean()\n        elif fusion_exp['param_strategy'] == 'hyper_param':\n            conf_thresh = fusion_exp['hyper_param_table'][scene]\n            if conf_thresh == -1:\n                photo_mask = confidence > confidence.mean()\n                if ref_view % 50 == 0: print(\"-- thresh: mean() {}\".format(confidence.mean()))\n            else:\n                photo_mask = confidence > conf_thresh\n                if ref_view % 50 == 0: print(\"-- thresh: {}\".format(conf_thresh))\n            \n        \n        flag_img = ref_img\n        ref_intrinsics, _ = scale_input(ref_intrinsics, flag_img)\n\n        all_srcview_depth_ests = []\n        all_srcview_x = []\n        all_srcview_y = []\n        all_srcview_geomask = []\n\n        # compute the geometric mask\n        geo_mask_sum = 0\n        dy_range = len(src_views) + 1\n        geo_mask_sums = [0] * (dy_range - 2)\n        for src_view in src_views:\n            # camera parameters of the source view\n            if args.cam_mode == 'short_range':\n                src_intrinsics, src_extrinsics = read_camera_parameters(\n                    os.path.join(root_dir, split, scene, 'cams_{}/{:0>8}_cam.txt'.format(scene.lower(), src_view)))\n            elif args.cam_mode == 'origin':\n                src_intrinsics, src_extrinsics = read_camera_parameters(\n                    os.path.join(root_dir, split, scene, 'cams/{:0>8}_cam.txt'.format(src_view)))\n            # the estimated depth of the source view\n            src_depth_est = read_pfm(os.path.join(out_dir, scene, 'depth_est/{:0>8}.pfm'.format(src_view)))[0]\n\n            src_intrinsics, _ = scale_input(src_intrinsics, flag_img)\n                \n            masks, geo_mask, depth_reprojected, x2d_src, y2d_src = check_geometric_consistency(ref_depth_est, ref_intrinsics,\n                                                                                               ref_extrinsics, src_depth_est,\n                                                                                               src_intrinsics, src_extrinsics)\n            geo_mask_sum += geo_mask.astype(np.int32)\n            for i in range(2, dy_range):\n                geo_mask_sums[i - 2] += masks[i - 2].astype(np.int32)\n\n            all_srcview_depth_ests.append(depth_reprojected)\n            all_srcview_x.append(x2d_src)\n            all_srcview_y.append(y2d_src)\n            all_srcview_geomask.append(geo_mask)\n\n        depth_est_averaged = (sum(all_srcview_depth_ests) + ref_depth_est) / (geo_mask_sum + 1)\n        # at least args.thres_view source views matched\n        geo_mask = 
geo_mask_sum >= dy_range\n        for i in range(2, dy_range):\n            geo_mask = np.logical_or(geo_mask, geo_mask_sums[i - 2] >= i)\n\n        final_mask = np.logical_and(photo_mask, geo_mask)\n\n        if ref_view < 3:\n            os.makedirs(os.path.join(out_dir, scene, \"mask\"), exist_ok=True)\n            save_mask(os.path.join(out_dir, scene, \"mask/{:0>8}_photo.png\".format(ref_view)), photo_mask)\n            save_mask(os.path.join(out_dir, scene, \"mask/{:0>8}_geo.png\".format(ref_view)), geo_mask)\n            save_mask(os.path.join(out_dir, scene, \"mask/{:0>8}_final.png\".format(ref_view)), final_mask)\n\n        print(\"processing {}, ref-view{:0>2}, photo/geo/final-mask:{:.3f}/{:.3f}/{:.3f}\".format(os.path.join(out_dir, scene), ref_view,\n                                                                                    photo_mask.mean(),\n                                                                                    geo_mask.mean(), final_mask.mean()))\n        \n        height, width = depth_est_averaged.shape[:2]\n        x, y = np.meshgrid(np.arange(0, width), np.arange(0, height))\n        # valid_points = np.logical_and(final_mask, ~used_mask[ref_view])\n        valid_points = final_mask\n        print(\"valid_points {:.3f}\".format(valid_points.mean()))\n        x, y, depth = x[valid_points], y[valid_points], depth_est_averaged[valid_points]\n \n        # color = ref_img[:-24, :, :][valid_points]\n        color = ref_img[28:1080-28, :, :][valid_points]\n\n        xyz_ref = np.matmul(np.linalg.inv(ref_intrinsics),\n                            np.vstack((x, y, np.ones_like(x))) * depth)\n        xyz_world = np.matmul(np.linalg.inv(ref_extrinsics),\n                              np.vstack((xyz_ref, np.ones_like(x))))[:3]\n        vertexs.append(xyz_world.transpose((1, 0)))\n        vertex_colors.append((color * 255).astype(np.uint8))\n\n    vertexs = np.concatenate(vertexs, axis=0)\n    vertex_colors = np.concatenate(vertex_colors, axis=0)\n    vertexs = np.array([tuple(v) for v in vertexs], dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])\n    vertex_colors = np.array([tuple(v) for v in vertex_colors], dtype=[('red', 'u1'), ('green', 'u1'), ('blue', 'u1')])\n\n    vertex_all = np.empty(len(vertexs), vertexs.dtype.descr + vertex_colors.dtype.descr)\n    for prop in vertexs.dtype.names:\n        vertex_all[prop] = vertexs[prop]\n    for prop in vertex_colors.dtype.names:\n        vertex_all[prop] = vertex_colors[prop]\n\n    el = PlyElement.describe(vertex_all, 'vertex')\n    PlyData([el]).write(plyfilename)\n    print(\"saving the final model to\", plyfilename)\n\n\ndef dypcd_filter_worker(scene):\n    save_name = '{}.ply'.format(scene)\n\n    filter_depth(scene, args.root_dir, args.split, args.out_dir, os.path.join(args.out_dir, fusion_exp['ply_path'], save_name), fusion_exp)\n\n\ndef init_worker():\n    signal.signal(signal.SIGINT, signal.SIG_IGN)\n\n\nif __name__ == '__main__':\n    \n    with open(os.path.join(args.list_file)) as f:\n        testlist = [line.rstrip() for line in f.readlines()]\n\n    for fusion_exp in tnt_fusion_exps:\n\n        if not os.path.isdir(os.path.join(args.out_dir, fusion_exp['ply_path'])):\n            os.mkdir(os.path.join(args.out_dir, fusion_exp['ply_path']))\n        \n\n        if args.single_processor:\n            for scene in testlist:\n                save_name = '{}.ply'.format(scene)\n                filter_depth(scene, args.root_dir, args.split, args.out_dir, os.path.join(args.out_dir, fusion_exp['ply_path'], 
save_name), fusion_exp)\n\n        else:\n            partial_func = partial(dypcd_filter_worker)\n            p = Pool(args.num_workers, init_worker)\n            try:\n                p.map(partial_func, testlist)\n            except KeyboardInterrupt:\n                print(\"....\\nCaught KeyboardInterrupt, terminating workers\")\n                p.terminate()\n            else:\n                p.close()\n            p.join()"
  },
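The DYnamic PCD variant above relaxes both thresholds jointly: for each `i` it builds a mask with `dist < i * dist_base` and relative depth difference `< i * rel_diff_base`, and a pixel survives when at least `i` source views pass the `i`-th pair for some `i`. The sketch below compresses the repo's two loops (per-source masks, then the `logical_or` over `geo_mask_sums`) into one function; `dynamic_geo_mask` and its stacked-array inputs are illustrative:

```python
# Dynamic consistency rule of dypcd.py on stacked per-view errors:
# accept a pixel if, for some i, at least i source views satisfy the
# i-th (progressively looser) threshold pair.
import numpy as np

def dynamic_geo_mask(dists, rel_diffs, dist_base=1/4, rel_diff_base=1/1300):
    """dists, rel_diffs: (n_src, H, W) reprojection errors per source view."""
    n_src = dists.shape[0]
    final = np.zeros(dists.shape[1:], dtype=bool)
    for i in range(2, n_src + 1):
        per_view = (dists < i * dist_base) & (rel_diffs < i * rel_diff_base)
        final |= per_view.sum(axis=0) >= i
    return final
```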
  {
    "path": "models/__init__.py",
    "content": "from models.geomvsnet import GeoMVSNet\nfrom models.loss import geomvsnet_loss"
  },
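The package only re-exports the network and its loss. A hypothetical instantiation is sketched below; the keyword names come from the `GeoMVSNet` constructor later in this repo, but every value is a placeholder assumption, not a shipped configuration:

```python
# Hypothetical usage of the exports; the keyword names match the GeoMVSNet
# constructor, but every value below is a placeholder, not a shipped config.
from models import GeoMVSNet, geomvsnet_loss

model = GeoMVSNet(
    levels=4,                                         # 4 coarse-to-fine stages
    hypo_plane_num_stages=[8, 8, 4, 4],               # depth hypotheses per stage
    depth_interal_ratio_stages=[0.5, 0.5, 0.5, 0.5],  # (spelling as in repo)
    feat_base_channel=8,
    reg_base_channel=8,
    group_cor_dim_stages=[8, 8, 4, 4],
)
```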
  {
    "path": "models/filter.py",
    "content": "# -*- coding: utf-8 -*-\n# @Description: Basic implementation of Frequency Domain Filtering strategy (Sec 3.2 in the paper).\n# @Author: Zhe Zhang (doublez@stu.pku.edu.cn)\n# @Affiliation: Peking University (PKU)\n# @LastEditDate: 2023-09-07\n\nimport torch\nimport numpy as np\nimport matplotlib.pyplot as plt\n\n\ndef frequency_domain_filter(depth, rho_ratio):\n    \"\"\"\n    large rho_ratio -> more information filtered\n    \"\"\"\n    f = torch.fft.fft2(depth)\n    fshift = torch.fft.fftshift(f)\n\n    b, h, w = depth.shape\n    k_h, k_w = h/rho_ratio, w/rho_ratio\n\n    fshift[:,:int(h/2-k_h/2),:] = 0\n    fshift[:,int(h/2+k_h/2):,:] = 0\n    fshift[:,:,:int(w/2-k_w/2)] = 0\n    fshift[:,:,int(w/2+k_w/2):] = 0\n\n    ishift = torch.fft.ifftshift(fshift)\n    idepth = torch.fft.ifft2(ishift)\n    depth_filtered = torch.abs(idepth)\n\n    return depth_filtered\n\n\ndef visual_fft_fig(fshift):\n    fft_fig = torch.abs(20 * torch.log(fshift))\n    plt.figure(figsize=(10, 10))\n    plt.subplot(121)\n    plt.imshow(fft_fig[0,:,:], cmap = 'gray')"
  },
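A small usage sketch for `frequency_domain_filter`: the `rho_ratio` schedule `[9, 4, 2, 1]` is the `curriculum_learning_rho_ratios` list used by `models/geomvsnet.py`, the toy tensor shapes are made up, and with `rho_ratio=1` the full spectrum is kept, so the output should match the input up to floating-point error:

```python
# Toy run of the low-pass filter; shapes are invented. [9, 4, 2, 1] is the
# curriculum_learning_rho_ratios schedule used in models/geomvsnet.py.
import torch
from models.filter import frequency_domain_filter

depth = torch.rand(2, 128, 160)        # B x H x W batch of depth maps
for rho in [9, 4, 2, 1]:
    smoothed = frequency_domain_filter(depth, rho_ratio=rho)
    # rho_ratio=1 keeps the whole spectrum, so the change should be ~0
    print(rho, float((smoothed - depth).abs().mean()))
```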
  {
    "path": "models/geometry.py",
    "content": "# -*- coding: utf-8 -*-\n# @Description: Geometric Prior Guided Feature Fusion & Probability Volume Geometry Embedding (Sec 3.1 in the paper).\n# @Author: Zhe Zhang (doublez@stu.pku.edu.cn)\n# @Affiliation: Peking University (PKU)\n# @LastEditDate: 2023-09-07\n\nimport math\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom models.submodules import ConvBnReLU3D\n\n\nclass GeoFeatureFusion(nn.Module):\n    def __init__(self, convolutional_layer_encoding=\"z\", mask_type=\"basic\", add_origin_feat_flag=True):\n        super(GeoFeatureFusion, self).__init__()        \n\n        self.convolutional_layer_encoding = convolutional_layer_encoding    # std / uv / z / xyz\n        self.mask_type = mask_type  # basic / mean\n        self.add_origin_feat_flag = add_origin_feat_flag    # True / False\n\n        if self.convolutional_layer_encoding == \"std\":\n            self.geoplanes = 0\n        elif self.convolutional_layer_encoding == \"uv\":\n            self.geoplanes = 2\n        elif self.convolutional_layer_encoding == \"z\":\n            self.geoplanes = 1\n        elif self.convolutional_layer_encoding == \"xyz\":\n            self.geoplanes = 3\n            self.geofeature = GeometryFeature()\n\n        # rgb encoder\n        self.rgb_conv_init = convbnrelu(in_channels=4, out_channels=8, kernel_size=5, stride=1, padding=2)\n\n        self.rgb_encoder_layer1 = BasicBlockGeo(inplanes=8, planes=16, stride=2, geoplanes=self.geoplanes)\n        self.rgb_encoder_layer2 = BasicBlockGeo(inplanes=16, planes=32, stride=1, geoplanes=self.geoplanes)\n        self.rgb_encoder_layer3 = BasicBlockGeo(inplanes=32, planes=64, stride=2, geoplanes=self.geoplanes)\n        self.rgb_encoder_layer4 = BasicBlockGeo(inplanes=64, planes=128, stride=1, geoplanes=self.geoplanes)\n        self.rgb_encoder_layer5 = BasicBlockGeo(inplanes=128, planes=256, stride=2, geoplanes=self.geoplanes)\n\n        self.rgb_decoder_layer4 = deconvbnrelu(in_channels=256, out_channels=128, kernel_size=5, stride=2, padding=2, output_padding=1)\n        self.rgb_decoder_layer2 = deconvbnrelu(in_channels=128, out_channels=32, kernel_size=5, stride=2, padding=2, output_padding=1)\n        self.rgb_decoder_layer0 = deconvbnrelu(in_channels=32, out_channels=16, kernel_size=3, stride=1, padding=1, output_padding=0)\n        self.rgb_decoder_layer= deconvbnrelu(in_channels=16, out_channels=8, kernel_size=5, stride=2, padding=2, output_padding=1)\n        self.rgb_decoder_output = deconvbnrelu(in_channels=8, out_channels=2, kernel_size=3, stride=1, padding=1, output_padding=0)\n\n\n        # depth encoder\n        self.depth_conv_init = convbnrelu(in_channels=2, out_channels=8, kernel_size=5, stride=1, padding=2)\n\n        self.depth_layer1 = BasicBlockGeo(inplanes=8, planes=16, stride=2, geoplanes=self.geoplanes)\n        self.depth_layer2 = BasicBlockGeo(inplanes=16, planes=32, stride=1, geoplanes=self.geoplanes)\n        self.depth_layer3 = BasicBlockGeo(inplanes=64, planes=64, stride=2, geoplanes=self.geoplanes)\n        self.depth_layer4 = BasicBlockGeo(inplanes=64, planes=128, stride=1, geoplanes=self.geoplanes)\n        self.depth_layer5 = BasicBlockGeo(inplanes=256, planes=256, stride=2, geoplanes=self.geoplanes)\n\n        self.decoder_layer3 = deconvbnrelu(in_channels=256, out_channels=128, kernel_size=5, stride=2, padding=2, output_padding=1)\n        self.decoder_layer4 = deconvbnrelu(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1, 
output_padding=0)\n        self.decoder_layer5 = deconvbnrelu(in_channels=64, out_channels=32, kernel_size=5, stride=2, padding=2, output_padding=1)\n        self.decoder_layer6 = deconvbnrelu(in_channels=32, out_channels=16, kernel_size=3, stride=1, padding=1, output_padding=0)\n        self.decoder_layer7 = deconvbnrelu(in_channels=16, out_channels=8, kernel_size=5, stride=2, padding=2, output_padding=1)\n\n\n        # output\n        self.rgbdepth_decoder_stage1 = deconvbnrelu(in_channels=32, out_channels=32, kernel_size=5, stride=2, padding=2, output_padding=1)\n        self.rgbdepth_decoder_stage2 = deconvbnrelu(in_channels=16, out_channels=16, kernel_size=5, stride=2, padding=2, output_padding=1)\n        self.rgbdepth_decoder_stage3 = deconvbnrelu(in_channels=8, out_channels=8, kernel_size=3, stride=1, padding=1, output_padding=0)\n\n        self.final_decoder_stage1 = deconvbnrelu(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1, output_padding=0)\n        self.final_decoder_stage2 = deconvbnrelu(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1, output_padding=0)\n        self.final_decoder_stage3 = deconvbnrelu(in_channels=8, out_channels=8, kernel_size=3, stride=1, padding=1, output_padding=0)\n\n\n        self.softmax = nn.Softmax(dim=1)\n        self.pooling = nn.AvgPool2d(kernel_size=2)\n        self.sparsepooling = SparseDownSampleClose(stride=2)\n\n        weights_init(self)\n\n\n    def forward(self, rgb, depth, confidence, depth_values, stage_idx, origin_feat, intrinsics_matrices_stage):\n\n        rgb = rgb\n        depth_min, depth_max = depth_values[:,0,None,None,None], depth_values[:,-1,None,None,None]\n        d = (depth - depth_min) / (depth_max - depth_min)\n\n        if self.mask_type == \"basic\":\n            valid_mask = torch.where(d>0, torch.full_like(d, 1.0), torch.full_like(d, 0.0))\n        elif self.mask_type == \"mean\":\n            valid_mask = torch.where(torch.logical_and(d>0, confidence>confidence.mean()), torch.full_like(d, 1.0), torch.full_like(d, 0.0))\n\n\n        # pre-data preparation\n        if self.convolutional_layer_encoding in [\"uv\", \"xyz\"]:\n            B, _, W, H = rgb.shape\n            position = AddCoordsNp(H, W)\n            position = position.call()\n            position = torch.from_numpy(position).to(rgb.device).repeat(B, 1, 1, 1).transpose(-1, 1)\n            unorm = position[:, 0:1, :, :]\n            vnorm = position[:, 1:2, :, :]\n\n            vnorm_s2 = self.pooling(vnorm)\n            vnorm_s3 = self.pooling(vnorm_s2)\n            vnorm_s4 = self.pooling(vnorm_s3)\n\n            unorm_s2 = self.pooling(unorm)\n            unorm_s3 = self.pooling(unorm_s2)\n            unorm_s4 = self.pooling(unorm_s3)\n\n        if self.convolutional_layer_encoding in [\"z\", \"xyz\"]:\n            d_s2, vm_s2 = self.sparsepooling(d, valid_mask)\n            d_s3, vm_s3 = self.sparsepooling(d_s2, vm_s2)\n            d_s4, vm_s4 = self.sparsepooling(d_s3, vm_s3)\n        \n        if self.convolutional_layer_encoding == \"xyz\":\n            K = intrinsics_matrices_stage\n            f352 = K[:, 1, 1]\n            f352 = f352.unsqueeze(1)\n            f352 = f352.unsqueeze(2)\n            f352 = f352.unsqueeze(3)\n            c352 = K[:, 1, 2]\n            c352 = c352.unsqueeze(1)\n            c352 = c352.unsqueeze(2)\n            c352 = c352.unsqueeze(3)\n            f1216 = K[:, 0, 0]\n            f1216 = f1216.unsqueeze(1)\n            f1216 = f1216.unsqueeze(2)\n            f1216 = 
f1216.unsqueeze(3)\n            c1216 = K[:, 0, 2]\n            c1216 = c1216.unsqueeze(1)\n            c1216 = c1216.unsqueeze(2)\n            c1216 = c1216.unsqueeze(3)\n\n\n        # geometric info\n        if self.convolutional_layer_encoding == \"std\":\n            geo_s1 = None\n            geo_s2 = None\n            geo_s3 = None\n            geo_s4 = None\n        elif self.convolutional_layer_encoding == \"uv\":\n            geo_s1 = torch.cat((vnorm, unorm), dim=1)\n            geo_s2 = torch.cat((vnorm_s2, unorm_s2), dim=1)\n            geo_s3 = torch.cat((vnorm_s3, unorm_s3), dim=1)\n            geo_s4 = torch.cat((vnorm_s4, unorm_s4), dim=1)\n        elif self.convolutional_layer_encoding == \"z\":\n            geo_s1 = d\n            geo_s2 = d_s2\n            geo_s3 = d_s3\n            geo_s4 = d_s4\n        elif self.convolutional_layer_encoding == \"xyz\":\n            geo_s1 = self.geofeature(d, vnorm, unorm, H, W, c352, c1216, f352, f1216)\n            geo_s2 = self.geofeature(d_s2, vnorm_s2, unorm_s2, H / 2, W / 2, c352, c1216, f352, f1216)\n            geo_s3 = self.geofeature(d_s3, vnorm_s3, unorm_s3, H / 4, W / 4, c352, c1216, f352, f1216)\n            geo_s4 = self.geofeature(d_s4, vnorm_s4, unorm_s4, H / 8, W / 8, c352, c1216, f352, f1216)\n\n        # -----------------------------------------------------------------------------------------\n\n        # 128*160 -> 256*320 -> 512*640\n        rgb_feature = self.rgb_conv_init(torch.cat((rgb, d), dim=1))            # b 8 h w\n        rgb_feature1 = self.rgb_encoder_layer1(rgb_feature, geo_s1, geo_s2)     # b 16 h/2 w/2\n        rgb_feature2 = self.rgb_encoder_layer2(rgb_feature1, geo_s2, geo_s2)    # b 32 h/2 w/2\n        rgb_feature3 = self.rgb_encoder_layer3(rgb_feature2, geo_s2, geo_s3)    # b 64 h/4 w/4\n        rgb_feature4 = self.rgb_encoder_layer4(rgb_feature3, geo_s3, geo_s3)    # b 128 h/4 w/4\n        rgb_feature5 = self.rgb_encoder_layer5(rgb_feature4, geo_s3, geo_s4)    # b 256 h/8 w/8\n\n        rgb_feature_decoder4 = self.rgb_decoder_layer4(rgb_feature5)\n        rgb_feature4_plus = rgb_feature_decoder4 + rgb_feature4         # b 128 h/4 w/4\n\n        rgb_feature_decoder2 = self.rgb_decoder_layer2(rgb_feature4_plus)\n        rgb_feature2_plus = rgb_feature_decoder2 + rgb_feature2         # b 32 h/2 w/2\n\n        rgb_feature_decoder0 = self.rgb_decoder_layer0(rgb_feature2_plus)\n        rgb_feature0_plus = rgb_feature_decoder0 + rgb_feature1          # b 16 h/2 w/2\n\n        rgb_feature_decoder = self.rgb_decoder_layer(rgb_feature0_plus)\n        rgb_feature_plus = rgb_feature_decoder + rgb_feature            # b 8 h w\n\n        rgb_output = self.rgb_decoder_output(rgb_feature_plus)          # b 2 h w\n\n        rgb_depth = rgb_output[:, 0:1, :, :]\n        rgb_conf = rgb_output[:, 1:2, :, :]\n\n        # -----------------------------------------------------------------------------------------\n\n        sparsed_feature = self.depth_conv_init(torch.cat((d, rgb_depth), dim=1))    # b 8 h w\n        sparsed_feature1 = self.depth_layer1(sparsed_feature, geo_s1, geo_s2)       # b 16 h/2 w/2\n        sparsed_feature2 = self.depth_layer2(sparsed_feature1, geo_s2, geo_s2)      # b 32 h/2 w/2\n\n        sparsed_feature2_plus = torch.cat([rgb_feature2_plus, sparsed_feature2], 1)\n        sparsed_feature3 = self.depth_layer3(sparsed_feature2_plus, geo_s2, geo_s3) # b 64 h/4 w/4\n        sparsed_feature4 = self.depth_layer4(sparsed_feature3, geo_s3, geo_s3)      # b 128 h/4 w/4\n\n        sparsed_feature4_plus 
= torch.cat([rgb_feature4_plus, sparsed_feature4], 1)\n        sparsed_feature5 = self.depth_layer5(sparsed_feature4_plus, geo_s3, geo_s4) # b 256 h/8 w/8\n\n        # -----------------------------------------------------------------------------------------\n\n        fusion3 = rgb_feature5 + sparsed_feature5\n        decoder_feature3 = self.decoder_layer3(fusion3) # b 128 h/4 w/4\n\n        fusion4 = sparsed_feature4 + decoder_feature3\n        decoder_feature4 = self.decoder_layer4(fusion4) # b 64 h/4 w/4\n\n        if stage_idx >= 1: \n            decoder_feature5 = self.decoder_layer5(decoder_feature4)\n            fusion5 = sparsed_feature2 + decoder_feature5   # b 32 h/2 w/2\n            if stage_idx == 1:\n                rgbdepth_feature = self.rgbdepth_decoder_stage1(fusion5)\n                if self.add_origin_feat_flag:\n                    final_feature = self.final_decoder_stage1(rgbdepth_feature + origin_feat)\n                else:\n                    final_feature = self.final_decoder_stage1(rgbdepth_feature)\n\n        if stage_idx >= 2:\n            decoder_feature6 = self.decoder_layer6(decoder_feature5)\n            fusion6 = sparsed_feature1 + decoder_feature6   # b 16 h/2 w/2\n            if stage_idx == 2:\n                rgbdepth_feature = self.rgbdepth_decoder_stage2(fusion6)\n                if self.add_origin_feat_flag:\n                    final_feature = self.final_decoder_stage2(rgbdepth_feature + origin_feat)\n                else:\n                    final_feature = self.final_decoder_stage2(rgbdepth_feature)\n\n        if stage_idx >= 3:\n            decoder_feature7 = self.decoder_layer7(decoder_feature6)\n            fusion7 = sparsed_feature + decoder_feature7    # b 8 h w\n            if stage_idx == 3:\n                rgbdepth_feature = self.rgbdepth_decoder_stage3(fusion7)\n                if self.add_origin_feat_flag:\n                    final_feature = self.final_decoder_stage3(rgbdepth_feature + origin_feat)\n                else:\n                    final_feature = self.final_decoder_stage3(rgbdepth_feature)\n\n\n        return final_feature\n\n\nclass GeoRegNet2d(nn.Module):\n    def __init__(self, input_channel=128, base_channel=32, convolutional_layer_encoding=\"std\"):\n        super(GeoRegNet2d, self).__init__()\n\n        self.convolutional_layer_encoding = convolutional_layer_encoding    # std / uv / z / xyz\n        self.mask_type = \"basic\"    # basic / mean\n\n        if self.convolutional_layer_encoding == \"std\":\n            self.geoplanes = 0\n        elif self.convolutional_layer_encoding == \"z\":\n            self.geoplanes = 1\n\n        self.conv_init = ConvBnReLU3D(input_channel, out_channels=8, kernel_size=(1,3,3), pad=(0,1,1))\n        self.encoder_layer1 = Reg_BasicBlockGeo(inplanes=8, planes=16, kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1), geoplanes=self.geoplanes)\n        self.encoder_layer2 = Reg_BasicBlockGeo(inplanes=16, planes=32,  kernel_size=(1,3,3), stride=1, padding=(0,1,1), geoplanes=self.geoplanes)\n        self.encoder_layer3 = Reg_BasicBlockGeo(inplanes=32, planes=64, kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1), geoplanes=self.geoplanes)\n        self.encoder_layer4 = Reg_BasicBlockGeo(inplanes=64, planes=128,  kernel_size=(1,3,3), stride=1, padding=(0,1,1), geoplanes=self.geoplanes)\n        self.encoder_layer5 = Reg_BasicBlockGeo(inplanes=128, planes=256, kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1), geoplanes=self.geoplanes)\n\n        self.decoder_layer4 = 
reg_deconvbnrelu(in_channels=256, out_channels=128, kernel_size=(1,5,5), stride=(1,2,2), padding=(0,2,2), output_padding=(0,1,1))\n        self.decoder_layer3 = reg_deconvbnrelu(in_channels=128, out_channels=64, kernel_size=(1,3,3), stride=1, padding=(0,1,1), output_padding=0)\n        self.decoder_layer2 = reg_deconvbnrelu(in_channels=64, out_channels=32, kernel_size=(1,5,5), stride=(1,2,2), padding=(0,2,2), output_padding=(0,1,1))\n        self.decoder_layer1 = reg_deconvbnrelu(in_channels=32, out_channels=16, kernel_size=(1,3,3), stride=1, padding=(0,1,1), output_padding=0)\n        self.decoder_layer = reg_deconvbnrelu(in_channels=16, out_channels=8, kernel_size=(1,5,5), stride=(1,2,2), padding=(0,2,2), output_padding=(0,1,1))\n\n        self.prob = reg_deconvbnrelu(in_channels=8, out_channels=1, kernel_size=(1,3,3), stride=1, padding=(0,1,1), output_padding=0)\n\n        self.depthpooling = nn.MaxPool3d((2,1,1),(2,1,1))\n        self.basicpooling = nn.MaxPool3d((1,2,2), (1,2,2))\n\n        weights_init(self)\n\n\n    def forward(self, x, stage_idx, geo_reg_data=None):\n\n        B, C, D, W, H = x.shape\n\n        if stage_idx >= 1 and self.convolutional_layer_encoding == \"z\":\n            prob_volume = geo_reg_data[\"prob_volume_last\"].unsqueeze(1)  # B 1 D H W\n        else:\n            assert self.convolutional_layer_encoding == \"std\"\n\n\n        # geometric info\n        if self.convolutional_layer_encoding == \"std\":\n            geo_s1 = None\n            geo_s2 = None\n            geo_s3 = None\n            geo_s4 = None\n        elif self.convolutional_layer_encoding == \"z\":\n            if stage_idx == 2:\n                geo_s1 = self.depthpooling(prob_volume)\n            else:\n                geo_s1 = prob_volume   # B 1 D H W\n            geo_s2 = self.basicpooling(geo_s1)\n            geo_s3 = self.basicpooling(geo_s2)\n  \n        feature = self.conv_init(x)     # B 8 D H W\n        feature1 = self.encoder_layer1(feature, geo_s1, geo_s1)     # B  16 D H/2 W/2\n        feature2 = self.encoder_layer2(feature1, geo_s2, geo_s2)    # B  32 D H/2 W/2\n        feature3 = self.encoder_layer3(feature2, geo_s2, geo_s2)    # B  64 D H/4 W/4\n        feature4 = self.encoder_layer4(feature3, geo_s3, geo_s3)    # B 128 D H/4 W/4\n        feature5 = self.encoder_layer5(feature4, geo_s3, geo_s3)    # B 256 D H/8 W/8\n\n        feature_decoder4 = self.decoder_layer4(feature5)\n        feature4_plus = feature_decoder4 + feature4           # B 128 D H/4 W/4\n\n        feature_decoder3 = self.decoder_layer3(feature4_plus)\n        feature3_plus = feature_decoder3 + feature3           # B 64 D H/4 W/4\n\n        feature_decoder2 = self.decoder_layer2(feature3_plus)\n        feature2_plus = feature_decoder2 + feature2           # B 32 D H/2 W/2\n\n        feature_decoder1 = self.decoder_layer1(feature2_plus)\n        feature1_plus = feature_decoder1 + feature1           # B 16 D H/2 W/2\n\n        feature_decoder = self.decoder_layer(feature1_plus)\n        feature_plus = feature_decoder + feature            # B  8 D H W\n\n        x = self.prob(feature_plus)\n\n        return x.squeeze(1)\n\n\n# --------------------------------------------------------------\n\n\nclass BasicBlockGeo(nn.Module):\n    expansion = 1\n    __constants__ = ['downsample']\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,\n                 base_width=64, dilation=1, norm_layer=None, geoplanes=3):\n        super(BasicBlockGeo, self).__init__()\n\n        if norm_layer is 
None:\n            norm_layer = nn.BatchNorm2d\n\n        if groups != 1 or base_width != 64:\n            raise ValueError('BasicBlock only supports groups=1 and base_width=64')\n        if dilation > 1:\n            raise NotImplementedError(\"Dilation > 1 not supported in BasicBlock\")\n\n        self.conv1 = conv3x3(inplanes + geoplanes, planes, stride)\n        self.bn1 = norm_layer(planes)\n        self.relu = nn.ReLU(inplace=True)\n        self.conv2 = conv3x3(planes+geoplanes, planes)\n        self.bn2 = norm_layer(planes)\n        if stride != 1 or inplanes != planes:\n            downsample = nn.Sequential(\n                conv1x1(inplanes+geoplanes, planes, stride),\n                norm_layer(planes),\n            )\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x, g1=None, g2=None):\n        identity = x\n        if g1 is not None:\n            x = torch.cat((x, g1), 1)\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        if g2 is not None:\n            out = torch.cat((g2,out), 1)\n        out = self.conv2(out)\n        out = self.bn2(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.relu(out)\n\n        return out\n\n\nclass GeometryFeature(nn.Module):\n    def __init__(self):\n        super(GeometryFeature, self).__init__()\n\n    def forward(self, z, vnorm, unorm, h, w, ch, cw, fh, fw):\n        x = z*(0.5*h*(vnorm+1)-ch)/fh\n        y = z*(0.5*w*(unorm+1)-cw)/fw\n        return torch.cat((x, y, z),1)\n\n\nclass SparseDownSampleClose(nn.Module):\n    def __init__(self, stride):\n        super(SparseDownSampleClose, self).__init__()\n        self.pooling = nn.MaxPool2d(stride, stride)\n        self.large_number = 600\n    def forward(self, d, mask):\n        encode_d = - (1-mask)*self.large_number - d\n\n        d = - self.pooling(encode_d)\n        mask_result = self.pooling(mask)\n        d_result = d - (1-mask_result)*self.large_number\n\n        return d_result, mask_result\n\n\ndef convbnrelu(in_channels, out_channels, kernel_size=3,stride=1, padding=1):\n    return nn.Sequential(\n\t\tnn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),\n\t\tnn.BatchNorm2d(out_channels),\n\t\tnn.ReLU(inplace=True)\n\t)\n\n\ndef deconvbnrelu(in_channels, out_channels, kernel_size=5, stride=2, padding=2, output_padding=1):\n    return nn.Sequential(\n\t\tnn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding, bias=False),\n\t\tnn.BatchNorm2d(out_channels),\n\t\tnn.ReLU(inplace=True)\n\t)\n\n\ndef weights_init(m):\n    \"\"\"Initialize filters with Gaussian random weights\"\"\"\n    if isinstance(m, nn.Conv2d):\n        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n        m.weight.data.normal_(0, math.sqrt(2. / n))\n        if m.bias is not None:\n            m.bias.data.zero_()\n    elif isinstance(m, nn.ConvTranspose2d):\n        n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels\n        m.weight.data.normal_(0, math.sqrt(2. 
/ n))\n        if m.bias is not None:\n            m.bias.data.zero_()\n    elif isinstance(m, nn.BatchNorm2d):\n        m.weight.data.fill_(1)\n        m.bias.data.zero_()\n\n\ndef conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1, bias=False, padding=1):\n    \"\"\"3x3 convolution with padding\"\"\"\n    if padding >= 1:\n        padding = dilation\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n                     padding=padding, groups=groups, bias=bias, dilation=dilation)\n\n\ndef conv1x1(in_planes, out_planes, stride=1, groups=1, bias=False):\n    \"\"\"1x1 convolution\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, groups=groups, bias=bias)\n\n\nclass AddCoordsNp():\n\t\"\"\"Add coords to a tensor\"\"\"\n\tdef __init__(self, x_dim=64, y_dim=64, with_r=False):\n\t\tself.x_dim = x_dim\n\t\tself.y_dim = y_dim\n\t\tself.with_r = with_r\n\n\tdef call(self):\n\t\t\"\"\"\n\t\tinput_tensor: (batch, x_dim, y_dim, c)\n\t\t\"\"\"\n\t\txx_ones = np.ones([self.x_dim], dtype=np.int32)\n\t\txx_ones = np.expand_dims(xx_ones, 1)\n\n\t\txx_range = np.expand_dims(np.arange(self.y_dim), 0)\n\n\t\txx_channel = np.matmul(xx_ones, xx_range)\n\t\txx_channel = np.expand_dims(xx_channel, -1)\n\n\t\tyy_ones = np.ones([self.y_dim], dtype=np.int32)\n\t\tyy_ones = np.expand_dims(yy_ones, 0)\n\n\t\tyy_range = np.expand_dims(np.arange(self.x_dim), 1)\n\n\t\tyy_channel = np.matmul(yy_range, yy_ones)\n\t\tyy_channel = np.expand_dims(yy_channel, -1)\n\n\t\txx_channel = xx_channel.astype('float32') / (self.y_dim - 1)\n\t\tyy_channel = yy_channel.astype('float32') / (self.x_dim - 1)\n\n\t\txx_channel = xx_channel*2 - 1\n\t\tyy_channel = yy_channel*2 - 1\n\n\t\tret = np.concatenate([xx_channel, yy_channel], axis=-1)\n\n\t\tif self.with_r:\n\t\t\trr = np.sqrt( np.square(xx_channel-0.5) + np.square(yy_channel-0.5))\n\t\t\tret = np.concatenate([ret, rr], axis=-1)\n\n\t\treturn ret\n\n\n# --------------------------------------------------------------\n\n\nclass Reg_BasicBlockGeo(nn.Module):\n\n    def __init__(self, inplanes, planes, kernel_size, stride, padding, downsample=None, groups=1,\n                 base_width=64, dilation=1, norm_layer=nn.BatchNorm3d, geoplanes=3):\n        super(Reg_BasicBlockGeo, self).__init__()\n\n        self.conv1 = regconv3D(inplanes + geoplanes, planes, kernel_size=(1,3,3), stride=1, padding=(0,1,1))\n        self.bn1 = norm_layer(planes)\n        self.relu = nn.ReLU(inplace=True)\n        self.conv2 = regconv3D(planes+geoplanes, planes, kernel_size, stride, padding)\n        self.bn2 = norm_layer(planes)\n        if stride != 1 or inplanes != planes:\n            downsample = nn.Sequential(\n                regconv1x1(inplanes+geoplanes, planes, kernel_size, stride, padding),\n                norm_layer(planes),\n            )\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x, g1=None, g2=None):\n        identity = x\n        if g1 is not None:\n            x = torch.cat((x, g1), 1)\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        if g2 is not None:\n            out = torch.cat((g2,out), 1)\n        out = self.conv2(out)\n        out = self.bn2(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.relu(out)\n\n        return out\n\n\ndef regconv3D(in_planes, out_planes, kernel_size, stride, padding, groups=1, dilation=1, bias=False):\n    
return nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,\n                     padding=padding, groups=groups, bias=bias, dilation=dilation)\n\n\ndef regconv1x1(in_planes, out_planes, kernel_size, stride, padding, groups=1, bias=False):\n    return nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, bias=bias)\n\n\ndef reg_deconvbnrelu(in_channels, out_channels, kernel_size, stride, padding, output_padding):\n    return nn.Sequential(\n\t\tnn.ConvTranspose3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding, bias=False),\n\t\tnn.BatchNorm3d(out_channels),\n\t\tnn.ReLU(inplace=True)\n\t)"
  },
  {
    "path": "models/geomvsnet.py",
    "content": "# -*- coding: utf-8 -*-\n# @Description: Main network architecture for GeoMVSNet.\n# @Author: Zhe Zhang (doublez@stu.pku.edu.cn)\n# @Affiliation: Peking University (PKU)\n# @LastEditDate: 2023-09-07\n\nimport math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom models.submodules import homo_warping, init_inverse_range, schedule_inverse_range, FPN, Reg2d\nfrom models.geometry import GeoFeatureFusion, GeoRegNet2d\nfrom models.filter import frequency_domain_filter\n\n\nclass GeoMVSNet(nn.Module):\n    def __init__(self, levels, hypo_plane_num_stages, depth_interal_ratio_stages, \n                    feat_base_channel, reg_base_channel, group_cor_dim_stages):\n        super(GeoMVSNet, self).__init__()\n        \n        self.levels = levels\n        self.hypo_plane_num_stages = hypo_plane_num_stages\n        self.depth_interal_ratio_stages = depth_interal_ratio_stages\n\n        self.StageNet = StageNet()\n\n        # feature settings\n        self.FeatureNet = FPN(base_channels=feat_base_channel)\n        self.coarest_separate_flag = True\n        if self.coarest_separate_flag:\n            self.CoarestFeatureNet = FPN(base_channels=feat_base_channel)\n        self.GeoFeatureFusionNet = GeoFeatureFusion(\n            convolutional_layer_encoding=\"z\", mask_type=\"basic\", add_origin_feat_flag=True)\n\n        # cost regularization settings\n        self.RegNet_stages = nn.ModuleList()\n        self.group_cor_dim_stages = group_cor_dim_stages\n        self.geo_reg_flag = True\n        self.geo_reg_encodings = ['std', 'z', 'z', 'z']     # must use std in idx-0\n        for stage_idx in range(self.levels):\n            in_dim = group_cor_dim_stages[stage_idx]\n            if self.geo_reg_flag:\n                self.RegNet_stages.append(GeoRegNet2d(input_channel=in_dim, base_channel=reg_base_channel, convolutional_layer_encoding=self.geo_reg_encodings[stage_idx]))\n            else:\n                self.RegNet_stages.append(Reg2d(input_channel=in_dim, base_channel=reg_base_channel))\n\n        # frequency domain filter settings\n        self.curriculum_learning_rho_ratios = [9, 4, 2, 1]\n\n\n    def forward(self, imgs, proj_matrices, intrinsics_matrices, depth_values, filename=None):\n        \n        features = []\n        if self.coarest_separate_flag:\n            coarsest_features = []\n        for nview_idx in range(len(imgs)):\n            img = imgs[nview_idx]\n            features.append(self.FeatureNet(img))   # B C H W\n            if self.coarest_separate_flag:\n                coarsest_features.append(self.CoarestFeatureNet(img))\n        \n        # coarse-to-fine\n        outputs = {}\n        for stage_idx in range(self.levels):\n            stage_name = \"stage{}\".format(stage_idx + 1)\n            B, C, H, W = features[0][stage_name].shape\n            proj_matrices_stage = proj_matrices[stage_name]\n            intrinsics_matrices_stage = intrinsics_matrices[stage_name]\n\n            # @Note features\n            if stage_idx == 0:\n                if self.coarest_separate_flag:\n                    features_stage = [feat[stage_name] for feat in coarsest_features]\n                else:\n                    features_stage = [feat[stage_name] for feat in features]\n            elif stage_idx >= 1:\n                features_stage = [feat[stage_name] for feat in features]\n                \n                ref_img_stage = F.interpolate(imgs[0], size=None, scale_factor=1./2**(3-stage_idx), mode=\"bilinear\", align_corners=False)\n   
             depth_last = F.interpolate(depth_last.unsqueeze(1), size=None, scale_factor=2, mode=\"bilinear\", align_corners=False)\n                confidence_last = F.interpolate(confidence_last.unsqueeze(1), size=None, scale_factor=2, mode=\"bilinear\", align_corners=False)\n                \n                # reference feature\n                features_stage[0] = self.GeoFeatureFusionNet(\n                    ref_img_stage, depth_last, confidence_last, depth_values, \n                    stage_idx, features_stage[0], intrinsics_matrices_stage\n                )\n\n\n            # @Note depth hypos\n            if stage_idx == 0:\n                depth_hypo = init_inverse_range(depth_values, self.hypo_plane_num_stages[stage_idx], img[0].device, img[0].dtype, H, W)\n            else:\n                inverse_min_depth, inverse_max_depth = outputs_stage['inverse_min_depth'].detach(), outputs_stage['inverse_max_depth'].detach()\n                depth_hypo = schedule_inverse_range(inverse_min_depth, inverse_max_depth, self.hypo_plane_num_stages[stage_idx], H, W)  # B D H W\n\n\n            # @Note cost regularization\n            geo_reg_data = {}\n            if self.geo_reg_flag:\n                geo_reg_data['depth_values'] = depth_values\n                if stage_idx >= 1 and self.geo_reg_encodings[stage_idx] == 'z':\n                    prob_volume_last = F.interpolate(prob_volume_last, size=None, scale_factor=2, mode=\"bilinear\", align_corners=False)\n                    geo_reg_data[\"prob_volume_last\"] = prob_volume_last\n\n            outputs_stage = self.StageNet(\n                stage_idx, features_stage, proj_matrices_stage, depth_hypo=depth_hypo, \n                regnet=self.RegNet_stages[stage_idx], group_cor_dim=self.group_cor_dim_stages[stage_idx], \n                depth_interal_ratio=self.depth_interal_ratio_stages[stage_idx], \n                geo_reg_data=geo_reg_data\n            )\n\n\n            # @Note frequency domain filter\n            depth_est = outputs_stage['depth']\n            depth_est_filtered = frequency_domain_filter(depth_est, rho_ratio=self.curriculum_learning_rho_ratios[stage_idx])\n            outputs_stage['depth_filtered'] = depth_est_filtered\n            depth_last = depth_est_filtered\n\n\n            confidence_last = outputs_stage['photometric_confidence']\n            prob_volume_last = outputs_stage['prob_volume']\n\n            outputs[stage_name] = outputs_stage\n            outputs.update(outputs_stage)\n\n        return outputs\n\n\nclass StageNet(nn.Module):\n    def __init__(self, attn_temp=2):\n        super(StageNet, self).__init__()\n        self.attn_temp = attn_temp\n\n    def forward(self, stage_idx, features, proj_matrices, depth_hypo, regnet, \n                    group_cor_dim, depth_interal_ratio, geo_reg_data=None):\n\n        # @Note step1: feature extraction\n        proj_matrices = torch.unbind(proj_matrices, 1)\n        ref_feature, src_features = features[0], features[1:]\n        ref_proj, src_projs = proj_matrices[0], proj_matrices[1:]\n        B, D, H, W = depth_hypo.shape\n        C = ref_feature.shape[1]\n\n\n        # @Note step2: cost aggregation\n        ref_volume =  ref_feature.unsqueeze(2).repeat(1, 1, D, 1, 1)\n        cor_weight_sum = 1e-8\n        cor_feats = 0\n        for src_idx, (src_fea, src_proj) in enumerate(zip(src_features, src_projs)):\n            save_fn = None\n            src_proj_new = src_proj[:, 0].clone()\n            src_proj_new[:, :3, :4] = torch.matmul(src_proj[:, 1, :3, :3], 
src_proj[:, 0, :3, :4])\n            ref_proj_new = ref_proj[:, 0].clone()\n            ref_proj_new[:, :3, :4] = torch.matmul(ref_proj[:, 1, :3, :3], ref_proj[:, 0, :3, :4])\n            warped_src = homo_warping(src_fea, src_proj_new, ref_proj_new, depth_hypo)  # B C D H W\n\n            warped_src = warped_src.reshape(B, group_cor_dim, C//group_cor_dim, D, H, W)\n            ref_volume = ref_volume.reshape(B, group_cor_dim, C//group_cor_dim, D, H, W)\n            cor_feat = (warped_src * ref_volume).mean(2)  # B G D H W\n            del warped_src, src_proj, src_fea\n\n            cor_weight = torch.softmax(cor_feat.sum(1) / self.attn_temp, 1) / math.sqrt(C)  # B D H W\n            cor_weight_sum += cor_weight  # B D H W\n            cor_feats += cor_weight.unsqueeze(1) * cor_feat  # B C D H W\n            del cor_weight, cor_feat\n\n        cost_volume = cor_feats / cor_weight_sum.unsqueeze(1)  # B C D H W\n        del cor_weight_sum, src_features\n        \n    \n        # @Note step3: cost regularization\n        if geo_reg_data == {}:\n            # basic\n            cost_reg = regnet(cost_volume)\n        else:\n            # probability volume geometry embedding\n            cost_reg = regnet(cost_volume, stage_idx, geo_reg_data)\n        del cost_volume\n        prob_volume = F.softmax(cost_reg, dim=1)  # B D H W\n\n\n        # @Note step4: depth regression\n        prob_max_indices = prob_volume.max(1, keepdim=True)[1]  # B 1 H W\n        depth = torch.gather(depth_hypo, 1, prob_max_indices).squeeze(1)  # B H W\n\n        with torch.no_grad():\n            photometric_confidence = prob_volume.max(1)[0]  # B H W\n            photometric_confidence = F.interpolate(photometric_confidence.unsqueeze(1), scale_factor=1, mode='bilinear', align_corners=True).squeeze(1)\n        \n        last_depth_itv = 1./depth_hypo[:,2,:,:] - 1./depth_hypo[:,1,:,:]\n        inverse_min_depth = 1/depth + depth_interal_ratio * last_depth_itv  # B H W\n        inverse_max_depth = 1/depth - depth_interal_ratio * last_depth_itv  # B H W\n\n\n        output_stage = {\n            \"depth\": depth,  \n            \"photometric_confidence\": photometric_confidence, \n            \"depth_hypo\": depth_hypo, \n            \"prob_volume\": prob_volume,\n            \"inverse_min_depth\": inverse_min_depth, \n            \"inverse_max_depth\": inverse_max_depth,\n        }\n        return output_stage"
  },
  {
    "path": "models/loss.py",
    "content": "# -*- coding: utf-8 -*-\n# @Description: Loss Functions (Sec 3.4 in the paper).\n# @Author: Zhe Zhang (doublez@stu.pku.edu.cn)\n# @Affiliation: Peking University (PKU)\n# @LastEditDate: 2023-09-07\n\nimport torch\n\n\ndef geomvsnet_loss(inputs, depth_gt_ms, mask_ms, **kwargs):\n\n    stage_lw = kwargs.get(\"stage_lw\", [1, 1, 1, 1])\n    depth_values = kwargs.get(\"depth_values\")\n    depth_min, depth_max = depth_values[:,0], depth_values[:,-1]\n    \n    total_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms[\"stage1\"].device, requires_grad=False)\n    pw_loss_stages = []\n    dds_loss_stages = []\n    for stage_idx, (stage_inputs, stage_key) in enumerate([(inputs[k], k) for k in inputs.keys() if \"stage\" in k]):\n        \n        depth = stage_inputs['depth_filtered']\n        prob_volume = stage_inputs['prob_volume']\n        depth_value = stage_inputs['depth_hypo']\n\n        depth_gt = depth_gt_ms[stage_key]\n        mask = mask_ms[stage_key] > 0.5\n\n\n        # pw loss\n        pw_loss = pixel_wise_loss(prob_volume, depth_gt, mask, depth_value)\n        pw_loss_stages.append(pw_loss)\n\n        # dds loss\n        dds_loss = depth_distribution_similarity_loss(depth, depth_gt, mask, depth_min, depth_max)\n        dds_loss_stages.append(dds_loss)\n        \n        # total loss\n        lam1, lam2 = 0.8, 0.2\n        total_loss = total_loss + stage_lw[stage_idx] * (lam1 * pw_loss + lam2 * dds_loss)\n  \n\n    depth_pred = stage_inputs['depth']\n    depth_gt = depth_gt_ms[stage_key]\n    epe = cal_metrics(depth_pred, depth_gt, mask, depth_min, depth_max)\n    \n    return total_loss, epe, pw_loss_stages, dds_loss_stages\n\n\n\ndef pixel_wise_loss(prob_volume, depth_gt, mask, depth_value):\n    mask_true = mask\n    valid_pixel_num = torch.sum(mask_true, dim=[1,2])+1e-12\n\n    shape = depth_gt.shape\n\n    depth_num = depth_value.shape[1]\n    depth_value_mat = depth_value\n\n    gt_index_image = torch.argmin(torch.abs(depth_value_mat-depth_gt.unsqueeze(1)), dim=1)\n\n    gt_index_image = torch.mul(mask_true, gt_index_image.type(torch.float))\n    gt_index_image = torch.round(gt_index_image).type(torch.long).unsqueeze(1)\n\n    gt_index_volume = torch.zeros(shape[0], depth_num, shape[1], shape[2]).type(mask_true.type()).scatter_(1, gt_index_image, 1)\n    cross_entropy_image = -torch.sum(gt_index_volume * torch.log(prob_volume+1e-12), dim=1).squeeze(1)\n    masked_cross_entropy_image = torch.mul(mask_true, cross_entropy_image)\n    masked_cross_entropy = torch.sum(masked_cross_entropy_image, dim=[1, 2])\n\n    masked_cross_entropy = torch.mean(masked_cross_entropy / valid_pixel_num)\n    \n    pw_loss = masked_cross_entropy\n    return pw_loss\n\n\ndef depth_distribution_similarity_loss(depth, depth_gt, mask, depth_min, depth_max):\n    depth_norm = depth * 128 / (depth_max - depth_min)[:,None,None]\n    depth_gt_norm = depth_gt * 128 / (depth_max - depth_min)[:,None,None]\n\n    M_bins = 48\n    kl_min = torch.min(torch.min(depth_gt), depth.mean()-3.*depth.std())\n    kl_max = torch.max(torch.max(depth_gt), depth.mean()+3.*depth.std())\n    bins = torch.linspace(kl_min, kl_max, steps=M_bins)\n\n    kl_divs = []\n    for i in range(len(bins) - 1):\n        bin_mask = (depth_gt >= bins[i]) & (depth_gt < bins[i+1])\n        merged_mask = mask & bin_mask \n\n        if merged_mask.sum() > 0:\n            p = depth_norm[merged_mask]\n            q = depth_gt_norm[merged_mask]\n            kl_div = torch.nn.functional.kl_div(torch.log(p)-torch.log(q), p, 
reduction='batchmean')\n            kl_div = torch.log(kl_div)\n            kl_divs.append(kl_div)\n\n    dds_loss = sum(kl_divs)\n    return dds_loss\n\n\ndef cal_metrics(depth_pred, depth_gt, mask, depth_min, depth_max):\n    depth_pred_norm = depth_pred * 128 / (depth_max - depth_min)[:,None,None]\n    depth_gt_norm = depth_gt * 128 / (depth_max - depth_min)[:,None,None]\n\n    abs_err = torch.abs(depth_pred_norm[mask] - depth_gt_norm[mask])\n    epe = abs_err.mean()\n    err1= (abs_err<=1).float().mean()*100\n    err3 = (abs_err<=3).float().mean()*100\n    \n    return epe  # err1, err3"
  },
  {
    "path": "models/submodules.py",
    "content": "# -*- coding: utf-8 -*-\n# @Description: Some sub-modules for the network.\n# @Author: Zhe Zhang (doublez@stu.pku.edu.cn)\n# @Affiliation: Peking University (PKU)\n# @LastEditDate: 2023-09-07\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass FPN(nn.Module):\n    \"\"\"FPN aligncorners downsample 4x\"\"\"\n    def __init__(self, base_channels, gn=False):\n        super(FPN, self).__init__()\n        self.base_channels = base_channels\n\n        self.conv0 = nn.Sequential(\n            Conv2d(3, base_channels, 3, 1, padding=1, gn=gn),\n            Conv2d(base_channels, base_channels, 3, 1, padding=1, gn=gn),\n        )\n\n        self.conv1 = nn.Sequential(\n            Conv2d(base_channels, base_channels * 2, 5, stride=2, padding=2, gn=gn),\n            Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1, gn=gn),\n            Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1, gn=gn),\n        )\n\n        self.conv2 = nn.Sequential(\n            Conv2d(base_channels * 2, base_channels * 4, 5, stride=2, padding=2, gn=gn),\n            Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1, gn=gn),\n            Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1, gn=gn),\n        )\n\n        self.conv3 = nn.Sequential(\n            Conv2d(base_channels * 4, base_channels * 8, 5, stride=2, padding=2, gn=gn),\n            Conv2d(base_channels * 8, base_channels * 8, 3, 1, padding=1, gn=gn),\n            Conv2d(base_channels * 8, base_channels * 8, 3, 1, padding=1, gn=gn),\n        )\n\n        self.out_channels = [8 * base_channels]\n        final_chs = base_channels * 8\n\n        self.inner1 = nn.Conv2d(base_channels * 4, final_chs, 1, bias=True)\n        self.inner2 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True)\n        self.inner3 = nn.Conv2d(base_channels * 1, final_chs, 1, bias=True)\n\n        self.out1 = nn.Conv2d(final_chs, base_channels * 8, 1, bias=False)\n        self.out2 = nn.Conv2d(final_chs, base_channels * 4, 3, padding=1, bias=False)\n        self.out3 = nn.Conv2d(final_chs, base_channels * 2, 3, padding=1, bias=False)\n        self.out4 = nn.Conv2d(final_chs, base_channels, 3, padding=1, bias=False)\n\n        self.out_channels.append(base_channels * 4)\n        self.out_channels.append(base_channels * 2)\n        self.out_channels.append(base_channels)\n\n    def forward(self, x):\n        conv0 = self.conv0(x)\n        conv1 = self.conv1(conv0)\n        conv2 = self.conv2(conv1)\n        conv3 = self.conv3(conv2)\n\n        intra_feat = conv3\n        outputs = {}\n        out1 = self.out1(intra_feat)\n\n        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode=\"bilinear\", align_corners=True) + self.inner1(conv2)\n        out2 = self.out2(intra_feat)\n\n        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode=\"bilinear\", align_corners=True) + self.inner2(conv1)\n        out3 = self.out3(intra_feat)\n\n        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode=\"bilinear\", align_corners=True) + self.inner3(conv0)\n        out4 = self.out4(intra_feat)\n\n        outputs[\"stage1\"] = out1\n        outputs[\"stage2\"] = out2\n        outputs[\"stage3\"] = out3\n        outputs[\"stage4\"] = out4\n\n        return outputs\n\n\nclass Reg2d(nn.Module):\n    def __init__(self, input_channel=128, base_channel=32):\n        super(Reg2d, self).__init__()\n        \n        self.conv0 = ConvBnReLU3D(input_channel, base_channel, kernel_size=(1,3,3), 
pad=(0,1,1))\n        self.conv1 = ConvBnReLU3D(base_channel, base_channel*2, kernel_size=(1,3,3), stride=(1,2,2), pad=(0,1,1))\n        self.conv2 = ConvBnReLU3D(base_channel*2, base_channel*2)\n\n        self.conv3 = ConvBnReLU3D(base_channel*2, base_channel*4, kernel_size=(1,3,3), stride=(1,2,2), pad=(0,1,1))\n        self.conv4 = ConvBnReLU3D(base_channel*4, base_channel*4)\n\n        self.conv5 = ConvBnReLU3D(base_channel*4, base_channel*8, kernel_size=(1,3,3), stride=(1,2,2), pad=(0,1,1))\n        self.conv6 = ConvBnReLU3D(base_channel*8, base_channel*8)\n\n        self.conv7 = nn.Sequential(\n            nn.ConvTranspose3d(base_channel*8, base_channel*4, kernel_size=(1,3,3), padding=(0,1,1), output_padding=(0,1,1), stride=(1,2,2), bias=False),\n            nn.BatchNorm3d(base_channel*4),\n            nn.ReLU(inplace=True))\n\n        self.conv9 = nn.Sequential(\n            nn.ConvTranspose3d(base_channel*4, base_channel*2, kernel_size=(1,3,3), padding=(0,1,1), output_padding=(0,1,1), stride=(1,2,2), bias=False),\n            nn.BatchNorm3d(base_channel*2),\n            nn.ReLU(inplace=True))\n\n        self.conv11 = nn.Sequential(\n            nn.ConvTranspose3d(base_channel*2, base_channel, kernel_size=(1,3,3), padding=(0,1,1), output_padding=(0,1,1), stride=(1,2,2), bias=False),\n            nn.BatchNorm3d(base_channel),\n            nn.ReLU(inplace=True))\n\n        self.prob = nn.Conv3d(base_channel, 1, 1, stride=1, padding=0)  # was hardcoded 8; conv11 outputs base_channel channels\n\n    def forward(self, x):\n        conv0 = self.conv0(x)\n        conv2 = self.conv2(self.conv1(conv0))\n        conv4 = self.conv4(self.conv3(conv2))\n        x = self.conv6(self.conv5(conv4))\n        x = conv4 + self.conv7(x)\n        x = conv2 + self.conv9(x)\n        x = conv0 + self.conv11(x)\n        x = self.prob(x)\n\n        return x.squeeze(1)\n\n\ndef homo_warping(src_fea, src_proj, ref_proj, depth_values):\n    # src_fea: [B, C, H, W]\n    # src_proj: [B, 4, 4]\n    # ref_proj: [B, 4, 4]\n    # depth_values: [B, Ndepth] or [B, Ndepth, H, W]\n    # out: [B, C, Ndepth, H, W]\n    C = src_fea.shape[1]\n    Hs, Ws = src_fea.shape[-2:]\n    B, num_depth, Hr, Wr = depth_values.shape\n\n    with torch.no_grad():\n        proj = torch.matmul(src_proj, torch.inverse(ref_proj))\n        rot = proj[:, :3, :3]  # [B,3,3]\n        trans = proj[:, :3, 3:4]  # [B,3,1]\n\n        y, x = torch.meshgrid([torch.arange(0, Hr, dtype=torch.float32, device=src_fea.device),\n                               torch.arange(0, Wr, dtype=torch.float32, device=src_fea.device)])\n        y = y.reshape(Hr*Wr)\n        x = x.reshape(Hr*Wr)\n        xyz = torch.stack((x, y, torch.ones_like(x)))  # [3, H*W]\n        xyz = torch.unsqueeze(xyz, 0).repeat(B, 1, 1)  # [B, 3, H*W]\n        rot_xyz = torch.matmul(rot, xyz)  # [B, 3, H*W]\n        rot_depth_xyz = rot_xyz.unsqueeze(2).repeat(1, 1, num_depth, 1) * depth_values.reshape(B, 1, num_depth, -1)  # [B, 3, Ndepth, H*W]\n        proj_xyz = rot_depth_xyz + trans.reshape(B, 3, 1, 1)  # [B, 3, Ndepth, H*W]\n        # FIXME divide 0\n        temp = proj_xyz[:, 2:3, :, :]\n        temp[temp==0] = 1e-9\n        proj_xy = proj_xyz[:, :2, :, :] / temp  # [B, 2, Ndepth, H*W]\n        # proj_xy = proj_xyz[:, :2, :, :] / proj_xyz[:, 2:3, :, :]  # [B, 2, Ndepth, H*W]\n\n        proj_x_normalized = proj_xy[:, 0, :, :] / ((Ws - 1) / 2) - 1\n        proj_y_normalized = proj_xy[:, 1, :, :] / ((Hs - 1) / 2) - 1\n        proj_xy = torch.stack((proj_x_normalized, proj_y_normalized), dim=3)  # [B, Ndepth, H*W, 2]\n        grid = proj_xy\n    if 
len(src_fea.shape)==4:\n        warped_src_fea = F.grid_sample(src_fea, grid.reshape(B, num_depth * Hr, Wr, 2), mode='bilinear', padding_mode='zeros', align_corners=True)\n        warped_src_fea = warped_src_fea.reshape(B, C, num_depth, Hr, Wr)\n    elif len(src_fea.shape)==5:\n        warped_src_fea = []\n        for d in range(src_fea.shape[2]):\n            warped_src_fea.append(F.grid_sample(src_fea[:,:,d], grid.reshape(B, num_depth, Hr, Wr, 2)[:,d], mode='bilinear', padding_mode='zeros', align_corners=True))\n        warped_src_fea = torch.stack(warped_src_fea, dim=2)\n\n    return warped_src_fea\n\n\ndef init_inverse_range(cur_depth, ndepths, device, dtype, H, W):\n    inverse_depth_min = 1. / cur_depth[:, 0]  # (B,)\n    inverse_depth_max = 1. / cur_depth[:, -1]\n    itv = torch.arange(0, ndepths, device=device, dtype=dtype, requires_grad=False).reshape(1, -1,1,1).repeat(1, 1, H, W)  / (ndepths - 1)  # 1 D H W\n    inverse_depth_hypo = inverse_depth_max[:,None, None, None] + (inverse_depth_min - inverse_depth_max)[:,None, None, None] * itv\n\n    return 1./inverse_depth_hypo\n\n\ndef schedule_inverse_range(inverse_min_depth, inverse_max_depth, ndepths, H, W):\n    # cur_depth_min, (B, H, W)\n    # cur_depth_max: (B, H, W)\n    itv = torch.arange(0, ndepths, device=inverse_min_depth.device, dtype=inverse_min_depth.dtype, requires_grad=False).reshape(1, -1,1,1).repeat(1, 1, H//2, W//2)  / (ndepths - 1)  # 1 D H W\n\n    inverse_depth_hypo = inverse_max_depth[:,None, :, :] + (inverse_min_depth - inverse_max_depth)[:,None, :, :] * itv  # B D H W\n    inverse_depth_hypo = F.interpolate(inverse_depth_hypo.unsqueeze(1), [ndepths, H, W], mode='trilinear', align_corners=True).squeeze(1)\n    return 1./inverse_depth_hypo\n\n\n# --------------------------------------------------------------\n\n\ndef init_bn(module):\n    if module.weight is not None:\n        nn.init.ones_(module.weight)\n    if module.bias is not None:\n        nn.init.zeros_(module.bias)\n    return\n\n\ndef init_uniform(module, init_method):\n    if module.weight is not None:\n        if init_method == \"kaiming\":\n            nn.init.kaiming_uniform_(module.weight)\n        elif init_method == \"xavier\":\n            nn.init.xavier_uniform_(module.weight)\n    return\n\n\nclass ConvBnReLU3D(nn.Module):\n    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):\n        super(ConvBnReLU3D, self).__init__()\n        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False)\n        self.bn = nn.BatchNorm3d(out_channels)\n\n    def forward(self, x):\n        return F.relu(self.bn(self.conv(x)), inplace=True)\n\n\nclass Conv2d(nn.Module):\n\n    def __init__(self, in_channels, out_channels, kernel_size, stride=1,\n                 relu=True, bn_momentum=0.1, init_method=\"xavier\", gn=False, group_channel=8, **kwargs):\n        super(Conv2d, self).__init__()\n        bn = not gn\n        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,\n                              bias=(not bn), **kwargs)\n        self.kernel_size = kernel_size\n        self.stride = stride\n        self.bn = nn.BatchNorm2d(out_channels, momentum=bn_momentum) if bn else None\n        self.gn = nn.GroupNorm(int(max(1, out_channels / group_channel)), out_channels) if gn else None\n        self.relu = relu\n\n    def forward(self, x):\n        x = self.conv(x)\n        if self.bn is not None:\n            x = self.bn(x)\n        else:\n            x = self.gn(x)\n 
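        # optional activation; exactly one of BatchNorm / GroupNorm was applied above (bn = not gn in __init__)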
       if self.relu:\n            x = F.relu(x, inplace=True)\n        return x\n\n    def init_weights(self, init_method):\n        init_uniform(self.conv, init_method)\n        if self.bn is not None:\n            init_bn(self.bn)"
  },
  {
    "path": "models/utils/__init__.py",
    "content": "from models.utils.utils import *"
  },
  {
    "path": "models/utils/opts.py",
    "content": "# -*- coding: utf-8 -*-\n# @Description: Options settings & configurations for GeoMVSNet.\n# @Author: Zhe Zhang (doublez@stu.pku.edu.cn)\n# @Affiliation: Peking University (PKU)\n# @LastEditDate: 2023-09-07\n\nimport argparse\n\ndef get_opts():\n    parser = argparse.ArgumentParser(description=\"args\")\n\n    # global settings\n    parser.add_argument('--mode', default='train', help='train or test', choices=['train', 'test', 'val'])\n    parser.add_argument('--which_dataset', default='dtu', choices=['dtu', 'tnt', 'blendedmvs'], help='which dataset for using')\n\n    parser.add_argument('--n_views', type=int, default=5, help='num of view')\n    parser.add_argument('--levels', type=int, default=4, help='num of stages')\n    parser.add_argument('--hypo_plane_num_stages', type=str, default=\"8,8,4,4\", help='num of hypothesis planes for each stage')\n    parser.add_argument('--depth_interal_ratio_stages', type=str, default=\"0.5,0.5,0.5,1\", help='depth interals for each stage')\n    parser.add_argument(\"--feat_base_channel\", type=int, default=8, help='channel num for base feature')\n    parser.add_argument(\"--reg_base_channel\", type=int, default=8, help='channel num for regularization')\n    parser.add_argument('--group_cor_dim_stages', type=str, default=\"8,8,4,4\", help='group correlation dim')\n\n    parser.add_argument('--batch_size', type=int, default=1, help='batch size for training')\n    parser.add_argument('--data_scale', type=str, choices=['mid', 'raw'], help='use mid or raw resolution')\n    parser.add_argument('--trainpath', help='data path for training')\n    parser.add_argument('--testpath', help='data path for testing')\n    parser.add_argument('--trainlist', help='data list for training')\n    parser.add_argument('--testlist', help='data list for testing')\n\n\n    # training config\n    parser.add_argument('--stage_lw', type=str, default=\"1,1,1,1\", help='loss weight for different stages')\n\n    parser.add_argument('--epochs', type=int, default=10, help='number of epochs to train')\n    parser.add_argument('--lr_scheduler', type=str, default='MS', help='scheduler for learning rate')\n    parser.add_argument('--lr', type=float, default=0.001, help='learning rate')\n    parser.add_argument('--lrepochs', type=str, default=\"1,3,5,7,9,11,13,15:1.5\", help='epoch ids to downscale lr and the downscale rate')\n    parser.add_argument('--wd', type=float, default=0.0, help='weight decay')\n\n    parser.add_argument('--summary_freq', type=int, default=100, help='print and summary frequency')\n    parser.add_argument('--save_freq', type=int, default=1, help='save checkpoint frequency')\n    parser.add_argument('--eval_freq', type=int, default=1, help='eval frequency')\n\n    parser.add_argument('--robust_train', action='store_true',help='robust training')\n\n    \n    # testing config\n    parser.add_argument('--split', type=str, choices=['intermediate', 'advanced'], help='intermediate|advanced for tanksandtemples')\n    parser.add_argument('--img_mode', type=str, default='resize', choices=['resize', 'crop'], help='image resolution matching strategy for TNT dataset')\n    parser.add_argument('--cam_mode', type=str, default='origin', choices=['origin', 'short_range'], help='camera parameter strategy for TNT dataset')\n\n    parser.add_argument('--loadckpt', default=None, help='load a specific checkpoint')\n    parser.add_argument('--logdir', default='./checkpoints/debug', help='the directory to save checkpoints/logs')\n    parser.add_argument('--nolog', 
action='store_true', help='do not log into .log file')\n    parser.add_argument('--notensorboard', action='store_true', help='do not log into tensorboard')\n    parser.add_argument('--save_conf_all_stages', action='store_true', help='save confidence maps for all stages')\n    parser.add_argument('--outdir', default='./outputs', help='output dir')\n    parser.add_argument('--resume', action='store_true', help='continue to train the model')\n\n\n    # pytorch config\n    parser.add_argument('--device', default='cuda', help='device to use')\n    parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed')\n    parser.add_argument('--pin_m', action='store_true', help='data loader pin memory')\n    parser.add_argument(\"--local_rank\", type=int, default=0)\n\n    return parser.parse_args()"
  },
  {
    "path": "models/utils/utils.py",
    "content": "# -*- coding: utf-8 -*-\n# @Description: Some useful utils.\n# @Author: Zhe Zhang (doublez@stu.pku.edu.cn)\n# @Affiliation: Peking University (PKU)\n# @LastEditDate: 2023-09-07\n\nimport random\nimport numpy as np\n\nimport torch\nimport torchvision.utils as vutils\n\n\n# torch.no_grad warpper for functions\ndef make_nograd_func(func):\n    def wrapper(*f_args, **f_kwargs):\n        with torch.no_grad():\n            ret = func(*f_args, **f_kwargs)\n        return ret\n\n    return wrapper\n\n\n# convert a function into recursive style to handle nested dict/list/tuple variables\ndef make_recursive_func(func):\n    def wrapper(vars):\n        if isinstance(vars, list):\n            return [wrapper(x) for x in vars]\n        elif isinstance(vars, tuple):\n            return tuple([wrapper(x) for x in vars])\n        elif isinstance(vars, dict):\n            return {k: wrapper(v) for k, v in vars.items()}\n        else:\n            return func(vars)\n\n    return wrapper\n\n\n@make_recursive_func\ndef tensor2float(vars):\n    if isinstance(vars, float):\n        return vars\n    elif isinstance(vars, torch.Tensor):\n        return vars.data.item()\n    else:\n        raise NotImplementedError(\"invalid input type {} for tensor2float\".format(type(vars)))\n\n\n@make_recursive_func\ndef tensor2numpy(vars):\n    if isinstance(vars, np.ndarray):\n        return vars\n    elif isinstance(vars, torch.Tensor):\n        return vars.detach().cpu().numpy().copy()\n    else:\n        raise NotImplementedError(\"invalid input type {} for tensor2numpy\".format(type(vars)))\n\n\n@make_recursive_func\ndef tocuda(vars):\n    if isinstance(vars, torch.Tensor):\n        return vars.to(torch.device(\"cuda\"))\n    elif isinstance(vars, str):\n        return vars\n    else:\n        raise NotImplementedError(\"invalid input type {} for tensor2numpy\".format(type(vars)))\n\n\ndef tb_save_scalars(logger, mode, scalar_dict, global_step):\n    scalar_dict = tensor2float(scalar_dict)\n    for key, value in scalar_dict.items():\n        if not isinstance(value, (list, tuple)):\n            name = '{}/{}'.format(mode, key)\n            logger.add_scalar(name, value, global_step)\n        else:\n            for idx in range(len(value)):\n                name = '{}/{}_{}'.format(mode, key, idx)\n                logger.add_scalar(name, value[idx], global_step)\n\n\ndef tb_save_images(logger, mode, images_dict, global_step):\n    images_dict = tensor2numpy(images_dict)\n\n    def preprocess(name, img):\n        if not (len(img.shape) == 3 or len(img.shape) == 4):\n            raise NotImplementedError(\"invalid img shape {}:{} in save_images\".format(name, img.shape))\n        if len(img.shape) == 3:\n            img = img[:, np.newaxis, :, :]\n        img = torch.from_numpy(img[:1])\n        return vutils.make_grid(img, padding=0, nrow=1, normalize=True, scale_each=True)\n\n    for key, value in images_dict.items():\n        if not isinstance(value, (list, tuple)):\n            name = '{}/{}'.format(mode, key)\n            logger.add_image(name, preprocess(name, value), global_step)\n        else:\n            for idx in range(len(value)):\n                name = '{}/{}_{}'.format(mode, key, idx)\n                logger.add_image(name, preprocess(name, value[idx]), global_step)\n\n\nclass DictAverageMeter(object):\n    def __init__(self):\n        self.data = {}\n        self.count = 0\n\n    def update(self, new_input):\n        self.count += 1\n        if len(self.data) == 0:\n            for k, v in 
new_input.items():\n                if not isinstance(v, float):\n                    raise NotImplementedError(\"invalid data {}: {}\".format(k, type(v)))\n                self.data[k] = v\n        else:\n            for k, v in new_input.items():\n                if not isinstance(v, float):\n                    raise NotImplementedError(\"invalid data {}: {}\".format(k, type(v)))\n                self.data[k] += v\n\n    def mean(self):\n        return {k: v / self.count for k, v in self.data.items()}\n\n\n# a wrapper to compute metrics for each image individually\ndef compute_metrics_for_each_image(metric_func):\n    def wrapper(depth_est, depth_gt, mask, *args):\n        batch_size = depth_gt.shape[0]\n        results = []\n        # compute result one by one\n        for idx in range(batch_size):\n            ret = metric_func(depth_est[idx], depth_gt[idx], mask[idx], *args)\n            results.append(ret)\n        return torch.stack(results).mean()\n\n    return wrapper\n\n\n@make_nograd_func\n@compute_metrics_for_each_image\ndef Thres_metrics(depth_est, depth_gt, mask, thres):\n    assert isinstance(thres, (int, float))\n    depth_est, depth_gt = depth_est[mask], depth_gt[mask]\n    errors = torch.abs(depth_est - depth_gt)\n    err_mask = errors > thres\n    return torch.mean(err_mask.float())\n\n\n# NOTE: please do not use this to build up training loss\n@make_nograd_func\n@compute_metrics_for_each_image\ndef AbsDepthError_metrics(depth_est, depth_gt, mask, thres=None):\n    depth_est, depth_gt = depth_est[mask], depth_gt[mask]\n    error = (depth_est - depth_gt).abs()\n    if thres is not None:\n        error = error[(error >= float(thres[0])) & (error <= float(thres[1]))]\n        if error.shape[0] == 0:\n            return torch.tensor(0, device=error.device, dtype=error.dtype)\n    return torch.mean(error)\n\n\nimport torch.distributed as dist\ndef synchronize():\n    \"\"\"\n    Helper function to synchronize (barrier) among all processes when\n    using distributed training\n    \"\"\"\n    if not dist.is_available():\n        return\n    if not dist.is_initialized():\n        return\n    world_size = dist.get_world_size()\n    if world_size == 1:\n        return\n    dist.barrier()\n\n\ndef get_world_size():\n    if not dist.is_available():\n        return 1\n    if not dist.is_initialized():\n        return 1\n    return dist.get_world_size()\n\n\ndef reduce_scalar_outputs(scalar_outputs):\n    world_size = get_world_size()\n    if world_size < 2:\n        return scalar_outputs\n    with torch.no_grad():\n        names = []\n        scalars = []\n        for k in sorted(scalar_outputs.keys()):\n            names.append(k)\n            scalars.append(scalar_outputs[k])\n        scalars = torch.stack(scalars, dim=0)\n        dist.reduce(scalars, dst=0)\n        if dist.get_rank() == 0:\n            # only main process gets accumulated, so only divide by\n            # world_size in this case\n            scalars /= world_size\n        reduced_scalars = {k: v for k, v in zip(names, scalars)}\n\n    return reduced_scalars\n\n\nimport torch\nfrom bisect import bisect_right\nclass WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):\n    def __init__(\n        self,\n        optimizer,\n        milestones,\n        gamma=0.1,\n        warmup_factor=1.0 / 3,\n        warmup_iters=500,\n        warmup_method=\"linear\",\n        last_epoch=-1,\n    ):\n        if not list(milestones) == sorted(milestones):\n            raise ValueError(\n                \"Milestones should 
be a list of\" \" increasing integers. Got {}\",\n                milestones,\n            )\n\n        if warmup_method not in (\"constant\", \"linear\"):\n            raise ValueError(\n                \"Only 'constant' or 'linear' warmup_method accepted\"\n                \"got {}\".format(warmup_method)\n            )\n        self.milestones = milestones\n        self.gamma = gamma\n        self.warmup_factor = warmup_factor\n        self.warmup_iters = warmup_iters\n        self.warmup_method = warmup_method\n        super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)\n\n    def get_lr(self):\n        warmup_factor = 1\n        if self.last_epoch < self.warmup_iters:\n            if self.warmup_method == \"constant\":\n                warmup_factor = self.warmup_factor\n            elif self.warmup_method == \"linear\":\n                alpha = float(self.last_epoch) / self.warmup_iters\n                warmup_factor = self.warmup_factor * (1 - alpha) + alpha\n        return [\n            base_lr\n            * warmup_factor\n            * self.gamma ** bisect_right(self.milestones, self.last_epoch)\n            for base_lr in self.base_lrs\n        ]\n\n    \ndef set_random_seed(seed):\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)"
  },
  {
    "path": "outputs/visual.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"- @Description: Juputer notebook for visualizing depth maps.\\n\",\n    \"- @Author: Zhe Zhang (doublez@stu.pku.edu.cn)\\n\",\n    \"- @Affiliation: Peking University (PKU)\\n\",\n    \"- @LastEditDate: 2023-09-07\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"ExecutionIndicator\": {\n     \"show\": true\n    },\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"import sys, os\\n\",\n    \"sys.path.append('../')\\n\",\n    \"import numpy as np\\n\",\n    \"import matplotlib.pyplot as plt\\n\",\n    \"import re\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def read_pfm(filename):\\n\",\n    \"    file = open(filename, 'rb')\\n\",\n    \"    color = None\\n\",\n    \"    width = None\\n\",\n    \"    height = None\\n\",\n    \"    scale = None\\n\",\n    \"    endian = None\\n\",\n    \"\\n\",\n    \"    header = file.readline().decode('utf-8').rstrip()\\n\",\n    \"    if header == 'PF':\\n\",\n    \"        color = True\\n\",\n    \"    elif header == 'Pf':\\n\",\n    \"        color = False\\n\",\n    \"    else:\\n\",\n    \"        raise Exception('Not a PFM file.')\\n\",\n    \"\\n\",\n    \"    dim_match = re.match(r'^(\\\\d+)\\\\s(\\\\d+)\\\\s$', file.readline().decode('utf-8'))\\n\",\n    \"    if dim_match:\\n\",\n    \"        width, height = map(int, dim_match.groups())\\n\",\n    \"    else:\\n\",\n    \"        raise Exception('Malformed PFM header.')\\n\",\n    \"\\n\",\n    \"    scale = float(file.readline().rstrip())\\n\",\n    \"    if scale < 0:  # little-endian\\n\",\n    \"        endian = '<'\\n\",\n    \"        scale = -scale\\n\",\n    \"    else:\\n\",\n    \"        endian = '>'  # big-endian\\n\",\n    \"\\n\",\n    \"    data = np.fromfile(file, endian + 'f')\\n\",\n    \"    shape = (height, width, 3) if color else (height, width)\\n\",\n    \"\\n\",\n    \"    data = np.reshape(data, shape)\\n\",\n    \"    data = np.flipud(data)\\n\",\n    \"    file.close()\\n\",\n    \"    return data, scale\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def read_depth(filename):\\n\",\n    \"    depth = read_pfm(filename)[0]\\n\",\n    \"    return np.array(depth, dtype=np.float32)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"assert False\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## DTU\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"ExecutionIndicator\": {\n     \"show\": true\n    },\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"exp_name = 'dtu/geomvsnet'\\n\",\n    \"depth_name = \\\"00000009.pfm\\\"\\n\",\n    \"\\n\",\n    \"scans = os.listdir(os.path.join(exp_name))\\n\",\n    \"scans = list(filter(lambda x: x.startswith(\\\"scan\\\"), scans))\\n\",\n    \"scans.sort(key=lambda x: int(x[4:]))\\n\",\n    \"for scan in scans:\\n\",\n    \"    depth_filename = os.path.join(exp_name, scan, \\\"depth_est\\\", depth_name)\\n\",\n    \"    if not os.path.exists(depth_filename): continue\\n\",\n    \"    depth = read_depth(depth_filename)\\n\",\n    \"\\n\",\n    \"    confidence_filename = os.path.join(exp_name, scan, \\\"confidence\\\", depth_name)\\n\",\n    \"    confidence = read_depth(confidence_filename)\\n\",\n    \"\\n\",\n    \"    print(scan, depth_name)\\n\",\n    \"\\n\",\n    \"    plt.figure(figsize=(12, 12))\\n\",\n    \"    plt.subplot(1, 2, 
1)\\n\",\n    \"    plt.xticks([]), plt.yticks([]), plt.axis('off')\\n\",\n    \"    plt.imshow(depth, 'viridis',  vmin=500, vmax=830)\\n\",\n    \"\\n\",\n    \"    plt.subplot(1, 2, 2)\\n\",\n    \"    plt.xticks([]), plt.yticks([]), plt.axis('off')\\n\",\n    \"    plt.imshow(confidence, 'viridis')\\n\",\n    \"    plt.show()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## TNT\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"exp_name = './tnt/blend/geomvsnet/'\\n\",\n    \"depth_name = \\\"00000009.pfm\\\"\\n\",\n    \"\\n\",\n    \"with open(\\\"../datasets/lists/tnt/intermediate.txt\\\") as f:\\n\",\n    \"    scans_i = [line.rstrip() for line in f.readlines()]\\n\",\n    \"\\n\",\n    \"with open(\\\"../datasets/lists/tnt/advanced.txt\\\") as f:\\n\",\n    \"    scans_a = [line.rstrip() for line in f.readlines()]\\n\",\n    \"\\n\",\n    \"scans = scans_i + scans_a\\n\",\n    \"\\n\",\n    \"for scan in scans:\\n\",\n    \"\\n\",\n    \"    depth_filename = os.path.join(exp_name, scan, \\\"depth_est\\\", depth_name)\\n\",\n    \"    if not os.path.exists(depth_filename): continue\\n\",\n    \"    depth = read_depth(depth_filename)\\n\",\n    \"\\n\",\n    \"    print(scan, depth_name, depth.shape)\\n\",\n    \"\\n\",\n    \"    plt.figure(figsize=(12, 12))\\n\",\n    \"    plt.xticks([]), plt.yticks([]), plt.axis('off')\\n\",\n    \"    plt.imshow(depth, 'viridis', vmin=0, vmax=10)\\n\",\n    \"\\n\",\n    \"    plt.show()\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.6.12\"\n  },\n  \"vscode\": {\n   \"interpreter\": {\n    \"hash\": \"d253918f84404206ad3cf9c22ee3709ef6e34cbea610b0ac9787033d60da5e03\"\n   }\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "requirements.txt",
    "content": "torch==1.10.0\ntorchvision\nopencv-python\nnumpy==1.18.1\npillow\nscipy\ntensorboardX\nplyfile\nopen3d\njupyter\nnotebook"
  },
  {
    "path": "scripts/blend/train_blend.sh",
    "content": "#!/usr/bin/env bash\nsource scripts/data_path.sh\n\nTHISNAME=\"geomvsnet\"\n\nLOG_DIR=\"./checkpoints/blend/\"$THISNAME \nif [ ! -d $LOG_DIR ]; then\n    mkdir -p $LOG_DIR\nfi\n\nCUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch --nproc_per_node=4 train.py ${@} \\\n    --which_dataset=\"blendedmvs\" --epochs=16 --logdir=$LOG_DIR \\\n    --trainpath=$BLENDEDMVS_ROOT --testpath=$BLENDEDMVS_ROOT \\\n    --trainlist=\"datasets/lists/blendedmvs/low_res_all.txt\" --testlist=\"datasets/lists/blendedmvs/val.txt\" \\\n    \\\n    --n_views=\"7\" --batch_size=2 --lr=0.001 --robust_train \\\n    --lr_scheduler=\"onecycle\""
  },
  {
    "path": "scripts/data_path.sh",
    "content": "#!/usr/bin/env bash\n\n# DTU\nDTU_TRAIN_ROOT=\"[/path/to/]dtu\"\nDTU_TEST_ROOT=\"[/path/to/]dtu-test\"\nDTU_QUANTITATIVE_ROOT=\"[/path/to/]dtu-evaluation\"\n\n# Tanks and Temples\nTNT_ROOT=\"[/path/to/]tnt\"\n\n# BlendedMVS\nBLENDEDMVS_ROOT=\"[/path/to/]blendmvs\""
  },
  {
    "path": "scripts/dtu/fusion_dtu.sh",
    "content": "#!/usr/bin/env bash\nsource scripts/data_path.sh\n\nTHISNAME=\"geomvsnet\"\nFUSION_METHOD=\"open3d\"\n\nLOG_DIR=\"./checkpoints/dtu/\"$THISNAME \nDTU_OUT_DIR=\"./outputs/dtu/\"$THISNAME\n\nif [ $FUSION_METHOD = \"pcd\" ] ; then\npython3 fusions/dtu/pcd.py ${@} \\\n    --testpath=$DTU_TEST_ROOT --testlist=\"datasets/lists/dtu/test.txt\" \\\n    --outdir=$DTU_OUT_DIR --logdir=$LOG_DIR --nolog \\\n    --num_worker=1 \\\n    \\\n    --thres_view=4 --conf=0.5 \\\n    \\\n    --plydir=$DTU_OUT_DIR\"/pcd_fusion_plys/\"\n    \nelif [ $FUSION_METHOD = \"gipuma\" ] ; then\n# source [/path/to/]anaconda3/etc/profile.d/conda.sh\n# conda activate fusibile\nCUDA_VISIBLE_DEVICES=0 python2 fusions/dtu/gipuma.py \\\n    --root_dir=$DTU_TEST_ROOT --list_file=\"datasets/lists/dtu/test.txt\" \\\n    --fusibile_exe_path=\"fusions/fusibile\" --out_folder=\"fusibile_fused\" \\\n    --depth_folder=$DTU_OUT_DIR \\\n    --downsample_factor=1 \\\n    \\\n    --prob_threshold=0.5 --disp_threshold=0.25 --num_consistent=3 \\\n    \\\n    --plydir=$DTU_OUT_DIR\"/gipuma_fusion_plys/\"\n\nelif [ $FUSION_METHOD = \"open3d\" ] ; then\nCUDA_VISIBLE_DEVICES=0 python fusions/dtu/_open3d.py --device=\"cuda\" \\\n    --root_path=$DTU_TEST_ROOT \\\n    --depth_path=$DTU_OUT_DIR \\\n    --data_list=\"datasets/lists/dtu/test.txt\" \\\n    \\\n    --prob_thresh=0.3 --dist_thresh=0.2 --num_consist=4 \\\n    \\\n    --ply_path=$DTU_OUT_DIR\"/open3d_fusion_plys/\"\n\nfi\n"
  },
  {
    "path": "scripts/dtu/matlab_quan_dtu.sh",
    "content": "#!/usr/bin/env bash\nsource scripts/data_path.sh\n\nOUTNAME=\"geomvsnet\"\n\nFUSIONMETHOD=\"open3d\"\n\n# Evaluation\necho \"<<<<<<<<<< start parallel evaluation\"\nMETHOD='mvsnet'\nPLYPATH='../../../outputs/dtu/'$OUTNAME'/'$FUSIONMETHOD'_fusion_plys/'\nRESULTPATH='../../../outputs/dtu/'$OUTNAME'/'$FUSIONMETHOD'_quantitative/'\nLOGPATH='outputs/dtu/'$OUTNAME'/'$FUSIONMETHOD'_quantitative/'$OUTNAME'.log'\n\nmkdir -p 'outputs/dtu/'$OUTNAME'/'$FUSIONMETHOD'_quantitative/'\n\nset_array=(1 4 9 10 11 12 13 15 23 24 29 32 33 34 48 49 62 75 77 110 114 118)\n\nnum_at_once=2   # 1 2 4 5 7 11 22\ntimes=`expr $((${#set_array[*]} / $num_at_once))`\nremain=`expr $((${#set_array[*]} - $num_at_once * $times))`\nthis_group_num=0\npos=0\n\nfor ((t=0; t<$times; t++))\ndo\n    if [ \"$t\" -ge `expr $(($times-$remain))` ] ; then\n        this_group_num=`expr $(($num_at_once + 1))`\n    else\n        this_group_num=$num_at_once\n    fi\n    \n    for set in \"${set_array[@]:pos:this_group_num}\"\n    do\n        matlab -nodesktop -nosplash -r \"cd datasets/evaluations/dtu_parallel; dataPath='$DTU_QUANTITATIVE_ROOT'; plyPath='$PLYPATH'; resultsPath='$RESULTPATH'; method_string='$METHOD'; thisset='$set'; BaseEvalMain_web\" &\n    done\n    wait\n\n    pos=`expr $(($pos + $this_group_num))`\n\ndone\nwait\n\n\nSET=[1,4,9,10,11,12,13,15,23,24,29,32,33,34,48,49,62,75,77,110,114,118]\n\nmatlab -nodesktop -nosplash -r \"cd datasets/evaluations/dtu_parallel; resultsPath='$RESULTPATH'; method_string='$METHOD'; set='$SET'; ComputeStat_web\" > $LOGPATH"
  },
  {
    "path": "scripts/dtu/test_dtu.sh",
    "content": "#!/usr/bin/env bash\nsource scripts/data_path.sh\n\nTHISNAME=\"geomvsnet\"\nBESTEPOCH=\"geomvsnet_release\"\n\nLOG_DIR=\"./checkpoints/dtu/\"$THISNAME \nDTU_CKPT_FILE=$LOG_DIR\"/model_\"$BESTEPOCH\".ckpt\"\nDTU_OUT_DIR=\"./outputs/dtu/\"$THISNAME\n\nCUDA_VISIBLE_DEVICES=0 python3 test.py ${@} \\\n    --which_dataset=\"dtu\" --loadckpt=$DTU_CKPT_FILE --batch_size=1 \\\n    --outdir=$DTU_OUT_DIR --logdir=$LOG_DIR --nolog \\\n    --testpath=$DTU_TEST_ROOT --testlist=\"datasets/lists/dtu/test.txt\" \\\n    \\\n    --data_scale=\"raw\" --n_views=\"5\""
  },
  {
    "path": "scripts/dtu/train_dtu.sh",
    "content": "#!/usr/bin/env bash\nsource scripts/data_path.sh\n\nTHISNAME=\"geomvsnet\"\n\nLOG_DIR=\"./checkpoints/dtu/\"$THISNAME \nif [ ! -d $LOG_DIR ]; then\n    mkdir -p $LOG_DIR\nfi\n\nCUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch --nproc_per_node=4 train.py ${@} \\\n    --which_dataset=\"dtu\" --epochs=16 --logdir=$LOG_DIR \\\n    --trainpath=$DTU_TRAIN_ROOT --testpath=$DTU_TRAIN_ROOT \\\n    --trainlist=\"datasets/lists/dtu/train.txt\" --testlist=\"datasets/lists/dtu/test.txt\" \\\n    \\\n    --data_scale=\"mid\" --n_views=\"5\" --batch_size=4 --lr=0.002 --robust_train \\\n    --lrepochs=\"1,3,5,7,9,11,13,15:1.5\""
  },
  {
    "path": "scripts/dtu/train_dtu_raw.sh",
    "content": "#!/usr/bin/env bash\nsource scripts/data_path.sh\n\nTHISNAME=\"geomvsnet_raw\"\n\nLOG_DIR=\"./checkpoints/dtu/\"$THISNAME \nif [ ! -d $LOG_DIR ]; then\n    mkdir -p $LOG_DIR\nfi\n\nCUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch --nproc_per_node=4 train.py ${@} \\\n    --which_dataset=\"dtu\" --epochs=16 --logdir=$LOG_DIR \\\n    --trainpath=$DTU_TRAIN_ROOT --testpath=$DTU_TRAIN_ROOT \\\n    --trainlist=\"datasets/lists/dtu/train.txt\" --testlist=\"datasets/lists/dtu/test.txt\" \\\n    \\\n    --data_scale=\"raw\" --n_views=\"5\" --batch_size=1 --lr=0.0005 --robust_train \\\n    --lrepochs=\"1,3,5,7,9,11,13,15:1.5\""
  },
  {
    "path": "scripts/tnt/fusion_tnt.sh",
    "content": "#!/usr/bin/env bash\nsource scripts/data_path.sh\n\nTHISNAME=\"blend/geomvsnet\"\n\nLOG_DIR=\"./checkpoints/tnt/\"$THISNAME \nTNT_OUT_DIR=\"./outputs/tnt/\"$THISNAME\n\n# Intermediate\npython3 fusions/tnt/dypcd.py ${@} \\\n    --root_dir=$TNT_ROOT --list_file=\"datasets/lists/tnt/intermediate.txt\" --split=\"intermediate\" \\\n    --out_dir=$TNT_OUT_DIR --ply_path=$TNT_OUT_DIR\"/dypcd_fusion_plys\" \\\n    --img_mode=\"resize\" --cam_mode=\"origin\" --single_processor \n\n# Advanced\npython3 fusions/tnt/dypcd.py ${@} \\\n    --root_dir=$TNT_ROOT --list_file=\"datasets/lists/tnt/advanced.txt\" --split=\"advanced\" \\\n    --out_dir=$TNT_OUT_DIR --ply_path=$TNT_OUT_DIR\"/dypcd_fusion_plys\" \\\n    --img_mode=\"resize\" --cam_mode=\"origin\" --single_processor"
  },
  {
    "path": "scripts/tnt/test_tnt.sh",
    "content": "#!/usr/bin/env bash\nsource scripts/data_path.sh\n\nTHISNAME=\"blend/geomvsnet\"\nBESTEPOCH=\"15\"\n\nLOG_DIR=\"./checkpoints/\"$THISNAME\nCKPT_FILE=$LOG_DIR\"/model_\"$BESTEPOCH\".ckpt\"\nTNT_OUT_DIR=\"./outputs/tnt/\"$THISNAME\n\n# Intermediate\nCUDA_VISIBLE_DEVICES=0 python3 test.py ${@} \\\n    --which_dataset=\"tnt\" --loadckpt=$CKPT_FILE --batch_size=1 \\\n    --outdir=$TNT_OUT_DIR --logdir=$LOG_DIR --nolog \\\n    --testpath=$TNT_ROOT --testlist=\"datasets/lists/tnt/intermediate.txt\" --split=\"intermediate\" \\\n    \\\n    --n_views=\"11\" --img_mode=\"resize\" --cam_mode=\"origin\"\n\n# Advanced\nCUDA_VISIBLE_DEVICES=0 python3 test.py ${@} \\\n    --which_dataset=\"tnt\" --loadckpt=$CKPT_FILE --batch_size=1 \\\n    --outdir=$TNT_OUT_DIR --logdir=$LOG_DIR --nolog \\\n    --testpath=$TNT_ROOT --testlist=\"datasets/lists/tnt/advanced.txt\" --split=\"advanced\" \\\n    \\\n    --n_views=\"11\" --img_mode=\"resize\" --cam_mode=\"origin\""
  },
  {
    "path": "test.py",
    "content": "# -*- coding: utf-8 -*-\n# @Description: Main process of network testing.\n# @Author: Zhe Zhang (doublez@stu.pku.edu.cn)\n# @Affiliation: Peking University (PKU)\n# @LastEditDate: 2023-09-07\n\nimport os, time, sys, gc, cv2, logging\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.parallel\nimport torch.backends.cudnn as cudnn\nfrom torch.utils.data import DataLoader\n\nfrom datasets.data_io import *\nfrom datasets.dtu import DTUDataset\nfrom datasets.tnt import TNTDataset\n\nfrom models.geomvsnet import GeoMVSNet\nfrom models.utils import *\nfrom models.utils.opts import get_opts\n\n\ncudnn.benchmark = True\n\nargs = get_opts()\n\n\ndef test():\n    total_time = 0\n    with torch.no_grad():\n        for batch_idx, sample in enumerate(TestImgLoader):\n            sample_cuda = tocuda(sample)\n            start_time = time.time()\n\n            # @Note GeoMVSNet main\n            outputs = model(\n                sample_cuda[\"imgs\"], \n                sample_cuda[\"proj_matrices\"], sample_cuda[\"intrinsics_matrices\"], \n                sample_cuda[\"depth_values\"], \n                sample[\"filename\"]\n            )\n\n            end_time = time.time()\n            total_time += end_time - start_time\n            outputs = tensor2numpy(outputs)\n            del sample_cuda\n\n            filenames = sample[\"filename\"]\n            cams = sample[\"proj_matrices\"][\"stage{}\".format(args.levels)].numpy()\n            imgs = sample[\"imgs\"]\n            logger.info('Iter {}/{}, Time:{:.3f} Res:{}'.format(batch_idx, len(TestImgLoader), end_time - start_time, imgs[0].shape))\n\n\n            for filename, cam, img, depth_est, photometric_confidence in zip(filenames, cams, imgs, outputs[\"depth\"], outputs[\"photometric_confidence\"]):\n                img = img[0].numpy()    # ref view\n                cam = cam[0]            # ref cam\n\n                depth_filename = os.path.join(args.outdir, filename.format('depth_est', '.pfm'))\n                confidence_filename = os.path.join(args.outdir, filename.format('confidence', '.pfm'))\n                cam_filename = os.path.join(args.outdir, filename.format('cams', '_cam.txt'))\n                img_filename = os.path.join(args.outdir, filename.format('images', '.jpg'))\n                os.makedirs(depth_filename.rsplit('/', 1)[0], exist_ok=True)\n                os.makedirs(confidence_filename.rsplit('/', 1)[0], exist_ok=True)\n                if args.which_dataset == 'dtu':\n                    os.makedirs(cam_filename.rsplit('/', 1)[0], exist_ok=True)\n                    os.makedirs(img_filename.rsplit('/', 1)[0], exist_ok=True)\n                \n                # save depth maps\n                save_pfm(depth_filename, depth_est)\n\n                # save confidence maps\n                confidence_list = [outputs['stage{}'.format(i)]['photometric_confidence'].squeeze(0) for i in range(1,5)]\n                photometric_confidence = confidence_list[-1]\n                if not args.save_conf_all_stages:\n                    save_pfm(confidence_filename, photometric_confidence) \n                else:\n                    for stage_idx, photometric_confidence in enumerate(confidence_list):\n                        if stage_idx != args.levels - 1:\n                            confidence_filename = os.path.join(args.outdir, filename.format('confidence', \"_stage\"+str(stage_idx)+'.pfm'))\n                        else:\n                            confidence_filename = 
os.path.join(args.outdir, filename.format('confidence', '.pfm'))\n                        save_pfm(confidence_filename, photometric_confidence) \n\n                # save cams, img\n                if args.which_dataset == 'dtu':\n                    write_cam(cam_filename, cam)\n                    img = np.clip(np.transpose(img, (1, 2, 0)) * 255, 0, 255).astype(np.uint8)\n                    img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)\n                    cv2.imwrite(img_filename, img_bgr)\n\n    torch.cuda.empty_cache()\n    gc.collect()\n    return total_time, len(TestImgLoader)\n\n\ndef initLogger():\n    logger = logging.getLogger()\n    logger.setLevel(logging.INFO)\n    curTime = time.strftime('%Y%m%d-%H%M', time.localtime(time.time()))\n\n    if args.which_dataset == 'tnt':\n        logfile = os.path.join(args.logdir, 'TNT-test-' + curTime + '.log')\n    else:\n        logfile = os.path.join(args.logdir, 'test-' + curTime + '.log')\n    \n    formatter = logging.Formatter(\"%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s\")\n    if not args.nolog:\n        fileHandler = logging.FileHandler(logfile, mode='a')\n        fileHandler.setFormatter(formatter)\n        logger.addHandler(fileHandler)\n    consoleHandler = logging.StreamHandler(sys.stdout)\n    consoleHandler.setFormatter(formatter)\n    logger.addHandler(consoleHandler)\n    logger.info(\"Logger initialized.\")\n    logger.info(\"Writing logs to file: {}\".format(logfile))\n    logger.info(\"Current time: {}\".format(curTime))\n\n    settings_str = \"All settings:\\n\"\n    for k,v in vars(args).items(): \n        settings_str += '{0}: {1}\\n'.format(k,v)\n    logger.info(settings_str)\n\n    return logger\n\n\nif __name__ == '__main__':\n    logger = initLogger()\n\n    # dataset, dataloader\n    if args.which_dataset == 'dtu':\n        test_dataset = DTUDataset(args.testpath, args.testlist, \"test\", args.n_views, max_wh=(1600, 1200))\n    elif args.which_dataset == 'tnt':\n        test_dataset = TNTDataset(args.testpath, args.testlist, split=args.split, n_views=args.n_views, img_wh=(-1, 1024), cam_mode=args.cam_mode, img_mode=args.img_mode)\n\n    TestImgLoader = DataLoader(test_dataset, args.batch_size, shuffle=False, num_workers=4, drop_last=False)\n\n    # @Note GeoMVSNet model\n    model = GeoMVSNet(\n        levels=args.levels, \n        hypo_plane_num_stages=[int(n) for n in args.hypo_plane_num_stages.split(\",\")], \n        depth_interal_ratio_stages=[float(ir) for ir in args.depth_interal_ratio_stages.split(\",\")],\n        feat_base_channel=args.feat_base_channel, \n        reg_base_channel=args.reg_base_channel,\n        group_cor_dim_stages=[int(n) for n in args.group_cor_dim_stages.split(\",\")],\n    )\n    \n    logger.info(\"loading model {}\".format(args.loadckpt))\n    state_dict = torch.load(args.loadckpt, map_location=torch.device(\"cpu\"))\n    model.load_state_dict(state_dict['model'], strict=False)\n\n    model.cuda()\n    model.eval()\n\n    test()"
  },
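  {
    "path": "tools/example_commands.sh",
    "content": "#!/usr/bin/env bash\n# Hypothetical usage sketch (editor addition, NOT part of the original release).\n# The flag names below are exactly the ones read by train.py / test.py; every path\n# and value is an illustrative assumption -- see models/utils/opts.py for the real\n# defaults and the full option list.\n\n# Training on DTU. The 'MS' scheduler parses --lrepochs as 'e1,e2,e3:gamma'\n# (see train.py), e.g. halve the lr after epochs 10, 12 and 14:\npython train.py --which_dataset dtu --trainpath /path/to/dtu_training --trainlist lists/dtu/train.txt --testpath /path/to/dtu_training --testlist lists/dtu/val.txt --logdir ./checkpoints/dtu-example --batch_size 1 --n_views 5 --lr_scheduler MS --lrepochs 10,12,14:2\n\n# Depth-map inference on the DTU evaluation set with a trained checkpoint:\npython test.py --which_dataset dtu --testpath /path/to/dtu_test --testlist lists/dtu/test.txt --loadckpt ./checkpoints/dtu-example/model_15.ckpt --outdir ./outputs/dtu-example --logdir ./outputs/dtu-example --batch_size 1 --n_views 5\n"
  },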
  {
    "path": "train.py",
    "content": "# -*- coding: utf-8 -*-\n# @Description: Main process of network training & evaluation.\n# @Author: Zhe Zhang (doublez@stu.pku.edu.cn)\n# @Affiliation: Peking University (PKU)\n# @LastEditDate: 2023-09-07\n\nimport os, sys, time, gc, datetime, logging, json\nimport torch\nimport torch.nn as nn\nimport torch.nn.parallel\nimport torch.backends.cudnn as cudnn\nimport torch.optim as optim\nimport torch.distributed as dist\nfrom torch.utils.data import DataLoader\nfrom tensorboardX import SummaryWriter\n\nfrom datasets.dtu import DTUDataset\nfrom datasets.blendedmvs import BlendedMVSDataset\n\nfrom models.geomvsnet import GeoMVSNet\nfrom models.loss import geomvsnet_loss\nfrom models.utils import *\nfrom models.utils.opts import get_opts\n\n\ncudnn.benchmark = True\nnum_gpus = int(os.environ[\"WORLD_SIZE\"]) if \"WORLD_SIZE\" in os.environ else 1\nis_distributed = num_gpus > 1\n\nargs = get_opts()\n\n\ndef train(model, model_loss, optimizer, TrainImgLoader, TestImgLoader, start_epoch, args):\n    if args.lr_scheduler == 'MS':\n        milestones = [len(TrainImgLoader) * int(epoch_idx) for epoch_idx in args.lrepochs.split(':')[0].split(',')]\n        lr_gamma = 1 / float(args.lrepochs.split(':')[1])\n        lr_scheduler = WarmupMultiStepLR(optimizer, milestones, gamma=lr_gamma, warmup_factor=1.0/3, warmup_iters=500, last_epoch=len(TrainImgLoader) * start_epoch - 1)\n    elif args.lr_scheduler == 'cos':\n        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=int(args.epochs*len(TrainImgLoader)), eta_min=0)\n    elif args.lr_scheduler == 'onecycle':\n        lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.lr, total_steps=int(args.epochs*len(TrainImgLoader)))\n    elif args.lr_scheduler == 'lambda':\n        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: 0.9 ** ((epoch-1) / len(TrainImgLoader)), last_epoch=len(TrainImgLoader)*start_epoch-1)\n\n\n    for epoch_idx in range(start_epoch, args.epochs):\n        logger.info('Epoch {}:'.format(epoch_idx))\n        global_step = len(TrainImgLoader) * epoch_idx\n\n        # training\n        for batch_idx, sample in enumerate(TrainImgLoader):\n            start_time = time.time()\n            global_step = len(TrainImgLoader) * epoch_idx + batch_idx\n            do_summary = global_step % args.summary_freq == 0\n            loss, scalar_outputs, image_outputs = train_sample(model, model_loss, optimizer, sample, args)\n            lr_scheduler.step()\n            if (not is_distributed) or (dist.get_rank() == 0):\n                if do_summary:\n                    if not args.notensorboard:\n                        tb_save_scalars(tb_writer, 'train', scalar_outputs, global_step)\n                        tb_save_images(tb_writer, 'train', image_outputs, global_step)\n                    logger.info(\"Epoch {}/{}, Iter {}/{}, 2mm_err={:.3f} | lr={:.6f}, train_loss={:.3f}, abs_err={:.3f}, pw_loss={:.3f}, dds_loss={:.3f}, time={:.3f}\".format(\n                           epoch_idx, args.epochs, batch_idx, len(TrainImgLoader),\n                           scalar_outputs[\"thres2mm_error\"],\n                           optimizer.param_groups[0][\"lr\"], \n                           loss,\n                           scalar_outputs[\"abs_depth_error\"],\n                           scalar_outputs[\"s3_pw_loss\"],\n                           scalar_outputs[\"s3_dds_loss\"],\n                           time.time() - start_time))\n                del 
scalar_outputs, image_outputs\n\n        # save checkpoint\n        if (not is_distributed) or (dist.get_rank() == 0):\n            if ((epoch_idx + 1) % args.save_freq == 0) or (epoch_idx == args.epochs-1):\n                torch.save({\n                    'epoch': epoch_idx,\n                    'model': model.module.state_dict(),\n                    'optimizer': optimizer.state_dict()},\n                    \"{}/model_{:0>2}.ckpt\".format(args.logdir, epoch_idx))  \n        gc.collect()\n\n        # testing\n        if (epoch_idx % args.eval_freq == 0) or (epoch_idx == args.epochs - 1):\n            avg_test_scalars = DictAverageMeter()\n            for batch_idx, sample in enumerate(TestImgLoader):\n                start_time = time.time()\n                global_step = len(TrainImgLoader) * epoch_idx + batch_idx\n                do_summary = global_step % args.summary_freq == 0\n                loss, scalar_outputs, image_outputs = test_sample_depth(model, model_loss, sample, args)\n                if (not is_distributed) or (dist.get_rank() == 0):\n                    if do_summary:\n                        if not args.notensorboard:\n                            tb_save_scalars(tb_writer, 'test', scalar_outputs, global_step)\n                            tb_save_images(tb_writer, 'test', image_outputs, global_step)\n                        logger.info(\n                            \"Epoch {}/{}, Iter {}/{}, 2mm_err={:.3f} | lr={:.6f}, test_loss={:.3f}, abs_err={:.3f}, pw_loss={:.3f}, dds_loss={:.3f}, time={:.3f}\".format(\n                            epoch_idx, args.epochs, batch_idx, len(TestImgLoader),\n                            scalar_outputs[\"thres2mm_error\"],\n                            optimizer.param_groups[0][\"lr\"], \n                            loss,\n                            scalar_outputs[\"abs_depth_error\"],\n                            scalar_outputs[\"s3_pw_loss\"],\n                            scalar_outputs[\"s3_dds_loss\"],\n                            time.time() - start_time))\n                    avg_test_scalars.update(scalar_outputs)\n                    del scalar_outputs, image_outputs\n\n            if (not is_distributed) or (dist.get_rank() == 0):\n                if not args.notensorboard:\n                    tb_save_scalars(tb_writer, 'fulltest', avg_test_scalars.mean(), global_step)\n                logger.info(\"avg_test_scalars: \" + json.dumps(avg_test_scalars.mean()))\n            gc.collect()\n\n\ndef train_sample(model, model_loss, optimizer, sample, args):\n    model.train()\n    optimizer.zero_grad()\n\n    sample_cuda = tocuda(sample)\n    depth_gt_ms, mask_ms = sample_cuda[\"depth\"], sample_cuda[\"mask\"]\n    depth_gt, mask = depth_gt_ms[\"stage{}\".format(args.levels)], mask_ms[\"stage{}\".format(args.levels)]\n\n    # @Note GeoMVSNet main\n    outputs = model(\n        sample_cuda[\"imgs\"], \n        sample_cuda[\"proj_matrices\"], sample_cuda[\"intrinsics_matrices\"], \n        sample_cuda[\"depth_values\"]\n    )\n\n    depth_est = outputs[\"depth\"]\n\n    loss, epe, pw_loss_stages, dds_loss_stages = model_loss(\n        outputs, depth_gt_ms, mask_ms, \n        stage_lw=[float(e) for e in args.stage_lw.split(\",\") if e], depth_values=sample_cuda[\"depth_values\"]\n    )\n\n    loss.backward()\n    optimizer.step()\n\n    scalar_outputs = {\n        \"loss\": loss,\n        \"epe\": epe,\n        \"s0_pw_loss\": pw_loss_stages[0],\n        \"s1_pw_loss\": pw_loss_stages[1],\n        \"s2_pw_loss\": pw_loss_stages[2],\n      
  \"s3_pw_loss\": pw_loss_stages[3],\n        \"s0_dds_loss\": dds_loss_stages[0],\n        \"s1_dds_loss\": dds_loss_stages[1],\n        \"s2_dds_loss\": dds_loss_stages[2],\n        \"s3_dds_loss\": dds_loss_stages[3],\n        \"abs_depth_error\": AbsDepthError_metrics(depth_est, depth_gt, mask > 0.5),\n        \"thres2mm_error\": Thres_metrics(depth_est, depth_gt, mask > 0.5, 2),\n        \"thres4mm_error\": Thres_metrics(depth_est, depth_gt, mask > 0.5, 4),\n        \"thres8mm_error\": Thres_metrics(depth_est, depth_gt, mask > 0.5, 8),\n    }\n\n    image_outputs = {\n        \"depth_est\": depth_est * mask,\n        \"depth_est_nomask\": depth_est,\n        \"depth_gt\": sample[\"depth\"][\"stage1\"],\n        \"ref_img\": sample[\"imgs\"][0],\n        \"mask\": sample[\"mask\"][\"stage1\"],\n        \"errormap\": (depth_est - depth_gt).abs() * mask,\n    }\n\n    if is_distributed:\n        scalar_outputs = reduce_scalar_outputs(scalar_outputs)\n\n    return tensor2float(scalar_outputs[\"loss\"]), tensor2float(scalar_outputs), tensor2numpy(image_outputs)\n\n\n@make_nograd_func\ndef test_sample_depth(model, model_loss, sample, args):\n    if is_distributed:\n        model_eval = model.module\n    else:\n        model_eval = model\n    model_eval.eval()\n\n    sample_cuda = tocuda(sample)\n    depth_gt_ms, mask_ms = sample_cuda[\"depth\"], sample_cuda[\"mask\"]\n    depth_gt, mask = depth_gt_ms[\"stage{}\".format(args.levels)], mask_ms[\"stage{}\".format(args.levels)]\n\n    outputs = model_eval(\n        sample_cuda[\"imgs\"], \n        sample_cuda[\"proj_matrices\"], sample_cuda[\"intrinsics_matrices\"], \n        sample_cuda[\"depth_values\"]\n    )\n    \n    depth_est = outputs[\"depth\"]\n\n    loss, epe, pw_loss_stages, dds_loss_stages = model_loss(\n        outputs, depth_gt_ms, mask_ms, \n        stage_lw=[float(e) for e in args.stage_lw.split(\",\") if e], depth_values=sample_cuda[\"depth_values\"]\n    )\n    \n    scalar_outputs = {\n        \"loss\": loss,\n        \"epe\": epe,\n        \"s0_pw_loss\": pw_loss_stages[0],\n        \"s1_pw_loss\": pw_loss_stages[1],\n        \"s2_pw_loss\": pw_loss_stages[2],\n        \"s3_pw_loss\": pw_loss_stages[3],\n        \"s0_dds_loss\": dds_loss_stages[0],\n        \"s1_dds_loss\": dds_loss_stages[1],\n        \"s2_dds_loss\": dds_loss_stages[2],\n        \"s3_dds_loss\": dds_loss_stages[3],\n        \"abs_depth_error\": AbsDepthError_metrics(depth_est, depth_gt, mask > 0.5),\n        \"thres2mm_error\": Thres_metrics(depth_est, depth_gt, mask > 0.5, 2),\n        \"thres4mm_error\": Thres_metrics(depth_est, depth_gt, mask > 0.5, 4),\n        \"thres8mm_error\": Thres_metrics(depth_est, depth_gt, mask > 0.5, 8),\n    }\n\n    image_outputs = {\n        \"depth_est\": depth_est * mask,\n        \"depth_est_nomask\": depth_est,\n        \"depth_gt\": sample[\"depth\"][\"stage1\"],\n        \"ref_img\": sample[\"imgs\"][0],\n        \"mask\": sample[\"mask\"][\"stage1\"],\n        \"errormap\": (depth_est - depth_gt).abs() * mask\n    }\n\n    if is_distributed:\n        scalar_outputs = reduce_scalar_outputs(scalar_outputs)\n\n    return tensor2float(scalar_outputs[\"loss\"]), tensor2float(scalar_outputs), tensor2numpy(image_outputs)\n\n\ndef initLogger():\n    logger = logging.getLogger()\n    logger.setLevel(logging.INFO)\n    curTime = time.strftime('%Y%m%d-%H%M', time.localtime(time.time()))\n    logfile = os.path.join(args.logdir, 'train-' + curTime + '.log')\n    formatter = logging.Formatter(\"%(asctime)s - 
%(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s\")\n    fileHandler = logging.FileHandler(logfile, mode='a')\n    fileHandler.setFormatter(formatter)\n    logger.addHandler(fileHandler)\n    consoleHandler = logging.StreamHandler(sys.stdout)\n    consoleHandler.setFormatter(formatter)\n    logger.addHandler(consoleHandler)\n    logger.info(\"Logger initialized.\")\n    logger.info(\"Writing logs to file: {}\".format(logfile))\n    logger.info(\"Current time: {}\".format(curTime))\n\n    settings_str = \"All settings:\\n\"\n    for k,v in vars(args).items(): \n        settings_str += '{0}: {1}\\n'.format(k,v)\n    logger.info(settings_str)\n\n    return logger\n\n\nif __name__ == '__main__':\n    logger = initLogger()\n\n    if args.resume:\n        assert args.mode == \"train\"\n        assert args.loadckpt is None\n\n    if is_distributed:\n        torch.cuda.set_device(args.local_rank)\n        torch.distributed.init_process_group(backend=\"nccl\", init_method=\"env://\")\n        synchronize()\n\n    set_random_seed(args.seed)\n    device = torch.device(args.device)\n\n\n    # tensorboard\n    if (not is_distributed) or (dist.get_rank() == 0):\n        if not os.path.isdir(args.logdir):\n            os.makedirs(args.logdir)\n        current_time_str = str(datetime.datetime.now().strftime('%Y%m%d_%H%M%S'))\n        logger.info(\"current time \" + current_time_str)\n        logger.info(\"creating new summary file\")\n        if not args.notensorboard:\n            tb_writer = SummaryWriter(args.logdir)\n\n\n    # @Note GeoMVSNet model\n    model = GeoMVSNet(\n        levels=args.levels, \n        hypo_plane_num_stages=[int(n) for n in args.hypo_plane_num_stages.split(\",\")], \n        depth_interal_ratio_stages=[float(ir) for ir in args.depth_interal_ratio_stages.split(\",\")],\n        feat_base_channel=args.feat_base_channel, \n        reg_base_channel=args.reg_base_channel,\n        group_cor_dim_stages=[int(n) for n in args.group_cor_dim_stages.split(\",\")],\n    )\n    model.to(device)\n\n    model_loss = geomvsnet_loss\n\n    # optimizer\n    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, betas=(0.9, 0.999), weight_decay=args.wd)\n\n\n    # load parameters\n    start_epoch = 0\n    if args.resume:\n        saved_models = [fn for fn in os.listdir(args.logdir) if fn.endswith(\".ckpt\")]\n        saved_models = sorted(saved_models, key=lambda x: int(x.split('_')[-1].split('.')[0]))\n        loadckpt = os.path.join(args.logdir, saved_models[-1])\n        logger.info(\"resuming: \" + loadckpt)\n        state_dict = torch.load(loadckpt, map_location=torch.device(\"cpu\"))\n        model.load_state_dict(state_dict['model'])\n        optimizer.load_state_dict(state_dict['optimizer'])\n        start_epoch = state_dict['epoch'] + 1\n\n\n    # distributed\n    if (not is_distributed) or (dist.get_rank() == 0):\n        logger.info(\"start at epoch {}\".format(start_epoch))\n        logger.info('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()])))\n\n    if is_distributed:\n        if dist.get_rank() == 0:\n            logger.info(\"Let's use {} GPUs in distributed mode!\".format(torch.cuda.device_count()))\n        model = torch.nn.parallel.DistributedDataParallel(\n            model, device_ids=[args.local_rank], output_device=args.local_rank,\n            find_unused_parameters=True,\n        )\n    else:\n        if torch.cuda.is_available():\n            logger.info(\"Let's use {} GPUs in 
parallel mode.\".format(torch.cuda.device_count()))\n            model = nn.DataParallel(model)\n\n\n    # dataset, dataloader\n    if args.which_dataset == \"dtu\":\n        train_dataset = DTUDataset(args.trainpath, args.trainlist, \"train\", args.n_views, data_scale=args.data_scale, robust_train=args.robust_train)\n        test_dataset = DTUDataset(args.testpath, args.testlist, \"val\", args.n_views, data_scale=args.data_scale)\n    elif args.which_dataset == \"blendedmvs\":\n        train_dataset = BlendedMVSDataset(args.trainpath, args.trainlist, \"train\", args.n_views, img_wh=(768, 576), robust_train=args.robust_train, augment=False)\n        test_dataset = BlendedMVSDataset(args.testpath, args.testlist, \"val\", args.n_views, img_wh=(768, 576))\n\n    if is_distributed:\n        train_sampler = torch.utils.data.DistributedSampler(train_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank())\n        test_sampler = torch.utils.data.DistributedSampler(test_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank())\n\n        TrainImgLoader = DataLoader(train_dataset, args.batch_size, sampler=train_sampler, num_workers=8, drop_last=True, pin_memory=args.pin_m)\n        TestImgLoader = DataLoader(test_dataset, args.batch_size, sampler=test_sampler, num_workers=8, drop_last=False, pin_memory=args.pin_m)\n    else:\n        TrainImgLoader = DataLoader(train_dataset, args.batch_size, shuffle=True, num_workers=8, drop_last=True, pin_memory=args.pin_m)\n        TestImgLoader = DataLoader(test_dataset, args.batch_size, shuffle=False, num_workers=8, drop_last=False, pin_memory=args.pin_m)\n\n    train(model, model_loss, optimizer, TrainImgLoader, TestImgLoader, start_epoch, args)"
  },
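  {
    "path": "utils_pfm_example.py",
    "content": "# -*- coding: utf-8 -*-\n# Hypothetical helper sketch (editor addition, NOT part of the original release):\n# reads back the .pfm depth / confidence maps that test.py writes via save_pfm()\n# from datasets/data_io.py. It assumes the conventional PFM layout (ASCII header,\n# bottom-up row order, negative scale = little-endian); if datasets/data_io.py\n# already exposes a reader, prefer that one instead.\n\nimport re\nimport sys\n\nimport numpy as np\n\n\ndef load_pfm(path):\n    \"\"\"Return (data, scale) for a grayscale ('Pf') or color ('PF') PFM file.\"\"\"\n    with open(path, 'rb') as f:\n        header = f.readline().rstrip()\n        if header == b'PF':\n            channels = 3\n        elif header == b'Pf':\n            channels = 1\n        else:\n            raise ValueError('Not a PFM file: {}'.format(path))\n\n        dims = re.match(rb'^(\\d+)\\s+(\\d+)\\s*$', f.readline())\n        if not dims:\n            raise ValueError('Malformed PFM dimension line in {}'.format(path))\n        width, height = int(dims.group(1)), int(dims.group(2))\n\n        scale = float(f.readline().rstrip())\n        endian = '<' if scale < 0 else '>'  # negative scale means little-endian\n\n        data = np.fromfile(f, endian + 'f', width * height * channels)\n        shape = (height, width, 3) if channels == 3 else (height, width)\n        # PFM stores rows bottom-to-top, so flip back to image order\n        return np.flipud(data.reshape(shape)), abs(scale)\n\n\nif __name__ == '__main__':\n    depth, _ = load_pfm(sys.argv[1])\n    print('shape: {}, min: {:.3f}, max: {:.3f}'.format(depth.shape, depth.min(), depth.max()))\n"
  }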
]