[
  {
    "path": ".gitignore",
    "content": "# Autogenerated folders\n__pycache__\nlogs\ntest\ndata\n\n# IDEs generated folders\n.spyproject\nvenv/\n.idea/\n__MACOSX/\n**/.DS_Store\n\n# other\npretrained\n*.pth"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2016-2019 VRG, CTU Prague\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE."
  },
  {
    "path": "README.md",
    "content": "# Deep Visual Geo-localization Benchmark\nThis is the official repository for the CVPR 2022 Oral paper [Deep Visual Geo-localization Benchmark](https://arxiv.org/abs/2204.03444).\nIt can be used to reproduce results from the paper, and to compute a wide range of experiments, by changing the components of a Visual Geo-localization pipeline.\n\n<img src=\"https://github.com/gmberton/gmberton.github.io/blob/main/images/vg_system.png\" width=\"90%\">\n\n## Setup\nBefore you begin experimenting with this toolbox, your dataset should be organized in a directory tree as such:\n\n```\n.\n├── benchmarking_vg\n└── datasets_vg\n    └── datasets\n        └── pitts30k\n            └── images\n                ├── train\n                │   ├── database\n                │   └── queries\n                ├── val\n                │   ├── database\n                │   └── queries\n                └── test\n                    ├── database\n                    └── queries\n```\nThe [VPR-datasets-downloader](https://github.com/gmberton/VPR-datasets-downloader) repo can be used to download a number of datasets. Detailed instructions on how to download datasets are in the repo. Note that many datasets are available, and _pitts30k_ is just an example.\n\n## Running experiments\n### Basic experiment\nFor a basic experiment run\n\n`$ python3 train.py --dataset_name=pitts30k`\n\nthis will train a ResNet-18 + NetVLAD on Pitts30k.\nThe experiment creates a folder named `./logs/default/YYYY-MM-DD_HH-mm-ss`, where checkpoints are saved, as well as an `info.log` file with training logs and other information, such as model size, FLOPs and descriptors dimensionality.\n\n### Architectures and mining\nYou can replace the backbone and the aggregation as such\n\n`$ python3 train.py --dataset_name=pitts30k --backbone=resnet50conv4 --aggregation=gem`\n\nyou can easily use ResNets cropped at conv4 or conv5.\n#### Add a fully connected layer\nTo add a fully connected layer of dimension 2048 to GeM pooling:\n\n`$ python3 train.py --dataset_name=pitts30k --backbone=resnet50conv4 --aggregation=gem --fc_output_dim=2048`\n\n#### Add PCA\nTo add PCA to a NetVLAD layer just do:\n\n`$ python3 eval.py --dataset_name=pitts30k --backbone=resnet50conv4 --aggregation=netvlad --pca_dim=2048 --pca_dataset_folder=pitts30k/images/train`\n\nwhere _pca_dataset_folder_ points to the folder with the images used to compute PCA. In the paper we compute PCA's principal components on the train set as it showed best results. PCA is used only at test time.\n#### Evaluate trained models\nTo evaluate the trained model on other datasets (this example is with the St Lucia dataset), simply run\n\n`$ python3 eval.py --backbone=resnet50conv4 --aggregation=gem --resume=logs/default/YYYY-MM-DD_HH-mm-ss/best_model.pth --dataset_name=st_lucia`\n\n#### Reproduce the results\nFinally, to reproduce our results, use the appropriate mining method: _full_ for _pitts30k_ and _partial_ for _msls_ as such:\n\n`$ python3 train.py --dataset_name=pitts30k --mining=full`\n\nAs simple as this, you can replicate all results from tables 3, 4, 5 of the main paper, as well as tables 2, 3, 4 of the supplementary.\n\n### Resize\nTo resize the images simply pass the parameters _resize_ with the target resolution. 
### Resize\nTo resize the images, simply pass the parameter _resize_ with the target resolution. For example, 80% of the resolution of the full _pitts30k_ images would be 384, 512, because the full images are 480, 640:\n\n`$ python3 train.py --dataset_name=pitts30k --resize=384 512`\n\n### Query pre/post-processing and predictions refinement\nWe gather all such methods under the _test_method_ parameter. The available methods are _hard_resize_, _single_query_, _central_crop_, _five_crops_, _nearest_crop_ and _maj_voting_.\nAlthough _hard_resize_ is the default, in most datasets it doesn't apply any transformation at all (see the paper for more information), because all images have the same resolution.\n\n`$ python3 eval.py --resume=logs/default/YYYY-MM-DD_HH-mm-ss/best_model.pth --dataset_name=tokyo247 --test_method=nearest_crop`\n\n### Data augmentation\nYou can reproduce all data augmentation techniques from the paper with simple commands, for example:\n\n`$ python3 train.py --dataset_name=pitts30k --horizontal_flipping --saturation 2 --brightness 1`\n\n### Off-the-shelf models trained on Landmark Recognition datasets\nThe code can automatically download and use models trained on Landmark Recognition datasets from popular repositories: [radenovic](https://github.com/filipradenovic/cnnimageretrieval-pytorch) and [naver](https://github.com/naver/deep-image-retrieval).\nThese repos offer ResNet-50/101 with GeM and a fully connected layer of dimension 2048 trained on such datasets, which can be used as follows:\n\n`$ python eval.py --off_the_shelf=radenovic_gldv1 --l2=after_pool --backbone=resnet101conv5 --aggregation=gem --fc_output_dim=2048`\n\n`$ python eval.py --dataset_name=pitts30k --off_the_shelf=naver --l2=none --backbone=resnet101conv5 --aggregation=gem --fc_output_dim=2048`\n\n### Using pretrained networks on other datasets\nCheck out our [pretrain_vg](https://github.com/rm-wu/pretrain_vg) repo, which we use to train such models.\nYou can automatically download those models and train starting from them as follows:\n\n`$ python train.py --dataset_name=pitts30k --pretrained=places`\n\n### Changing the threshold distance\nYou can use a distance threshold different from the default 25 meters (for example, 100 meters), as follows:\n\n`$ python3 eval.py --resume=logs/default/YYYY-MM-DD_HH-mm-ss/best_model.pth --val_positive_dist_threshold=100`\n\n### Changing the recall values (R@N)\nBy default the toolbox computes recalls@1, 5, 10, 20, but you can compute other recall values as follows:\n\n`$ python3 eval.py --resume=logs/default/YYYY-MM-DD_HH-mm-ss/best_model.pth --recall_values 1 5 10 15 20 50 100`\n\n
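R@N here is the standard visual place recognition recall: the percentage of queries with at least one ground-truth positive among their top-N retrieved database images. A minimal sketch of this computation (a hypothetical helper, not the toolbox's own code):\n\n```python\nimport numpy as np\n\n# predictions[i]: ranked database indexes retrieved for query i (best first)\n# positives_per_query[i]: its ground-truth positives, e.g. from BaseDataset.get_positives()\ndef compute_recalls(predictions, positives_per_query, recall_values=(1, 5, 10, 20)):\n    recalls = np.zeros(len(recall_values))\n    for preds, positives in zip(predictions, positives_per_query):\n        for i, n in enumerate(recall_values):\n            if np.isin(preds[:n], positives).any():\n                recalls[i:] += 1  # a hit in the top-n also counts for every larger n\n                break\n    return recalls / len(predictions) * 100\n```\n\n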
### Model Zoo\nWe are currently exploring hosting options, so this is a partial list of models. More models will be added soon!\n\n<details>\n     <summary><b>Pretrained models with different backbones</b></summary></br>\n    Pretrained networks employing different backbones.</br></br>\n\t<table>\n\t\t<tr>\n\t\t\t<th rowspan=2>Model</th>\n\t\t\t<th colspan=\"3\">Training on Pitts30k</th>\n\t\t\t<th colspan=\"3\">Training on MSLS</th>\n\t \t</tr>\n\t \t<tr>\n\t  \t\t<td>Pitts30k (R@1)</td>\n\t   \t\t<td>MSLS (R@1)</td>\n\t   \t\t<td>Download</td>\n\t\t\t<td>Pitts30k (R@1)</td>\n\t   \t\t<td>MSLS (R@1)</td>\n\t   \t\t<td>Download</td>\n\t \t</tr>\n\t\t<tr>\n\t\t\t<td>vgg16-gem</td>\n\t\t\t<td>78.5</td> <td>43.4</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/1-e9v_mynIX5XBsdtN_mG9tz5-nA5PWiq/view?usp=sharing\">[Link]</a></td>\n\t\t\t<td>70.2</td> <td>66.7</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/1GqgO-qG-WNJXWty43KgvDtW0OpG0Wrq-/view?usp=sharing\">[Link]</a></td>\n\t \t</tr>\n\t \t<tr>\n\t \t\t<td>resnet18-gem</td>\n\t\t\t<td>77.8</td> <td>35.3</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/1R66NYeLlxBIqLviUVL9XPZkrtmyMn_tU/view?usp=sharing\">[Link]</a></td>\n\t\t\t<td>71.6</td> <td>65.3</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/1IH0d_ME2kU3pagsKhx5ZfRfyWriErajn/view?usp=sharing\">[Link]</a></td>\n\t \t</tr>\n\t \t<tr>\n\t\t\t<td> resnet50-gem </td>\n\t\t\t<td>82.0</td> <td>38.0</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/1esgXzRFvDFHrMnwwR3GlTnErXjFNrYV7/view?usp=sharing\">[Link]</a></td>\n\t\t\t<td>77.4</td> <td>72.0</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/1uuIYJN4N7lQqqsN32pbZwjhz5Xvv3zr-/view?usp=sharing\">[Link]</a></td>\n\t \t</tr>\n\t \t<tr>\n\t\t\t<td> resnet101-gem </td>\n\t\t\t<td>82.4</td> <td>39.6</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/1Sd-sezmbzOGbZcy3eqRnWH07eoJ7CM0X/view?usp=sharing\">[Link]</a></td>\n\t\t\t<td>77.2</td> <td>72.5</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/1Iondvd8P3vb3piHFTA-RUgTFpqh0I31M/view?usp=sharing\">[Link]</a></td>\n\t \t</tr>\n\t \t<tr>\n\t\t\t<td> ViT(224)-CLS </td>\n\t\t\t<td> - </td> <td> - </td>\n\t\t\t<td> - </td>\n\t\t\t<td> 80.4 </td> <td> 69.3 </td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/1mPiIqFKnW1HWtXqKJhLLmIgMWoV14auG/view?usp=sharing\">[Link]</a></td>\n\t \t</tr>\n\t \t<tr>\n\t\t\t<td>vgg16-netvlad</td>\n\t\t\t<td>83.2</td> <td>50.9</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/14s7OZor6wrlGBKeXr0vKbPfTzlW9preM/view?usp=sharing\">[Link]</a></td>\n\t\t\t<td>79.0</td> <td>74.6</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/1dwai3uNudjvns58JIyaf5CBRg4ojcWIW/view?usp=sharing\">[Link]</a></td>\n\t \t</tr>\n\t \t<tr>\n\t\t\t<td>resnet18-netvlad</td>\n\t\t\t<td>86.4</td> <td>47.4</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/1KFwonDQYdvzTAIILsOMjmLRUR76jXXvB/view?usp=sharing\">[Link]</a></td>\n\t\t\t<td>81.6</td> <td>75.8</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/1_Ozq2TdvwLAJUwy7YH9l69GsfOU-MlFZ/view?usp=sharing\">[Link]</a></td>\n\t \t</tr>\n\t \t<tr>\n\t\t\t<td>resnet50-netvlad</td>\n\t\t\t<td>86.0</td> <td>50.7</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/1KL8HoAApOjJFETin7Q7u7IcsOvroKlSj/view?usp=sharing\">[Link]</a></td>\n\t\t\t<td>80.9</td> <td>76.9</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/1krf0A6CeW8GqLqHWZ7dlSNJ9aTJ4dotF/view?usp=sharing\">[Link]</a></td>\n\t \t</tr>\n\t \t<tr>\n\t\t\t<td>resnet101-netvlad</td>\n\t\t\t<td>86.5</td> <td>51.8</td>\n\t\t\t<td><a 
href=\"https://drive.google.com/file/d/1064kDJ0LPyWoU7J4bMvAa0lTNEhAEi8v/view?usp=sharing\">[Link]</a></td>\n\t\t\t<td>80.8</td> <td>77.7</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/1rtPfsgfJ2Zoxs5uu7Ph1_qc7q-hIxJek/view?usp=sharing\">[Link]</a></td>\n\t \t</tr>\n\t\t<tr>\n\t\t\t<td>cct384-netvlad</td>\n\t\t\t<td>85.0</td> <td>52.5</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/1Rx0oG4PG9bEraIg4y7e6Z24Q6b_TGr5u/view?usp=sharing\">[Link]</a></td>\n\t\t\t<td>80.3</td> <td>85.1</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/1wDZ6XRVYz6bcGe_p3Iiz2NfIe9MmZZMN/view?usp=sharing\">[Link]</a></td>\n\t \t</tr>\n\t</table>\n    \n</details>\n\n<details>\n \t<summary><b>Pretrained models with different aggregation methods</b></summary></br>\n \tPretrained networks trained using different aggregation methods.</br></br>\n    <table>\n\t\t<tr>\n\t\t\t<th rowspan=2>Model</th>\n\t\t\t <th colspan=\"3\">Training on Pitts30k (R@1)</th>\n\t\t\t <th colspan=\"3\">Training on MSLS (R@1)</th>\n\t \t</tr>\n\t \t<tr>\n\t  \t\t<td>Pitts30k (R@1)</td>\n\t   \t\t<td>MSLS (R@1)</td>\n\t   \t\t<td>Download</td>\n\t\t\t<td>Pitts30k (R@1)</td>\n\t   \t\t<td>MSLS (R@1)</td>\n\t   \t\t<td>Download</td>\n\t \t</tr>\n\t\t<tr>\n\t\t\t<td>resnet50-gem</td>\n\t\t\t<td>82.0</td> <td>38.0</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/1esgXzRFvDFHrMnwwR3GlTnErXjFNrYV7/view?usp=sharing\">[Link]</a></td>\n\t\t\t<td>77.4</td> <td>72.0</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/1uuIYJN4N7lQqqsN32pbZwjhz5Xvv3zr-/view?usp=sharing\">[Link]</a></td>\n\t \t</tr>\n\t \t<tr>\n\t\t\t<td>resnet50-gem-fc2048</td>\n\t\t\t<td>80.1</td> <td>33.7</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/1GCbE4gzcRXMH8ETD2YCPo0I3suAXDr-y/view?usp=sharing\">[Link]</a></td>\n\t\t\t<td>79.2</td> <td>73.5</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/1oSf11wAxaoEbjLnjfX0EWZ65dgccwdDD/view?usp=sharing\">[Link]</a></td>\n\t \t</tr>\n\t \t<tr>\n\t\t\t<td>resnet50-gem-fc65536</td>\n\t\t\t<td>80.8</td> <td>35.8</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/19GjodUuAGKpac6WhIcfuy3tiPV1J-ikn/view?usp=sharing\">[Link]</a></td>\n\t\t\t<td>79.0</td> <td>74.4</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/1OGwt651loL2vXnQYyABqitL39IEiXhag/view?usp=sharing\">[Link]</a></td>\n\t \t</tr>\n\t \t<tr>\n\t\t\t<td>resnet50-netvlad</td>\n\t\t\t<td>86.0</td> <td>50.7</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/1KL8HoAApOjJFETin7Q7u7IcsOvroKlSj/view?usp=sharing\">[Link]</a></td>\n\t\t\t<td>80.9</td> <td>76.9</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/1krf0A6CeW8GqLqHWZ7dlSNJ9aTJ4dotF/view?usp=sharing\">[Link]</a></td>\n\t \t</tr>\n\t \t<tr>\n\t\t\t<td>resnet50-crn</td>\n\t\t\t<td>85.8</td> <td>54.0</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/1mLOkILfIf8Wegi3tva9390TRIbWDxRor/view?usp=sharing\">[Link]</a></td>\n\t\t\t<td>80.8</td> <td>77.8</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/1KJzXwCsbyT0uNDl925H2J0QKXhKaeEgW/view?usp=sharing\">[Link]</a></td>\n\t \t</tr>\n\t</table>\n</details>\n\n\n<details>\n     <summary><b>Pretrained models with different mining methods</b></summary><br/>\n    Pretained networks trained using three different mining methods (random, full database mining and partial database mining):</br></br>\n\t<table>\n\t\t<tr>\n\t\t\t<th rowspan=2>Model</th>\n\t\t\t <th colspan=\"3\">Training on Pitts30k (R@1)</th>\n\t\t\t <th colspan=\"3\">Training on MSLS (R@1)</th>\n\t \t</tr>\n\t \t<tr>\n\t  
\t\t<td>Pitts30k (R@1)</td>\n\t   \t\t<td>MSLS (R@1)</td>\n\t   \t\t<td>Download</td>\n\t\t\t<td>Pitts30k (R@1)</td>\n\t   \t\t<td>MSLS (R@1)</td>\n\t   \t\t<td>Download</td>\n\t \t</tr>\n\t\t<tr>\n\t\t\t<td> resnet18-gem-random</td>\n\t\t\t<td>73.7</td> <td>30.5</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/12Ds-LcvFcA609bZVBTLNjAZIzV-g8UGK/view?usp=sharing\">[Link]</a></td>\n\t\t\t<td>62.2</td> <td>50.6</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/1oNZyfjTaulVTFX4wRrj0YISqxLuNRyhy/view?usp=sharing\">[Link]</a></td>\n\t \t</tr>\n\t\t<tr>\n\t\t\t<td> resnet18-gem-full</td>\n\t\t\t<td>77.8</td> <td>35.3</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/1bHVsnb6Km2npBsGK9ylI1vuOuc3WLKJb/view?usp=sharing\">[Link]</a></td>\n\t\t\t<td>70.1</td><td>61.8</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/1BbANLPVPxWDau2RP0cWTSS3FybbyUPL1/view?usp=sharing\">[Link]</a></td>\n\t \t</tr>\n\t\t<tr>\n\t\t\t<td> resnet18-gem-partial</td>\n\t\t\t<td>76.5</td> <td>34.2</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/1R66NYeLlxBIqLviUVL9XPZkrtmyMn_tU/view?usp=sharing\">[Link]</a></td>\n\t\t\t<td>71.6</td> <td>65.3</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/1IH0d_ME2kU3pagsKhx5ZfRfyWriErajn/view?usp=sharing\">[Link]</a></td>\n\t \t</tr>\n\t\t<tr>\n\t\t\t<td> resnet18-netvlad-random</td>\n\t\t\t<td>83.9</td> <td>43.6</td> \n\t\t\t<td><a href=\"https://drive.google.com/file/d/19OcEe2ckk-D8drrmxpKkkarT_5mCkjnt/view?usp=sharing\">[Link]</a></td>\n\t\t\t<td>73.3</td> <td>61.5</td>\n\t \t\t<td><a href=\"https://drive.google.com/file/d/1JlEbKbnWyCbR4zP1ZYDct3pYtuJrUmVp/view?usp=sharing\">[Link]</a></td>\n\t \t</tr>\n\t \t<tr>\n\t\t\t<td> resnet18-netvlad-full</td>\n\t\t\t<td>86.4</td> <td>47.4</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/1kwgyDEfRYtdaOEimQQlmj77rIR2tH3st/view?usp=sharing\">[Link]</a></td>\n\t\t\t<td>-</td><td>-</td>\n\t\t\t<td>-</td>\n\t \t</tr>\n\t \t<tr>\n\t\t\t<td> resnet18-netvlad-partial</td>\n\t\t\t<td>86.2</td> <td>47.3</td> \n\t\t\t<td><a href=\"https://drive.google.com/file/d/1KFwonDQYdvzTAIILsOMjmLRUR76jXXvB/view?usp=sharing\">[Link]</a></td>\n\t\t\t<td>81.6</td> <td>75.8</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/1_Ozq2TdvwLAJUwy7YH9l69GsfOU-MlFZ/view?usp=sharing\">[Link]</a></td>\n\t \t</tr>\n\t \t<tr>\n\t\t\t<td> resnet50-gem-random</td>\n\t\t\t<td>77.9</td> <td>34.3</td> \n\t\t\t<td><a href=\"https://drive.google.com/file/d/1f9be75EaG0fFLeNF0bufSre_efKH_ObU/view?usp=sharing\">[Link]</a></td>\n\t\t\t<td>69.5</td> <td>57.4</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/1h9-av6qMn-LVapI5KA4cZhT5BKaZ79C6/view?usp=sharing\">[Link]</a></td>\n\t\t</tr>\n\t\t<tr>\n\t\t\t<td> resnet50-gem-full</td>\n\t\t\t<td>82.0</td> <td>38.0</td> \n\t\t\t<td><a href=\"https://drive.google.com/file/d/1quS9ZjOrXBqNDBhQzlSj8aeh3dBfP1GY/view?usp=sharing\">[Link]</a></td>\n\t\t\t<td>77.3</td> <td>69.7</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/1pxU881eTcz_YdQthKz5yohU7WoLpXt8J/view?usp=sharing\">[Link]</a></td>\n\t\t</tr>\n\t\t<tr>\n\t\t\t<td> resnet50-gem-partial</td>\n\t\t\t<td>82.3</td> <td>39.0</td> \n\t\t\t<td><a href=\"https://drive.google.com/file/d/1esgXzRFvDFHrMnwwR3GlTnErXjFNrYV7/view?usp=sharing\">[Link]</a></td>\n\t\t\t<td>77.4</td> <td>72.0</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/1uuIYJN4N7lQqqsN32pbZwjhz5Xvv3zr-/view?usp=sharing\">[Link]</a></td>\n\t\t</tr>\n\t\t<tr>\n\t\t\t<td> resnet50-netvlad-random</td>\n\t\t\t<td>83.4</td> <td>45.0</td> \n\t\t\t<td><a 
href=\"https://drive.google.com/file/d/1TkzlO-ZS42u6e783y2O3JZhcIoI7CEVj/view?usp=sharing\">[Link]</a></td>\n\t\t\t<td>74.9</td> <td>63.6</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/1E_X2nrnLxBqvLfVfNKtorOGW_VmwOqSu/view?usp=sharing\">[Link]</a></td>\n\t\t</tr>\n\t\t<tr>\n\t\t\t<td> resnet50-netvlad-full</td>\n\t\t\t<td>86.0</td> <td>50.7</td> \n\t\t\t<td><a href=\"https://drive.google.com/file/d/133uxEJZ0gK6XL1myhSAFC7wibZtWnugK/view?usp=sharing\">[Link]</a></td>\n\t\t\t<td>-</td><td>-</td>\n\t\t\t<td>-</td>\n\t\t</tr>\n\t\t<tr>\n\t\t\t<td> resnet50-netvlad-partial</td>\n\t\t\t<td>85.5</td> <td>48.6</td> \n\t\t\t<td><a href=\"https://drive.google.com/file/d/1GCbE4gzcRXMH8ETD2YCPo0I3suAXDr-y/view?usp=sharing\">[Link]</a></td>\n\t\t\t<td>80.9</td> <td>76.9</td>\n\t\t\t<td><a href=\"https://drive.google.com/file/d/1krf0A6CeW8GqLqHWZ7dlSNJ9aTJ4dotF/view?usp=sharing\">[Link]</a></td>\n\t\t</tr>\n\t</table>\n</details>\n\n\nIf you find our work useful in your research please consider citing our paper:\n```bibtex\n@inproceedings{Berton_CVPR_2022_benchmark,\n    author    = {Berton, Gabriele and Mereu, Riccardo and Trivigno, Gabriele and Masone, Carlo and Csurka, Gabriela and Sattler, Torsten and Caputo, Barbara},\n    title     = {Deep Visual Geo-Localization Benchmark},\n    booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition},\n    month     = {June},\n    year      = {2022}\n}\n```\n\n\n## Acknowledgements\nParts of this repo are inspired by the following great repositories:\n- [NetVLAD's original code](https://github.com/Relja/netvlad) (in MATLAB)\n- [NetVLAD layer in PyTorch](https://github.com/lyakaap/NetVLAD-pytorch)\n- [NetVLAD training in PyTorch](https://github.com/Nanne/pytorch-NetVlad/)\n- [GeM layer](https://github.com/filipradenovic/cnnimageretrieval-pytorch)\n- [Deep Image Retrieval](https://github.com/naver/deep-image-retrieval)\n- [Mapillary Street-level Sequences](https://github.com/mapillary/mapillary_sls)\n- [Compact Convolutional Transformers](https://github.com/SHI-Labs/Compact-Transformers)\n\nCheck out also our other repo [_CosPlace_](https://github.com/gmberton/CosPlace), from the CVPR 2022 paper \"Rethinking Visual Geo-localization for Large-Scale Applications\", which provides a new SOTA in visual geo-localization / visual place recognition.\n"
  },
  {
    "path": "commons.py",
    "content": "\n\"\"\"\nThis file contains some functions and classes which can be useful in very diverse projects.\n\"\"\"\n\nimport os\nimport sys\nimport torch\nimport random\nimport logging\nimport traceback\nimport numpy as np\nfrom os.path import join\n\n\ndef make_deterministic(seed=0):\n    \"\"\"Make results deterministic. If seed == -1, do not make deterministic.\n    Running the script in a deterministic way might slow it down.\n    \"\"\"\n    if seed == -1:\n        return\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.backends.cudnn.deterministic = True\n    torch.backends.cudnn.benchmark = False\n\n\ndef setup_logging(save_dir, console=\"debug\",\n                  info_filename=\"info.log\", debug_filename=\"debug.log\"):\n    \"\"\"Set up logging files and console output.\n    Creates one file for INFO logs and one for DEBUG logs.\n    Args:\n        save_dir (str): creates the folder where to save the files.\n        debug (str):\n            if == \"debug\" prints on console debug messages and higher\n            if == \"info\"  prints on console info messages and higher\n            if == None does not use console (useful when a logger has already been set)\n        info_filename (str): the name of the info file. if None, don't create info file\n        debug_filename (str): the name of the debug file. if None, don't create debug file\n    \"\"\"\n    if os.path.exists(save_dir):\n        raise FileExistsError(f\"{save_dir} already exists!\")\n    os.makedirs(save_dir, exist_ok=True)\n    # logging.Logger.manager.loggerDict.keys() to check which loggers are in use\n    base_formatter = logging.Formatter('%(asctime)s   %(message)s', \"%Y-%m-%d %H:%M:%S\")\n    logger = logging.getLogger('')\n    logger.setLevel(logging.DEBUG)\n    \n    if info_filename is not None:\n        info_file_handler = logging.FileHandler(join(save_dir, info_filename))\n        info_file_handler.setLevel(logging.INFO)\n        info_file_handler.setFormatter(base_formatter)\n        logger.addHandler(info_file_handler)\n    \n    if debug_filename is not None:\n        debug_file_handler = logging.FileHandler(join(save_dir, debug_filename))\n        debug_file_handler.setLevel(logging.DEBUG)\n        debug_file_handler.setFormatter(base_formatter)\n        logger.addHandler(debug_file_handler)\n    \n    if console is not None:\n        console_handler = logging.StreamHandler()\n        if console == \"debug\":\n            console_handler.setLevel(logging.DEBUG)\n        if console == \"info\":\n            console_handler.setLevel(logging.INFO)\n        console_handler.setFormatter(base_formatter)\n        logger.addHandler(console_handler)\n    \n    def exception_handler(type_, value, tb):\n        logger.info(\"\\n\" + \"\".join(traceback.format_exception(type, value, tb)))\n    sys.excepthook = exception_handler\n"
  },
  {
    "path": "datasets_ws.py",
    "content": "\nimport os\nimport torch\nimport faiss\nimport logging\nimport numpy as np\nfrom glob import glob\nfrom tqdm import tqdm\nfrom PIL import Image\nfrom os.path import join\nimport torch.utils.data as data\nimport torchvision.transforms as T\nfrom torch.utils.data.dataset import Subset\nfrom sklearn.neighbors import NearestNeighbors\nfrom torch.utils.data.dataloader import DataLoader\n\n\nbase_transform = T.Compose([\n    T.ToTensor(),\n    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n])\n\n\ndef path_to_pil_img(path):\n    return Image.open(path).convert(\"RGB\")\n\n\ndef collate_fn(batch):\n    \"\"\"Creates mini-batch tensors from the list of tuples (images,\n        triplets_local_indexes, triplets_global_indexes).\n        triplets_local_indexes are the indexes referring to each triplet within images.\n        triplets_global_indexes are the global indexes of each image.\n    Args:\n        batch: list of tuple (images, triplets_local_indexes, triplets_global_indexes).\n            considering each query to have 10 negatives (negs_num_per_query=10):\n            - images: torch tensor of shape (12, 3, h, w).\n            - triplets_local_indexes: torch tensor of shape (10, 3).\n            - triplets_global_indexes: torch tensor of shape (12).\n    Returns:\n        images: torch tensor of shape (batch_size*12, 3, h, w).\n        triplets_local_indexes: torch tensor of shape (batch_size*10, 3).\n        triplets_global_indexes: torch tensor of shape (batch_size, 12).\n    \"\"\"\n    images = torch.cat([e[0] for e in batch])\n    triplets_local_indexes = torch.cat([e[1][None] for e in batch])\n    triplets_global_indexes = torch.cat([e[2][None] for e in batch])\n    for i, (local_indexes, global_indexes) in enumerate(zip(triplets_local_indexes, triplets_global_indexes)):\n        local_indexes += len(global_indexes) * i  # Increment local indexes by offset (len(global_indexes) is 12)\n    return images, torch.cat(tuple(triplets_local_indexes)), triplets_global_indexes\n\n\nclass PCADataset(data.Dataset):\n    def __init__(self, args, datasets_folder=\"dataset\", dataset_folder=\"pitts30k/images/train\"):\n        dataset_folder_full_path = join(datasets_folder, dataset_folder)\n        if not os.path.exists(dataset_folder_full_path):\n            raise FileNotFoundError(f\"Folder {dataset_folder_full_path} does not exist\")\n        self.images_paths = sorted(glob(join(dataset_folder_full_path, \"**\", \"*.jpg\"), recursive=True))\n    \n    def __getitem__(self, index):\n        return base_transform(path_to_pil_img(self.images_paths[index]))\n    \n    def __len__(self):\n        return len(self.images_paths)\n\n\nclass BaseDataset(data.Dataset):\n    \"\"\"Dataset with images from database and queries, used for inference (testing and building cache).\n    \"\"\"\n    def __init__(self, args, datasets_folder=\"datasets\", dataset_name=\"pitts30k\", split=\"train\"):\n        super().__init__()\n        self.args = args\n        self.dataset_name = dataset_name\n        self.dataset_folder = join(datasets_folder, dataset_name, \"images\", split)\n        if not os.path.exists(self.dataset_folder):\n            raise FileNotFoundError(f\"Folder {self.dataset_folder} does not exist\")\n        \n        self.resize = args.resize\n        self.test_method = args.test_method\n        \n        #### Read paths and UTM coordinates for all images.\n        database_folder = join(self.dataset_folder, \"database\")\n        queries_folder = 
join(self.dataset_folder, \"queries\")\n        if not os.path.exists(database_folder):\n            raise FileNotFoundError(f\"Folder {database_folder} does not exist\")\n        if not os.path.exists(queries_folder):\n            raise FileNotFoundError(f\"Folder {queries_folder} does not exist\")\n        self.database_paths = sorted(glob(join(database_folder, \"**\", \"*.jpg\"), recursive=True))\n        self.queries_paths = sorted(glob(join(queries_folder, \"**\", \"*.jpg\"), recursive=True))\n        # The format must be path/to/file/@utm_easting@utm_northing@...@.jpg\n        self.database_utms = np.array([(path.split(\"@\")[1], path.split(\"@\")[2]) for path in self.database_paths]).astype(float)\n        self.queries_utms = np.array([(path.split(\"@\")[1], path.split(\"@\")[2]) for path in self.queries_paths]).astype(float)\n        \n        # Find soft_positives_per_query, which are within val_positive_dist_threshold (default 25 meters)\n        knn = NearestNeighbors(n_jobs=-1)\n        knn.fit(self.database_utms)\n        self.soft_positives_per_query = knn.radius_neighbors(self.queries_utms,\n                                                             radius=args.val_positive_dist_threshold,\n                                                             return_distance=False)\n        \n        self.images_paths = list(self.database_paths) + list(self.queries_paths)\n        \n        self.database_num = len(self.database_paths)\n        self.queries_num = len(self.queries_paths)\n    \n    def __getitem__(self, index):\n        img = path_to_pil_img(self.images_paths[index])\n        img = base_transform(img)\n        # With database images self.test_method should always be \"hard_resize\"\n        if self.test_method == \"hard_resize\":\n            # self.test_method==\"hard_resize\" is the default; it resizes all images to the same size.\n            img = T.functional.resize(img, self.resize)\n        else:\n            img = self._test_query_transform(img)\n        return img, index\n    \n    def _test_query_transform(self, img):\n        \"\"\"Transform query image according to self.test_method.\"\"\"\n        C, H, W = img.shape\n        if self.test_method == \"single_query\":\n            # self.test_method==\"single_query\" is used when queries have varying sizes, and can't be stacked in a batch.\n            processed_img = T.functional.resize(img, min(self.resize))\n        elif self.test_method == \"central_crop\":\n            # Take the biggest central crop of size self.resize. Preserves ratio.\n            scale = max(self.resize[0]/H, self.resize[1]/W)\n            processed_img = torch.nn.functional.interpolate(img.unsqueeze(0), scale_factor=scale).squeeze(0)\n            processed_img = T.functional.center_crop(processed_img, self.resize)\n            assert processed_img.shape[1:] == torch.Size(self.resize), f\"{processed_img.shape[1:]} {self.resize}\"\n        elif self.test_method == \"five_crops\" or self.test_method == 'nearest_crop' or self.test_method == 'maj_voting':\n            # Get 5 square crops with size==shorter_side (usually 480). 
Preserves ratio and allows batches.\n            shorter_side = min(self.resize)\n            processed_img = T.functional.resize(img, shorter_side)\n            processed_img = torch.stack(T.functional.five_crop(processed_img, shorter_side))\n            assert processed_img.shape == torch.Size([5, 3, shorter_side, shorter_side]), \\\n                f\"{processed_img.shape} {torch.Size([5, 3, shorter_side, shorter_side])}\"\n        return processed_img\n    \n    def __len__(self):\n        return len(self.images_paths)\n    \n    def __repr__(self):\n        return f\"< {self.__class__.__name__}, {self.dataset_name} - #database: {self.database_num}; #queries: {self.queries_num} >\"\n    \n    def get_positives(self):\n        return self.soft_positives_per_query\n\n\nclass TripletsDataset(BaseDataset):\n    \"\"\"Dataset used for training; it computes the triplets\n    with TripletsDataset.compute_triplets(), using various mining methods.\n    If is_inference == True, it uses the methods of the parent class BaseDataset;\n    this is used, for example, when computing the cache, because we compute features\n    of each image, not triplets.\n    \"\"\"\n    def __init__(self, args, datasets_folder=\"datasets\", dataset_name=\"pitts30k\", split=\"train\", negs_num_per_query=10):\n        super().__init__(args, datasets_folder, dataset_name, split)\n        self.mining = args.mining\n        self.neg_samples_num = args.neg_samples_num  # Number of negatives to randomly sample\n        self.negs_num_per_query = negs_num_per_query  # Number of negatives per query in each batch\n        if self.mining == \"full\":  # \"Full database mining\" keeps a cache with the last used negatives\n            self.neg_cache = [np.empty((0,), dtype=np.int32) for _ in range(self.queries_num)]\n        self.is_inference = False\n        \n        identity_transform = T.Lambda(lambda x: x)\n        self.resized_transform = T.Compose([\n            T.Resize(self.resize) if self.resize is not None else identity_transform,\n            base_transform\n        ])\n        \n        self.query_transform = T.Compose([\n                T.ColorJitter(args.brightness, args.contrast, args.saturation, args.hue),\n                T.RandomPerspective(args.rand_perspective),\n                T.RandomResizedCrop(size=self.resize, scale=(1-args.random_resized_crop, 1)),\n                T.RandomRotation(degrees=args.random_rotation),\n                self.resized_transform,\n        ])\n        \n        # Find hard_positives_per_query, which are within train_positives_dist_threshold (10 meters)\n        knn = NearestNeighbors(n_jobs=-1)\n        knn.fit(self.database_utms)\n        self.hard_positives_per_query = list(knn.radius_neighbors(self.queries_utms,\n                                             radius=args.train_positives_dist_threshold,  # 10 meters\n                                             return_distance=False))\n        \n        #### Some queries might have no positive; we should remove those queries.\n        queries_without_any_hard_positive = np.where(np.array([len(p) for p in self.hard_positives_per_query], dtype=object) == 0)[0]\n        if len(queries_without_any_hard_positive) != 0:\n            logging.info(f\"There are {len(queries_without_any_hard_positive)} queries without any positives \" +\n                         \"within the training set. 
They won't be considered as they're useless for training.\")\n        # Remove queries without positives\n        self.hard_positives_per_query = np.delete(self.hard_positives_per_query, queries_without_any_hard_positive)\n        self.soft_positives_per_query = np.delete(self.soft_positives_per_query, queries_without_any_hard_positive)\n        self.queries_paths = np.delete(self.queries_paths, queries_without_any_hard_positive)\n        \n        # Recompute images_paths and queries_num because some queries might have been removed\n        self.images_paths = list(self.database_paths) + list(self.queries_paths)\n        self.queries_num = len(self.queries_paths)\n        \n        # msls_weighted refers to the mining presented in MSLS paper's supplementary.\n        # Basically, images from uncommon domains are sampled more often. Works only with MSLS dataset.\n        if self.mining == \"msls_weighted\":\n            notes = [p.split(\"@\")[-2] for p in self.queries_paths]\n            try:\n                night_indexes = np.where(np.array([n.split(\"_\")[0] == \"night\" for n in notes]))[0]\n                sideways_indexes = np.where(np.array([n.split(\"_\")[1] == \"sideways\" for n in notes]))[0]\n            except IndexError:\n                raise RuntimeError(\"You're using msls_weighted mining but this dataset \" +\n                                   \"does not have night/sideways information. Are you using Mapillary SLS?\")\n            self.weights = np.ones(self.queries_num)\n            assert len(night_indexes) != 0 and len(sideways_indexes) != 0, \\\n                \"There should be night and sideways images for msls_weighted mining, but there are none. Are you using Mapillary SLS?\"\n            self.weights[night_indexes] += self.queries_num / len(night_indexes)\n            self.weights[sideways_indexes] += self.queries_num / len(sideways_indexes)\n            self.weights /= self.weights.sum()\n            logging.info(f\"#sideways_indexes [{len(sideways_indexes)}/{self.queries_num}]; \" +\n                         f\"#night_indexes [{len(night_indexes)}/{self.queries_num}]\")\n    \n    def __getitem__(self, index):\n        if self.is_inference:\n            # At inference time return the single image. This is used for caching or computing NetVLAD's clusters\n            return super().__getitem__(index)\n        query_index, best_positive_index, neg_indexes = torch.split(self.triplets_global_indexes[index], (1, 1, self.negs_num_per_query))\n        query = self.query_transform(path_to_pil_img(self.queries_paths[query_index]))\n        positive = self.resized_transform(path_to_pil_img(self.database_paths[best_positive_index]))\n        negatives = [self.resized_transform(path_to_pil_img(self.database_paths[i])) for i in neg_indexes]\n        images = torch.stack((query, positive, *negatives), 0)\n        triplets_local_indexes = torch.empty((0, 3), dtype=torch.int)\n        for neg_num in range(len(neg_indexes)):\n            triplets_local_indexes = torch.cat((triplets_local_indexes, torch.tensor([0, 1, 2 + neg_num]).reshape(1, 3)))\n        return images, triplets_local_indexes, self.triplets_global_indexes[index]\n    \n    def __len__(self):\n        if self.is_inference:\n            # At inference time return the number of images. 
This is used for caching or computing NetVLAD's clusters\n            return super().__len__()\n        else:\n            return len(self.triplets_global_indexes)\n    \n    def compute_triplets(self, args, model):\n        self.is_inference = True\n        if self.mining == \"full\":\n            self.compute_triplets_full(args, model)\n        elif self.mining == \"partial\" or self.mining == \"msls_weighted\":\n            self.compute_triplets_partial(args, model)\n        elif self.mining == \"random\":\n            self.compute_triplets_random(args, model)\n    \n    @staticmethod\n    def compute_cache(args, model, subset_ds, cache_shape):\n        \"\"\"Compute the cache containing features of images, which is used to\n        find best positive and hardest negatives.\"\"\"\n        subset_dl = DataLoader(dataset=subset_ds, num_workers=args.num_workers,\n                               batch_size=args.infer_batch_size, shuffle=False,\n                               pin_memory=(args.device == \"cuda\"))\n        model = model.eval()\n        \n        # RAMEfficient2DMatrix can be replaced by np.zeros, but using\n        # RAMEfficient2DMatrix is RAM efficient for full database mining.\n        cache = RAMEfficient2DMatrix(cache_shape, dtype=np.float32)\n        with torch.no_grad():\n            for images, indexes in tqdm(subset_dl, ncols=100):\n                images = images.to(args.device)\n                features = model(images)\n                cache[indexes.numpy()] = features.cpu().numpy()\n        return cache\n    \n    def get_query_features(self, query_index, cache):\n        query_features = cache[query_index + self.database_num]\n        if query_features is None:\n            raise RuntimeError(f\"For query {self.queries_paths[query_index]} \" +\n                               f\"with index {query_index} features have not been computed!\\n\" +\n                               \"There might be some bug with caching\")\n        return query_features\n    \n    def get_best_positive_index(self, args, query_index, cache, query_features):\n        positives_features = cache[self.hard_positives_per_query[query_index]]\n        faiss_index = faiss.IndexFlatL2(args.features_dim)\n        faiss_index.add(positives_features)\n        # Search the best positive (within 10 meters AND nearest in features space)\n        _, best_positive_num = faiss_index.search(query_features.reshape(1, -1), 1)\n        best_positive_index = self.hard_positives_per_query[query_index][best_positive_num[0]].item()\n        return best_positive_index\n    \n    def get_hardest_negatives_indexes(self, args, cache, query_features, neg_samples):\n        neg_features = cache[neg_samples]\n        faiss_index = faiss.IndexFlatL2(args.features_dim)\n        faiss_index.add(neg_features)\n        # Search the 10 nearest negatives (further than 25 meters and nearest in features space)\n        _, neg_nums = faiss_index.search(query_features.reshape(1, -1), self.negs_num_per_query)\n        neg_nums = neg_nums.reshape(-1)\n        neg_indexes = neg_samples[neg_nums].astype(np.int32)\n        return neg_indexes\n    \n    def compute_triplets_random(self, args, model):\n        self.triplets_global_indexes = []\n        # Take 1000 random queries\n        sampled_queries_indexes = np.random.choice(self.queries_num, args.cache_refresh_rate, replace=False)\n        # Take all the positives\n        positives_indexes = [self.hard_positives_per_query[i] for i in sampled_queries_indexes]\n        
positives_indexes = [p for pos in positives_indexes for p in pos]  # Flatten list of lists to a list\n        positives_indexes = list(np.unique(positives_indexes))\n        \n        # Compute the cache only for queries and their positives, in order to find the best positive\n        subset_ds = Subset(self, positives_indexes + list(sampled_queries_indexes + self.database_num))\n        cache = self.compute_cache(args, model, subset_ds, (len(self), args.features_dim))\n        \n        # This loop's iterations could be done individually in the __getitem__(). This way is slower but clearer (and yields same results)\n        for query_index in tqdm(sampled_queries_indexes, ncols=100):\n            query_features = self.get_query_features(query_index, cache)\n            best_positive_index = self.get_best_positive_index(args, query_index, cache, query_features)\n            \n            # Choose some random database images, from those remove the soft_positives, and then take the first 10 images as neg_indexes\n            soft_positives = self.soft_positives_per_query[query_index]\n            neg_indexes = np.random.choice(self.database_num, size=self.negs_num_per_query+len(soft_positives), replace=False)\n            neg_indexes = np.setdiff1d(neg_indexes, soft_positives, assume_unique=True)[:self.negs_num_per_query]\n            \n            self.triplets_global_indexes.append((query_index, best_positive_index, *neg_indexes))\n        # self.triplets_global_indexes is a tensor of shape [1000, 12]\n        self.triplets_global_indexes = torch.tensor(self.triplets_global_indexes)\n    \n    def compute_triplets_full(self, args, model):\n        self.triplets_global_indexes = []\n        # Take 1000 random queries\n        sampled_queries_indexes = np.random.choice(self.queries_num, args.cache_refresh_rate, replace=False)\n        # Take all database indexes\n        database_indexes = list(range(self.database_num))\n        #  Compute features for all images and store them in cache\n        subset_ds = Subset(self, database_indexes + list(sampled_queries_indexes + self.database_num))\n        cache = self.compute_cache(args, model, subset_ds, (len(self), args.features_dim))\n        \n        # This loop's iterations could be done individually in the __getitem__(). 
This way is slower but clearer (and yields same results)\n        for query_index in tqdm(sampled_queries_indexes, ncols=100):\n            query_features = self.get_query_features(query_index, cache)\n            best_positive_index = self.get_best_positive_index(args, query_index, cache, query_features)\n            # Choose 1000 random database images (neg_indexes)\n            neg_indexes = np.random.choice(self.database_num, self.neg_samples_num, replace=False)\n            # Remove any soft_positives from neg_indexes\n            soft_positives = self.soft_positives_per_query[query_index]\n            neg_indexes = np.setdiff1d(neg_indexes, soft_positives, assume_unique=True)\n            # Concatenate neg_indexes with the previous top 10 negatives (neg_cache)\n            neg_indexes = np.unique(np.concatenate([self.neg_cache[query_index], neg_indexes]))\n            # Search the hardest negatives\n            neg_indexes = self.get_hardest_negatives_indexes(args, cache, query_features, neg_indexes)\n            # Update nearest negatives in neg_cache\n            self.neg_cache[query_index] = neg_indexes\n            self.triplets_global_indexes.append((query_index, best_positive_index, *neg_indexes))\n        # self.triplets_global_indexes is a tensor of shape [1000, 12]\n        self.triplets_global_indexes = torch.tensor(self.triplets_global_indexes)\n    \n    def compute_triplets_partial(self, args, model):\n        self.triplets_global_indexes = []\n        # Take 1000 random queries\n        if self.mining == \"partial\":\n            sampled_queries_indexes = np.random.choice(self.queries_num, args.cache_refresh_rate, replace=False)\n        elif self.mining == \"msls_weighted\":  # Pick night and sideways queries with higher probability\n            sampled_queries_indexes = np.random.choice(self.queries_num, args.cache_refresh_rate, replace=False, p=self.weights)\n        \n        # Sample 1000 random database images for the negatives\n        sampled_database_indexes = np.random.choice(self.database_num, self.neg_samples_num, replace=False)\n        # Take all the positives\n        positives_indexes = [self.hard_positives_per_query[i] for i in sampled_queries_indexes]\n        positives_indexes = [p for pos in positives_indexes for p in pos]\n        # Merge them into database_indexes and remove duplicates\n        database_indexes = list(sampled_database_indexes) + positives_indexes\n        database_indexes = list(np.unique(database_indexes))\n        \n        subset_ds = Subset(self, database_indexes + list(sampled_queries_indexes + self.database_num))\n        cache = self.compute_cache(args, model, subset_ds, cache_shape=(len(self), args.features_dim))\n        \n        # This loop's iterations could be done individually in the __getitem__(). 
This way is slower but clearer (and yields same results)\n        for query_index in tqdm(sampled_queries_indexes, ncols=100):\n            query_features = self.get_query_features(query_index, cache)\n            best_positive_index = self.get_best_positive_index(args, query_index, cache, query_features)\n            \n            # Choose the hardest negatives within sampled_database_indexes, ensuring that there are no positives\n            soft_positives = self.soft_positives_per_query[query_index]\n            neg_indexes = np.setdiff1d(sampled_database_indexes, soft_positives, assume_unique=True)\n            \n            # Take all database images that are negatives and are within the sampled database images (aka database_indexes)\n            neg_indexes = self.get_hardest_negatives_indexes(args, cache, query_features, neg_indexes)\n            self.triplets_global_indexes.append((query_index, best_positive_index, *neg_indexes))\n        # self.triplets_global_indexes is a tensor of shape [1000, 12]\n        self.triplets_global_indexes = torch.tensor(self.triplets_global_indexes)\n\n\nclass RAMEfficient2DMatrix:\n    \"\"\"This class behaves similarly to a numpy.ndarray initialized\n    with np.zeros(), but is implemented to save RAM when the rows\n    within the 2D array are sparse. In this case it's needed because\n    we don't always compute features for each image, just for a few of\n    them.\"\"\"\n    def __init__(self, shape, dtype=np.float32):\n        self.shape = shape\n        self.dtype = dtype\n        self.matrix = [None] * shape[0]\n    \n    def __setitem__(self, indexes, vals):\n        assert vals.shape[1] == self.shape[1], f\"{vals.shape[1]} {self.shape[1]}\"\n        for i, val in zip(indexes, vals):\n            self.matrix[i] = val.astype(self.dtype, copy=False)\n    \n    def __getitem__(self, index):\n        if hasattr(index, \"__len__\"):\n            return np.array([self.matrix[i] for i in index])\n        else:\n            return self.matrix[index]\n"
  },
  {
    "path": "eval.py",
    "content": "\n\"\"\"\nWith this script you can evaluate checkpoints or test models from two popular\nlandmark retrieval github repos.\nThe first is https://github.com/naver/deep-image-retrieval from Naver labs,\nprovides ResNet-50 and ResNet-101 trained with AP on Google Landmarks 18 clean.\n$ python eval.py --off_the_shelf=naver --l2=none --backbone=resnet101conv5 --aggregation=gem --fc_output_dim=2048\n\nThe second is https://github.com/filipradenovic/cnnimageretrieval-pytorch from\nRadenovic, provides ResNet-50 and ResNet-101 trained with a triplet loss\non Google Landmarks 18 and sfm120k.\n$ python eval.py --off_the_shelf=radenovic_gldv1 --l2=after_pool --backbone=resnet101conv5 --aggregation=gem --fc_output_dim=2048\n$ python eval.py --off_the_shelf=radenovic_sfm --l2=after_pool --backbone=resnet101conv5 --aggregation=gem --fc_output_dim=2048\n\nNote that although the architectures are almost the same, Naver's\nimplementation does not use a l2 normalization before/after the GeM aggregation,\nwhile Radenovic's uses it after (and we use it before, which shows better\nresults in VG)\n\"\"\"\n\nimport os\nimport sys\nimport torch\nimport parser\nimport logging\nimport sklearn\nfrom os.path import join\nfrom datetime import datetime\nfrom torch.utils.model_zoo import load_url\nfrom google_drive_downloader import GoogleDriveDownloader as gdd\n\nimport test\nimport util\nimport commons\nimport datasets_ws\nfrom model import network\n\nOFF_THE_SHELF_RADENOVIC = {\n    'resnet50conv5_sfm'    : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet50-gem-w-97bf910.pth',\n    'resnet101conv5_sfm'   : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet101-gem-w-a155e54.pth',\n    'resnet50conv5_gldv1'  : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet50-gem-w-83fdc30.pth',\n    'resnet101conv5_gldv1' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet101-gem-w-a4d43db.pth',\n}\n\nOFF_THE_SHELF_NAVER = {\n    \"resnet50conv5\"  : \"1oPtE_go9tnsiDLkWjN4NMpKjh-_md1G5\",\n    'resnet101conv5' : \"1UWJGDuHtzaQdFhSMojoYVQjmCXhIwVvy\"\n}\n\n######################################### SETUP #########################################\nargs = parser.parse_arguments()\nstart_time = datetime.now()\nargs.save_dir = join(\"test\", args.save_dir, start_time.strftime('%Y-%m-%d_%H-%M-%S'))\ncommons.setup_logging(args.save_dir)\ncommons.make_deterministic(args.seed)\nlogging.info(f\"Arguments: {args}\")\nlogging.info(f\"The outputs are being saved in {args.save_dir}\")\n\n######################################### MODEL #########################################\nmodel = network.GeoLocalizationNet(args)\nmodel = model.to(args.device)\n\nif args.aggregation in [\"netvlad\", \"crn\"]:\n    args.features_dim *= args.netvlad_clusters\n\nif args.off_the_shelf.startswith(\"radenovic\") or args.off_the_shelf.startswith(\"naver\"):\n    if args.off_the_shelf.startswith(\"radenovic\"):\n        pretrain_dataset_name = args.off_the_shelf.split(\"_\")[1]  # sfm or gldv1 datasets\n        url = OFF_THE_SHELF_RADENOVIC[f\"{args.backbone}_{pretrain_dataset_name}\"]\n        state_dict = load_url(url, model_dir=join(\"data\", \"off_the_shelf_nets\"))\n    else:\n        # This is a hacky workaround to maintain compatibility\n        sys.modules['sklearn.decomposition.pca'] = sklearn.decomposition._pca\n        zip_file_path = join(\"data\", \"off_the_shelf_nets\", args.backbone + 
\"_naver.zip\")\n        if not os.path.exists(zip_file_path):\n            gdd.download_file_from_google_drive(file_id=OFF_THE_SHELF_NAVER[args.backbone],\n                                                dest_path=zip_file_path, unzip=True)\n        if args.backbone == \"resnet50conv5\":\n            state_dict_filename = \"Resnet50-AP-GeM.pt\"\n        elif args.backbone == \"resnet101conv5\":\n            state_dict_filename = \"Resnet-101-AP-GeM.pt\"\n        state_dict = torch.load(join(\"data\", \"off_the_shelf_nets\", state_dict_filename))\n    state_dict = state_dict[\"state_dict\"]\n    model_keys = model.state_dict().keys()\n    renamed_state_dict = {k: v for k, v in zip(model_keys, state_dict.values())}\n    model.load_state_dict(renamed_state_dict)\nelif args.resume is not None:\n    logging.info(f\"Resuming model from {args.resume}\")\n    model = util.resume_model(args, model)\n# Enable DataParallel after loading checkpoint, otherwise doing it before\n# would append \"module.\" in front of the keys of the state dict triggering errors\nmodel = torch.nn.DataParallel(model)\n\nif args.pca_dim is None:\n    pca = None\nelse:\n    full_features_dim = args.features_dim\n    args.features_dim = args.pca_dim\n    pca = util.compute_pca(args, model, args.pca_dataset_folder, full_features_dim)\n\n######################################### DATASETS #########################################\ntest_ds = datasets_ws.BaseDataset(args, args.datasets_folder, args.dataset_name, \"test\")\nlogging.info(f\"Test set: {test_ds}\")\n\n######################################### TEST on TEST SET #########################################\nrecalls, recalls_str = test.test(args, test_ds, model, args.test_method, pca)\nlogging.info(f\"Recalls on {test_ds}: {recalls_str}\")\n\nlogging.info(f\"Finished in {str(datetime.now() - start_time)[:-7]}\")\n"
  },
  {
    "path": "model/__init__.py",
    "content": ""
  },
  {
    "path": "model/aggregation.py",
    "content": "\nimport math\nimport torch\nimport faiss\nimport logging\nimport numpy as np\nfrom tqdm import tqdm\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn.parameter import Parameter\nfrom torch.utils.data import DataLoader, SubsetRandomSampler\n\nimport model.functional as LF\nimport model.normalization as normalization\n\nclass MAC(nn.Module):\n    def __init__(self):\n        super().__init__()\n    def forward(self, x):\n        return LF.mac(x)\n    def __repr__(self):\n        return self.__class__.__name__ + '()'\n\nclass SPoC(nn.Module):\n    def __init__(self):\n        super().__init__()\n    def forward(self, x):\n        return LF.spoc(x)\n    def __repr__(self):\n        return self.__class__.__name__ + '()'\n\nclass GeM(nn.Module):\n    def __init__(self, p=3, eps=1e-6, work_with_tokens=False):\n        super().__init__()\n        self.p = Parameter(torch.ones(1)*p)\n        self.eps = eps\n        self.work_with_tokens=work_with_tokens\n    def forward(self, x):\n        return LF.gem(x, p=self.p, eps=self.eps, work_with_tokens=self.work_with_tokens)\n    def __repr__(self):\n        return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')'\n\nclass RMAC(nn.Module):\n    def __init__(self, L=3, eps=1e-6):\n        super().__init__()\n        self.L = L\n        self.eps = eps\n    def forward(self, x):\n        return LF.rmac(x, L=self.L, eps=self.eps)\n    def __repr__(self):\n        return self.__class__.__name__ + '(' + 'L=' + '{}'.format(self.L) + ')'\n\n\nclass Flatten(torch.nn.Module):\n    def __init__(self): super().__init__()\n    def forward(self, x): assert x.shape[2] == x.shape[3] == 1; return x[:,:,0,0]\n\nclass RRM(nn.Module):\n    \"\"\"Residual Retrieval Module as described in the paper \n    `Leveraging EfficientNet and Contrastive Learning for AccurateGlobal-scale \n    Location Estimation <https://arxiv.org/pdf/2105.07645.pdf>`\n    \"\"\"\n    def __init__(self, dim):\n        super().__init__()\n        self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)\n        self.flatten = Flatten()\n        self.ln1 = nn.LayerNorm(normalized_shape=dim)\n        self.fc1 = nn.Linear(in_features=dim, out_features=dim)\n        self.relu = nn.ReLU()\n        self.fc2 = nn.Linear(in_features=dim, out_features=dim)\n        self.ln2 = nn.LayerNorm(normalized_shape=dim)\n        self.l2 = normalization.L2Norm()\n    def forward(self, x):\n        x = self.avgpool(x)\n        x = self.flatten(x)\n        x = self.ln1(x)\n        identity = x\n        out = self.fc2(self.relu(self.fc1(x)))\n        out += identity\n        out = self.l2(self.ln2(out))\n        return out\n\n\n# based on https://github.com/lyakaap/NetVLAD-pytorch/blob/master/netvlad.py\nclass NetVLAD(nn.Module):\n    \"\"\"NetVLAD layer implementation\"\"\"\n\n    def __init__(self, clusters_num=64, dim=128, normalize_input=True, work_with_tokens=False):\n        \"\"\"\n        Args:\n            clusters_num : int\n                The number of clusters\n            dim : int\n                Dimension of descriptors\n            alpha : float\n                Parameter of initialization. 
A larger value gives a harder assignment.\n            normalize_input : bool\n                If true, descriptor-wise L2 normalization is applied to input.\n        \"\"\"\n        super().__init__()\n        self.clusters_num = clusters_num\n        self.dim = dim\n        self.alpha = 0\n        self.normalize_input = normalize_input\n        self.work_with_tokens = work_with_tokens\n        if work_with_tokens:\n            self.conv = nn.Conv1d(dim, clusters_num, kernel_size=1, bias=False)\n        else:\n            self.conv = nn.Conv2d(dim, clusters_num, kernel_size=(1, 1), bias=False)\n        self.centroids = nn.Parameter(torch.rand(clusters_num, dim))\n\n    def init_params(self, centroids, descriptors):\n        centroids_assign = centroids / np.linalg.norm(centroids, axis=1, keepdims=True)\n        dots = np.dot(centroids_assign, descriptors.T)\n        dots.sort(0)\n        dots = dots[::-1, :]  # sort, descending\n\n        # Choose alpha so that, on average, the soft-assignment weight of the\n        # second-closest centroid is 1% of the closest one's (hence log(0.01))\n        self.alpha = (-np.log(0.01) / np.mean(dots[0,:] - dots[1,:])).item()\n        self.centroids = nn.Parameter(torch.from_numpy(centroids))\n        if self.work_with_tokens:\n            self.conv.weight = nn.Parameter(torch.from_numpy(self.alpha * centroids_assign).unsqueeze(2))\n        else:\n            self.conv.weight = nn.Parameter(torch.from_numpy(self.alpha*centroids_assign).unsqueeze(2).unsqueeze(3))\n        self.conv.bias = None\n\n    def forward(self, x):\n        if self.work_with_tokens:\n            x = x.permute(0, 2, 1)\n            N, D, _ = x.shape[:]\n        else:\n            N, D, H, W = x.shape[:]\n        if self.normalize_input:\n            x = F.normalize(x, p=2, dim=1)  # Across descriptor dim\n        x_flatten = x.view(N, D, -1)\n        soft_assign = self.conv(x).view(N, self.clusters_num, -1)\n        soft_assign = F.softmax(soft_assign, dim=1)\n        vlad = torch.zeros([N, self.clusters_num, D], dtype=x_flatten.dtype, device=x_flatten.device)\n        for cluster in range(self.clusters_num):  # Slower than non-looped, but lower memory usage\n            residual = x_flatten.unsqueeze(0).permute(1, 0, 2, 3) - \\\n                    self.centroids[cluster:cluster+1, :].expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).unsqueeze(0)\n            residual = residual * soft_assign[:,cluster:cluster+1,:].unsqueeze(2)\n            vlad[:,cluster:cluster+1,:] = residual.sum(dim=-1)\n        vlad = F.normalize(vlad, p=2, dim=2)  # intra-normalization\n        vlad = vlad.view(N, -1)  # Flatten\n        vlad = F.normalize(vlad, p=2, dim=1)  # L2 normalize\n        return vlad\n\n    def initialize_netvlad_layer(self, args, cluster_ds, backbone):\n        descriptors_num = 50000\n        descs_num_per_image = 100\n        images_num = math.ceil(descriptors_num / descs_num_per_image)\n        random_sampler = SubsetRandomSampler(np.random.choice(len(cluster_ds), images_num, replace=False))\n        random_dl = DataLoader(dataset=cluster_ds, num_workers=args.num_workers,\n                                batch_size=args.infer_batch_size, sampler=random_sampler)\n        with torch.no_grad():\n            backbone = backbone.eval()\n            logging.debug(\"Extracting features to initialize NetVLAD layer\")\n            descriptors = np.zeros(shape=(descriptors_num, args.features_dim), dtype=np.float32)\n            for iteration, (inputs, _) in enumerate(tqdm(random_dl, ncols=100)):\n                inputs = inputs.to(args.device)\n                outputs = backbone(inputs)\n                norm_outputs = F.normalize(outputs, p=2, dim=1)\n                image_descriptors = 
norm_outputs.view(norm_outputs.shape[0], args.features_dim, -1).permute(0, 2, 1)\n                image_descriptors = image_descriptors.cpu().numpy()\n                batchix = iteration * args.infer_batch_size * descs_num_per_image\n                for ix in range(image_descriptors.shape[0]):\n                    sample = np.random.choice(image_descriptors.shape[1], descs_num_per_image, replace=False)\n                    startix = batchix + ix * descs_num_per_image\n                    descriptors[startix:startix + descs_num_per_image, :] = image_descriptors[ix, sample, :]\n        kmeans = faiss.Kmeans(args.features_dim, self.clusters_num, niter=100, verbose=False)\n        kmeans.train(descriptors)\n        logging.debug(f\"NetVLAD centroids shape: {kmeans.centroids.shape}\")\n        self.init_params(kmeans.centroids, descriptors)\n        self = self.to(args.device)\n\n\nclass CRNModule(nn.Module):\n    def __init__(self, dim):\n        super().__init__()\n        # Downsample pooling\n        self.downsample_pool = nn.AvgPool2d(kernel_size=3, stride=(2, 2),\n                                            padding=0, ceil_mode=True)\n        \n        # Multiscale Context Filters\n        self.filter_3_3 = nn.Conv2d(in_channels=dim, out_channels=32,\n                                    kernel_size=(3, 3), padding=1)\n        self.filter_5_5 = nn.Conv2d(in_channels=dim, out_channels=32,\n                                    kernel_size=(5, 5), padding=2)\n        self.filter_7_7 = nn.Conv2d(in_channels=dim, out_channels=20,\n                                    kernel_size=(7, 7), padding=3)\n        \n        # Accumulation weight\n        self.acc_w = nn.Conv2d(in_channels=84, out_channels=1, kernel_size=(1, 1))\n        # Upsampling\n        self.upsample = F.interpolate\n        \n        self._initialize_weights()\n    \n    def _initialize_weights(self):\n        # Initialize Context Filters\n        torch.nn.init.xavier_normal_(self.filter_3_3.weight)\n        torch.nn.init.constant_(self.filter_3_3.bias, 0.0)\n        torch.nn.init.xavier_normal_(self.filter_5_5.weight)\n        torch.nn.init.constant_(self.filter_5_5.bias, 0.0)\n        torch.nn.init.xavier_normal_(self.filter_7_7.weight)\n        torch.nn.init.constant_(self.filter_7_7.bias, 0.0)\n        \n        torch.nn.init.constant_(self.acc_w.weight, 1.0)\n        torch.nn.init.constant_(self.acc_w.bias, 0.0)\n        self.acc_w.weight.requires_grad = False\n        self.acc_w.bias.requires_grad = False\n    \n    def forward(self, x):\n        # Contextual Reweighting Network\n        x_crn = self.downsample_pool(x)\n        \n        # Compute multiscale context filters g_n\n        g_3 = self.filter_3_3(x_crn)\n        g_5 = self.filter_5_5(x_crn)\n        g_7 = self.filter_7_7(x_crn)\n        g = torch.cat((g_3, g_5, g_7), dim=1)\n        g = F.relu(g)\n        \n        w = F.relu(self.acc_w(g))  # Accumulation weight\n        mask = self.upsample(w, scale_factor=2, mode='bilinear')  # Reweighting Mask\n        \n        return mask\n\n\nclass CRN(NetVLAD):\n    def __init__(self, clusters_num=64, dim=128, normalize_input=True):\n        super().__init__(clusters_num, dim, normalize_input)\n        self.crn = CRNModule(dim)\n    \n    def forward(self, x):\n        N, D, H, W = x.shape[:]\n        if self.normalize_input:\n            x = F.normalize(x, p=2, dim=1)  # Across descriptor dim\n        \n        mask = self.crn(x)\n        \n        x_flatten = x.view(N, D, -1)\n        soft_assign = self.conv(x).view(N, 
self.clusters_num, -1)\n        soft_assign = F.softmax(soft_assign, dim=1)\n        \n        # Weight soft_assign using CRN's mask\n        soft_assign = soft_assign * mask.view(N, 1, H * W)\n        \n        vlad = torch.zeros([N, self.clusters_num, D], dtype=x_flatten.dtype, device=x_flatten.device)\n        for cluster in range(self.clusters_num):  # Slower than non-looped, but lower memory usage\n            residual = x_flatten.unsqueeze(0).permute(1, 0, 2, 3) - \\\n                       self.centroids[cluster:cluster + 1, :].expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).unsqueeze(0)\n            residual = residual * soft_assign[:, cluster:cluster + 1, :].unsqueeze(2)\n            vlad[:, cluster:cluster + 1, :] = residual.sum(dim=-1)\n        \n        vlad = F.normalize(vlad, p=2, dim=2)  # intra-normalization\n        vlad = vlad.view(N, -1)  # Flatten\n        vlad = F.normalize(vlad, p=2, dim=1)  # L2 normalize\n        return vlad\n\n"
  },
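# A minimal usage sketch of the aggregation layers above (illustrative only:
# shapes and hyperparameters are assumptions, not values fixed by this file).
# GeM returns an (N, C, 1, 1) map that network.py flattens downstream, while
# NetVLAD returns an L2-normalized (N, K*C) descriptor. The NetVLAD centroids
# are still random here, since initialize_netvlad_layer() has not been called.
import torch
from model.aggregation import GeM, NetVLAD

feats = torch.randn(2, 256, 30, 40)                   # dummy backbone feature map
gem_desc = GeM()(feats)                               # -> (2, 256, 1, 1)
vlad_desc = NetVLAD(clusters_num=64, dim=256)(feats)  # -> (2, 64 * 256)
assert vlad_desc.shape == (2, 64 * 256)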
  {
    "path": "model/cct/__init__.py",
    "content": "from .cct import cct_14_7x2_384, cct_14_7x2_224"
  },
  {
    "path": "model/cct/cct.py",
    "content": "from torch.hub import load_state_dict_from_url\nimport torch.nn as nn\nimport torch\nimport torch.nn.functional as F\nfrom .transformers import TransformerClassifier\nfrom .tokenizer import Tokenizer\nfrom .helpers import pe_check\n\nfrom timm.models.registry import register_model\n\n\nmodel_urls = {\n    'cct_7_3x1_32':\n        'https://shi-labs.com/projects/cct/checkpoints/pretrained/cct_7_3x1_32_cifar10_300epochs.pth',\n    'cct_7_3x1_32_sine':\n        'https://shi-labs.com/projects/cct/checkpoints/pretrained/cct_7_3x1_32_sine_cifar10_5000epochs.pth',\n    'cct_7_3x1_32_c100':\n        'https://shi-labs.com/projects/cct/checkpoints/pretrained/cct_7_3x1_32_cifar100_300epochs.pth',\n    'cct_7_3x1_32_sine_c100':\n        'https://shi-labs.com/projects/cct/checkpoints/pretrained/cct_7_3x1_32_sine_cifar100_5000epochs.pth',\n    'cct_7_7x2_224_sine':\n        'https://shi-labs.com/projects/cct/checkpoints/pretrained/cct_7_7x2_224_flowers102.pth',\n    'cct_14_7x2_224':\n        'https://shi-labs.com/projects/cct/checkpoints/pretrained/cct_14_7x2_224_imagenet.pth',\n    'cct_14_7x2_384':\n        'https://shi-labs.com/projects/cct/checkpoints/finetuned/cct_14_7x2_384_imagenet.pth',\n    'cct_14_7x2_384_fl':\n        'https://shi-labs.com/projects/cct/checkpoints/finetuned/cct_14_7x2_384_flowers102.pth',\n}\n\n\nclass CCT(nn.Module):\n    def __init__(self,\n                 img_size=224,\n                 embedding_dim=768,\n                 n_input_channels=3,\n                 n_conv_layers=1,\n                 kernel_size=7,\n                 stride=2,\n                 padding=3,\n                 pooling_kernel_size=3,\n                 pooling_stride=2,\n                 pooling_padding=1,\n                 dropout=0.,\n                 attention_dropout=0.1,\n                 stochastic_depth=0.1,\n                 num_layers=14,\n                 num_heads=6,\n                 mlp_ratio=4.0,\n                 num_classes=1000,\n                 positional_embedding='learnable',\n                 aggregation=None,\n                 *args, **kwargs):\n        super(CCT, self).__init__()\n\n        self.tokenizer = Tokenizer(n_input_channels=n_input_channels,\n                                   n_output_channels=embedding_dim,\n                                   kernel_size=kernel_size,\n                                   stride=stride,\n                                   padding=padding,\n                                   pooling_kernel_size=pooling_kernel_size,\n                                   pooling_stride=pooling_stride,\n                                   pooling_padding=pooling_padding,\n                                   max_pool=True,\n                                   activation=nn.ReLU,\n                                   n_conv_layers=n_conv_layers,\n                                   conv_bias=False)\n\n        self.classifier = TransformerClassifier(\n            sequence_length=self.tokenizer.sequence_length(n_channels=n_input_channels,\n                                                           height=img_size,\n                                                           width=img_size),\n            embedding_dim=embedding_dim,\n            seq_pool=True,\n            dropout=dropout,\n            attention_dropout=attention_dropout,\n            stochastic_depth=stochastic_depth,\n            num_layers=num_layers,\n            num_heads=num_heads,\n            mlp_ratio=mlp_ratio,\n            num_classes=num_classes,\n            
positional_embedding=positional_embedding\n        )\n        if aggregation in ['cls', 'seqpool']:\n            self.aggregation = aggregation\n        else:\n            self.aggregation = None\n\n    def forward(self, x):\n        x = self.tokenizer(x)\n        x = self.classifier(x)\n        if self.aggregation == 'cls':\n            return x[:, 0]\n        elif self.aggregation == 'seqpool':\n            x = torch.matmul(F.softmax(self.classifier.attention_pool(x), dim=1).transpose(-1, -2), x).squeeze(-2)\n            return x\n        else:\n            # x = x.permute(0, 2, 1)\n            return x\n\n\ndef _cct(arch, pretrained, progress,\n         num_layers, num_heads, mlp_ratio, embedding_dim,\n         kernel_size=3, stride=None, padding=None,\n         aggregation=None, *args, **kwargs):\n    stride = stride if stride is not None else max(1, (kernel_size // 2) - 1)\n    padding = padding if padding is not None else max(1, (kernel_size // 2))\n    model = CCT(num_layers=num_layers,\n                num_heads=num_heads,\n                mlp_ratio=mlp_ratio,\n                embedding_dim=embedding_dim,\n                kernel_size=kernel_size,\n                stride=stride,\n                padding=padding,\n                aggregation=aggregation,\n                *args, **kwargs)\n\n    if pretrained:\n        if arch in model_urls:\n            state_dict = load_state_dict_from_url(model_urls[arch],\n                                                  progress=progress)\n            state_dict = pe_check(model, state_dict)\n            model.load_state_dict(state_dict, strict=False)\n        else:\n            raise RuntimeError(f'Variant {arch} does not yet have pretrained weights.')\n    return model\n\n\ndef cct_2(arch, pretrained, progress, aggregation=None, *args, **kwargs):\n    return _cct(arch, pretrained, progress, num_layers=2, num_heads=2, mlp_ratio=1, embedding_dim=128,\n                aggregation=aggregation, *args, **kwargs)\n\n\ndef cct_4(arch, pretrained, progress, aggregation=None, *args, **kwargs):\n    return _cct(arch, pretrained, progress, num_layers=4, num_heads=2, mlp_ratio=1, embedding_dim=128,\n                aggregation=aggregation, *args, **kwargs)\n\n\ndef cct_6(arch, pretrained, progress, aggregation=None, *args, **kwargs):\n    return _cct(arch, pretrained, progress, num_layers=6, num_heads=4, mlp_ratio=2, embedding_dim=256,\n                aggregation=aggregation, *args, **kwargs)\n\n\ndef cct_7(arch, pretrained, progress, aggregation=None, *args, **kwargs):\n    return _cct(arch, pretrained, progress, num_layers=7, num_heads=4, mlp_ratio=2, embedding_dim=256,\n                aggregation=aggregation, *args, **kwargs)\n\n\ndef cct_14(arch, pretrained, progress, aggregation=None, *args, **kwargs):\n    return _cct(arch, pretrained, progress, num_layers=14, num_heads=6, mlp_ratio=3, embedding_dim=384,\n                aggregation=aggregation, *args, **kwargs)\n\n\n@register_model\ndef cct_2_3x2_32(pretrained=False, progress=False,\n                 img_size=32, positional_embedding='learnable', num_classes=10,\n                 aggregation=None, *args, **kwargs):\n    return cct_2('cct_2_3x2_32', pretrained, progress,\n                 kernel_size=3, n_conv_layers=2,\n                 img_size=img_size, positional_embedding=positional_embedding,\n                 num_classes=num_classes, aggregation=aggregation,\n                 *args, **kwargs)\n\n\n@register_model\ndef cct_2_3x2_32_sine(pretrained=False, progress=False,\n                      
img_size=32, positional_embedding='sine', num_classes=10,\n                      aggregation=None, *args, **kwargs):\n    return cct_2('cct_2_3x2_32_sine', pretrained, progress,\n                 kernel_size=3, n_conv_layers=2,\n                 img_size=img_size, positional_embedding=positional_embedding,\n                 num_classes=num_classes, aggregation=aggregation,\n                 *args, **kwargs)\n\n\n@register_model\ndef cct_4_3x2_32(pretrained=False, progress=False,\n                 img_size=32, positional_embedding='learnable', num_classes=10,\n                 aggregation=None, *args, **kwargs):\n    return cct_4('cct_4_3x2_32', pretrained, progress,\n                 kernel_size=3, n_conv_layers=2,\n                 img_size=img_size, positional_embedding=positional_embedding,\n                 num_classes=num_classes, aggregation=aggregation,\n                 *args, **kwargs)\n\n\n@register_model\ndef cct_4_3x2_32_sine(pretrained=False, progress=False,\n                      img_size=32, positional_embedding='sine', num_classes=10,\n                     aggregation=None, *args, **kwargs):\n    return cct_4('cct_4_3x2_32_sine', pretrained, progress,\n                 kernel_size=3, n_conv_layers=2,\n                 img_size=img_size, positional_embedding=positional_embedding,\n                 num_classes=num_classes, aggregation=aggregation,\n                 *args, **kwargs)\n\n\n@register_model\ndef cct_6_3x1_32(pretrained=False, progress=False,\n                 img_size=32, positional_embedding='learnable', num_classes=10,\n                aggregation=None, *args, **kwargs):\n    return cct_6('cct_6_3x1_32', pretrained, progress,\n                 kernel_size=3, n_conv_layers=1,\n                 img_size=img_size, positional_embedding=positional_embedding,\n                 num_classes=num_classes, aggregation=aggregation,\n                 *args, **kwargs)\n\n\n@register_model\ndef cct_6_3x1_32_sine(pretrained=False, progress=False,\n                      img_size=32, positional_embedding='sine', num_classes=10,\n                      aggregation=None, *args, **kwargs):\n    return cct_6('cct_6_3x1_32_sine', pretrained, progress,\n                 kernel_size=3, n_conv_layers=1,\n                 img_size=img_size, positional_embedding=positional_embedding,\n                 num_classes=num_classes, aggregation=aggregation,\n                 *args, **kwargs)\n\n\n@register_model\ndef cct_6_3x2_32(pretrained=False, progress=False,\n                 img_size=32, positional_embedding='learnable', num_classes=10,\n                 aggregation=None, *args, **kwargs):\n    return cct_6('cct_6_3x2_32', pretrained, progress,\n                 kernel_size=3, n_conv_layers=2,\n                 img_size=img_size, positional_embedding=positional_embedding,\n                 num_classes=num_classes, aggregation=aggregation,\n                 *args, **kwargs)\n\n\n@register_model\ndef cct_6_3x2_32_sine(pretrained=False, progress=False,\n                      img_size=32, positional_embedding='sine', num_classes=10,\n                      aggregation=None, *args, **kwargs):\n    return cct_6('cct_6_3x2_32_sine', pretrained, progress,\n                 kernel_size=3, n_conv_layers=2,\n                 img_size=img_size, positional_embedding=positional_embedding,\n                 num_classes=num_classes, aggregation=aggregation,\n                 *args, **kwargs)\n\n\n@register_model\ndef cct_7_3x1_32(pretrained=False, progress=False,\n                 img_size=32, 
positional_embedding='learnable', num_classes=10,\n                 aggregation=None, *args, **kwargs):\n    return cct_7('cct_7_3x1_32', pretrained, progress,\n                 kernel_size=3, n_conv_layers=1,\n                 img_size=img_size, positional_embedding=positional_embedding,\n                 num_classes=num_classes, aggregation=aggregation,\n                 *args, **kwargs)\n\n\n@register_model\ndef cct_7_3x1_32_sine(pretrained=False, progress=False,\n                      img_size=32, positional_embedding='sine', num_classes=10,\n                      aggregation=None, *args, **kwargs):\n    return cct_7('cct_7_3x1_32_sine', pretrained, progress,\n                 kernel_size=3, n_conv_layers=1,\n                 img_size=img_size, positional_embedding=positional_embedding,\n                 num_classes=num_classes, aggregation=aggregation,\n                 *args, **kwargs)\n\n\n@register_model\ndef cct_7_3x1_32_c100(pretrained=False, progress=False,\n                      img_size=32, positional_embedding='learnable', num_classes=100,\n                      aggregation=None, *args, **kwargs):\n    return cct_7('cct_7_3x1_32_c100', pretrained, progress,\n                 kernel_size=3, n_conv_layers=1,\n                 img_size=img_size, positional_embedding=positional_embedding,\n                 num_classes=num_classes, aggregation=aggregation,\n                 *args, **kwargs)\n\n\n@register_model\ndef cct_7_3x1_32_sine_c100(pretrained=False, progress=False,\n                           img_size=32, positional_embedding='sine', num_classes=100,\n                           aggregation=None, *args, **kwargs):\n    return cct_7('cct_7_3x1_32_sine_c100', pretrained, progress,\n                 kernel_size=3, n_conv_layers=1,\n                 img_size=img_size, positional_embedding=positional_embedding,\n                 num_classes=num_classes, aggregation=aggregation,\n                 *args, **kwargs)\n\n\n@register_model\ndef cct_7_3x2_32(pretrained=False, progress=False,\n                 img_size=32, positional_embedding='learnable', num_classes=10,\n                 aggregation=None, *args, **kwargs):\n    return cct_7('cct_7_3x2_32', pretrained, progress,\n                 kernel_size=3, n_conv_layers=2,\n                 img_size=img_size, positional_embedding=positional_embedding,\n                 num_classes=num_classes, aggregation=aggregation,\n                 *args, **kwargs)\n\n\n@register_model\ndef cct_7_3x2_32_sine(pretrained=False, progress=False,\n                      img_size=32, positional_embedding='sine', num_classes=10,\n                      aggregation=None, *args, **kwargs):\n    return cct_7('cct_7_3x2_32_sine', pretrained, progress,\n                 kernel_size=3, n_conv_layers=2,\n                 img_size=img_size, positional_embedding=positional_embedding,\n                 num_classes=num_classes, aggregation=aggregation,\n                 *args, **kwargs)\n\n\n@register_model\ndef cct_7_7x2_224(pretrained=False, progress=False,\n                  img_size=224, positional_embedding='learnable', num_classes=102,\n                  aggregation=None, *args, **kwargs):\n    return cct_7('cct_7_7x2_224', pretrained, progress,\n                 kernel_size=7, n_conv_layers=2,\n                 img_size=img_size, positional_embedding=positional_embedding,\n                 num_classes=num_classes, aggregation=aggregation,\n                 *args, **kwargs)\n\n\n@register_model\ndef cct_7_7x2_224_sine(pretrained=False, progress=False,\n       
                img_size=224, positional_embedding='sine', num_classes=102,\n                       aggregation=None, *args, **kwargs):\n    return cct_7('cct_7_7x2_224_sine', pretrained, progress,\n                 kernel_size=7, n_conv_layers=2,\n                 img_size=img_size, positional_embedding=positional_embedding,\n                 num_classes=num_classes, aggregation=aggregation,\n                 *args, **kwargs)\n\n\n@register_model\ndef cct_14_7x2_224(pretrained=False, progress=False,\n                   img_size=224, positional_embedding='learnable', num_classes=1000,\n                   aggregation=None, *args, **kwargs):\n    return cct_14('cct_14_7x2_224', pretrained, progress,\n                  kernel_size=7, n_conv_layers=2,\n                  img_size=img_size, positional_embedding=positional_embedding,\n                  num_classes=num_classes, aggregation=aggregation,\n                  *args, **kwargs)\n\n\n@register_model\ndef cct_14_7x2_384(pretrained=False, progress=False,\n                   img_size=384, positional_embedding='learnable', num_classes=1000,\n                   aggregation=None, *args, **kwargs):\n    return cct_14('cct_14_7x2_384', pretrained, progress,\n                  kernel_size=7, n_conv_layers=2,\n                  img_size=img_size, positional_embedding=positional_embedding,\n                  num_classes=num_classes, aggregation=aggregation,\n                  *args, **kwargs)\n\n\n@register_model\ndef cct_14_7x2_384_fl(pretrained=False, progress=False,\n                      img_size=384, positional_embedding='learnable', num_classes=102,\n                      aggregation=None, *args, **kwargs):\n    return cct_14('cct_14_7x2_384_fl', pretrained, progress,\n                  kernel_size=7, n_conv_layers=2,\n                  img_size=img_size, positional_embedding=positional_embedding,\n                  num_classes=num_classes, aggregation=aggregation,\n                  *args, **kwargs)\n"
  },
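# An illustrative sketch of how this factory is consumed (assumptions:
# pretrained=False to avoid the checkpoint download; the 384x384 input matches
# the variant's native resolution). With aggregation=None the model returns the
# full token sequence, which network.py then feeds to NetVLAD/GeM; with
# aggregation='cls' or 'seqpool' it returns a single (N, 384) vector instead.
import torch
from model.cct import cct_14_7x2_384

model = cct_14_7x2_384(pretrained=False, progress=False, aggregation=None)
tokens = model(torch.randn(1, 3, 384, 384))
print(tokens.shape)  # (1, 576, 384): a 24x24 grid of 384-dim tokens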
  {
    "path": "model/cct/embedder.py",
    "content": "import torch.nn as nn\n\n\nclass Embedder(nn.Module):\n    def __init__(self,\n                 word_embedding_dim=300,\n                 vocab_size=100000,\n                 padding_idx=1,\n                 pretrained_weight=None,\n                 embed_freeze=False,\n                 *args, **kwargs):\n        super(Embedder, self).__init__()\n        self.embeddings = nn.Embedding.from_pretrained(pretrained_weight, freeze=embed_freeze) \\\n            if pretrained_weight is not None else \\\n            nn.Embedding(vocab_size, word_embedding_dim, padding_idx=padding_idx)\n        self.embeddings.weight.requires_grad = not embed_freeze\n\n    def forward_mask(self, mask):\n        bsz, seq_len = mask.shape\n        new_mask = mask.view(bsz, seq_len, 1)\n        new_mask = new_mask.sum(-1)\n        new_mask = (new_mask > 0)\n        return new_mask\n\n    def forward(self, x, mask=None):\n        embed = self.embeddings(x)\n        embed = embed if mask is None else embed * self.forward_mask(mask).unsqueeze(-1).float()\n        return embed, mask\n\n    @staticmethod\n    def init_weight(m):\n        if isinstance(m, nn.Linear):\n            nn.init.trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        else:\n            nn.init.normal_(m.weight)\n"
  },
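# A small illustrative sketch of the Embedder API above (vocabulary size,
# embedding dim and token ids are arbitrary assumptions). The optional mask
# zeroes the vectors of padded positions.
import torch
from model.cct.embedder import Embedder

emb = Embedder(word_embedding_dim=8, vocab_size=50, padding_idx=1)
ids = torch.tensor([[4, 7, 1, 1]])   # token ids; 1 is the padding index
mask = ids != 1                      # True where a real token is present
vectors, _ = emb(ids, mask=mask)     # -> (1, 4, 8), padded rows zeroed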
  {
    "path": "model/cct/helpers.py",
    "content": "import math\nimport torch\nimport torch.nn.functional as F\n\n\ndef resize_pos_embed(posemb, posemb_new, num_tokens=1):\n    # Copied from `timm` by Ross Wightman:\n    # github.com/rwightman/pytorch-image-models\n    # Rescale the grid of position embeddings when loading from state_dict. Adapted from\n    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224\n    ntok_new = posemb_new.shape[1]\n    if num_tokens:\n        posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]\n        ntok_new -= num_tokens\n    else:\n        posemb_tok, posemb_grid = posemb[:, :0], posemb[0]\n    gs_old = int(math.sqrt(len(posemb_grid)))\n    gs_new = int(math.sqrt(ntok_new))\n    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)\n    posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bilinear')\n    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1)\n    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)\n    return posemb\n\n\ndef pe_check(model, state_dict, pe_key='classifier.positional_emb'):\n    if pe_key is not None and pe_key in state_dict.keys() and pe_key in model.state_dict().keys():\n        if model.state_dict()[pe_key].shape != state_dict[pe_key].shape:\n            state_dict[pe_key] = resize_pos_embed(state_dict[pe_key],\n                                                  model.state_dict()[pe_key],\n                                                  num_tokens=model.classifier.num_tokens)\n    return state_dict\n"
  },
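# An illustrative check (grid sizes are assumptions) of what resize_pos_embed
# does: it bilinearly rescales the grid portion of a positional embedding so a
# checkpoint trained at one resolution can be loaded at another, keeping the
# leading class/distillation tokens untouched.
import torch
from model.cct.helpers import resize_pos_embed

old = torch.randn(1, 1 + 14 * 14, 384)  # 1 class token + 14x14 grid
new = torch.zeros(1, 1 + 24 * 24, 384)  # target: 24x24 grid
assert resize_pos_embed(old, new, num_tokens=1).shape == new.shape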
  {
    "path": "model/cct/stochastic_depth.py",
    "content": "# Thanks to rwightman's timm package\n# github.com:rwightman/pytorch-image-models\n\nimport torch\nimport torch.nn as nn\n\n\ndef drop_path(x, drop_prob: float = 0., training: bool = False):\n    \"\"\"\n    Obtained from: github.com:rwightman/pytorch-image-models\n    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).\n    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,\n    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...\n    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for\n    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use\n    'survival rate' as the argument.\n    \"\"\"\n    if drop_prob == 0. or not training:\n        return x\n    keep_prob = 1 - drop_prob\n    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets\n    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)\n    random_tensor.floor_()  # binarize\n    output = x.div(keep_prob) * random_tensor\n    return output\n\n\nclass DropPath(nn.Module):\n    \"\"\"\n    Obtained from: github.com:rwightman/pytorch-image-models\n    Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).\n    \"\"\"\n\n    def __init__(self, drop_prob=None):\n        super(DropPath, self).__init__()\n        self.drop_prob = drop_prob\n\n    def forward(self, x):\n        return drop_path(x, self.drop_prob, self.training)\n"
  },
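# A quick illustrative sanity check (not from the original repo) of the
# stochastic-depth behavior: in eval mode DropPath is the identity; in train
# mode it zeroes whole samples and rescales survivors by 1/keep_prob so the
# expected value is preserved.
import torch
from model.cct.stochastic_depth import DropPath

dp = DropPath(drop_prob=0.5)
x = torch.ones(8, 4)
dp.eval()
assert torch.equal(dp(x), x)                       # identity at inference
dp.train()
assert set(dp(x).unique().tolist()) <= {0.0, 2.0}  # rows are all-0 or all-2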
  {
    "path": "model/cct/tokenizer.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass Tokenizer(nn.Module):\n    def __init__(self,\n                 kernel_size, stride, padding,\n                 pooling_kernel_size=3, pooling_stride=2, pooling_padding=1,\n                 n_conv_layers=1,\n                 n_input_channels=3,\n                 n_output_channels=64,\n                 in_planes=64,\n                 activation=None,\n                 max_pool=True,\n                 conv_bias=False):\n        super(Tokenizer, self).__init__()\n\n        n_filter_list = [n_input_channels] + \\\n                        [in_planes for _ in range(n_conv_layers - 1)] + \\\n                        [n_output_channels]\n\n        self.conv_layers = nn.Sequential(\n            *[nn.Sequential(\n                nn.Conv2d(n_filter_list[i], n_filter_list[i + 1],\n                          kernel_size=(kernel_size, kernel_size),\n                          stride=(stride, stride),\n                          padding=(padding, padding), bias=conv_bias),\n                nn.Identity() if activation is None else activation(),\n                nn.MaxPool2d(kernel_size=pooling_kernel_size,\n                             stride=pooling_stride,\n                             padding=pooling_padding) if max_pool else nn.Identity()\n            )\n                for i in range(n_conv_layers)\n            ])\n\n        self.flattener = nn.Flatten(2, 3)\n        self.apply(self.init_weight)\n\n    def sequence_length(self, n_channels=3, height=224, width=224):\n        return self.forward(torch.zeros((1, n_channels, height, width))).shape[1]\n\n    def forward(self, x):\n        return self.flattener(self.conv_layers(x)).transpose(-2, -1)\n\n    @staticmethod\n    def init_weight(m):\n        if isinstance(m, nn.Conv2d):\n            nn.init.kaiming_normal_(m.weight)\n\n\nclass TextTokenizer(nn.Module):\n    def __init__(self,\n                 kernel_size, stride, padding,\n                 pooling_kernel_size=3, pooling_stride=2, pooling_padding=1,\n                 embedding_dim=300,\n                 n_output_channels=128,\n                 activation=None,\n                 max_pool=True,\n                 *args, **kwargs):\n        super(TextTokenizer, self).__init__()\n\n        self.max_pool = max_pool\n        self.conv_layers = nn.Sequential(\n            nn.Conv2d(1, n_output_channels,\n                      kernel_size=(kernel_size, embedding_dim),\n                      stride=(stride, 1),\n                      padding=(padding, 0), bias=False),\n            nn.Identity() if activation is None else activation(),\n            nn.MaxPool2d(\n                kernel_size=(pooling_kernel_size, 1),\n                stride=(pooling_stride, 1),\n                padding=(pooling_padding, 0)\n            ) if max_pool else nn.Identity()\n        )\n\n        self.apply(self.init_weight)\n\n    def seq_len(self, seq_len=32, embed_dim=300):\n        return self.forward(torch.zeros((1, seq_len, embed_dim)))[0].shape[1]\n\n    def forward_mask(self, mask):\n        new_mask = mask.unsqueeze(1).float()\n        cnn_weight = torch.ones(\n            (1, 1, self.conv_layers[0].kernel_size[0]),\n            device=mask.device,\n            dtype=torch.float)\n        new_mask = F.conv1d(\n            new_mask, cnn_weight, None,\n            self.conv_layers[0].stride[0], self.conv_layers[0].padding[0], 1, 1)\n        if self.max_pool:\n            new_mask = F.max_pool1d(\n                new_mask, 
self.conv_layers[2].kernel_size[0],\n                self.conv_layers[2].stride[0], self.conv_layers[2].padding[0], 1, False, False)\n        new_mask = new_mask.squeeze(1)\n        new_mask = (new_mask > 0)\n        return new_mask\n\n    def forward(self, x, mask=None):\n        x = x.unsqueeze(1)\n        x = self.conv_layers(x)\n        x = x.transpose(1, 3).squeeze(1)\n        x = x if mask is None else x * self.forward_mask(mask).unsqueeze(-1).float()\n        return x, mask\n\n    @staticmethod\n    def init_weight(m):\n        if isinstance(m, nn.Conv2d):\n            nn.init.kaiming_normal_(m.weight)\n"
  },
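# An illustrative sketch (sizes assumed) of how CCT uses the Tokenizer: it
# turns an image into a token sequence, and sequence_length() is what
# TransformerClassifier uses to size its positional embedding. A 7x2 tokenizer
# (two conv+maxpool stages, each with stride 2) downsamples 224x224 by a
# factor of 16, giving 14*14 = 196 tokens.
import torch
import torch.nn as nn
from model.cct.tokenizer import Tokenizer

tok = Tokenizer(kernel_size=7, stride=2, padding=3, n_conv_layers=2,
                n_output_channels=384, activation=nn.ReLU)
print(tok.sequence_length(n_channels=3, height=224, width=224))  # 196
print(tok(torch.randn(1, 3, 224, 224)).shape)                    # (1, 196, 384)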
  {
    "path": "model/cct/transformers.py",
    "content": "import torch\nfrom torch.nn import Module, ModuleList, Linear, Dropout, LayerNorm, Identity, Parameter, init\nimport torch.nn.functional as F\nfrom .stochastic_depth import DropPath\n\n\nclass Attention(Module):\n    \"\"\"\n    Obtained from timm: github.com:rwightman/pytorch-image-models\n    \"\"\"\n\n    def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // self.num_heads\n        self.scale = head_dim ** -0.5\n\n        self.qkv = Linear(dim, dim * 3, bias=False)\n        self.attn_drop = Dropout(attention_dropout)\n        self.proj = Linear(dim, dim)\n        self.proj_drop = Dropout(projection_dropout)\n\n    def forward(self, x):\n        B, N, C = x.shape\n        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[2]\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass MaskedAttention(Module):\n    def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // self.num_heads\n        self.scale = head_dim ** -0.5\n\n        self.qkv = Linear(dim, dim * 3, bias=False)\n        self.attn_drop = Dropout(attention_dropout)\n        self.proj = Linear(dim, dim)\n        self.proj_drop = Dropout(projection_dropout)\n\n    def forward(self, x, mask=None):\n        B, N, C = x.shape\n        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[2]\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale\n\n        if mask is not None:\n            mask_value = -torch.finfo(attn.dtype).max\n            assert mask.shape[-1] == attn.shape[-1], 'mask has incorrect dimensions'\n            mask = mask[:, None, :] * mask[:, :, None]\n            mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)\n            attn.masked_fill_(~mask, mask_value)\n\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass TransformerEncoderLayer(Module):\n    \"\"\"\n    Inspired by torch.nn.TransformerEncoderLayer and timm.\n    \"\"\"\n\n    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,\n                 attention_dropout=0.1, drop_path_rate=0.1):\n        super(TransformerEncoderLayer, self).__init__()\n        self.pre_norm = LayerNorm(d_model)\n        self.self_attn = Attention(dim=d_model, num_heads=nhead,\n                                   attention_dropout=attention_dropout, projection_dropout=dropout)\n\n        self.linear1 = Linear(d_model, dim_feedforward)\n        self.dropout1 = Dropout(dropout)\n        self.norm1 = LayerNorm(d_model)\n        self.linear2 = Linear(dim_feedforward, d_model)\n        self.dropout2 = Dropout(dropout)\n\n        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else Identity()\n\n        self.activation = F.gelu\n\n    def forward(self, src: torch.Tensor, *args, **kwargs) -> torch.Tensor:\n        src = src + 
self.drop_path(self.self_attn(self.pre_norm(src)))\n        src = self.norm1(src)\n        src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))\n        src = src + self.drop_path(self.dropout2(src2))\n        return src\n\n\nclass MaskedTransformerEncoderLayer(Module):\n    \"\"\"\n    Inspired by torch.nn.TransformerEncoderLayer and timm.\n    \"\"\"\n\n    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,\n                 attention_dropout=0.1, drop_path_rate=0.1):\n        super(MaskedTransformerEncoderLayer, self).__init__()\n        self.pre_norm = LayerNorm(d_model)\n        self.self_attn = MaskedAttention(dim=d_model, num_heads=nhead,\n                                         attention_dropout=attention_dropout, projection_dropout=dropout)\n\n        self.linear1 = Linear(d_model, dim_feedforward)\n        self.dropout1 = Dropout(dropout)\n        self.norm1 = LayerNorm(d_model)\n        self.linear2 = Linear(dim_feedforward, d_model)\n        self.dropout2 = Dropout(dropout)\n\n        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else Identity()\n\n        self.activation = F.gelu\n\n    def forward(self, src: torch.Tensor, mask=None, *args, **kwargs) -> torch.Tensor:\n        src = src + self.drop_path(self.self_attn(self.pre_norm(src), mask))\n        src = self.norm1(src)\n        src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))\n        src = src + self.drop_path(self.dropout2(src2))\n        return src\n\n\nclass TransformerClassifier(Module):\n    def __init__(self,\n                 seq_pool=True,\n                 embedding_dim=768,\n                 num_layers=12,\n                 num_heads=12,\n                 mlp_ratio=4.0,\n                 num_classes=1000,\n                 dropout=0.1,\n                 attention_dropout=0.1,\n                 stochastic_depth=0.1,\n                 positional_embedding='learnable',\n                 sequence_length=None):\n        super().__init__()\n        positional_embedding = positional_embedding if \\\n            positional_embedding in ['sine', 'learnable', 'none'] else 'sine'\n        dim_feedforward = int(embedding_dim * mlp_ratio)\n        self.embedding_dim = embedding_dim\n        self.sequence_length = sequence_length\n        self.seq_pool = seq_pool\n\n        assert sequence_length is not None or positional_embedding == 'none', \\\n            f\"Positional embedding is set to {positional_embedding} and\" \\\n            f\" the sequence length was not specified.\"\n\n        if not seq_pool:\n            sequence_length += 1\n            self.class_emb = Parameter(torch.zeros(1, 1, self.embedding_dim),\n                                       requires_grad=True)\n        else:\n            self.attention_pool = Linear(self.embedding_dim, 1)\n\n        if positional_embedding != 'none':\n            if positional_embedding == 'learnable':\n                self.positional_emb = Parameter(torch.zeros(1, sequence_length, embedding_dim),\n                                                requires_grad=True)\n                init.trunc_normal_(self.positional_emb, std=0.2)\n            else:\n                self.positional_emb = Parameter(self.sinusoidal_embedding(sequence_length, embedding_dim),\n                                                requires_grad=False)\n        else:\n            self.positional_emb = None\n\n        self.dropout = Dropout(p=dropout)\n        dpr = [x.item() for x in torch.linspace(0, stochastic_depth, 
num_layers)]\n        self.blocks = ModuleList([\n            TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,\n                                    dim_feedforward=dim_feedforward, dropout=dropout,\n                                    attention_dropout=attention_dropout, drop_path_rate=dpr[i])\n            for i in range(num_layers)])\n        self.norm = LayerNorm(embedding_dim)\n\n        # self.fc = Linear(embedding_dim, num_classes)\n        self.apply(self.init_weight)\n\n    def forward(self, x):\n        if self.positional_emb is None and x.size(1) < self.sequence_length:\n            x = F.pad(x, (0, 0, 0, self.sequence_length - x.size(1)), mode='constant', value=0)\n\n        if not self.seq_pool:\n            cls_token = self.class_emb.expand(x.shape[0], -1, -1)\n            x = torch.cat((cls_token, x), dim=1)\n\n        if self.positional_emb is not None:\n            x += self.positional_emb\n\n        x = self.dropout(x)\n\n        for blk in self.blocks:\n            x = blk(x)\n        x = self.norm(x)\n        # TODO: TOREMOVE\n        # if self.seq_pool:\n        #    x = torch.matmul(F.softmax(self.attention_pool(x), dim=1).transpose(-1, -2), x).squeeze(-2)\n        #else:\n        #    x = x[:, 0]\n        # x = self.fc(x)\n        return x\n\n    @staticmethod\n    def init_weight(m):\n        if isinstance(m, Linear):\n            init.trunc_normal_(m.weight, std=.02)\n            if isinstance(m, Linear) and m.bias is not None:\n                init.constant_(m.bias, 0)\n        elif isinstance(m, LayerNorm):\n            init.constant_(m.bias, 0)\n            init.constant_(m.weight, 1.0)\n\n    @staticmethod\n    def sinusoidal_embedding(n_channels, dim):\n        pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]\n                                for p in range(n_channels)])\n        pe[:, 0::2] = torch.sin(pe[:, 0::2])\n        pe[:, 1::2] = torch.cos(pe[:, 1::2])\n        return pe.unsqueeze(0)\n\n\nclass MaskedTransformerClassifier(Module):\n    def __init__(self,\n                 seq_pool=True,\n                 embedding_dim=768,\n                 num_layers=12,\n                 num_heads=12,\n                 mlp_ratio=4.0,\n                 num_classes=1000,\n                 dropout=0.1,\n                 attention_dropout=0.1,\n                 stochastic_depth=0.1,\n                 positional_embedding='sine',\n                 seq_len=None,\n                 *args, **kwargs):\n        super().__init__()\n        positional_embedding = positional_embedding if \\\n            positional_embedding in ['sine', 'learnable', 'none'] else 'sine'\n        dim_feedforward = int(embedding_dim * mlp_ratio)\n        self.embedding_dim = embedding_dim\n        self.seq_len = seq_len\n        self.seq_pool = seq_pool\n\n        assert seq_len is not None or positional_embedding == 'none', \\\n            f\"Positional embedding is set to {positional_embedding} and\" \\\n            f\" the sequence length was not specified.\"\n\n        if not seq_pool:\n            seq_len += 1\n            self.class_emb = Parameter(torch.zeros(1, 1, self.embedding_dim),\n                                       requires_grad=True)\n        else:\n            self.attention_pool = Linear(self.embedding_dim, 1)\n\n        if positional_embedding != 'none':\n            if positional_embedding == 'learnable':\n                seq_len += 1  # padding idx\n                self.positional_emb = Parameter(torch.zeros(1, seq_len, embedding_dim),\n
                                                requires_grad=True)\n                init.trunc_normal_(self.positional_emb, std=0.2)\n            else:\n                self.positional_emb = Parameter(self.sinusoidal_embedding(seq_len,\n                                                                          embedding_dim,\n                                                                          padding_idx=True),\n                                                requires_grad=False)\n        else:\n            self.positional_emb = None\n\n        self.dropout = Dropout(p=dropout)\n        dpr = [x.item() for x in torch.linspace(0, stochastic_depth, num_layers)]\n        self.blocks = ModuleList([\n            MaskedTransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,\n                                          dim_feedforward=dim_feedforward, dropout=dropout,\n                                          attention_dropout=attention_dropout, drop_path_rate=dpr[i])\n            for i in range(num_layers)])\n        self.norm = LayerNorm(embedding_dim)\n\n        self.fc = Linear(embedding_dim, num_classes)\n        self.apply(self.init_weight)\n\n    def forward(self, x, mask=None):\n        if self.positional_emb is None and x.size(1) < self.seq_len:\n            x = F.pad(x, (0, 0, 0, self.seq_len - x.size(1)), mode='constant', value=0)\n\n        if not self.seq_pool:\n            cls_token = self.class_emb.expand(x.shape[0], -1, -1)\n            x = torch.cat((cls_token, x), dim=1)\n            if mask is not None:\n                mask = torch.cat([torch.ones(size=(mask.shape[0], 1), device=mask.device), mask.float()], dim=1)\n                mask = (mask > 0)\n\n        if self.positional_emb is not None:\n            x += self.positional_emb\n\n        x = self.dropout(x)\n\n        for blk in self.blocks:\n            x = blk(x, mask=mask)\n        x = self.norm(x)\n\n        if self.seq_pool:\n            x = torch.matmul(F.softmax(self.attention_pool(x), dim=1).transpose(-1, -2), x).squeeze(-2)\n        else:\n            x = x[:, 0]\n\n        x = self.fc(x)\n        return x\n\n    @staticmethod\n    def init_weight(m):\n        if isinstance(m, Linear):\n            init.trunc_normal_(m.weight, std=.02)\n            if isinstance(m, Linear) and m.bias is not None:\n                init.constant_(m.bias, 0)\n        elif isinstance(m, LayerNorm):\n            init.constant_(m.bias, 0)\n            init.constant_(m.weight, 1.0)\n\n    @staticmethod\n    def sinusoidal_embedding(n_channels, dim, padding_idx=False):\n        pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]\n                                for p in range(n_channels)])\n        pe[:, 0::2] = torch.sin(pe[:, 0::2])\n        pe[:, 1::2] = torch.cos(pe[:, 1::2])\n        pe = pe.unsqueeze(0)\n        if padding_idx:\n            return torch.cat([torch.zeros((1, 1, dim)), pe], dim=1)\n        return pe\n"
  },
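# An illustrative sketch (hyperparameters assumed) of TransformerClassifier as
# this repo uses it: the final Linear head is commented out above, so forward()
# returns the normalized token sequence, and the descriptor aggregation
# (cls/seqpool/NetVLAD/GeM) happens outside this module.
import torch
from model.cct.transformers import TransformerClassifier

clf = TransformerClassifier(embedding_dim=64, num_layers=2, num_heads=4,
                            mlp_ratio=2.0, sequence_length=49)
out = clf(torch.randn(2, 49, 64))
print(out.shape)  # (2, 49, 64): tokens, not class logits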
  {
    "path": "model/functional.py",
    "content": "\nimport math\nimport torch\nimport torch.nn.functional as F\n\ndef sare_ind(query, positive, negative):\n    '''all 3 inputs are supposed to be shape 1xn_features'''\n    dist_pos = ((query - positive)**2).sum(1)\n    dist_neg = ((query - negative)**2).sum(1)\n    \n    dist = - torch.cat((dist_pos, dist_neg))\n    dist = F.log_softmax(dist, 0)\n    \n    #loss = (- dist[:, 0]).mean() on a batch\n    loss = -dist[0]\n    return loss\n\ndef sare_joint(query, positive, negatives):\n    '''query and positive have to be 1xn_features; whereas negatives has to be\n    shape n_negative x n_features. n_negative is usually 10'''\n    # NOTE: the implementation is the same if batch_size=1 as all operations\n    # are vectorial. If there were the additional n_batch dimension a different\n    # handling of that situation would have to be implemented here.\n    # This function is declared anyway for the sake of clarity as the 2 should\n    # be called in different situations because, even though there would be\n    # no Exceptions, there would actually be a conceptual error.\n    return sare_ind(query, positive, negatives)\n\ndef mac(x):\n    return F.adaptive_max_pool2d(x, (1,1))\n\ndef spoc(x):\n    return F.adaptive_avg_pool2d(x, (1,1))\n\ndef gem(x, p=3, eps=1e-6, work_with_tokens=False):\n    if work_with_tokens:\n        x = x.permute(0, 2, 1)\n        # unseqeeze to maintain compatibility with Flatten\n        return F.avg_pool1d(x.clamp(min=eps).pow(p), (x.size(-1))).pow(1./p).unsqueeze(3)\n    else:\n        return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p)\n\ndef rmac(x, L=3, eps=1e-6):\n    ovr = 0.4 # desired overlap of neighboring regions\n    steps = torch.Tensor([2, 3, 4, 5, 6, 7]) # possible regions for the long dimension\n    W = x.size(3)\n    H = x.size(2)\n    w = min(W, H)\n    # w2 = math.floor(w/2.0 - 1)\n    b = (max(H, W)-w)/(steps-1)\n    (tmp, idx) = torch.min(torch.abs(((w**2 - w*b)/w**2)-ovr), 0) # steps(idx) regions for long dimension\n    # region overplus per dimension\n    Wd = 0;\n    Hd = 0;\n    if H < W:  \n        Wd = idx.item() + 1\n    elif H > W:\n        Hd = idx.item() + 1\n    v = F.max_pool2d(x, (x.size(-2), x.size(-1)))\n    v = v / (torch.norm(v, p=2, dim=1, keepdim=True) + eps).expand_as(v)\n    for l in range(1, L+1):\n        wl = math.floor(2*w/(l+1))\n        wl2 = math.floor(wl/2 - 1)\n        if l+Wd == 1:\n            b = 0\n        else:\n            b = (W-wl)/(l+Wd-1)\n        cenW = torch.floor(wl2 + torch.Tensor(range(l-1+Wd+1))*b) - wl2 # center coordinates\n        if l+Hd == 1:\n            b = 0\n        else:\n            b = (H-wl)/(l+Hd-1)\n        cenH = torch.floor(wl2 + torch.Tensor(range(l-1+Hd+1))*b) - wl2 # center coordinates\n        for i_ in cenH.tolist():\n            for j_ in cenW.tolist():\n                if wl == 0:\n                    continue\n                R = x[:,:,(int(i_)+torch.Tensor(range(wl)).long()).tolist(),:]\n                R = R[:,:,:,(int(j_)+torch.Tensor(range(wl)).long()).tolist()]\n                vt = F.max_pool2d(R, (R.size(-2), R.size(-1)))\n                vt = vt / (torch.norm(vt, p=2, dim=1, keepdim=True) + eps).expand_as(vt)\n                v += vt\n    return v\n\n"
  },
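# A tiny illustrative check (inputs invented) of the functions above. GeM with
# p=1 reduces to average pooling (SPoC) and approaches max pooling (MAC) as p
# grows; sare_ind is the softmax cross-entropy over one (positive, negative)
# distance pair.
import torch
import model.functional as LF

x = torch.rand(1, 8, 5, 5) + 0.1
assert torch.allclose(LF.gem(x, p=1), LF.spoc(x), atol=1e-5)
print(LF.gem(x, p=100).squeeze()[:3], LF.mac(x).squeeze()[:3])  # nearly equal

q, p, n = torch.randn(1, 16), torch.randn(1, 16), torch.randn(1, 16)
print(LF.sare_ind(q, p, n))  # scalar loss; smaller when q is closer to p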
  {
    "path": "model/network.py",
    "content": "\nimport os\nimport torch\nimport logging\nimport torchvision\nfrom torch import nn\nfrom os.path import join\nfrom transformers import ViTModel\nfrom google_drive_downloader import GoogleDriveDownloader as gdd\n\nfrom model.cct import cct_14_7x2_384\nfrom model.aggregation import Flatten\nfrom model.normalization import L2Norm\nimport model.aggregation as aggregation\n\n# Pretrained models on Google Landmarks v2 and Places 365\nPRETRAINED_MODELS = {\n    'resnet18_places'  : '1DnEQXhmPxtBUrRc81nAvT8z17bk-GBj5',\n    'resnet50_places'  : '1zsY4mN4jJ-AsmV3h4hjbT72CBfJsgSGC',\n    'resnet101_places' : '1E1ibXQcg7qkmmmyYgmwMTh7Xf1cDNQXa',\n    'vgg16_places'     : '1UWl1uz6rZ6Nqmp1K5z3GHAIZJmDh4bDu',\n    'resnet18_gldv2'   : '1wkUeUXFXuPHuEvGTXVpuP5BMB-JJ1xke',\n    'resnet50_gldv2'   : '1UDUv6mszlXNC1lv6McLdeBNMq9-kaA70',\n    'resnet101_gldv2'  : '1apiRxMJpDlV0XmKlC5Na_Drg2jtGL-uE',\n    'vgg16_gldv2'      : '10Ov9JdO7gbyz6mB5x0v_VSAUMj91Ta4o'\n}\n\n\nclass GeoLocalizationNet(nn.Module):\n    \"\"\"The used networks are composed of a backbone and an aggregation layer.\n    \"\"\"\n    def __init__(self, args):\n        super().__init__()\n        self.backbone = get_backbone(args)\n        self.arch_name = args.backbone\n        self.aggregation = get_aggregation(args)\n\n        if args.aggregation in [\"gem\", \"spoc\", \"mac\", \"rmac\"]:\n            if args.l2 == \"before_pool\":\n                self.aggregation = nn.Sequential(L2Norm(), self.aggregation, Flatten())\n            elif args.l2 == \"after_pool\":\n                self.aggregation = nn.Sequential(self.aggregation, L2Norm(), Flatten())\n            elif args.l2 == \"none\":\n                self.aggregation = nn.Sequential(self.aggregation, Flatten())\n        \n        if args.fc_output_dim != None:\n            # Concatenate fully connected layer to the aggregation layer\n            self.aggregation = nn.Sequential(self.aggregation,\n                                             nn.Linear(args.features_dim, args.fc_output_dim),\n                                             L2Norm())\n            args.features_dim = args.fc_output_dim\n\n    def forward(self, x):\n        x = self.backbone(x)\n        x = self.aggregation(x)\n        return x\n\n\ndef get_aggregation(args):\n    if args.aggregation == \"gem\":\n        return aggregation.GeM(work_with_tokens=args.work_with_tokens)\n    elif args.aggregation == \"spoc\":\n        return aggregation.SPoC()\n    elif args.aggregation == \"mac\":\n        return aggregation.MAC()\n    elif args.aggregation == \"rmac\":\n        return aggregation.RMAC()\n    elif args.aggregation == \"netvlad\":\n        return aggregation.NetVLAD(clusters_num=args.netvlad_clusters, dim=args.features_dim,\n                                   work_with_tokens=args.work_with_tokens)\n    elif args.aggregation == 'crn':\n        return aggregation.CRN(clusters_num=args.netvlad_clusters, dim=args.features_dim)\n    elif args.aggregation == \"rrm\":\n        return aggregation.RRM(args.features_dim)\n    elif args.aggregation in ['cls', 'seqpool']:\n        return nn.Identity()\n\n\ndef get_pretrained_model(args):\n    if args.pretrain == 'places':  num_classes = 365\n    elif args.pretrain == 'gldv2':  num_classes = 512\n    \n    if args.backbone.startswith(\"resnet18\"):\n        model = torchvision.models.resnet18(num_classes=num_classes)\n    elif args.backbone.startswith(\"resnet50\"):\n        model = torchvision.models.resnet50(num_classes=num_classes)\n    elif 
args.backbone.startswith(\"resnet101\"):\n        model = torchvision.models.resnet101(num_classes=num_classes)\n    elif args.backbone.startswith(\"vgg16\"):\n        model = torchvision.models.vgg16(num_classes=num_classes)\n    \n    if args.backbone.startswith('resnet'):\n        model_name = args.backbone.split('conv')[0] + \"_\" + args.pretrain\n    else:\n        model_name = args.backbone + \"_\" + args.pretrain\n    file_path = join(\"data\", \"pretrained_nets\", model_name +\".pth\")\n    \n    if not os.path.exists(file_path):\n        gdd.download_file_from_google_drive(file_id=PRETRAINED_MODELS[model_name],\n                                            dest_path=file_path)\n    state_dict = torch.load(file_path, map_location=torch.device('cpu'))\n    model.load_state_dict(state_dict)\n    return model\n\n\ndef get_backbone(args):\n    # The aggregation layer works differently based on the type of architecture\n    args.work_with_tokens = args.backbone.startswith('cct') or args.backbone.startswith('vit')\n    if args.backbone.startswith(\"resnet\"):\n        if args.pretrain in ['places', 'gldv2']:\n            backbone = get_pretrained_model(args)\n        elif args.backbone.startswith(\"resnet18\"):\n            backbone = torchvision.models.resnet18(pretrained=True)\n        elif args.backbone.startswith(\"resnet50\"):\n            backbone = torchvision.models.resnet50(pretrained=True)\n        elif args.backbone.startswith(\"resnet101\"):\n            backbone = torchvision.models.resnet101(pretrained=True)\n        for name, child in backbone.named_children():\n            # Freeze layers before conv_3\n            if name == \"layer3\":\n                break\n            for params in child.parameters():\n                params.requires_grad = False\n        if args.backbone.endswith(\"conv4\"):\n            logging.debug(f\"Train only conv4_x of the resnet{args.backbone.split('conv')[0]} (remove conv5_x), freeze the previous ones\")\n            layers = list(backbone.children())[:-3]\n        elif args.backbone.endswith(\"conv5\"):\n            logging.debug(f\"Train only conv4_x and conv5_x of the resnet{args.backbone.split('conv')[0]}, freeze the previous ones\")\n            layers = list(backbone.children())[:-2]\n    elif args.backbone == \"vgg16\":\n        if args.pretrain in ['places', 'gldv2']:\n            backbone = get_pretrained_model(args)\n        else:\n            backbone = torchvision.models.vgg16(pretrained=True)\n        layers = list(backbone.features.children())[:-2]\n        for l in layers[:-5]:\n            for p in l.parameters(): p.requires_grad = False\n        logging.debug(\"Train last layers of the vgg16, freeze the previous ones\")\n    elif args.backbone == \"alexnet\":\n        backbone = torchvision.models.alexnet(pretrained=True)\n        layers = list(backbone.features.children())[:-2]\n        for l in layers[:5]:\n            for p in l.parameters(): p.requires_grad = False\n        logging.debug(\"Train last layers of the alexnet, freeze the previous ones\")\n    elif args.backbone.startswith(\"cct\"):\n        if args.backbone.startswith(\"cct384\"):\n            backbone = cct_14_7x2_384(pretrained=True, progress=True, aggregation=args.aggregation)\n        if args.trunc_te:\n            logging.debug(f\"Truncate CCT at transformers encoder {args.trunc_te}\")\n            backbone.classifier.blocks = torch.nn.ModuleList(backbone.classifier.blocks[:args.trunc_te].children())\n        if args.freeze_te:\n            
logging.debug(f\"Freeze all the layers up to tranformer encoder {args.freeze_te}\")\n            for p in backbone.parameters():\n                p.requires_grad = False\n            for name, child in backbone.classifier.blocks.named_children():\n                if int(name) > args.freeze_te:\n                    for params in child.parameters():\n                        params.requires_grad = True\n        args.features_dim = 384\n        return backbone\n    elif args.backbone.startswith(\"vit\"):\n        assert args.resize[0] in [224, 384], f'Image size for ViT must be either 224 or 384, but it\\'s {args.resize[0]}'\n        if args.resize[0] == 224:\n            backbone = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')\n        elif args.resize[0] == 384:\n            backbone = ViTModel.from_pretrained('google/vit-base-patch16-384')\n\n        if args.trunc_te:\n            logging.debug(f\"Truncate ViT at transformers encoder {args.trunc_te}\")\n            backbone.encoder.layer = backbone.encoder.layer[:args.trunc_te]\n        if args.freeze_te:\n            logging.debug(f\"Freeze all the layers up to tranformer encoder {args.freeze_te+1}\")\n            for p in backbone.parameters():\n                p.requires_grad = False\n            for name, child in backbone.encoder.layer.named_children():\n                if int(name) > args.freeze_te:\n                    for params in child.parameters():\n                        params.requires_grad = True\n        backbone = VitWrapper(backbone, args.aggregation)\n        \n        args.features_dim = 768\n        return backbone\n\n    backbone = torch.nn.Sequential(*layers)\n    args.features_dim = get_output_channels_dim(backbone)  # Dinamically obtain number of channels in output\n    return backbone\n\n\nclass VitWrapper(nn.Module):\n    def __init__(self, vit_model, aggregation):\n        super().__init__()\n        self.vit_model = vit_model\n        self.aggregation = aggregation\n    def forward(self, x):\n        if self.aggregation in [\"netvlad\", \"gem\"]:\n            return self.vit_model(x).last_hidden_state[:, 1:, :]\n        else:\n            return self.vit_model(x).last_hidden_state[:, 0, :]\n\n\ndef get_output_channels_dim(model):\n    \"\"\"Return the number of channels in the output of a model.\"\"\"\n    return model(torch.ones([1, 3, 224, 224])).shape[1]\n\n"
  },
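# An illustrative end-to-end sketch (argument values assumed; the real code
# builds `args` via its argparse setup) of how GeoLocalizationNet wires a
# backbone to an aggregation module. Any `pretrain` value other than 'places'
# or 'gldv2' falls back to ImageNet weights, which torchvision downloads on
# first use; the repo's dependencies must be installed for the import to work.
from types import SimpleNamespace
import torch
from model.network import GeoLocalizationNet

args = SimpleNamespace(backbone="resnet18conv4", aggregation="gem",
                       pretrain="imagenet", l2="before_pool",
                       fc_output_dim=None, netvlad_clusters=64)
model = GeoLocalizationNet(args)           # get_backbone() also sets args.features_dim
desc = model(torch.randn(1, 3, 224, 224))  # -> (1, args.features_dim), here (1, 256)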
  {
    "path": "model/normalization.py",
    "content": "\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass L2Norm(nn.Module):\n    def __init__(self, dim=1):\n        super().__init__()\n        self.dim = dim\n    def forward(self, x):\n        return F.normalize(x, p=2, dim=self.dim)\n\n"
  },
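# A one-line illustrative check of L2Norm: every row of the output has unit
# Euclidean norm (dim=1 by default).
import torch
from model.normalization import L2Norm

v = L2Norm()(torch.randn(4, 128))
assert torch.allclose(v.norm(p=2, dim=1), torch.ones(4), atol=1e-5)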
  {
    "path": "model/sync_batchnorm/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n# File   : __init__.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/2018\n#\n# This file is part of Synchronized-BatchNorm-PyTorch.\n# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch\n# Distributed under MIT License.\n\nfrom .batchnorm import set_sbn_eps_mode\nfrom .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d\nfrom .batchnorm import patch_sync_batchnorm, convert_model\nfrom .replicate import DataParallelWithCallback, patch_replication_callback\n"
  },
  {
    "path": "model/sync_batchnorm/batchnorm.py",
    "content": "# -*- coding: utf-8 -*-\n# File   : batchnorm.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/2018\n#\n# This file is part of Synchronized-BatchNorm-PyTorch.\n# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch\n# Distributed under MIT License.\n\nimport collections\nimport contextlib\n\nimport torch\nimport torch.nn.functional as F\n\nfrom torch.nn.modules.batchnorm import _BatchNorm\n\ntry:\n    from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast\nexcept ImportError:\n    ReduceAddCoalesced = Broadcast = None\n\ntry:\n    from jactorch.parallel.comm import SyncMaster\n    from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback\nexcept ImportError:\n    from .comm import SyncMaster\n    from .replicate import DataParallelWithCallback\n\n__all__ = [\n    'set_sbn_eps_mode',\n    'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d',\n    'patch_sync_batchnorm', 'convert_model'\n]\n\n\nSBN_EPS_MODE = 'clamp'\n\n\ndef set_sbn_eps_mode(mode):\n    global SBN_EPS_MODE\n    assert mode in ('clamp', 'plus')\n    SBN_EPS_MODE = mode\n\n\ndef _sum_ft(tensor):\n    \"\"\"sum over the first and last dimention\"\"\"\n    return tensor.sum(dim=0).sum(dim=-1)\n\n\ndef _unsqueeze_ft(tensor):\n    \"\"\"add new dimensions at the front and the tail\"\"\"\n    return tensor.unsqueeze(0).unsqueeze(-1)\n\n\n_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])\n_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])\n\n\nclass _SynchronizedBatchNorm(_BatchNorm):\n    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):\n        assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.'\n\n        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine,\n                                                     track_running_stats=track_running_stats)\n\n        if not self.track_running_stats:\n            import warnings\n            warnings.warn('track_running_stats=False is not supported by the SynchronizedBatchNorm.')\n\n        self._sync_master = SyncMaster(self._data_parallel_master)\n\n        self._is_parallel = False\n        self._parallel_id = None\n        self._slave_pipe = None\n\n    def forward(self, input):\n        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.\n        if not (self._is_parallel and self.training):\n            return F.batch_norm(\n                input, self.running_mean, self.running_var, self.weight, self.bias,\n                self.training, self.momentum, self.eps)\n\n        # Resize the input to (B, C, -1).\n        input_shape = input.size()\n        assert input.size(1) == self.num_features, 'Channel size mismatch: got {}, expect {}.'.format(input.size(1), self.num_features)\n        input = input.view(input.size(0), self.num_features, -1)\n\n        # Compute the sum and square-sum.\n        sum_size = input.size(0) * input.size(2)\n        input_sum = _sum_ft(input)\n        input_ssum = _sum_ft(input ** 2)\n\n        # Reduce-and-broadcast the statistics.\n        if self._parallel_id == 0:\n            mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))\n        else:\n            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, 
sum_size))\n\n        # Compute the output.\n        if self.affine:\n            # MJY:: Fuse the multiplication for speed.\n            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)\n        else:\n            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)\n\n        # Reshape it.\n        return output.view(input_shape)\n\n    def __data_parallel_replicate__(self, ctx, copy_id):\n        self._is_parallel = True\n        self._parallel_id = copy_id\n\n        # parallel_id == 0 means master device.\n        if self._parallel_id == 0:\n            ctx.sync_master = self._sync_master\n        else:\n            self._slave_pipe = ctx.sync_master.register_slave(copy_id)\n\n    def _data_parallel_master(self, intermediates):\n        \"\"\"Reduce the sum and square-sum, compute the statistics, and broadcast it.\"\"\"\n\n        # Always using same \"device order\" makes the ReduceAdd operation faster.\n        # Thanks to:: Tete Xiao (http://tetexiao.com/)\n        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())\n\n        to_reduce = [i[1][:2] for i in intermediates]\n        to_reduce = [j for i in to_reduce for j in i]  # flatten\n        target_gpus = [i[1].sum.get_device() for i in intermediates]\n\n        sum_size = sum([i[1].sum_size for i in intermediates])\n        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)\n        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)\n\n        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)\n\n        outputs = []\n        for i, rec in enumerate(intermediates):\n            outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))\n\n        return outputs\n\n    def _compute_mean_std(self, sum_, ssum, size):\n        \"\"\"Compute the mean and standard-deviation with sum and square-sum. This method\n        also maintains the moving average on the master device.\"\"\"\n        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'\n        mean = sum_ / size\n        sumvar = ssum - sum_ * mean\n        unbias_var = sumvar / (size - 1)\n        bias_var = sumvar / size\n\n        if hasattr(torch, 'no_grad'):\n            with torch.no_grad():\n                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data\n                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data\n        else:\n            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data\n            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data\n\n        if SBN_EPS_MODE == 'clamp':\n            return mean, bias_var.clamp(self.eps) ** -0.5\n        elif SBN_EPS_MODE == 'plus':\n            return mean, (bias_var + self.eps) ** -0.5\n        else:\n            raise ValueError('Unknown EPS mode: {}.'.format(SBN_EPS_MODE))\n\n\nclass SynchronizedBatchNorm1d(_SynchronizedBatchNorm):\n    r\"\"\"Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a\n    mini-batch.\n\n    .. 
math::\n\n        y = \\frac{x - mean[x]}{ \\sqrt{Var[x] + \\epsilon}} * gamma + beta\n\n    This module differs from the built-in PyTorch BatchNorm1d as the mean and\n    standard-deviation are reduced across all devices during training.\n\n    For example, when one uses `nn.DataParallel` to wrap the network during\n    training, PyTorch's implementation normalize the tensor on each device using\n    the statistics only on that device, which accelerated the computation and\n    is also easy to implement, but the statistics might be inaccurate.\n    Instead, in this synchronized version, the statistics will be computed\n    over all training samples distributed on multiple devices.\n\n    Note that, for one-GPU or CPU-only case, this module behaves exactly same\n    as the built-in PyTorch implementation.\n\n    The mean and standard-deviation are calculated per-dimension over\n    the mini-batches and gamma and beta are learnable parameter vectors\n    of size C (where C is the input size).\n\n    During training, this layer keeps a running estimate of its computed mean\n    and variance. The running sum is kept with a default momentum of 0.1.\n\n    During evaluation, this running mean/variance is used for normalization.\n\n    Because the BatchNorm is done over the `C` dimension, computing statistics\n    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm\n\n    Args:\n        num_features: num_features from an expected input of size\n            `batch_size x num_features [x width]`\n        eps: a value added to the denominator for numerical stability.\n            Default: 1e-5\n        momentum: the value used for the running_mean and running_var\n            computation. Default: 0.1\n        affine: a boolean value that when set to ``True``, gives the layer learnable\n            affine parameters. Default: ``True``\n\n    Shape::\n        - Input: :math:`(N, C)` or :math:`(N, C, L)`\n        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)\n\n    Examples:\n        >>> # With Learnable Parameters\n        >>> m = SynchronizedBatchNorm1d(100)\n        >>> # Without Learnable Parameters\n        >>> m = SynchronizedBatchNorm1d(100, affine=False)\n        >>> input = torch.autograd.Variable(torch.randn(20, 100))\n        >>> output = m(input)\n    \"\"\"\n\n    def _check_input_dim(self, input):\n        if input.dim() != 2 and input.dim() != 3:\n            raise ValueError('expected 2D or 3D input (got {}D input)'\n                             .format(input.dim()))\n\n\nclass SynchronizedBatchNorm2d(_SynchronizedBatchNorm):\n    r\"\"\"Applies Batch Normalization over a 4d input that is seen as a mini-batch\n    of 3d inputs\n\n    .. 
math::\n\n        y = \\frac{x - mean[x]}{ \\sqrt{Var[x] + \\epsilon}} * gamma + beta\n\n    This module differs from the built-in PyTorch BatchNorm2d as the mean and\n    standard-deviation are reduced across all devices during training.\n\n    For example, when one uses `nn.DataParallel` to wrap the network during\n    training, PyTorch's implementation normalize the tensor on each device using\n    the statistics only on that device, which accelerated the computation and\n    is also easy to implement, but the statistics might be inaccurate.\n    Instead, in this synchronized version, the statistics will be computed\n    over all training samples distributed on multiple devices.\n\n    Note that, for one-GPU or CPU-only case, this module behaves exactly same\n    as the built-in PyTorch implementation.\n\n    The mean and standard-deviation are calculated per-dimension over\n    the mini-batches and gamma and beta are learnable parameter vectors\n    of size C (where C is the input size).\n\n    During training, this layer keeps a running estimate of its computed mean\n    and variance. The running sum is kept with a default momentum of 0.1.\n\n    During evaluation, this running mean/variance is used for normalization.\n\n    Because the BatchNorm is done over the `C` dimension, computing statistics\n    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm\n\n    Args:\n        num_features: num_features from an expected input of\n            size batch_size x num_features x height x width\n        eps: a value added to the denominator for numerical stability.\n            Default: 1e-5\n        momentum: the value used for the running_mean and running_var\n            computation. Default: 0.1\n        affine: a boolean value that when set to ``True``, gives the layer learnable\n            affine parameters. Default: ``True``\n\n    Shape::\n        - Input: :math:`(N, C, H, W)`\n        - Output: :math:`(N, C, H, W)` (same shape as input)\n\n    Examples:\n        >>> # With Learnable Parameters\n        >>> m = SynchronizedBatchNorm2d(100)\n        >>> # Without Learnable Parameters\n        >>> m = SynchronizedBatchNorm2d(100, affine=False)\n        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))\n        >>> output = m(input)\n    \"\"\"\n\n    def _check_input_dim(self, input):\n        if input.dim() != 4:\n            raise ValueError('expected 4D input (got {}D input)'\n                             .format(input.dim()))\n\n\nclass SynchronizedBatchNorm3d(_SynchronizedBatchNorm):\n    r\"\"\"Applies Batch Normalization over a 5d input that is seen as a mini-batch\n    of 4d inputs\n\n    .. 
math::\n\n        y = \\frac{x - mean[x]}{ \\sqrt{Var[x] + \\epsilon}} * gamma + beta\n\n    This module differs from the built-in PyTorch BatchNorm3d as the mean and\n    standard-deviation are reduced across all devices during training.\n\n    For example, when one uses `nn.DataParallel` to wrap the network during\n    training, PyTorch's implementation normalize the tensor on each device using\n    the statistics only on that device, which accelerated the computation and\n    is also easy to implement, but the statistics might be inaccurate.\n    Instead, in this synchronized version, the statistics will be computed\n    over all training samples distributed on multiple devices.\n\n    Note that, for one-GPU or CPU-only case, this module behaves exactly same\n    as the built-in PyTorch implementation.\n\n    The mean and standard-deviation are calculated per-dimension over\n    the mini-batches and gamma and beta are learnable parameter vectors\n    of size C (where C is the input size).\n\n    During training, this layer keeps a running estimate of its computed mean\n    and variance. The running sum is kept with a default momentum of 0.1.\n\n    During evaluation, this running mean/variance is used for normalization.\n\n    Because the BatchNorm is done over the `C` dimension, computing statistics\n    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm\n    or Spatio-temporal BatchNorm\n\n    Args:\n        num_features: num_features from an expected input of\n            size batch_size x num_features x depth x height x width\n        eps: a value added to the denominator for numerical stability.\n            Default: 1e-5\n        momentum: the value used for the running_mean and running_var\n            computation. Default: 0.1\n        affine: a boolean value that when set to ``True``, gives the layer learnable\n            affine parameters. 
Default: ``True``\n\n    Shape::\n        - Input: :math:`(N, C, D, H, W)`\n        - Output: :math:`(N, C, D, H, W)` (same shape as input)\n\n    Examples:\n        >>> # With Learnable Parameters\n        >>> m = SynchronizedBatchNorm3d(100)\n        >>> # Without Learnable Parameters\n        >>> m = SynchronizedBatchNorm3d(100, affine=False)\n        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))\n        >>> output = m(input)\n    \"\"\"\n\n    def _check_input_dim(self, input):\n        if input.dim() != 5:\n            raise ValueError('expected 5D input (got {}D input)'\n                             .format(input.dim()))\n\n\n@contextlib.contextmanager\ndef patch_sync_batchnorm():\n    import torch.nn as nn\n\n    backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d\n\n    nn.BatchNorm1d = SynchronizedBatchNorm1d\n    nn.BatchNorm2d = SynchronizedBatchNorm2d\n    nn.BatchNorm3d = SynchronizedBatchNorm3d\n\n    yield\n\n    nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup\n\n\ndef convert_model(module):\n    \"\"\"Traverse the input module and its child recursively\n       and replace all instance of torch.nn.modules.batchnorm.BatchNorm*N*d\n       to SynchronizedBatchNorm*N*d\n\n    Args:\n        module: the input module needs to be convert to SyncBN model\n\n    Examples:\n        >>> import torch.nn as nn\n        >>> import torchvision\n        >>> # m is a standard pytorch model\n        >>> m = torchvision.models.resnet18(True)\n        >>> m = nn.DataParallel(m)\n        >>> # after convert, m is using SyncBN\n        >>> m = convert_model(m)\n    \"\"\"\n    if isinstance(module, torch.nn.DataParallel):\n        mod = module.module\n        mod = convert_model(mod)\n        mod = DataParallelWithCallback(mod, device_ids=module.device_ids)\n        return mod\n\n    mod = module\n    for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d,\n                                        torch.nn.modules.batchnorm.BatchNorm2d,\n                                        torch.nn.modules.batchnorm.BatchNorm3d],\n                                       [SynchronizedBatchNorm1d,\n                                        SynchronizedBatchNorm2d,\n                                        SynchronizedBatchNorm3d]):\n        if isinstance(module, pth_module):\n            mod = sync_module(module.num_features, module.eps, module.momentum, module.affine)\n            mod.running_mean = module.running_mean\n            mod.running_var = module.running_var\n            if module.affine:\n                mod.weight.data = module.weight.data.clone().detach()\n                mod.bias.data = module.bias.data.clone().detach()\n\n    for name, child in module.named_children():\n        mod.add_module(name, convert_model(child))\n\n    return mod\n"
  },
  {
    "path": "model/sync_batchnorm/batchnorm_reimpl.py",
    "content": "#! /usr/bin/env python3\n# -*- coding: utf-8 -*-\n# File   : batchnorm_reimpl.py\n# Author : acgtyrant\n# Date   : 11/01/2018\n#\n# This file is part of Synchronized-BatchNorm-PyTorch.\n# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch\n# Distributed under MIT License.\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.init as init\n\n__all__ = ['BatchNorm2dReimpl']\n\n\nclass BatchNorm2dReimpl(nn.Module):\n    \"\"\"\n    A re-implementation of batch normalization, used for testing the numerical\n    stability.\n\n    Author: acgtyrant\n    See also:\n    https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14\n    \"\"\"\n    def __init__(self, num_features, eps=1e-5, momentum=0.1):\n        super().__init__()\n\n        self.num_features = num_features\n        self.eps = eps\n        self.momentum = momentum\n        self.weight = nn.Parameter(torch.empty(num_features))\n        self.bias = nn.Parameter(torch.empty(num_features))\n        self.register_buffer('running_mean', torch.zeros(num_features))\n        self.register_buffer('running_var', torch.ones(num_features))\n        self.reset_parameters()\n\n    def reset_running_stats(self):\n        self.running_mean.zero_()\n        self.running_var.fill_(1)\n\n    def reset_parameters(self):\n        self.reset_running_stats()\n        init.uniform_(self.weight)\n        init.zeros_(self.bias)\n\n    def forward(self, input_):\n        batchsize, channels, height, width = input_.size()\n        numel = batchsize * height * width\n        input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)\n        sum_ = input_.sum(1)\n        sum_of_square = input_.pow(2).sum(1)\n        mean = sum_ / numel\n        sumvar = sum_of_square - sum_ * mean\n\n        self.running_mean = (\n                (1 - self.momentum) * self.running_mean\n                + self.momentum * mean.detach()\n        )\n        unbias_var = sumvar / (numel - 1)\n        self.running_var = (\n                (1 - self.momentum) * self.running_var\n                + self.momentum * unbias_var.detach()\n        )\n\n        bias_var = sumvar / numel\n        inv_std = 1 / (bias_var + self.eps).pow(0.5)\n        output = (\n                (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *\n                self.weight.unsqueeze(1) + self.bias.unsqueeze(1))\n\n        return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()\n\n"
  },
  {
    "path": "model/sync_batchnorm/comm.py",
    "content": "# -*- coding: utf-8 -*-\n# File   : comm.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/2018\n# \n# This file is part of Synchronized-BatchNorm-PyTorch.\n# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch\n# Distributed under MIT License.\n\nimport queue\nimport collections\nimport threading\n\n__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']\n\n\nclass FutureResult(object):\n    \"\"\"A thread-safe future implementation. Used only as one-to-one pipe.\"\"\"\n\n    def __init__(self):\n        self._result = None\n        self._lock = threading.Lock()\n        self._cond = threading.Condition(self._lock)\n\n    def put(self, result):\n        with self._lock:\n            assert self._result is None, 'Previous result has\\'t been fetched.'\n            self._result = result\n            self._cond.notify()\n\n    def get(self):\n        with self._lock:\n            if self._result is None:\n                self._cond.wait()\n\n            res = self._result\n            self._result = None\n            return res\n\n\n_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])\n_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])\n\n\nclass SlavePipe(_SlavePipeBase):\n    \"\"\"Pipe for master-slave communication.\"\"\"\n\n    def run_slave(self, msg):\n        self.queue.put((self.identifier, msg))\n        ret = self.result.get()\n        self.queue.put(True)\n        return ret\n\n\nclass SyncMaster(object):\n    \"\"\"An abstract `SyncMaster` object.\n\n    - During the replication, as the data parallel will trigger an callback of each module, all slave devices should\n    call `register(id)` and obtain an `SlavePipe` to communicate with the master.\n    - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,\n    and passed to a registered callback.\n    - After receiving the messages, the master device should gather the information and determine to message passed\n    back to each slave devices.\n    \"\"\"\n\n    def __init__(self, master_callback):\n        \"\"\"\n\n        Args:\n            master_callback: a callback to be invoked after having collected messages from slave devices.\n        \"\"\"\n        self._master_callback = master_callback\n        self._queue = queue.Queue()\n        self._registry = collections.OrderedDict()\n        self._activated = False\n\n    def __getstate__(self):\n        return {'master_callback': self._master_callback}\n\n    def __setstate__(self, state):\n        self.__init__(state['master_callback'])\n\n    def register_slave(self, identifier):\n        \"\"\"\n        Register an slave device.\n\n        Args:\n            identifier: an identifier, usually is the device id.\n\n        Returns: a `SlavePipe` object which can be used to communicate with the master device.\n\n        \"\"\"\n        if self._activated:\n            assert self._queue.empty(), 'Queue is not clean before next initialization.'\n            self._activated = False\n            self._registry.clear()\n        future = FutureResult()\n        self._registry[identifier] = _MasterRegistry(future)\n        return SlavePipe(identifier, self._queue, future)\n\n    def run_master(self, master_msg):\n        \"\"\"\n        Main entry for the master device in each forward pass.\n        The messages were first collected from each devices (including the master device), and then\n        an callback will 
be invoked to compute the message to be sent back to each devices\n        (including the master device).\n\n        Args:\n            master_msg: the message that the master want to send to itself. This will be placed as the first\n            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.\n\n        Returns: the message to be sent back to the master device.\n\n        \"\"\"\n        self._activated = True\n\n        intermediates = [(0, master_msg)]\n        for i in range(self.nr_slaves):\n            intermediates.append(self._queue.get())\n\n        results = self._master_callback(intermediates)\n        assert results[0][0] == 0, 'The first result should belongs to the master.'\n\n        for i, res in results:\n            if i == 0:\n                continue\n            self._registry[i].result.put(res)\n\n        for i in range(self.nr_slaves):\n            assert self._queue.get() is True\n\n        return results[0][1]\n\n    @property\n    def nr_slaves(self):\n        return len(self._registry)\n"
  },
  {
    "path": "model/sync_batchnorm/replicate.py",
    "content": "# -*- coding: utf-8 -*-\n# File   : replicate.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/2018\n# \n# This file is part of Synchronized-BatchNorm-PyTorch.\n# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch\n# Distributed under MIT License.\n\nimport functools\n\nfrom torch.nn.parallel.data_parallel import DataParallel\n\n__all__ = [\n    'CallbackContext',\n    'execute_replication_callbacks',\n    'DataParallelWithCallback',\n    'patch_replication_callback'\n]\n\n\nclass CallbackContext(object):\n    pass\n\n\ndef execute_replication_callbacks(modules):\n    \"\"\"\n    Execute an replication callback `__data_parallel_replicate__` on each module created by original replication.\n\n    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`\n\n    Note that, as all modules are isomorphism, we assign each sub-module with a context\n    (shared among multiple copies of this module on different devices).\n    Through this context, different copies can share some information.\n\n    We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback\n    of any slave copies.\n    \"\"\"\n    master_copy = modules[0]\n    nr_modules = len(list(master_copy.modules()))\n    ctxs = [CallbackContext() for _ in range(nr_modules)]\n\n    for i, module in enumerate(modules):\n        for j, m in enumerate(module.modules()):\n            if hasattr(m, '__data_parallel_replicate__'):\n                m.__data_parallel_replicate__(ctxs[j], i)\n\n\nclass DataParallelWithCallback(DataParallel):\n    \"\"\"\n    Data Parallel with a replication callback.\n\n    An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by\n    original `replicate` function.\n    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`\n\n    Examples:\n        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)\n        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])\n        # sync_bn.__data_parallel_replicate__ will be invoked.\n    \"\"\"\n\n    def replicate(self, module, device_ids):\n        modules = super(DataParallelWithCallback, self).replicate(module, device_ids)\n        execute_replication_callbacks(modules)\n        return modules\n\n\ndef patch_replication_callback(data_parallel):\n    \"\"\"\n    Monkey-patch an existing `DataParallel` object. Add the replication callback.\n    Useful when you have customized `DataParallel` implementation.\n\n    Examples:\n        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)\n        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])\n        > patch_replication_callback(sync_bn)\n        # this is equivalent to\n        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)\n        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])\n    \"\"\"\n\n    assert isinstance(data_parallel, DataParallel)\n\n    old_replicate = data_parallel.replicate\n\n    @functools.wraps(old_replicate)\n    def new_replicate(module, device_ids):\n        modules = old_replicate(module, device_ids)\n        execute_replication_callbacks(modules)\n        return modules\n\n    data_parallel.replicate = new_replicate\n"
  },
  {
    "path": "model/sync_batchnorm/unittest.py",
    "content": "# -*- coding: utf-8 -*-\n# File   : unittest.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/2018\n#\n# This file is part of Synchronized-BatchNorm-PyTorch.\n# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch\n# Distributed under MIT License.\n\nimport unittest\nimport torch\n\n\nclass TorchTestCase(unittest.TestCase):\n    def assertTensorClose(self, x, y):\n        adiff = float((x - y).abs().max())\n        if (y == 0).all():\n            rdiff = 'NaN'\n        else:\n            rdiff = float((adiff / y).abs().max())\n\n        message = (\n            'Tensor close check failed\\n'\n            'adiff={}\\n'\n            'rdiff={}\\n'\n        ).format(adiff, rdiff)\n        self.assertTrue(torch.allclose(x, y, atol=1e-5, rtol=1e-3), message)\n\n"
  },
  {
    "path": "parser.py",
    "content": "\nimport os\nimport torch\nimport argparse\n\n\ndef parse_arguments():\n    parser = argparse.ArgumentParser(description=\"Benchmarking Visual Geolocalization\",\n                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)\n    # Training parameters\n    parser.add_argument(\"--train_batch_size\", type=int, default=4,\n                        help=\"Number of triplets (query, pos, negs) in a batch. Each triplet consists of 12 images\")\n    parser.add_argument(\"--infer_batch_size\", type=int, default=16,\n                        help=\"Batch size for inference (caching and testing)\")\n    parser.add_argument(\"--criterion\", type=str, default='triplet', help='loss to be used',\n                        choices=[\"triplet\", \"sare_ind\", \"sare_joint\"])\n    parser.add_argument(\"--margin\", type=float, default=0.1,\n                        help=\"margin for the triplet loss\")\n    parser.add_argument(\"--epochs_num\", type=int, default=1000,\n                        help=\"number of epochs to train for\")\n    parser.add_argument(\"--patience\", type=int, default=3)\n    parser.add_argument(\"--lr\", type=float, default=0.00001, help=\"_\")\n    parser.add_argument(\"--lr_crn_layer\", type=float, default=5e-3, help=\"Learning rate for the CRN layer\")\n    parser.add_argument(\"--lr_crn_net\", type=float, default=5e-4, help=\"Learning rate to finetune pretrained network when using CRN\")\n    parser.add_argument(\"--optim\", type=str, default=\"adam\", help=\"_\", choices=[\"adam\", \"sgd\"])\n    parser.add_argument(\"--cache_refresh_rate\", type=int, default=1000,\n                        help=\"How often to refresh cache, in number of queries\")\n    parser.add_argument(\"--queries_per_epoch\", type=int, default=5000,\n                        help=\"How many queries to consider for one epoch. Must be multiple of cache_refresh_rate\")\n    parser.add_argument(\"--negs_num_per_query\", type=int, default=10,\n                        help=\"How many negatives to consider per each query in the loss\")\n    parser.add_argument(\"--neg_samples_num\", type=int, default=1000,\n                        help=\"How many negatives to use to compute the hardest ones\")\n    parser.add_argument(\"--mining\", type=str, default=\"partial\", choices=[\"partial\", \"full\", \"random\", \"msls_weighted\"])\n    # Model parameters\n    parser.add_argument(\"--backbone\", type=str, default=\"resnet18conv4\",\n                        choices=[\"alexnet\", \"vgg16\", \"resnet18conv4\", \"resnet18conv5\",\n                                 \"resnet50conv4\", \"resnet50conv5\", \"resnet101conv4\", \"resnet101conv5\",\n                                 \"cct384\", \"vit\"], help=\"_\")\n    parser.add_argument(\"--l2\", type=str, default=\"before_pool\", choices=[\"before_pool\", \"after_pool\", \"none\"],\n                        help=\"When (and if) to apply the l2 norm with shallow aggregation layers\")\n    parser.add_argument(\"--aggregation\", type=str, default=\"netvlad\", choices=[\"netvlad\", \"gem\", \"spoc\", \"mac\", \"rmac\", \"crn\", \"rrm\",\n                                                                               \"cls\", \"seqpool\"])\n    parser.add_argument('--netvlad_clusters', type=int, default=64, help=\"Number of clusters for NetVLAD layer.\")\n    parser.add_argument('--pca_dim', type=int, default=None, help=\"PCA dimension (number of principal components). 
If None, PCA is not used.\")\n    parser.add_argument('--fc_output_dim', type=int, default=None,\n                        help=\"Output dimension of fully connected layer. If None, don't use a fully connected layer.\")\n    parser.add_argument('--pretrain', type=str, default=\"imagenet\", choices=['imagenet', 'gldv2', 'places'],\n                        help=\"Select the pretrained weights for the starting network\")\n    parser.add_argument(\"--off_the_shelf\", type=str, default=\"imagenet\", choices=[\"imagenet\", \"radenovic_sfm\", \"radenovic_gldv1\", \"naver\"],\n                        help=\"Off-the-shelf networks from popular GitHub repos. Only with ResNet-50/101 + GeM + FC 2048\")\n    parser.add_argument(\"--trunc_te\", type=int, default=None, choices=list(range(0, 14)))\n    parser.add_argument(\"--freeze_te\", type=int, default=None, choices=list(range(-1, 14)))\n    # Initialization parameters\n    parser.add_argument(\"--seed\", type=int, default=0)\n    parser.add_argument(\"--resume\", type=str, default=None,\n                        help=\"Path to load checkpoint from, for resuming training or testing.\")\n    # Other parameters\n    parser.add_argument(\"--device\", type=str, default=\"cuda\", choices=[\"cuda\", \"cpu\"])\n    parser.add_argument(\"--num_workers\", type=int, default=8, help=\"num_workers for all dataloaders\")\n    parser.add_argument('--resize', type=int, default=[480, 640], nargs=2, help=\"Resizing shape for images (HxW).\")\n    parser.add_argument('--test_method', type=str, default=\"hard_resize\",\n                        choices=[\"hard_resize\", \"single_query\", \"central_crop\", \"five_crops\", \"nearest_crop\", \"maj_voting\"],\n                        help=\"This includes pre/post-processing methods and prediction refinement\")\n    parser.add_argument(\"--majority_weight\", type=float, default=0.01,\n                        help=\"only for majority voting, scale factor, the higher it is the more importance is given to agreement\")\n    parser.add_argument(\"--efficient_ram_testing\", action='store_true', help=\"_\")\n    parser.add_argument(\"--val_positive_dist_threshold\", type=int, default=25, help=\"_\")\n    parser.add_argument(\"--train_positives_dist_threshold\", type=int, default=10, help=\"_\")\n    parser.add_argument('--recall_values', type=int, default=[1, 5, 10, 20], nargs=\"+\",\n                        help=\"Recalls to be computed, such as R@5.\")\n    # Data augmentation parameters\n    parser.add_argument(\"--brightness\", type=float, default=0, help=\"_\")\n    parser.add_argument(\"--contrast\", type=float, default=0, help=\"_\")\n    parser.add_argument(\"--saturation\", type=float, default=0, help=\"_\")\n    parser.add_argument(\"--hue\", type=float, default=0, help=\"_\")\n    parser.add_argument(\"--rand_perspective\", type=float, default=0, help=\"_\")\n    parser.add_argument(\"--horizontal_flip\", action='store_true', help=\"_\")\n    parser.add_argument(\"--random_resized_crop\", type=float, default=0, help=\"_\")\n    parser.add_argument(\"--random_rotation\", type=float, default=0, help=\"_\")\n    # Paths parameters\n    parser.add_argument(\"--datasets_folder\", type=str, default=None, help=\"Path with all datasets\")\n    parser.add_argument(\"--dataset_name\", type=str, default=\"pitts30k\", help=\"Relative path of the dataset\")\n    parser.add_argument(\"--pca_dataset_folder\", type=str, default=None,\n                        help=\"Path with images to be used to compute PCA (ie: pitts30k/images/train\")\n  
  parser.add_argument(\"--save_dir\", type=str, default=\"default\",\n                        help=\"Folder name of the current run (saved in ./logs/)\")\n    args = parser.parse_args()\n    \n    if args.datasets_folder is None:\n        try:\n            args.datasets_folder = os.environ['DATASETS_FOLDER']\n        except KeyError:\n            raise Exception(\"You should set the parameter --datasets_folder or export \" +\n                            \"the DATASETS_FOLDER environment variable as such \\n\" +\n                            \"export DATASETS_FOLDER=../datasets_vg/datasets\")\n    \n    if args.aggregation == \"crn\" and args.resume is None:\n        raise ValueError(\"CRN must be resumed from a trained NetVLAD checkpoint, but you set resume=None.\")\n    \n    if args.queries_per_epoch % args.cache_refresh_rate != 0:\n        raise ValueError(\"Ensure that queries_per_epoch is divisible by cache_refresh_rate, \" +\n                         f\"because {args.queries_per_epoch} is not divisible by {args.cache_refresh_rate}\")\n    \n    if torch.cuda.device_count() >= 2 and args.criterion in ['sare_joint', \"sare_ind\"]:\n        raise NotImplementedError(\"SARE losses are not implemented for multiple GPUs, \" +\n                                  f\"but you're using {torch.cuda.device_count()} GPUs and {args.criterion} loss.\")\n    \n    if args.mining == \"msls_weighted\" and args.dataset_name != \"msls\":\n        raise ValueError(\"msls_weighted mining can only be applied to msls dataset, but you're using it on {args.dataset_name}\")\n    \n    if args.off_the_shelf in [\"radenovic_sfm\", \"radenovic_gldv1\", \"naver\"]:\n        if args.backbone not in [\"resnet50conv5\", \"resnet101conv5\"] or args.aggregation != \"gem\" or args.fc_output_dim != 2048:\n            raise ValueError(\"Off-the-shelf models are trained only with ResNet-50/101 + GeM + FC 2048\")\n    \n    if args.pca_dim is not None and args.pca_dataset_folder is None:\n        raise ValueError(\"Please specify --pca_dataset_folder when using pca\")\n    \n    if args.backbone == \"vit\":\n        if args.resize != [224, 224] and args.resize != [384, 384]:\n            raise ValueError(f'Image size for ViT must be either 224 or 384 {args.resize}')\n    if args.backbone == \"cct384\":\n        if args.resize != [384, 384]:\n            raise ValueError(f'Image size for CCT384 must be 384, but it is {args.resize}')\n    \n    if args.backbone in [\"alexnet\", \"vgg16\", \"resnet18conv4\", \"resnet18conv5\",\n                         \"resnet50conv4\", \"resnet50conv5\", \"resnet101conv4\", \"resnet101conv5\"]:\n        if args.aggregation in [\"cls\", \"seqpool\"]:\n            raise ValueError(f\"CNNs like {args.backbone} can't work with aggregation {args.aggregation}\")\n    if args.backbone in [\"cct384\"]:\n        if args.aggregation in [\"spoc\", \"mac\", \"rmac\", \"crn\", \"rrm\"]:\n            raise ValueError(f\"CCT can't work with aggregation {args.aggregation}. Please use one among [netvlad, gem, cls, seqpool]\")\n    if args.backbone == \"vit\":\n        if args.aggregation not in [\"cls\", \"gem\", \"netvlad\"]:\n            raise ValueError(f\"ViT can't work with aggregation {args.aggregation}. Please use one among [netvlad, gem, cls]\")\n\n    return args\n"
  },
  {
    "path": "requirements.txt",
    "content": "numpy==1.19.4\ntorchvision==0.8.1\npsutil==5.6.7\nfaiss_cpu==1.5.3\ntqdm==4.48.2\ntorch==1.7.0\nPillow==8.2.0\nscikit_learn==0.24.1\ntorchscan==0.1.1\ngoogledrivedownloader==0.4\nrequests==2.26.0\ntimm==0.4.12\ntransformers==4.8.2\neinops\n"
  },
  {
    "path": "test.py",
    "content": "\nimport faiss\nimport torch\nimport logging\nimport numpy as np\nfrom tqdm import tqdm\nfrom torch.utils.data import DataLoader\nfrom torch.utils.data.dataset import Subset\n\n\ndef test_efficient_ram_usage(args, eval_ds, model, test_method=\"hard_resize\"):\n    \"\"\"This function gives the same output as test(), but uses much less RAM.\n    This can be useful when testing with large descriptors (e.g. NetVLAD) on large datasets (e.g. San Francisco).\n    Obviously it is slower than test(), and can't be used with PCA.\n    \"\"\"\n    \n    model = model.eval()\n    if test_method == 'nearest_crop' or test_method == \"maj_voting\":\n        distances = np.empty([eval_ds.queries_num * 5, eval_ds.database_num], dtype=np.float32)\n    else:\n        distances = np.empty([eval_ds.queries_num, eval_ds.database_num], dtype=np.float32)\n\n    with torch.no_grad():\n        if test_method == 'nearest_crop' or test_method == 'maj_voting':\n            queries_features = np.ones((eval_ds.queries_num * 5, args.features_dim), dtype=\"float32\")\n        else:\n            queries_features = np.ones((eval_ds.queries_num, args.features_dim), dtype=\"float32\")\n        logging.debug(\"Extracting queries features for evaluation/testing\")\n        queries_infer_batch_size = 1 if test_method == \"single_query\" else args.infer_batch_size\n        eval_ds.test_method = test_method\n        queries_subset_ds = Subset(eval_ds, list(range(eval_ds.database_num, eval_ds.database_num+eval_ds.queries_num)))\n        queries_dataloader = DataLoader(dataset=queries_subset_ds, num_workers=args.num_workers,\n                                        batch_size=queries_infer_batch_size, pin_memory=(args.device == \"cuda\"))\n        for inputs, indices in tqdm(queries_dataloader, ncols=100):\n            if test_method == \"five_crops\" or test_method == \"nearest_crop\" or test_method == 'maj_voting':\n                inputs = torch.cat(tuple(inputs))  # shape = 5*bs x 3 x 480 x 480\n            features = model(inputs.to(args.device))\n            if test_method == \"five_crops\":  # Compute mean along the 5 crops\n                features = torch.stack(torch.split(features, 5)).mean(1)\n            if test_method == \"nearest_crop\" or test_method == 'maj_voting':\n                start_idx = (indices[0] - eval_ds.database_num) * 5\n                end_idx = start_idx + indices.shape[0] * 5\n                indices = np.arange(start_idx, end_idx)\n                queries_features[indices, :] = features.cpu().numpy()\n            else:\n                queries_features[indices.numpy()-eval_ds.database_num, :] = features.cpu().numpy()\n\n        queries_features = torch.tensor(queries_features).type(torch.float32).cuda()\n        \n        logging.debug(\"Extracting database features for evaluation/testing\")\n        # For database use \"hard_resize\", although it usually has no effect because database images have same resolution\n        eval_ds.test_method = \"hard_resize\"\n        database_subset_ds = Subset(eval_ds, list(range(eval_ds.database_num)))\n        database_dataloader = DataLoader(dataset=database_subset_ds, num_workers=args.num_workers,\n                                         batch_size=args.infer_batch_size, pin_memory=(args.device == \"cuda\"))\n        for inputs, indices in tqdm(database_dataloader, ncols=100):\n            inputs = inputs.to(args.device)\n            features = model(inputs)\n            for pn, (index, pred_feature) in enumerate(zip(indices, features)):\n     
           distances[:, index] = ((queries_features-pred_feature)**2).sum(1).cpu().numpy()\n        del features, queries_features, pred_feature\n        \n    predictions = distances.argsort(axis=1)[:, :max(args.recall_values)]\n    \n    if test_method == 'nearest_crop':\n        distances = np.array([distances[row, index] for row, index in enumerate(predictions)])\n        distances = np.reshape(distances, (eval_ds.queries_num, 20 * 5))\n        predictions = np.reshape(predictions, (eval_ds.queries_num, 20 * 5))\n        for q in range(eval_ds.queries_num):\n            # sort predictions by distance\n            sort_idx = np.argsort(distances[q])\n            predictions[q] = predictions[q, sort_idx]\n            # remove duplicated predictions, i.e. keep only the closest ones\n            _, unique_idx = np.unique(predictions[q], return_index=True)\n            # unique_idx is sorted based on the unique values, sort it again\n            predictions[q, :20] = predictions[q, np.sort(unique_idx)][:20]\n        predictions = predictions[:, :20]  # keep only the closest 20 predictions for each query\n    elif test_method == 'maj_voting':\n        distances = np.array([distances[row, index] for row, index in enumerate(predictions)])\n        distances = np.reshape(distances, (eval_ds.queries_num, 5, 20))\n        predictions = np.reshape(predictions, (eval_ds.queries_num, 5, 20))\n        for q in range(eval_ds.queries_num):\n            # apply votings, modifying distances in-place\n            top_n_voting('top1', predictions[q], distances[q], args.majority_weight)\n            top_n_voting('top5', predictions[q], distances[q], args.majority_weight)\n            top_n_voting('top10', predictions[q], distances[q], args.majority_weight)\n\n            # flatten dist and preds from 5, 20 -> 20*5\n            # and then proceed as usual to keep only the first 20\n            dists = distances[q].flatten()\n            preds = predictions[q].flatten()\n\n            # sort predictions by distance\n            sort_idx = np.argsort(dists)\n            preds = preds[sort_idx]\n            # remove duplicated predictions, i.e. 
keep only the closest ones\n            _, unique_idx = np.unique(preds, return_index=True)\n            # unique_idx is sorted based on the unique values, sort it again\n            # here the row corresponding to the first crop is used as a\n            # 'buffer' for each query, and in the end the dimension\n            # relative to crops is eliminated\n            predictions[q, 0, :20] = preds[np.sort(unique_idx)][:20]\n        predictions = predictions[:, 0, :20]  # keep only the closer 20 predictions for each query\n    del distances\n    \n    #### For each query, check if the predictions are correct\n    positives_per_query = eval_ds.get_positives()\n    # args.recall_values by default is [1, 5, 10, 20]\n    recalls = np.zeros(len(args.recall_values))\n    for query_index, pred in enumerate(predictions):\n        for i, n in enumerate(args.recall_values):\n            if np.any(np.in1d(pred[:n], positives_per_query[query_index])):\n                recalls[i:] += 1\n                break\n    \n    recalls = recalls / eval_ds.queries_num * 100\n    recalls_str = \", \".join([f\"R@{val}: {rec:.1f}\" for val, rec in zip(args.recall_values, recalls)])\n    return recalls, recalls_str\n\n\ndef test(args, eval_ds, model, test_method=\"hard_resize\", pca=None):\n    \"\"\"Compute features of the given dataset and compute the recalls.\"\"\"\n    \n    assert test_method in [\"hard_resize\", \"single_query\", \"central_crop\", \"five_crops\",\n                           \"nearest_crop\", \"maj_voting\"], f\"test_method can't be {test_method}\"\n    \n    if args.efficient_ram_testing:\n        return test_efficient_ram_usage(args, eval_ds, model, test_method)\n    \n    model = model.eval()\n    with torch.no_grad():\n        logging.debug(\"Extracting database features for evaluation/testing\")\n        # For database use \"hard_resize\", although it usually has no effect because database images have same resolution\n        eval_ds.test_method = \"hard_resize\"\n        database_subset_ds = Subset(eval_ds, list(range(eval_ds.database_num)))\n        database_dataloader = DataLoader(dataset=database_subset_ds, num_workers=args.num_workers,\n                                         batch_size=args.infer_batch_size, pin_memory=(args.device == \"cuda\"))\n        \n        if test_method == \"nearest_crop\" or test_method == 'maj_voting':\n            all_features = np.empty((5 * eval_ds.queries_num + eval_ds.database_num, args.features_dim), dtype=\"float32\")\n        else:\n            all_features = np.empty((len(eval_ds), args.features_dim), dtype=\"float32\")\n\n        for inputs, indices in tqdm(database_dataloader, ncols=100):\n            features = model(inputs.to(args.device))\n            features = features.cpu().numpy()\n            if pca is not None:\n                features = pca.transform(features)\n            all_features[indices.numpy(), :] = features\n        \n        logging.debug(\"Extracting queries features for evaluation/testing\")\n        queries_infer_batch_size = 1 if test_method == \"single_query\" else args.infer_batch_size\n        eval_ds.test_method = test_method\n        queries_subset_ds = Subset(eval_ds, list(range(eval_ds.database_num, eval_ds.database_num+eval_ds.queries_num)))\n        queries_dataloader = DataLoader(dataset=queries_subset_ds, num_workers=args.num_workers,\n                                        batch_size=queries_infer_batch_size, pin_memory=(args.device == \"cuda\"))\n        for inputs, indices in tqdm(queries_dataloader, 
ncols=100):\n            if test_method == \"five_crops\" or test_method == \"nearest_crop\" or test_method == 'maj_voting':\n                inputs = torch.cat(tuple(inputs))  # shape = 5*bs x 3 x 480 x 480\n            features = model(inputs.to(args.device))\n            if test_method == \"five_crops\":  # Compute mean along the 5 crops\n                features = torch.stack(torch.split(features, 5)).mean(1)\n            features = features.cpu().numpy()\n            if pca is not None:\n                features = pca.transform(features)\n            \n            if test_method == \"nearest_crop\" or test_method == 'maj_voting':  # store the features of all 5 crops\n                start_idx = eval_ds.database_num + (indices[0] - eval_ds.database_num) * 5\n                end_idx = start_idx + indices.shape[0] * 5\n                indices = np.arange(start_idx, end_idx)\n                all_features[indices, :] = features\n            else:\n                all_features[indices.numpy(), :] = features\n    \n    queries_features = all_features[eval_ds.database_num:]\n    database_features = all_features[:eval_ds.database_num]\n    \n    faiss_index = faiss.IndexFlatL2(args.features_dim)\n    faiss_index.add(database_features)\n    del database_features, all_features\n    \n    logging.debug(\"Calculating recalls\")\n    distances, predictions = faiss_index.search(queries_features, max(args.recall_values))\n    \n    if test_method == 'nearest_crop':\n        distances = np.reshape(distances, (eval_ds.queries_num, 20 * 5))\n        predictions = np.reshape(predictions, (eval_ds.queries_num, 20 * 5))\n        for q in range(eval_ds.queries_num):\n            # sort predictions by distance\n            sort_idx = np.argsort(distances[q])\n            predictions[q] = predictions[q, sort_idx]\n            # remove duplicated predictions, i.e. keep only the closest ones\n            _, unique_idx = np.unique(predictions[q], return_index=True)\n            # unique_idx is sorted based on the unique values, sort it again\n            predictions[q, :20] = predictions[q, np.sort(unique_idx)][:20]\n        predictions = predictions[:, :20]  # keep only the closer 20 predictions for each query\n    elif test_method == 'maj_voting':\n        distances = np.reshape(distances, (eval_ds.queries_num, 5, 20))\n        predictions = np.reshape(predictions, (eval_ds.queries_num, 5, 20))\n        for q in range(eval_ds.queries_num):\n            # votings, modify distances in-place\n            top_n_voting('top1', predictions[q], distances[q], args.majority_weight)\n            top_n_voting('top5', predictions[q], distances[q], args.majority_weight)\n            top_n_voting('top10', predictions[q], distances[q], args.majority_weight)\n\n            # flatten dist and preds from 5, 20 -> 20*5\n            # and then proceed as usual to keep only first 20\n            dists = distances[q].flatten()\n            preds = predictions[q].flatten()\n\n            # sort predictions by distance\n            sort_idx = np.argsort(dists)\n            preds = preds[sort_idx]\n            # remove duplicated predictions, i.e. 
keep only the closest ones\n            _, unique_idx = np.unique(preds, return_index=True)\n            # unique_idx is sorted based on the unique values, sort it again\n            # here the row corresponding to the first crop is used as a\n            # 'buffer' for each query, and in the end the dimension\n            # relative to crops is eliminated\n            predictions[q, 0, :20] = preds[np.sort(unique_idx)][:20]\n        predictions = predictions[:, 0, :20]  # keep only the closest 20 predictions for each query\n\n    #### For each query, check if the predictions are correct\n    positives_per_query = eval_ds.get_positives()\n    # args.recall_values by default is [1, 5, 10, 20]\n    recalls = np.zeros(len(args.recall_values))\n    for query_index, pred in enumerate(predictions):\n        for i, n in enumerate(args.recall_values):\n            if np.any(np.in1d(pred[:n], positives_per_query[query_index])):\n                recalls[i:] += 1\n                break\n    # Divide by the number of queries and multiply by 100, so the recalls are expressed as percentages\n    recalls = recalls / eval_ds.queries_num * 100\n    recalls_str = \", \".join([f\"R@{val}: {rec:.1f}\" for val, rec in zip(args.recall_values, recalls)])\n    return recalls, recalls_str\n\n\ndef top_n_voting(topn, predictions, distances, maj_weight):\n    if topn == 'top1':\n        n = 1\n        selected = 0\n    elif topn == 'top5':\n        n = 5\n        selected = slice(0, 5)\n    elif topn == 'top10':\n        n = 10\n        selected = slice(0, 10)\n    # find predictions that repeat in the first, first five,\n    # or first ten columns for each crop\n    vals, counts = np.unique(predictions[:, selected], return_counts=True)\n    # for each prediction that repeats more than once,\n    # subtract from its score\n    for val, count in zip(vals[counts > 1], counts[counts > 1]):\n        mask = (predictions[:, selected] == val)\n        distances[:, selected][mask] -= maj_weight * count/n\n"
  },
  {
    "path": "train.py",
    "content": "\nimport math\nimport torch\nimport logging\nimport numpy as np\nfrom tqdm import tqdm\nimport torch.nn as nn\nimport multiprocessing\nfrom os.path import join\nfrom datetime import datetime\nimport torchvision.transforms as transforms\nfrom torch.utils.data.dataloader import DataLoader\n\nimport util\nimport test\nimport parser\nimport commons\nimport datasets_ws\nfrom model import network\nfrom model.sync_batchnorm import convert_model\nfrom model.functional import sare_ind, sare_joint\n\ntorch.backends.cudnn.benchmark = True  # Provides a speedup\n#### Initial setup: parser, logging...\nargs = parser.parse_arguments()\nstart_time = datetime.now()\nargs.save_dir = join(\"logs\", args.save_dir, start_time.strftime('%Y-%m-%d_%H-%M-%S'))\ncommons.setup_logging(args.save_dir)\ncommons.make_deterministic(args.seed)\nlogging.info(f\"Arguments: {args}\")\nlogging.info(f\"The outputs are being saved in {args.save_dir}\")\nlogging.info(f\"Using {torch.cuda.device_count()} GPUs and {multiprocessing.cpu_count()} CPUs\")\n\n#### Creation of Datasets\nlogging.debug(f\"Loading dataset {args.dataset_name} from folder {args.datasets_folder}\")\n\ntriplets_ds = datasets_ws.TripletsDataset(args, args.datasets_folder, args.dataset_name, \"train\", args.negs_num_per_query)\nlogging.info(f\"Train query set: {triplets_ds}\")\n\nval_ds = datasets_ws.BaseDataset(args, args.datasets_folder, args.dataset_name, \"val\")\nlogging.info(f\"Val set: {val_ds}\")\n\ntest_ds = datasets_ws.BaseDataset(args, args.datasets_folder, args.dataset_name, \"test\")\nlogging.info(f\"Test set: {test_ds}\")\n\n#### Initialize model\nmodel = network.GeoLocalizationNet(args)\nmodel = model.to(args.device)\nif args.aggregation in [\"netvlad\", \"crn\"]:  # If using NetVLAD layer, initialize it\n    if not args.resume:\n        triplets_ds.is_inference = True\n        model.aggregation.initialize_netvlad_layer(args, triplets_ds, model.backbone)\n    args.features_dim *= args.netvlad_clusters\n\nmodel = torch.nn.DataParallel(model)\n\n#### Setup Optimizer and Loss\nif args.aggregation == \"crn\":\n    crn_params = list(model.module.aggregation.crn.parameters())\n    net_params = list(model.module.backbone.parameters()) + \\\n        list([m[1] for m in model.module.aggregation.named_parameters() if not m[0].startswith('crn')])\n    if args.optim == \"adam\":\n        optimizer = torch.optim.Adam([{'params': crn_params, 'lr': args.lr_crn_layer},\n                                      {'params': net_params, 'lr': args.lr_crn_net}])\n        logging.info(\"You're using CRN with Adam, it is advised to use SGD\")\n    elif args.optim == \"sgd\":\n        optimizer = torch.optim.SGD([{'params': crn_params, 'lr': args.lr_crn_layer, 'momentum': 0.9, 'weight_decay': 0.001},\n                                     {'params': net_params, 'lr': args.lr_crn_net, 'momentum': 0.9, 'weight_decay': 0.001}])\nelse:\n    if args.optim == \"adam\":\n        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)\n    elif args.optim == \"sgd\":\n        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=0.001)\n\nif args.criterion == \"triplet\":\n    criterion_triplet = nn.TripletMarginLoss(margin=args.margin, p=2, reduction=\"sum\")\nelif args.criterion == \"sare_ind\":\n    criterion_triplet = sare_ind\nelif args.criterion == \"sare_joint\":\n    criterion_triplet = sare_joint\n\n#### Resume model, optimizer, and other training parameters\nif args.resume:\n    if args.aggregation != 'crn':\n        
model, optimizer, best_r5, start_epoch_num, not_improved_num = util.resume_train(args, model, optimizer)\n    else:\n        # CRN uses pretrained NetVLAD, then requires loading with strict=False and\n        # does not load the optimizer from the checkpoint file.\n        model, _, best_r5, start_epoch_num, not_improved_num = util.resume_train(args, model, strict=False)\n    logging.info(f\"Resuming from epoch {start_epoch_num} with best recall@5 {best_r5:.1f}\")\nelse:\n    best_r5 = start_epoch_num = not_improved_num = 0\n\nif args.backbone.startswith('vit'):\n    logging.info(f\"Output dimension of the model is {args.features_dim}\")\nelse:\n    logging.info(f\"Output dimension of the model is {args.features_dim}, with {util.get_flops(model, args.resize)}\")\n\n\nif torch.cuda.device_count() >= 2:\n    # When using more than 1GPU, use sync_batchnorm for torch.nn.DataParallel\n    model = convert_model(model)\n    model = model.cuda()\n\n#### Training loop\nfor epoch_num in range(start_epoch_num, args.epochs_num):\n    logging.info(f\"Start training epoch: {epoch_num:02d}\")\n    \n    epoch_start_time = datetime.now()\n    epoch_losses = np.zeros((0, 1), dtype=np.float32)\n    \n    # How many loops should an epoch last (default is 5000/1000=5)\n    loops_num = math.ceil(args.queries_per_epoch / args.cache_refresh_rate)\n    for loop_num in range(loops_num):\n        logging.debug(f\"Cache: {loop_num} / {loops_num}\")\n        \n        # Compute triplets to use in the triplet loss\n        triplets_ds.is_inference = True\n        triplets_ds.compute_triplets(args, model)\n        triplets_ds.is_inference = False\n        \n        triplets_dl = DataLoader(dataset=triplets_ds, num_workers=args.num_workers,\n                                 batch_size=args.train_batch_size,\n                                 collate_fn=datasets_ws.collate_fn,\n                                 pin_memory=(args.device == \"cuda\"),\n                                 drop_last=True)\n        \n        model = model.train()\n        \n        # images shape: (train_batch_size*12)*3*H*W ; by default train_batch_size=4, H=480, W=640\n        # triplets_local_indexes shape: (train_batch_size*10)*3 ; because 10 triplets per query\n        for images, triplets_local_indexes, _ in tqdm(triplets_dl, ncols=100):\n            \n            # Flip all triplets or none\n            if args.horizontal_flip:\n                images = transforms.RandomHorizontalFlip()(images)\n            \n            # Compute features of all images (images contains queries, positives and negatives)\n            features = model(images.to(args.device))\n            loss_triplet = 0\n            \n            if args.criterion == \"triplet\":\n                triplets_local_indexes = torch.transpose(\n                    triplets_local_indexes.view(args.train_batch_size, args.negs_num_per_query, 3), 1, 0)\n                for triplets in triplets_local_indexes:\n                    queries_indexes, positives_indexes, negatives_indexes = triplets.T\n                    loss_triplet += criterion_triplet(features[queries_indexes],\n                                                      features[positives_indexes],\n                                                      features[negatives_indexes])\n            elif args.criterion == 'sare_joint':\n                # sare_joint needs to receive all the negatives at once\n                triplet_index_batch = triplets_local_indexes.view(args.train_batch_size, 10, 3)\n                for 
batch_triplet_index in triplet_index_batch:\n                    q = features[batch_triplet_index[0, 0]].unsqueeze(0)  # obtain query as tensor of shape 1xn_features\n                    p = features[batch_triplet_index[0, 1]].unsqueeze(0)  # obtain positive as tensor of shape 1xn_features\n                    n = features[batch_triplet_index[:, 2]]               # obtain negatives as tensor of shape 10xn_features\n                    loss_triplet += criterion_triplet(q, p, n)\n            elif args.criterion == \"sare_ind\":\n                for triplet in triplets_local_indexes:\n                    # triplet is a 1-D tensor with the 3 scalars indexes of the triplet\n                    q_i, p_i, n_i = triplet\n                    loss_triplet += criterion_triplet(features[q_i:q_i+1], features[p_i:p_i+1], features[n_i:n_i+1])\n            \n            del features\n            loss_triplet /= (args.train_batch_size * args.negs_num_per_query)\n            \n            optimizer.zero_grad()\n            loss_triplet.backward()\n            optimizer.step()\n            \n            # Keep track of all losses by appending them to epoch_losses\n            batch_loss = loss_triplet.item()\n            epoch_losses = np.append(epoch_losses, batch_loss)\n            del loss_triplet\n        \n        logging.debug(f\"Epoch[{epoch_num:02d}]({loop_num}/{loops_num}): \" +\n                      f\"current batch triplet loss = {batch_loss:.4f}, \" +\n                      f\"average epoch triplet loss = {epoch_losses.mean():.4f}\")\n    \n    logging.info(f\"Finished epoch {epoch_num:02d} in {str(datetime.now() - epoch_start_time)[:-7]}, \"\n                 f\"average epoch triplet loss = {epoch_losses.mean():.4f}\")\n    \n    # Compute recalls on validation set\n    recalls, recalls_str = test.test(args, val_ds, model)\n    logging.info(f\"Recalls on val set {val_ds}: {recalls_str}\")\n    \n    is_best = recalls[1] > best_r5\n    \n    # Save checkpoint, which contains all training parameters\n    util.save_checkpoint(args, {\n        \"epoch_num\": epoch_num, \"model_state_dict\": model.state_dict(),\n        \"optimizer_state_dict\": optimizer.state_dict(), \"recalls\": recalls, \"best_r5\": best_r5,\n        \"not_improved_num\": not_improved_num\n    }, is_best, filename=\"last_model.pth\")\n    \n    # If recall@5 did not improve for \"many\" epochs, stop training\n    if is_best:\n        logging.info(f\"Improved: previous best R@5 = {best_r5:.1f}, current R@5 = {recalls[1]:.1f}\")\n        best_r5 = recalls[1]\n        not_improved_num = 0\n    else:\n        not_improved_num += 1\n        logging.info(f\"Not improved: {not_improved_num} / {args.patience}: best R@5 = {best_r5:.1f}, current R@5 = {recalls[1]:.1f}\")\n        if not_improved_num >= args.patience:\n            logging.info(f\"Performance did not improve for {not_improved_num} epochs. Stop training.\")\n            break\n\n\nlogging.info(f\"Best R@5: {best_r5:.1f}\")\nlogging.info(f\"Trained for {epoch_num+1:02d} epochs, in total in {str(datetime.now() - start_time)[:-7]}\")\n\n#### Test best model on test set\nbest_model_state_dict = torch.load(join(args.save_dir, \"best_model.pth\"))[\"model_state_dict\"]\nmodel.load_state_dict(best_model_state_dict)\n\nrecalls, recalls_str = test.test(args, test_ds, model, test_method=args.test_method)\nlogging.info(f\"Recalls on {test_ds}: {recalls_str}\")\n"
  },
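  {
    "path": "examples/util_usage_sketch.py",
    "content": "\n# Illustrative sketch, NOT part of the original repository: a minimal, hedged example of\n# how the helpers in util.py fit together at evaluation time. The file name and the\n# pca_dataset_folder value below are hypothetical; parser.parse_arguments() is assumed to\n# provide resume, device, resize, features_dim and pca_dim, as it does for train.py.\nimport util\nimport parser\nfrom model import network\n\nargs = parser.parse_arguments()\nmodel = network.GeoLocalizationNet(args).to(args.device)\n\n# Load pretrained weights from the checkpoint passed via --resume\nmodel = util.resume_model(args, model)\n\n# FLOPs of a forward pass at the given resolution, as a human-readable string\nprint(f\"Model FLOPs: {util.get_flops(model, args.resize)}\")\n\n# Fit PCA on descriptors extracted from the train set (PCA is used only at test time);\n# for NetVLAD, features_dim is assumed to already include the netvlad_clusters factor\npca = util.compute_pca(args, model, \"pitts30k/images/train\", args.features_dim)\nprint(f\"PCA will project descriptors from {args.features_dim} to {args.pca_dim} dimensions\")\n"
  },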
  {
    "path": "util.py",
    "content": "\nimport re\nimport torch\nimport shutil\nimport logging\nimport torchscan\nimport numpy as np\nfrom collections import OrderedDict\nfrom os.path import join\nfrom sklearn.decomposition import PCA\n\nimport datasets_ws\n\n\ndef get_flops(model, input_shape=(480, 640)):\n    \"\"\"Return the FLOPs as a string, such as '22.33 GFLOPs'\"\"\"\n    assert len(input_shape) == 2, f\"input_shape should have len==2, but it's {input_shape}\"\n    module_info = torchscan.crawl_module(model, (3, input_shape[0], input_shape[1]))\n    output = torchscan.utils.format_info(module_info)\n    return re.findall(\"Floating Point Operations on forward: (.*)\\n\", output)[0]\n\n\ndef save_checkpoint(args, state, is_best, filename):\n    model_path = join(args.save_dir, filename)\n    torch.save(state, model_path)\n    if is_best:\n        shutil.copyfile(model_path, join(args.save_dir, \"best_model.pth\"))\n\n\ndef resume_model(args, model):\n    checkpoint = torch.load(args.resume, map_location=args.device)\n    if 'model_state_dict' in checkpoint:\n        state_dict = checkpoint['model_state_dict']\n    else:\n        # The pre-trained models that we provide in the README do not have 'state_dict' in the keys as\n        # the checkpoint is directly the state dict\n        state_dict = checkpoint\n    # if the model contains the prefix \"module\" which is appendend by\n    # DataParallel, remove it to avoid errors when loading dict\n    if list(state_dict.keys())[0].startswith('module'):\n        state_dict = OrderedDict({k.replace('module.', ''): v for (k, v) in state_dict.items()})\n    model.load_state_dict(state_dict)\n    return model\n\n\ndef resume_train(args, model, optimizer=None, strict=False):\n    \"\"\"Load model, optimizer, and other training parameters\"\"\"\n    logging.debug(f\"Loading checkpoint: {args.resume}\")\n    checkpoint = torch.load(args.resume)\n    start_epoch_num = checkpoint[\"epoch_num\"]\n    model.load_state_dict(checkpoint[\"model_state_dict\"], strict=strict)\n    if optimizer:\n        optimizer.load_state_dict(checkpoint[\"optimizer_state_dict\"])\n    best_r5 = checkpoint[\"best_r5\"]\n    not_improved_num = checkpoint[\"not_improved_num\"]\n    logging.debug(f\"Loaded checkpoint: start_epoch_num = {start_epoch_num}, \"\n                  f\"current_best_R@5 = {best_r5:.1f}\")\n    if args.resume.endswith(\"last_model.pth\"):  # Copy best model to current save_dir\n        shutil.copy(args.resume.replace(\"last_model.pth\", \"best_model.pth\"), args.save_dir)\n    return model, optimizer, best_r5, start_epoch_num, not_improved_num\n\n\ndef compute_pca(args, model, pca_dataset_folder, full_features_dim):\n    model = model.eval()\n    pca_ds = datasets_ws.PCADataset(args, args.datasets_folder, pca_dataset_folder)\n    dl = torch.utils.data.DataLoader(pca_ds, args.infer_batch_size, shuffle=True)\n    pca_features = np.empty([min(len(pca_ds), 2**14), full_features_dim])\n    with torch.no_grad():\n        for i, images in enumerate(dl):\n            if i*args.infer_batch_size >= len(pca_features):\n                break\n            features = model(images).cpu().numpy()\n            pca_features[i*args.infer_batch_size : (i*args.infer_batch_size)+len(features)] = features\n    pca = PCA(args.pca_dim)\n    pca.fit(pca_features)\n    return pca\n"
  }
]