Repository: gmberton/deep-visual-geo-localization-benchmark
Branch: main
Commit: 4af519437403
Files: 29
Total size: 172.7 KB
Directory structure:
deep-visual-geo-localization-benchmark/
├── .gitignore
├── LICENSE
├── README.md
├── commons.py
├── datasets_ws.py
├── eval.py
├── model/
│ ├── __init__.py
│ ├── aggregation.py
│ ├── cct/
│ │ ├── __init__.py
│ │ ├── cct.py
│ │ ├── embedder.py
│ │ ├── helpers.py
│ │ ├── stochastic_depth.py
│ │ ├── tokenizer.py
│ │ └── transformers.py
│ ├── functional.py
│ ├── network.py
│ ├── normalization.py
│ └── sync_batchnorm/
│ ├── __init__.py
│ ├── batchnorm.py
│ ├── batchnorm_reimpl.py
│ ├── comm.py
│ ├── replicate.py
│ └── unittest.py
├── parser.py
├── requirements.txt
├── test.py
├── train.py
└── util.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Autogenerated folders
__pycache__
logs
test
data
# IDEs generated folders
.spyproject
venv/
.idea/
__MACOSX/
**/.DS_Store
# other
pretrained
*.pth
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2016-2019 VRG, CTU Prague
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# Deep Visual Geo-localization Benchmark
This is the official repository for the CVPR 2022 Oral paper [Deep Visual Geo-localization Benchmark](https://arxiv.org/abs/2204.03444).
It can be used to reproduce the results from the paper and to run a wide range of experiments by changing the components of a Visual Geo-localization pipeline.
## Setup
Before you begin experimenting with this toolbox, your dataset should be organized in a directory tree as such:
```
.
├── benchmarking_vg
└── datasets_vg
└── datasets
└── pitts30k
└── images
├── train
│ ├── database
│ └── queries
├── val
│ ├── database
│ └── queries
└── test
├── database
└── queries
```
The [VPR-datasets-downloader](https://github.com/gmberton/VPR-datasets-downloader) repo can be used to download a number of datasets. Detailed instructions on how to download datasets are in the repo. Note that many datasets are available, and _pitts30k_ is just an example.
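Images are expected to embed their UTM coordinates in the filename, in the format `path/to/file/@utm_easting@utm_northing@...@.jpg` (this is the convention read by `datasets_ws.py`). A minimal sketch of how the coordinates are parsed, using a made-up filename:
```python
# Hypothetical filename, following the @utm_easting@utm_northing@...@.jpg convention
path = "datasets/pitts30k/images/train/database/@0585000.50@4477000.25@other_fields@.jpg"
# The two coordinates are read by position after splitting on "@"
utm_easting = float(path.split("@")[1])   # 585000.5
utm_northing = float(path.split("@")[2])  # 4477000.25
```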
## Running experiments
### Basic experiment
For a basic experiment, run
`$ python3 train.py --dataset_name=pitts30k`
This will train a ResNet-18 + NetVLAD on Pitts30k.
The experiment creates a folder named `./logs/default/YYYY-MM-DD_HH-mm-ss`, where checkpoints are saved, as well as an `info.log` file with training logs and other information, such as model size, FLOPs and descriptors dimensionality.
### Architectures and mining
You can replace the backbone and the aggregation as such:
`$ python3 train.py --dataset_name=pitts30k --backbone=resnet50conv4 --aggregation=gem`
ResNets can be cropped at either conv4 or conv5.
#### Add a fully connected layer
To add a fully connected layer of dimension 2048 to GeM pooling:
`$ python3 train.py --dataset_name=pitts30k --backbone=resnet50conv4 --aggregation=gem --fc_output_dim=2048`
#### Add PCA
To add PCA to a NetVLAD layer just do:
`$ python3 eval.py --dataset_name=pitts30k --backbone=resnet50conv4 --aggregation=netvlad --pca_dim=2048 --pca_dataset_folder=pitts30k/images/train`
where _pca_dataset_folder_ points to the folder with the images used to compute the PCA. In the paper we compute the PCA's principal components on the train set, as it gave the best results. PCA is used only at test time.
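As a minimal sketch of what this does conceptually (this is not the exact code in `util.compute_pca`): a PCA is fitted on descriptors extracted from the `pca_dataset_folder` images, and every descriptor is then reduced before retrieval. Array sizes below are placeholders.
```python
import numpy as np
from sklearn.decomposition import PCA

# Placeholder descriptors standing in for features extracted from the train set
train_descriptors = np.random.randn(1000, 512).astype(np.float32)
pca = PCA(n_components=256)  # corresponds to --pca_dim
pca.fit(train_descriptors)

# At test time, each database/query descriptor is reduced before retrieval
test_descriptors = np.random.randn(10, 512).astype(np.float32)
reduced = pca.transform(test_descriptors)  # shape (10, 256)
```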
#### Evaluate trained models
To evaluate the trained model on other datasets (this example is with the St Lucia dataset), simply run
`$ python3 eval.py --backbone=resnet50conv4 --aggregation=gem --resume=logs/default/YYYY-MM-DD_HH-mm-ss/best_model.pth --dataset_name=st_lucia`
#### Reproduce the results
Finally, to reproduce our results, use the appropriate mining method: _full_ for _pitts30k_ and _partial_ for _msls_ as such:
`$ python3 train.py --dataset_name=pitts30k --mining=full`
With commands as simple as these you can replicate all results from Tables 3, 4 and 5 of the main paper, as well as Tables 2, 3 and 4 of the supplementary.
### Resize
To resize the images, simply pass the parameter _resize_ with the target resolution. For example, to use 80% of the full _pitts30k_ resolution, pass 384 512, because the full images are 480×640 (0.8 × 480 = 384 and 0.8 × 640 = 512):
`$ python3 train.py --dataset_name=pitts30k --resize=384 512`
### Query pre/post-processing and predictions refinement
We gather all such methods under the _test_method_ parameter. The available methods are _hard_resize_, _single_query_, _central_crop_, _five_crops_mean_, _nearest_crop_ and _majority_voting_.
Although _hard_resize_ is the default, in most datasets it doesn't apply any transformation at all (see the paper for more information), because all images have the same resolution.
`$ python3 eval.py --resume=logs/default/YYYY-MM-DD_HH-mm-ss/best_model.pth --dataset_name=tokyo247 --test_method=nearest_crop`
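As a rough sketch, the crop-based test methods build five square crops of each query, mirroring `_test_query_transform` in `datasets_ws.py`; _nearest_crop_ and _majority_voting_ then refine the predictions obtained from those crops. Tensor sizes below are illustrative.
```python
import torch
import torchvision.transforms as T

img = torch.rand(3, 480, 640)                 # a query image tensor (C, H, W)
shorter_side = 480                            # min(resize)
img = T.functional.resize(img, shorter_side)  # preserves the aspect ratio
crops = torch.stack(T.functional.five_crop(img, shorter_side))
print(crops.shape)                            # torch.Size([5, 3, 480, 480])
```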
### Data augmentation
You can reproduce all data augmentation techniques from the paper with simple commands, for example:
`$ python3 train.py --dataset_name=pitts30k --horizontal_flipping --saturation 2 --brightness 1`
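Under the hood these flags are turned into a torchvision pipeline applied to the queries (see `TripletsDataset.query_transform` in `datasets_ws.py`). A minimal sketch, with illustrative values matching the command above and the other augmentations disabled:
```python
import torchvision.transforms as T

query_transform = T.Compose([
    T.ColorJitter(brightness=1, contrast=0, saturation=2, hue=0),
    T.RandomPerspective(distortion_scale=0),        # --rand_perspective
    T.RandomResizedCrop(size=[480, 640], scale=(1, 1)),  # --random_resized_crop
    T.RandomRotation(degrees=0),                    # --random_rotation
])
```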
### Off-the-shelf models trained on Landmark Recognition datasets
The code allows you to automatically download and use models trained on Landmark Recognition datasets from two popular repositories: [radenovic](https://github.com/filipradenovic/cnnimageretrieval-pytorch) and [naver](https://github.com/naver/deep-image-retrieval).
These repos offer ResNet-50/101 with GeM and an FC layer of dimension 2048 trained on such datasets, which can be used as such:
`$ python eval.py --off_the_shelf=radenovic_gldv1 --l2=after_pool --backbone=resnet101conv5 --aggregation=gem --fc_output_dim=2048`
`$ python eval.py --dataset_name=pitts30k --off_the_shelf=naver --l2=none --backbone=resnet101conv5 --aggregation=gem --fc_output_dim=2048`
### Using pretrained networks on other datasets
Check out our [pretrain_vg](https://github.com/rm-wu/pretrain_vg) repo, which we use to train such models.
You can automatically download these models and start training from them as such:
`$ python train.py --dataset_name=pitts30k --pretrained=places`
### Changing the threshold distance
You can use a distance threshold different from the default 25 meters as simply as this (for example, 100 meters):
`$ python3 eval.py --resume=logs/default/YYYY-MM-DD_HH-mm-ss/best_model.pth --val_positive_dist_threshold=100`
### Changing the recall values (R@N)
By default the toolbox computes recalls @ 1, 5, 10 and 20, but you can compute other recall values as such:
`$ python3 eval.py --resume=logs/default/YYYY-MM-DD_HH-mm-ss/best_model.pth --recall_values 1 5 10 15 20 50 100`
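For reference, a hedged sketch of the standard recall@N computation (the helper name below is made up; the toolbox's own implementation lives in `test.py`): a query counts as correct at N if any of its top-N retrieved database images lies within the positive distance threshold (default 25 meters).
```python
import numpy as np

def recall_at_n(predictions, positives_per_query, n):
    """predictions: (num_queries, k) ranked database indexes;
    positives_per_query: per-query arrays of database indexes within the threshold."""
    hits = [np.any(np.isin(preds[:n], pos))
            for preds, pos in zip(predictions, positives_per_query)]
    return 100 * np.mean(hits)
```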
### Model Zoo
We are currently exploring hosting options, so this is a partial list of models. More models will be added soon!
#### Pretrained models with different backbones

Pretrained networks employing different backbones. The first three result columns refer to models trained on Pitts30k, the last three to models trained on MSLS.

| Model | Pitts30k (R@1) | MSLS (R@1) | Download | Pitts30k (R@1) | MSLS (R@1) | Download |
|-------|----------------|------------|----------|----------------|------------|----------|
| vgg16-gem | 78.5 | 43.4 | [Link] | 70.2 | 66.7 | [Link] |
| resnet18-gem | 77.8 | 35.3 | [Link] | 71.6 | 65.3 | [Link] |
| resnet50-gem | 82.0 | 38.0 | [Link] | 77.4 | 72.0 | [Link] |
| resnet101-gem | 82.4 | 39.6 | [Link] | 77.2 | 72.5 | [Link] |
| ViT(224)-CLS | - | - | - | 80.4 | 69.3 | [Link] |
| vgg16-netvlad | 83.2 | 50.9 | [Link] | 79.0 | 74.6 | [Link] |
| resnet18-netvlad | 86.4 | 47.4 | [Link] | 81.6 | 75.8 | [Link] |
| resnet50-netvlad | 86.0 | 50.7 | [Link] | 80.9 | 76.9 | [Link] |
| resnet101-netvlad | 86.5 | 51.8 | [Link] | 80.8 | 77.7 | [Link] |
| cct384-netvlad | 85.0 | 52.5 | [Link] | 80.3 | 85.1 | [Link] |
#### Pretrained models with different aggregation methods

Pretrained networks trained using different aggregation methods. The first three result columns refer to models trained on Pitts30k, the last three to models trained on MSLS.

| Model | Pitts30k (R@1) | MSLS (R@1) | Download | Pitts30k (R@1) | MSLS (R@1) | Download |
|-------|----------------|------------|----------|----------------|------------|----------|
| resnet50-gem | 82.0 | 38.0 | [Link] | 77.4 | 72.0 | [Link] |
| resnet50-gem-fc2048 | 80.1 | 33.7 | [Link] | 79.2 | 73.5 | [Link] |
| resnet50-gem-fc65536 | 80.8 | 35.8 | [Link] | 79.0 | 74.4 | [Link] |
| resnet50-netvlad | 86.0 | 50.7 | [Link] | 80.9 | 76.9 | [Link] |
| resnet50-crn | 85.8 | 54.0 | [Link] | 80.8 | 77.8 | [Link] |
#### Pretrained models with different mining methods

Pretrained networks trained using three different mining methods (random, full database mining and partial database mining). The first three result columns refer to models trained on Pitts30k, the last three to models trained on MSLS.

| Model | Pitts30k (R@1) | MSLS (R@1) | Download | Pitts30k (R@1) | MSLS (R@1) | Download |
|-------|----------------|------------|----------|----------------|------------|----------|
| resnet18-gem-random | 73.7 | 30.5 | [Link] | 62.2 | 50.6 | [Link] |
| resnet18-gem-full | 77.8 | 35.3 | [Link] | 70.1 | 61.8 | [Link] |
| resnet18-gem-partial | 76.5 | 34.2 | [Link] | 71.6 | 65.3 | [Link] |
| resnet18-netvlad-random | 83.9 | 43.6 | [Link] | 73.3 | 61.5 | [Link] |
| resnet18-netvlad-full | 86.4 | 47.4 | [Link] | - | - | - |
| resnet18-netvlad-partial | 86.2 | 47.3 | [Link] | 81.6 | 75.8 | [Link] |
| resnet50-gem-random | 77.9 | 34.3 | [Link] | 69.5 | 57.4 | [Link] |
| resnet50-gem-full | 82.0 | 38.0 | [Link] | 77.3 | 69.7 | [Link] |
| resnet50-gem-partial | 82.3 | 39.0 | [Link] | 77.4 | 72.0 | [Link] |
| resnet50-netvlad-random | 83.4 | 45.0 | [Link] | 74.9 | 63.6 | [Link] |
| resnet50-netvlad-full | 86.0 | 50.7 | [Link] | - | - | - |
| resnet50-netvlad-partial | 85.5 | 48.6 | [Link] | 80.9 | 76.9 | [Link] |
If you find our work useful in your research, please consider citing our paper:
```bibtex
@inproceedings{Berton_CVPR_2022_benchmark,
author = {Berton, Gabriele and Mereu, Riccardo and Trivigno, Gabriele and Masone, Carlo and Csurka, Gabriela and Sattler, Torsten and Caputo, Barbara},
title = {Deep Visual Geo-Localization Benchmark},
booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition},
month = {June},
year = {2022}
}
```
## Acknowledgements
Parts of this repo are inspired by the following great repositories:
- [NetVLAD's original code](https://github.com/Relja/netvlad) (in MATLAB)
- [NetVLAD layer in PyTorch](https://github.com/lyakaap/NetVLAD-pytorch)
- [NetVLAD training in PyTorch](https://github.com/Nanne/pytorch-NetVlad/)
- [GeM layer](https://github.com/filipradenovic/cnnimageretrieval-pytorch)
- [Deep Image Retrieval](https://github.com/naver/deep-image-retrieval)
- [Mapillary Street-level Sequences](https://github.com/mapillary/mapillary_sls)
- [Compact Convolutional Transformers](https://github.com/SHI-Labs/Compact-Transformers)
Check out also our other repo [_CosPlace_](https://github.com/gmberton/CosPlace), from the CVPR 2022 paper "Rethinking Visual Geo-localization for Large-Scale Applications", which provides a new SOTA in visual geo-localization / visual place recognition.
================================================
FILE: commons.py
================================================
"""
This file contains some functions and classes which can be useful in very diverse projects.
"""
import os
import sys
import torch
import random
import logging
import traceback
import numpy as np
from os.path import join
def make_deterministic(seed=0):
"""Make results deterministic. If seed == -1, do not make deterministic.
Running the script in a deterministic way might slow it down.
"""
if seed == -1:
return
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def setup_logging(save_dir, console="debug",
info_filename="info.log", debug_filename="debug.log"):
"""Set up logging files and console output.
Creates one file for INFO logs and one for DEBUG logs.
Args:
save_dir (str): the folder where the log files will be saved (it is created by this function).
console (str):
if == "debug" prints on console debug messages and higher
if == "info" prints on console info messages and higher
if == None does not use console (useful when a logger has already been set)
info_filename (str): the name of the info file. If None, don't create the info file.
debug_filename (str): the name of the debug file. If None, don't create the debug file.
"""
if os.path.exists(save_dir):
raise FileExistsError(f"{save_dir} already exists!")
os.makedirs(save_dir, exist_ok=True)
# logging.Logger.manager.loggerDict.keys() to check which loggers are in use
base_formatter = logging.Formatter('%(asctime)s %(message)s', "%Y-%m-%d %H:%M:%S")
logger = logging.getLogger('')
logger.setLevel(logging.DEBUG)
if info_filename is not None:
info_file_handler = logging.FileHandler(join(save_dir, info_filename))
info_file_handler.setLevel(logging.INFO)
info_file_handler.setFormatter(base_formatter)
logger.addHandler(info_file_handler)
if debug_filename is not None:
debug_file_handler = logging.FileHandler(join(save_dir, debug_filename))
debug_file_handler.setLevel(logging.DEBUG)
debug_file_handler.setFormatter(base_formatter)
logger.addHandler(debug_file_handler)
if console is not None:
console_handler = logging.StreamHandler()
if console == "debug":
console_handler.setLevel(logging.DEBUG)
if console == "info":
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(base_formatter)
logger.addHandler(console_handler)
def exception_handler(type_, value, tb):
logger.info("\n" + "".join(traceback.format_exception(type_, value, tb)))
sys.excepthook = exception_handler
================================================
FILE: datasets_ws.py
================================================
import os
import torch
import faiss
import logging
import numpy as np
from glob import glob
from tqdm import tqdm
from PIL import Image
from os.path import join
import torch.utils.data as data
import torchvision.transforms as T
from torch.utils.data.dataset import Subset
from sklearn.neighbors import NearestNeighbors
from torch.utils.data.dataloader import DataLoader
base_transform = T.Compose([
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def path_to_pil_img(path):
return Image.open(path).convert("RGB")
def collate_fn(batch):
"""Creates mini-batch tensors from the list of tuples (images,
triplets_local_indexes, triplets_global_indexes).
triplets_local_indexes are the indexes referring to each triplet within images.
triplets_global_indexes are the global indexes of each image.
Args:
batch: list of tuple (images, triplets_local_indexes, triplets_global_indexes).
considering each query to have 10 negatives (negs_num_per_query=10):
- images: torch tensor of shape (12, 3, h, w).
- triplets_local_indexes: torch tensor of shape (10, 3).
- triplets_global_indexes: torch tensor of shape (12).
Returns:
images: torch tensor of shape (batch_size*12, 3, h, w).
triplets_local_indexes: torch tensor of shape (batch_size*10, 3).
triplets_global_indexes: torch tensor of shape (batch_size, 12).
"""
images = torch.cat([e[0] for e in batch])
triplets_local_indexes = torch.cat([e[1][None] for e in batch])
triplets_global_indexes = torch.cat([e[2][None] for e in batch])
for i, (local_indexes, global_indexes) in enumerate(zip(triplets_local_indexes, triplets_global_indexes)):
local_indexes += len(global_indexes) * i # Increment local indexes by offset (len(global_indexes) is 12)
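# e.g. with negs_num_per_query=10, the i-th sample's local indexes (0, 1, 2..11) are shifted to (12*i, 12*i+1, ...)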
return images, torch.cat(tuple(triplets_local_indexes)), triplets_global_indexes
class PCADataset(data.Dataset):
def __init__(self, args, datasets_folder="dataset", dataset_folder="pitts30k/images/train"):
dataset_folder_full_path = join(datasets_folder, dataset_folder)
if not os.path.exists(dataset_folder_full_path):
raise FileNotFoundError(f"Folder {dataset_folder_full_path} does not exist")
self.images_paths = sorted(glob(join(dataset_folder_full_path, "**", "*.jpg"), recursive=True))
def __getitem__(self, index):
return base_transform(path_to_pil_img(self.images_paths[index]))
def __len__(self):
return len(self.images_paths)
class BaseDataset(data.Dataset):
"""Dataset with images from database and queries, used for inference (testing and building cache).
"""
def __init__(self, args, datasets_folder="datasets", dataset_name="pitts30k", split="train"):
super().__init__()
self.args = args
self.dataset_name = dataset_name
self.dataset_folder = join(datasets_folder, dataset_name, "images", split)
if not os.path.exists(self.dataset_folder):
raise FileNotFoundError(f"Folder {self.dataset_folder} does not exist")
self.resize = args.resize
self.test_method = args.test_method
#### Read paths and UTM coordinates for all images.
database_folder = join(self.dataset_folder, "database")
queries_folder = join(self.dataset_folder, "queries")
if not os.path.exists(database_folder):
raise FileNotFoundError(f"Folder {database_folder} does not exist")
if not os.path.exists(queries_folder):
raise FileNotFoundError(f"Folder {queries_folder} does not exist")
self.database_paths = sorted(glob(join(database_folder, "**", "*.jpg"), recursive=True))
self.queries_paths = sorted(glob(join(queries_folder, "**", "*.jpg"), recursive=True))
# The format must be path/to/file/@utm_easting@utm_northing@...@.jpg
self.database_utms = np.array([(path.split("@")[1], path.split("@")[2]) for path in self.database_paths]).astype(float)
self.queries_utms = np.array([(path.split("@")[1], path.split("@")[2]) for path in self.queries_paths]).astype(float)
# Find soft_positives_per_query, which are within val_positive_dist_threshold (default 25 meters)
knn = NearestNeighbors(n_jobs=-1)
knn.fit(self.database_utms)
self.soft_positives_per_query = knn.radius_neighbors(self.queries_utms,
radius=args.val_positive_dist_threshold,
return_distance=False)
self.images_paths = list(self.database_paths) + list(self.queries_paths)
self.database_num = len(self.database_paths)
self.queries_num = len(self.queries_paths)
def __getitem__(self, index):
img = path_to_pil_img(self.images_paths[index])
img = base_transform(img)
# With database images self.test_method should always be "hard_resize"
if self.test_method == "hard_resize":
# self.test_method=="hard_resize" is the default, resizes all images to the same size.
img = T.functional.resize(img, self.resize)
else:
img = self._test_query_transform(img)
return img, index
def _test_query_transform(self, img):
"""Transform query image according to self.test_method."""
C, H, W = img.shape
if self.test_method == "single_query":
# self.test_method=="single_query" is used when queries have varying sizes, and can't be stacked in a batch.
processed_img = T.functional.resize(img, min(self.resize))
elif self.test_method == "central_crop":
# Take the biggest central crop of size self.resize. Preserves ratio.
scale = max(self.resize[0]/H, self.resize[1]/W)
processed_img = torch.nn.functional.interpolate(img.unsqueeze(0), scale_factor=scale).squeeze(0)
processed_img = T.functional.center_crop(processed_img, self.resize)
assert processed_img.shape[1:] == torch.Size(self.resize), f"{processed_img.shape[1:]} {self.resize}"
elif self.test_method == "five_crops" or self.test_method == 'nearest_crop' or self.test_method == 'maj_voting':
# Get 5 square crops with size==shorter_side (usually 480). Preserves ratio and allows batches.
shorter_side = min(self.resize)
processed_img = T.functional.resize(img, shorter_side)
processed_img = torch.stack(T.functional.five_crop(processed_img, shorter_side))
assert processed_img.shape == torch.Size([5, 3, shorter_side, shorter_side]), \
f"{processed_img.shape} {torch.Size([5, 3, shorter_side, shorter_side])}"
return processed_img
def __len__(self):
return len(self.images_paths)
def __repr__(self):
return f"< {self.__class__.__name__}, {self.dataset_name} - #database: {self.database_num}; #queries: {self.queries_num} >"
def get_positives(self):
return self.soft_positives_per_query
class TripletsDataset(BaseDataset):
"""Dataset used for training, it is used to compute the triplets
with TripletsDataset.compute_triplets() with various mining methods.
If is_inference == True, uses methods of the parent class BaseDataset,
this is used for example when computing the cache, because we compute features
of each image, not triplets.
"""
def __init__(self, args, datasets_folder="datasets", dataset_name="pitts30k", split="train", negs_num_per_query=10):
super().__init__(args, datasets_folder, dataset_name, split)
self.mining = args.mining
self.neg_samples_num = args.neg_samples_num # Number of negatives to randomly sample
self.negs_num_per_query = negs_num_per_query # Number of negatives per query in each batch
if self.mining == "full": # "Full database mining" keeps a cache with last used negatives
self.neg_cache = [np.empty((0,), dtype=np.int32) for _ in range(self.queries_num)]
self.is_inference = False
identity_transform = T.Lambda(lambda x: x)
self.resized_transform = T.Compose([
T.Resize(self.resize) if self.resize is not None else identity_transform,
base_transform
])
self.query_transform = T.Compose([
T.ColorJitter(args.brightness, args.contrast, args.saturation, args.hue),
T.RandomPerspective(args.rand_perspective),
T.RandomResizedCrop(size=self.resize, scale=(1-args.random_resized_crop, 1)),
T.RandomRotation(degrees=args.random_rotation),
self.resized_transform,
])
# Find hard_positives_per_query, which are within train_positives_dist_threshold (10 meters)
knn = NearestNeighbors(n_jobs=-1)
knn.fit(self.database_utms)
self.hard_positives_per_query = list(knn.radius_neighbors(self.queries_utms,
radius=args.train_positives_dist_threshold, # 10 meters
return_distance=False))
#### Some queries might have no positive, we should remove those queries.
queries_without_any_hard_positive = np.where(np.array([len(p) for p in self.hard_positives_per_query], dtype=object) == 0)[0]
if len(queries_without_any_hard_positive) != 0:
logging.info(f"There are {len(queries_without_any_hard_positive)} queries without any positives " +
"within the training set. They won't be considered as they're useless for training.")
# Remove queries without positives
self.hard_positives_per_query = np.delete(self.hard_positives_per_query, queries_without_any_hard_positive)
self.soft_positives_per_query = np.delete(self.soft_positives_per_query, queries_without_any_hard_positive)
self.queries_paths = np.delete(self.queries_paths, queries_without_any_hard_positive)
# Recompute images_paths and queries_num because some queries might have been removed
self.images_paths = list(self.database_paths) + list(self.queries_paths)
self.queries_num = len(self.queries_paths)
# msls_weighted refers to the mining presented in MSLS paper's supplementary.
# Basically, images from uncommon domains are sampled more often. Works only with MSLS dataset.
if self.mining == "msls_weighted":
notes = [p.split("@")[-2] for p in self.queries_paths]
try:
night_indexes = np.where(np.array([n.split("_")[0] == "night" for n in notes]))[0]
sideways_indexes = np.where(np.array([n.split("_")[1] == "sideways" for n in notes]))[0]
except IndexError:
raise RuntimeError("You're using msls_weighted mining but this dataset " +
"does not have night/sideways information. Are you using Mapillary SLS?")
self.weights = np.ones(self.queries_num)
assert len(night_indexes) != 0 and len(sideways_indexes) != 0, \
"There should be night and sideways images for msls_weighted mining, but there are none. Are you using Mapillary SLS?"
self.weights[night_indexes] += self.queries_num / len(night_indexes)
self.weights[sideways_indexes] += self.queries_num / len(sideways_indexes)
self.weights /= self.weights.sum()
logging.info(f"#sideways_indexes [{len(sideways_indexes)}/{self.queries_num}]; " +
"#night_indexes; [{len(night_indexes)}/{self.queries_num}]")
def __getitem__(self, index):
if self.is_inference:
# At inference time return the single image. This is used for caching or computing NetVLAD's clusters
return super().__getitem__(index)
query_index, best_positive_index, neg_indexes = torch.split(self.triplets_global_indexes[index], (1, 1, self.negs_num_per_query))
query = self.query_transform(path_to_pil_img(self.queries_paths[query_index]))
positive = self.resized_transform(path_to_pil_img(self.database_paths[best_positive_index]))
negatives = [self.resized_transform(path_to_pil_img(self.database_paths[i])) for i in neg_indexes]
images = torch.stack((query, positive, *negatives), 0)
triplets_local_indexes = torch.empty((0, 3), dtype=torch.int)
for neg_num in range(len(neg_indexes)):
triplets_local_indexes = torch.cat((triplets_local_indexes, torch.tensor([0, 1, 2 + neg_num]).reshape(1, 3)))
return images, triplets_local_indexes, self.triplets_global_indexes[index]
def __len__(self):
if self.is_inference:
# At inference time return the number of images. This is used for caching or computing NetVLAD's clusters
return super().__len__()
else:
return len(self.triplets_global_indexes)
def compute_triplets(self, args, model):
self.is_inference = True
if self.mining == "full":
self.compute_triplets_full(args, model)
elif self.mining == "partial" or self.mining == "msls_weighted":
self.compute_triplets_partial(args, model)
elif self.mining == "random":
self.compute_triplets_random(args, model)
@staticmethod
def compute_cache(args, model, subset_ds, cache_shape):
"""Compute the cache containing features of images, which is used to
find best positive and hardest negatives."""
subset_dl = DataLoader(dataset=subset_ds, num_workers=args.num_workers,
batch_size=args.infer_batch_size, shuffle=False,
pin_memory=(args.device == "cuda"))
model = model.eval()
# RAMEfficient2DMatrix can be replaced by np.zeros, but using
# RAMEfficient2DMatrix is RAM efficient for full database mining.
cache = RAMEfficient2DMatrix(cache_shape, dtype=np.float32)
with torch.no_grad():
for images, indexes in tqdm(subset_dl, ncols=100):
images = images.to(args.device)
features = model(images)
cache[indexes.numpy()] = features.cpu().numpy()
return cache
def get_query_features(self, query_index, cache):
query_features = cache[query_index + self.database_num]
if query_features is None:
raise RuntimeError(f"For query {self.queries_paths[query_index]} " +
f"with index {query_index} features have not been computed!\n" +
"There might be some bug with caching")
return query_features
def get_best_positive_index(self, args, query_index, cache, query_features):
positives_features = cache[self.hard_positives_per_query[query_index]]
faiss_index = faiss.IndexFlatL2(args.features_dim)
faiss_index.add(positives_features)
# Search the best positive (within 10 meters AND nearest in features space)
_, best_positive_num = faiss_index.search(query_features.reshape(1, -1), 1)
best_positive_index = self.hard_positives_per_query[query_index][best_positive_num[0]].item()
return best_positive_index
def get_hardest_negatives_indexes(self, args, cache, query_features, neg_samples):
neg_features = cache[neg_samples]
faiss_index = faiss.IndexFlatL2(args.features_dim)
faiss_index.add(neg_features)
# Search the 10 nearest negatives (further than 25 meters and nearest in features space)
_, neg_nums = faiss_index.search(query_features.reshape(1, -1), self.negs_num_per_query)
neg_nums = neg_nums.reshape(-1)
neg_indexes = neg_samples[neg_nums].astype(np.int32)
return neg_indexes
def compute_triplets_random(self, args, model):
self.triplets_global_indexes = []
# Take 1000 random queries
sampled_queries_indexes = np.random.choice(self.queries_num, args.cache_refresh_rate, replace=False)
# Take all the positives
positives_indexes = [self.hard_positives_per_query[i] for i in sampled_queries_indexes]
positives_indexes = [p for pos in positives_indexes for p in pos] # Flatten list of lists to a list
positives_indexes = list(np.unique(positives_indexes))
# Compute the cache only for queries and their positives, in order to find the best positive
subset_ds = Subset(self, positives_indexes + list(sampled_queries_indexes + self.database_num))
cache = self.compute_cache(args, model, subset_ds, (len(self), args.features_dim))
# This loop's iterations could be done individually in the __getitem__(). This way is slower but clearer (and yields same results)
for query_index in tqdm(sampled_queries_indexes, ncols=100):
query_features = self.get_query_features(query_index, cache)
best_positive_index = self.get_best_positive_index(args, query_index, cache, query_features)
# Choose some random database images, from those remove the soft_positives, and then take the first 10 images as neg_indexes
soft_positives = self.soft_positives_per_query[query_index]
neg_indexes = np.random.choice(self.database_num, size=self.negs_num_per_query+len(soft_positives), replace=False)
neg_indexes = np.setdiff1d(neg_indexes, soft_positives, assume_unique=True)[:self.negs_num_per_query]
self.triplets_global_indexes.append((query_index, best_positive_index, *neg_indexes))
# self.triplets_global_indexes is a tensor of shape [1000, 12]
self.triplets_global_indexes = torch.tensor(self.triplets_global_indexes)
def compute_triplets_full(self, args, model):
self.triplets_global_indexes = []
# Take 1000 random queries
sampled_queries_indexes = np.random.choice(self.queries_num, args.cache_refresh_rate, replace=False)
# Take all database indexes
database_indexes = list(range(self.database_num))
# Compute features for all images and store them in cache
subset_ds = Subset(self, database_indexes + list(sampled_queries_indexes + self.database_num))
cache = self.compute_cache(args, model, subset_ds, (len(self), args.features_dim))
# This loop's iterations could be done individually in the __getitem__(). This way is slower but clearer (and yields same results)
for query_index in tqdm(sampled_queries_indexes, ncols=100):
query_features = self.get_query_features(query_index, cache)
best_positive_index = self.get_best_positive_index(args, query_index, cache, query_features)
# Choose 1000 random database images (neg_indexes)
neg_indexes = np.random.choice(self.database_num, self.neg_samples_num, replace=False)
# Remove any soft_positives from neg_indexes
soft_positives = self.soft_positives_per_query[query_index]
neg_indexes = np.setdiff1d(neg_indexes, soft_positives, assume_unique=True)
# Concatenate neg_indexes with the previous top 10 negatives (neg_cache)
neg_indexes = np.unique(np.concatenate([self.neg_cache[query_index], neg_indexes]))
# Search the hardest negatives
neg_indexes = self.get_hardest_negatives_indexes(args, cache, query_features, neg_indexes)
# Update nearest negatives in neg_cache
self.neg_cache[query_index] = neg_indexes
self.triplets_global_indexes.append((query_index, best_positive_index, *neg_indexes))
# self.triplets_global_indexes is a tensor of shape [1000, 12]
self.triplets_global_indexes = torch.tensor(self.triplets_global_indexes)
def compute_triplets_partial(self, args, model):
self.triplets_global_indexes = []
# Take 1000 random queries
if self.mining == "partial":
sampled_queries_indexes = np.random.choice(self.queries_num, args.cache_refresh_rate, replace=False)
elif self.mining == "msls_weighted": # Pick night and sideways queries with higher probability
sampled_queries_indexes = np.random.choice(self.queries_num, args.cache_refresh_rate, replace=False, p=self.weights)
# Sample 1000 random database images for the negatives
sampled_database_indexes = np.random.choice(self.database_num, self.neg_samples_num, replace=False)
# Take all the positives
positives_indexes = [self.hard_positives_per_query[i] for i in sampled_queries_indexes]
positives_indexes = [p for pos in positives_indexes for p in pos]
# Merge them into database_indexes and remove duplicates
database_indexes = list(sampled_database_indexes) + positives_indexes
database_indexes = list(np.unique(database_indexes))
subset_ds = Subset(self, database_indexes + list(sampled_queries_indexes + self.database_num))
cache = self.compute_cache(args, model, subset_ds, cache_shape=(len(self), args.features_dim))
# This loop's iterations could be done individually in the __getitem__(). This way is slower but clearer (and yields same results)
for query_index in tqdm(sampled_queries_indexes, ncols=100):
query_features = self.get_query_features(query_index, cache)
best_positive_index = self.get_best_positive_index(args, query_index, cache, query_features)
# Choose the hardest negatives within sampled_database_indexes, ensuring that there are no positives
soft_positives = self.soft_positives_per_query[query_index]
neg_indexes = np.setdiff1d(sampled_database_indexes, soft_positives, assume_unique=True)
# Take all database images that are negatives and are within the sampled database images (aka database_indexes)
neg_indexes = self.get_hardest_negatives_indexes(args, cache, query_features, neg_indexes)
self.triplets_global_indexes.append((query_index, best_positive_index, *neg_indexes))
# self.triplets_global_indexes is a tensor of shape [1000, 12]
self.triplets_global_indexes = torch.tensor(self.triplets_global_indexes)
class RAMEfficient2DMatrix:
"""This class behaves similarly to a numpy.ndarray initialized
with np.zeros(), but is implemented to save RAM when the rows
within the 2D array are sparse. In this case it's needed because
we don't always compute features for each image, just for a few of
them."""
def __init__(self, shape, dtype=np.float32):
self.shape = shape
self.dtype = dtype
self.matrix = [None] * shape[0]
def __setitem__(self, indexes, vals):
assert vals.shape[1] == self.shape[1], f"{vals.shape[1]} {self.shape[1]}"
for i, val in zip(indexes, vals):
self.matrix[i] = val.astype(self.dtype, copy=False)
def __getitem__(self, index):
if hasattr(index, "__len__"):
return np.array([self.matrix[i] for i in index])
else:
return self.matrix[index]
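# Illustrative usage (not from the repo): only rows that are written occupy RAM.
#   cache = RAMEfficient2DMatrix((5, 3))
#   cache[np.array([0, 2])] = np.ones((2, 3), dtype=np.float32)
#   cache[0]          # -> array([1., 1., 1.], dtype=float32)
#   cache[1] is None  # True: this row was never computed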
================================================
FILE: eval.py
================================================
"""
With this script you can evaluate checkpoints or test models from two popular
landmark retrieval github repos.
The first is https://github.com/naver/deep-image-retrieval from Naver labs,
provides ResNet-50 and ResNet-101 trained with AP on Google Landmarks 18 clean.
$ python eval.py --off_the_shelf=naver --l2=none --backbone=resnet101conv5 --aggregation=gem --fc_output_dim=2048
The second is https://github.com/filipradenovic/cnnimageretrieval-pytorch from
Radenovic, provides ResNet-50 and ResNet-101 trained with a triplet loss
on Google Landmarks 18 and sfm120k.
$ python eval.py --off_the_shelf=radenovic_gldv1 --l2=after_pool --backbone=resnet101conv5 --aggregation=gem --fc_output_dim=2048
$ python eval.py --off_the_shelf=radenovic_sfm --l2=after_pool --backbone=resnet101conv5 --aggregation=gem --fc_output_dim=2048
Note that although the architectures are almost the same, Naver's
implementation does not use a l2 normalization before/after the GeM aggregation,
while Radenovic's uses it after (and we use it before, which shows better
results in VG)
"""
import os
import sys
import torch
import parser
import logging
import sklearn
from os.path import join
from datetime import datetime
from torch.utils.model_zoo import load_url
from google_drive_downloader import GoogleDriveDownloader as gdd
import test
import util
import commons
import datasets_ws
from model import network
OFF_THE_SHELF_RADENOVIC = {
'resnet50conv5_sfm' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet50-gem-w-97bf910.pth',
'resnet101conv5_sfm' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet101-gem-w-a155e54.pth',
'resnet50conv5_gldv1' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet50-gem-w-83fdc30.pth',
'resnet101conv5_gldv1' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet101-gem-w-a4d43db.pth',
}
OFF_THE_SHELF_NAVER = {
"resnet50conv5" : "1oPtE_go9tnsiDLkWjN4NMpKjh-_md1G5",
'resnet101conv5' : "1UWJGDuHtzaQdFhSMojoYVQjmCXhIwVvy"
}
######################################### SETUP #########################################
args = parser.parse_arguments()
start_time = datetime.now()
args.save_dir = join("test", args.save_dir, start_time.strftime('%Y-%m-%d_%H-%M-%S'))
commons.setup_logging(args.save_dir)
commons.make_deterministic(args.seed)
logging.info(f"Arguments: {args}")
logging.info(f"The outputs are being saved in {args.save_dir}")
######################################### MODEL #########################################
model = network.GeoLocalizationNet(args)
model = model.to(args.device)
if args.aggregation in ["netvlad", "crn"]:
args.features_dim *= args.netvlad_clusters
if args.off_the_shelf.startswith("radenovic") or args.off_the_shelf.startswith("naver"):
if args.off_the_shelf.startswith("radenovic"):
pretrain_dataset_name = args.off_the_shelf.split("_")[1] # sfm or gldv1 datasets
url = OFF_THE_SHELF_RADENOVIC[f"{args.backbone}_{pretrain_dataset_name}"]
state_dict = load_url(url, model_dir=join("data", "off_the_shelf_nets"))
else:
# This is a hacky workaround to maintain compatibility
sys.modules['sklearn.decomposition.pca'] = sklearn.decomposition._pca
zip_file_path = join("data", "off_the_shelf_nets", args.backbone + "_naver.zip")
if not os.path.exists(zip_file_path):
gdd.download_file_from_google_drive(file_id=OFF_THE_SHELF_NAVER[args.backbone],
dest_path=zip_file_path, unzip=True)
if args.backbone == "resnet50conv5":
state_dict_filename = "Resnet50-AP-GeM.pt"
elif args.backbone == "resnet101conv5":
state_dict_filename = "Resnet-101-AP-GeM.pt"
state_dict = torch.load(join("data", "off_the_shelf_nets", state_dict_filename))
state_dict = state_dict["state_dict"]
model_keys = model.state_dict().keys()
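# Positional renaming: assumes the checkpoint's tensors appear in the same order as the model's state dict keys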
renamed_state_dict = {k: v for k, v in zip(model_keys, state_dict.values())}
model.load_state_dict(renamed_state_dict)
elif args.resume is not None:
logging.info(f"Resuming model from {args.resume}")
model = util.resume_model(args, model)
# Enable DataParallel after loading the checkpoint; doing it before
# would prepend "module." to the keys of the state dict, triggering errors
model = torch.nn.DataParallel(model)
if args.pca_dim is None:
pca = None
else:
full_features_dim = args.features_dim
args.features_dim = args.pca_dim
pca = util.compute_pca(args, model, args.pca_dataset_folder, full_features_dim)
######################################### DATASETS #########################################
test_ds = datasets_ws.BaseDataset(args, args.datasets_folder, args.dataset_name, "test")
logging.info(f"Test set: {test_ds}")
######################################### TEST on TEST SET #########################################
recalls, recalls_str = test.test(args, test_ds, model, args.test_method, pca)
logging.info(f"Recalls on {test_ds}: {recalls_str}")
logging.info(f"Finished in {str(datetime.now() - start_time)[:-7]}")
================================================
FILE: model/__init__.py
================================================
================================================
FILE: model/aggregation.py
================================================
import math
import torch
import faiss
import logging
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.utils.data import DataLoader, SubsetRandomSampler
import model.functional as LF
import model.normalization as normalization
class MAC(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return LF.mac(x)
def __repr__(self):
return self.__class__.__name__ + '()'
class SPoC(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return LF.spoc(x)
def __repr__(self):
return self.__class__.__name__ + '()'
class GeM(nn.Module):
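"""Generalized Mean pooling: GeM(x)_c = (mean_i x_{c,i}^p)^(1/p), with a learnable
exponent p (p=1 gives average pooling; p -> inf approaches max pooling)."""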
def __init__(self, p=3, eps=1e-6, work_with_tokens=False):
super().__init__()
self.p = Parameter(torch.ones(1)*p)
self.eps = eps
self.work_with_tokens=work_with_tokens
def forward(self, x):
return LF.gem(x, p=self.p, eps=self.eps, work_with_tokens=self.work_with_tokens)
def __repr__(self):
return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')'
class RMAC(nn.Module):
def __init__(self, L=3, eps=1e-6):
super().__init__()
self.L = L
self.eps = eps
def forward(self, x):
return LF.rmac(x, L=self.L, eps=self.eps)
def __repr__(self):
return self.__class__.__name__ + '(' + 'L=' + '{}'.format(self.L) + ')'
class Flatten(torch.nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, x):
        assert x.shape[2] == x.shape[3] == 1
        return x[:, :, 0, 0]
class RRM(nn.Module):
"""Residual Retrieval Module as described in the paper
`Leveraging EfficientNet and Contrastive Learning for Accurate Global-scale
Location Estimation`
"""
def __init__(self, dim):
super().__init__()
self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
self.flatten = Flatten()
self.ln1 = nn.LayerNorm(normalized_shape=dim)
self.fc1 = nn.Linear(in_features=dim, out_features=dim)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(in_features=dim, out_features=dim)
self.ln2 = nn.LayerNorm(normalized_shape=dim)
self.l2 = normalization.L2Norm()
def forward(self, x):
x = self.avgpool(x)
x = self.flatten(x)
x = self.ln1(x)
identity = x
out = self.fc2(self.relu(self.fc1(x)))
out += identity
out = self.l2(self.ln2(out))
return out
# based on https://github.com/lyakaap/NetVLAD-pytorch/blob/master/netvlad.py
class NetVLAD(nn.Module):
"""NetVLAD layer implementation"""
def __init__(self, clusters_num=64, dim=128, normalize_input=True, work_with_tokens=False):
"""
Args:
clusters_num : int
The number of clusters
dim : int
Dimension of descriptors
alpha : float
Parameter of initialization. Larger value is harder assignment.
normalize_input : bool
If true, descriptor-wise L2 normalization is applied to input.
"""
super().__init__()
self.clusters_num = clusters_num
self.dim = dim
self.alpha = 0
self.normalize_input = normalize_input
self.work_with_tokens = work_with_tokens
if work_with_tokens:
self.conv = nn.Conv1d(dim, clusters_num, kernel_size=1, bias=False)
else:
self.conv = nn.Conv2d(dim, clusters_num, kernel_size=(1, 1), bias=False)
self.centroids = nn.Parameter(torch.rand(clusters_num, dim))
def init_params(self, centroids, descriptors):
centroids_assign = centroids / np.linalg.norm(centroids, axis=1, keepdims=True)
dots = np.dot(centroids_assign, descriptors.T)
dots.sort(0)
dots = dots[::-1, :] # sort, descending
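# alpha is set so that, on average, the soft-assignment weight ratio between the two closest centroids is ~0.01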
self.alpha = (-np.log(0.01) / np.mean(dots[0,:] - dots[1,:])).item()
self.centroids = nn.Parameter(torch.from_numpy(centroids))
if self.work_with_tokens:
self.conv.weight = nn.Parameter(torch.from_numpy(self.alpha * centroids_assign).unsqueeze(2))
else:
self.conv.weight = nn.Parameter(torch.from_numpy(self.alpha*centroids_assign).unsqueeze(2).unsqueeze(3))
self.conv.bias = None
def forward(self, x):
if self.work_with_tokens:
x = x.permute(0, 2, 1)
N, D, _ = x.shape[:]
else:
N, D, H, W = x.shape[:]
if self.normalize_input:
x = F.normalize(x, p=2, dim=1) # Across descriptor dim
x_flatten = x.view(N, D, -1)
soft_assign = self.conv(x).view(N, self.clusters_num, -1)
soft_assign = F.softmax(soft_assign, dim=1)
vlad = torch.zeros([N, self.clusters_num, D], dtype=x_flatten.dtype, device=x_flatten.device)
for cl in range(self.clusters_num): # Slower than non-looped, but lower memory usage
residual = x_flatten.unsqueeze(0).permute(1, 0, 2, 3) - \
self.centroids[cl:cl+1, :].expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).unsqueeze(0)
residual = residual * soft_assign[:,cl:cl+1,:].unsqueeze(2)
vlad[:,cl:cl+1,:] = residual.sum(dim=-1)
vlad = F.normalize(vlad, p=2, dim=2) # intra-normalization
vlad = vlad.view(N, -1) # Flatten
vlad = F.normalize(vlad, p=2, dim=1) # L2 normalize
return vlad
def initialize_netvlad_layer(self, args, cluster_ds, backbone):
descriptors_num = 50000
descs_num_per_image = 100
images_num = math.ceil(descriptors_num / descs_num_per_image)
random_sampler = SubsetRandomSampler(np.random.choice(len(cluster_ds), images_num, replace=False))
random_dl = DataLoader(dataset=cluster_ds, num_workers=args.num_workers,
batch_size=args.infer_batch_size, sampler=random_sampler)
with torch.no_grad():
backbone = backbone.eval()
logging.debug("Extracting features to initialize NetVLAD layer")
descriptors = np.zeros(shape=(descriptors_num, args.features_dim), dtype=np.float32)
for iteration, (inputs, _) in enumerate(tqdm(random_dl, ncols=100)):
inputs = inputs.to(args.device)
outputs = backbone(inputs)
norm_outputs = F.normalize(outputs, p=2, dim=1)
image_descriptors = norm_outputs.view(norm_outputs.shape[0], args.features_dim, -1).permute(0, 2, 1)
image_descriptors = image_descriptors.cpu().numpy()
batchix = iteration * args.infer_batch_size * descs_num_per_image
for ix in range(image_descriptors.shape[0]):
sample = np.random.choice(image_descriptors.shape[1], descs_num_per_image, replace=False)
startix = batchix + ix * descs_num_per_image
descriptors[startix:startix + descs_num_per_image, :] = image_descriptors[ix, sample, :]
kmeans = faiss.Kmeans(args.features_dim, self.clusters_num, niter=100, verbose=False)
kmeans.train(descriptors)
logging.debug(f"NetVLAD centroids shape: {kmeans.centroids.shape}")
self.init_params(kmeans.centroids, descriptors)
self = self.to(args.device)
class CRNModule(nn.Module):
def __init__(self, dim):
super().__init__()
# Downsample pooling
self.downsample_pool = nn.AvgPool2d(kernel_size=3, stride=(2, 2),
padding=0, ceil_mode=True)
# Multiscale Context Filters
self.filter_3_3 = nn.Conv2d(in_channels=dim, out_channels=32,
kernel_size=(3, 3), padding=1)
self.filter_5_5 = nn.Conv2d(in_channels=dim, out_channels=32,
kernel_size=(5, 5), padding=2)
self.filter_7_7 = nn.Conv2d(in_channels=dim, out_channels=20,
kernel_size=(7, 7), padding=3)
# Accumulation weight
self.acc_w = nn.Conv2d(in_channels=84, out_channels=1, kernel_size=(1, 1))
# Upsampling
self.upsample = F.interpolate
self._initialize_weights()
def _initialize_weights(self):
# Initialize Context Filters
torch.nn.init.xavier_normal_(self.filter_3_3.weight)
torch.nn.init.constant_(self.filter_3_3.bias, 0.0)
torch.nn.init.xavier_normal_(self.filter_5_5.weight)
torch.nn.init.constant_(self.filter_5_5.bias, 0.0)
torch.nn.init.xavier_normal_(self.filter_7_7.weight)
torch.nn.init.constant_(self.filter_7_7.bias, 0.0)
torch.nn.init.constant_(self.acc_w.weight, 1.0)
torch.nn.init.constant_(self.acc_w.bias, 0.0)
self.acc_w.weight.requires_grad = False
self.acc_w.bias.requires_grad = False
def forward(self, x):
# Contextual Reweighting Network
x_crn = self.downsample_pool(x)
# Compute multiscale context filters g_n
g_3 = self.filter_3_3(x_crn)
g_5 = self.filter_5_5(x_crn)
g_7 = self.filter_7_7(x_crn)
g = torch.cat((g_3, g_5, g_7), dim=1)
g = F.relu(g)
w = F.relu(self.acc_w(g)) # Accumulation weight
mask = self.upsample(w, scale_factor=2, mode='bilinear') # Reweighting Mask
return mask
class CRN(NetVLAD):
def __init__(self, clusters_num=64, dim=128, normalize_input=True):
super().__init__(clusters_num, dim, normalize_input)
self.crn = CRNModule(dim)
def forward(self, x):
N, D, H, W = x.shape[:]
if self.normalize_input:
x = F.normalize(x, p=2, dim=1) # Across descriptor dim
mask = self.crn(x)
x_flatten = x.view(N, D, -1)
soft_assign = self.conv(x).view(N, self.clusters_num, -1)
soft_assign = F.softmax(soft_assign, dim=1)
# Weight soft_assign using CRN's mask
soft_assign = soft_assign * mask.view(N, 1, H * W)
vlad = torch.zeros([N, self.clusters_num, D], dtype=x_flatten.dtype, device=x_flatten.device)
for cl in range(self.clusters_num): # Slower than non-looped, but lower memory usage
residual = x_flatten.unsqueeze(0).permute(1, 0, 2, 3) - \
self.centroids[cl:cl + 1, :].expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).unsqueeze(0)
residual = residual * soft_assign[:, cl:cl + 1, :].unsqueeze(2)
vlad[:, cl:cl + 1, :] = residual.sum(dim=-1)
vlad = F.normalize(vlad, p=2, dim=2) # intra-normalization
vlad = vlad.view(N, -1) # Flatten
vlad = F.normalize(vlad, p=2, dim=1) # L2 normalize
return vlad
================================================
FILE: model/cct/__init__.py
================================================
from .cct import cct_14_7x2_384, cct_14_7x2_224
================================================
FILE: model/cct/cct.py
================================================
from torch.hub import load_state_dict_from_url
import torch.nn as nn
import torch
import torch.nn.functional as F
from .transformers import TransformerClassifier
from .tokenizer import Tokenizer
from .helpers import pe_check
from timm.models.registry import register_model
model_urls = {
'cct_7_3x1_32':
'https://shi-labs.com/projects/cct/checkpoints/pretrained/cct_7_3x1_32_cifar10_300epochs.pth',
'cct_7_3x1_32_sine':
'https://shi-labs.com/projects/cct/checkpoints/pretrained/cct_7_3x1_32_sine_cifar10_5000epochs.pth',
'cct_7_3x1_32_c100':
'https://shi-labs.com/projects/cct/checkpoints/pretrained/cct_7_3x1_32_cifar100_300epochs.pth',
'cct_7_3x1_32_sine_c100':
'https://shi-labs.com/projects/cct/checkpoints/pretrained/cct_7_3x1_32_sine_cifar100_5000epochs.pth',
'cct_7_7x2_224_sine':
'https://shi-labs.com/projects/cct/checkpoints/pretrained/cct_7_7x2_224_flowers102.pth',
'cct_14_7x2_224':
'https://shi-labs.com/projects/cct/checkpoints/pretrained/cct_14_7x2_224_imagenet.pth',
'cct_14_7x2_384':
'https://shi-labs.com/projects/cct/checkpoints/finetuned/cct_14_7x2_384_imagenet.pth',
'cct_14_7x2_384_fl':
'https://shi-labs.com/projects/cct/checkpoints/finetuned/cct_14_7x2_384_flowers102.pth',
}
class CCT(nn.Module):
def __init__(self,
img_size=224,
embedding_dim=768,
n_input_channels=3,
n_conv_layers=1,
kernel_size=7,
stride=2,
padding=3,
pooling_kernel_size=3,
pooling_stride=2,
pooling_padding=1,
dropout=0.,
attention_dropout=0.1,
stochastic_depth=0.1,
num_layers=14,
num_heads=6,
mlp_ratio=4.0,
num_classes=1000,
positional_embedding='learnable',
aggregation=None,
*args, **kwargs):
super(CCT, self).__init__()
self.tokenizer = Tokenizer(n_input_channels=n_input_channels,
n_output_channels=embedding_dim,
kernel_size=kernel_size,
stride=stride,
padding=padding,
pooling_kernel_size=pooling_kernel_size,
pooling_stride=pooling_stride,
pooling_padding=pooling_padding,
max_pool=True,
activation=nn.ReLU,
n_conv_layers=n_conv_layers,
conv_bias=False)
self.classifier = TransformerClassifier(
sequence_length=self.tokenizer.sequence_length(n_channels=n_input_channels,
height=img_size,
width=img_size),
embedding_dim=embedding_dim,
seq_pool=True,
dropout=dropout,
attention_dropout=attention_dropout,
stochastic_depth=stochastic_depth,
num_layers=num_layers,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
num_classes=num_classes,
positional_embedding=positional_embedding
)
if aggregation in ['cls', 'seqpool']:
self.aggregation = aggregation
else:
self.aggregation = None
def forward(self, x):
x = self.tokenizer(x)
x = self.classifier(x)
if self.aggregation == 'cls':
return x[:, 0]
elif self.aggregation == 'seqpool':
x = torch.matmul(F.softmax(self.classifier.attention_pool(x), dim=1).transpose(-1, -2), x).squeeze(-2)
return x
else:
# x = x.permute(0, 2, 1)
return x
def _cct(arch, pretrained, progress,
num_layers, num_heads, mlp_ratio, embedding_dim,
kernel_size=3, stride=None, padding=None,
aggregation=None, *args, **kwargs):
stride = stride if stride is not None else max(1, (kernel_size // 2) - 1)
padding = padding if padding is not None else max(1, (kernel_size // 2))
model = CCT(num_layers=num_layers,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
embedding_dim=embedding_dim,
kernel_size=kernel_size,
stride=stride,
padding=padding,
aggregation=aggregation,
*args, **kwargs)
if pretrained:
if arch in model_urls:
state_dict = load_state_dict_from_url(model_urls[arch],
progress=progress)
state_dict = pe_check(model, state_dict)
model.load_state_dict(state_dict, strict=False)
else:
raise RuntimeError(f'Variant {arch} does not yet have pretrained weights.')
return model
def cct_2(arch, pretrained, progress, aggregation=None, *args, **kwargs):
return _cct(arch, pretrained, progress, num_layers=2, num_heads=2, mlp_ratio=1, embedding_dim=128,
aggregation=aggregation, *args, **kwargs)
def cct_4(arch, pretrained, progress, aggregation=None, *args, **kwargs):
return _cct(arch, pretrained, progress, num_layers=4, num_heads=2, mlp_ratio=1, embedding_dim=128,
aggregation=aggregation, *args, **kwargs)
def cct_6(arch, pretrained, progress, aggregation=None, *args, **kwargs):
return _cct(arch, pretrained, progress, num_layers=6, num_heads=4, mlp_ratio=2, embedding_dim=256,
aggregation=aggregation, *args, **kwargs)
def cct_7(arch, pretrained, progress, aggregation=None, *args, **kwargs):
return _cct(arch, pretrained, progress, num_layers=7, num_heads=4, mlp_ratio=2, embedding_dim=256,
aggregation=aggregation, *args, **kwargs)
def cct_14(arch, pretrained, progress, aggregation=None, *args, **kwargs):
return _cct(arch, pretrained, progress, num_layers=14, num_heads=6, mlp_ratio=3, embedding_dim=384,
aggregation=aggregation, *args, **kwargs)
@register_model
def cct_2_3x2_32(pretrained=False, progress=False,
img_size=32, positional_embedding='learnable', num_classes=10,
aggregation=None, *args, **kwargs):
return cct_2('cct_2_3x2_32', pretrained, progress,
kernel_size=3, n_conv_layers=2,
img_size=img_size, positional_embedding=positional_embedding,
num_classes=num_classes, aggregation=aggregation,
*args, **kwargs)
@register_model
def cct_2_3x2_32_sine(pretrained=False, progress=False,
img_size=32, positional_embedding='sine', num_classes=10,
aggregation=None, *args, **kwargs):
return cct_2('cct_2_3x2_32_sine', pretrained, progress,
kernel_size=3, n_conv_layers=2,
img_size=img_size, positional_embedding=positional_embedding,
num_classes=num_classes, aggregation=aggregation,
*args, **kwargs)
@register_model
def cct_4_3x2_32(pretrained=False, progress=False,
img_size=32, positional_embedding='learnable', num_classes=10,
aggregation=None, *args, **kwargs):
return cct_4('cct_4_3x2_32', pretrained, progress,
kernel_size=3, n_conv_layers=2,
img_size=img_size, positional_embedding=positional_embedding,
num_classes=num_classes, aggregation=aggregation,
*args, **kwargs)
@register_model
def cct_4_3x2_32_sine(pretrained=False, progress=False,
img_size=32, positional_embedding='sine', num_classes=10,
aggregation=None, *args, **kwargs):
return cct_4('cct_4_3x2_32_sine', pretrained, progress,
kernel_size=3, n_conv_layers=2,
img_size=img_size, positional_embedding=positional_embedding,
num_classes=num_classes, aggregation=aggregation,
*args, **kwargs)
@register_model
def cct_6_3x1_32(pretrained=False, progress=False,
img_size=32, positional_embedding='learnable', num_classes=10,
aggregation=None, *args, **kwargs):
return cct_6('cct_6_3x1_32', pretrained, progress,
kernel_size=3, n_conv_layers=1,
img_size=img_size, positional_embedding=positional_embedding,
num_classes=num_classes, aggregation=aggregation,
*args, **kwargs)
@register_model
def cct_6_3x1_32_sine(pretrained=False, progress=False,
img_size=32, positional_embedding='sine', num_classes=10,
aggregation=None, *args, **kwargs):
return cct_6('cct_6_3x1_32_sine', pretrained, progress,
kernel_size=3, n_conv_layers=1,
img_size=img_size, positional_embedding=positional_embedding,
num_classes=num_classes, aggregation=aggregation,
*args, **kwargs)
@register_model
def cct_6_3x2_32(pretrained=False, progress=False,
img_size=32, positional_embedding='learnable', num_classes=10,
aggregation=None, *args, **kwargs):
return cct_6('cct_6_3x2_32', pretrained, progress,
kernel_size=3, n_conv_layers=2,
img_size=img_size, positional_embedding=positional_embedding,
num_classes=num_classes, aggregation=aggregation,
*args, **kwargs)
@register_model
def cct_6_3x2_32_sine(pretrained=False, progress=False,
img_size=32, positional_embedding='sine', num_classes=10,
aggregation=None, *args, **kwargs):
return cct_6('cct_6_3x2_32_sine', pretrained, progress,
kernel_size=3, n_conv_layers=2,
img_size=img_size, positional_embedding=positional_embedding,
num_classes=num_classes, aggregation=aggregation,
*args, **kwargs)
@register_model
def cct_7_3x1_32(pretrained=False, progress=False,
img_size=32, positional_embedding='learnable', num_classes=10,
aggregation=None, *args, **kwargs):
return cct_7('cct_7_3x1_32', pretrained, progress,
kernel_size=3, n_conv_layers=1,
img_size=img_size, positional_embedding=positional_embedding,
num_classes=num_classes, aggregation=aggregation,
*args, **kwargs)
@register_model
def cct_7_3x1_32_sine(pretrained=False, progress=False,
img_size=32, positional_embedding='sine', num_classes=10,
aggregation=None, *args, **kwargs):
return cct_7('cct_7_3x1_32_sine', pretrained, progress,
kernel_size=3, n_conv_layers=1,
img_size=img_size, positional_embedding=positional_embedding,
num_classes=num_classes, aggregation=aggregation,
*args, **kwargs)
@register_model
def cct_7_3x1_32_c100(pretrained=False, progress=False,
img_size=32, positional_embedding='learnable', num_classes=100,
aggregation=None, *args, **kwargs):
return cct_7('cct_7_3x1_32_c100', pretrained, progress,
kernel_size=3, n_conv_layers=1,
img_size=img_size, positional_embedding=positional_embedding,
num_classes=num_classes, aggregation=aggregation,
*args, **kwargs)
@register_model
def cct_7_3x1_32_sine_c100(pretrained=False, progress=False,
img_size=32, positional_embedding='sine', num_classes=100,
aggregation=None, *args, **kwargs):
return cct_7('cct_7_3x1_32_sine_c100', pretrained, progress,
kernel_size=3, n_conv_layers=1,
img_size=img_size, positional_embedding=positional_embedding,
num_classes=num_classes, aggregation=aggregation,
*args, **kwargs)
@register_model
def cct_7_3x2_32(pretrained=False, progress=False,
img_size=32, positional_embedding='learnable', num_classes=10,
aggregation=None, *args, **kwargs):
return cct_7('cct_7_3x2_32', pretrained, progress,
kernel_size=3, n_conv_layers=2,
img_size=img_size, positional_embedding=positional_embedding,
num_classes=num_classes, aggregation=aggregation,
*args, **kwargs)
@register_model
def cct_7_3x2_32_sine(pretrained=False, progress=False,
img_size=32, positional_embedding='sine', num_classes=10,
aggregation=None, *args, **kwargs):
return cct_7('cct_7_3x2_32_sine', pretrained, progress,
kernel_size=3, n_conv_layers=2,
img_size=img_size, positional_embedding=positional_embedding,
num_classes=num_classes, aggregation=aggregation,
*args, **kwargs)
@register_model
def cct_7_7x2_224(pretrained=False, progress=False,
img_size=224, positional_embedding='learnable', num_classes=102,
aggregation=None, *args, **kwargs):
return cct_7('cct_7_7x2_224', pretrained, progress,
kernel_size=7, n_conv_layers=2,
img_size=img_size, positional_embedding=positional_embedding,
num_classes=num_classes, aggregation=aggregation,
*args, **kwargs)
@register_model
def cct_7_7x2_224_sine(pretrained=False, progress=False,
img_size=224, positional_embedding='sine', num_classes=102,
aggregation=None, *args, **kwargs):
return cct_7('cct_7_7x2_224_sine', pretrained, progress,
kernel_size=7, n_conv_layers=2,
img_size=img_size, positional_embedding=positional_embedding,
num_classes=num_classes, aggregation=aggregation,
*args, **kwargs)
@register_model
def cct_14_7x2_224(pretrained=False, progress=False,
img_size=224, positional_embedding='learnable', num_classes=1000,
aggregation=None, *args, **kwargs):
return cct_14('cct_14_7x2_224', pretrained, progress,
kernel_size=7, n_conv_layers=2,
img_size=img_size, positional_embedding=positional_embedding,
num_classes=num_classes, aggregation=aggregation,
*args, **kwargs)
@register_model
def cct_14_7x2_384(pretrained=False, progress=False,
img_size=384, positional_embedding='learnable', num_classes=1000,
aggregation=None, *args, **kwargs):
return cct_14('cct_14_7x2_384', pretrained, progress,
kernel_size=7, n_conv_layers=2,
img_size=img_size, positional_embedding=positional_embedding,
num_classes=num_classes, aggregation=aggregation,
*args, **kwargs)
@register_model
def cct_14_7x2_384_fl(pretrained=False, progress=False,
img_size=384, positional_embedding='learnable', num_classes=102,
aggregation=None, *args, **kwargs):
return cct_14('cct_14_7x2_384_fl', pretrained, progress,
kernel_size=7, n_conv_layers=2,
img_size=img_size, positional_embedding=positional_embedding,
num_classes=num_classes, aggregation=aggregation,
*args, **kwargs)
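# Minimal usage sketch (illustrative, not part of the training pipeline): one of the
# registered variants can be instantiated directly; pretrained=False avoids any
# checkpoint download. Since the classifier's pooling is left to the external
# aggregation module, the output is expected to be the full token sequence.
if __name__ == "__main__":
    import torch
    model = cct_14_7x2_384(pretrained=False, progress=False)
    tokens = model(torch.rand(1, 3, 384, 384))
    print(tokens.shape)  # expected: torch.Size([1, 576, 384]) given the 7x2 tokenizer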
================================================
FILE: model/cct/embedder.py
================================================
import torch.nn as nn
class Embedder(nn.Module):
def __init__(self,
word_embedding_dim=300,
vocab_size=100000,
padding_idx=1,
pretrained_weight=None,
embed_freeze=False,
*args, **kwargs):
super(Embedder, self).__init__()
self.embeddings = nn.Embedding.from_pretrained(pretrained_weight, freeze=embed_freeze) \
if pretrained_weight is not None else \
nn.Embedding(vocab_size, word_embedding_dim, padding_idx=padding_idx)
self.embeddings.weight.requires_grad = not embed_freeze
def forward_mask(self, mask):
bsz, seq_len = mask.shape
new_mask = mask.view(bsz, seq_len, 1)
new_mask = new_mask.sum(-1)
new_mask = (new_mask > 0)
return new_mask
def forward(self, x, mask=None):
embed = self.embeddings(x)
embed = embed if mask is None else embed * self.forward_mask(mask).unsqueeze(-1).float()
return embed, mask
@staticmethod
def init_weight(m):
if isinstance(m, nn.Linear):
    nn.init.trunc_normal_(m.weight, std=.02)
    if m.bias is not None:
        nn.init.constant_(m.bias, 0)
elif hasattr(m, 'weight'):
    # The original dangling `else` applied normal_ to any module without a Linear
    # bias, overwriting the trunc_normal_ init and crashing on weight-less modules.
    nn.init.normal_(m.weight)
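# Minimal usage sketch (illustrative, with made-up sizes): token ids go in,
# per-token embeddings come out, and padded positions are zeroed via the mask.
if __name__ == "__main__":
    import torch
    emb = Embedder(word_embedding_dim=8, vocab_size=100)
    ids = torch.randint(0, 100, (2, 5))
    mask = torch.ones(2, 5)
    out, _ = emb(ids, mask)
    print(out.shape)  # torch.Size([2, 5, 8])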
================================================
FILE: model/cct/helpers.py
================================================
import math
import torch
import torch.nn.functional as F
def resize_pos_embed(posemb, posemb_new, num_tokens=1):
# Copied from `timm` by Ross Wightman:
# github.com/rwightman/pytorch-image-models
# Rescale the grid of position embeddings when loading from state_dict. Adapted from
# https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
ntok_new = posemb_new.shape[1]
if num_tokens:
posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
ntok_new -= num_tokens
else:
posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
gs_old = int(math.sqrt(len(posemb_grid)))
gs_new = int(math.sqrt(ntok_new))
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bilinear')
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1)
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
return posemb
def pe_check(model, state_dict, pe_key='classifier.positional_emb'):
if pe_key is not None and pe_key in state_dict.keys() and pe_key in model.state_dict().keys():
if model.state_dict()[pe_key].shape != state_dict[pe_key].shape:
state_dict[pe_key] = resize_pos_embed(state_dict[pe_key],
model.state_dict()[pe_key],
num_tokens=model.classifier.num_tokens)
return state_dict
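# Illustrative sketch: resizing a 14x14 grid of positional embeddings (plus one
# class token) to a 24x24 grid, as happens when 224px weights are loaded for
# 384px inputs. The dimensions below are assumptions chosen to match CCT-14.
if __name__ == "__main__":
    old = torch.rand(1, 1 + 14 * 14, 384)
    new = torch.zeros(1, 1 + 24 * 24, 384)
    print(resize_pos_embed(old, new, num_tokens=1).shape)  # torch.Size([1, 577, 384])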
================================================
FILE: model/cct/stochastic_depth.py
================================================
# Thanks to rwightman's timm package
# github.com:rwightman/pytorch-image-models
import torch
import torch.nn as nn
def drop_path(x, drop_prob: float = 0., training: bool = False):
"""
Obtained from: github.com:rwightman/pytorch-image-models
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
class DropPath(nn.Module):
"""
Obtained from: github.com:rwightman/pytorch-image-models
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
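# Illustrative sanity check: in training mode, drop_path zeroes a sample's residual
# branch with probability drop_prob and rescales the survivors by 1/keep_prob, so
# the expected value is preserved; in eval mode it is the identity.
if __name__ == "__main__":
    x = torch.ones(8, 4)
    dp = DropPath(drop_prob=0.5)
    dp.train()
    print(dp(x))  # each row is either all 0.0 or all 2.0
    dp.eval()
    print(dp(x).equal(x))  # True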
================================================
FILE: model/cct/tokenizer.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
class Tokenizer(nn.Module):
def __init__(self,
kernel_size, stride, padding,
pooling_kernel_size=3, pooling_stride=2, pooling_padding=1,
n_conv_layers=1,
n_input_channels=3,
n_output_channels=64,
in_planes=64,
activation=None,
max_pool=True,
conv_bias=False):
super(Tokenizer, self).__init__()
n_filter_list = [n_input_channels] + \
[in_planes for _ in range(n_conv_layers - 1)] + \
[n_output_channels]
self.conv_layers = nn.Sequential(
*[nn.Sequential(
nn.Conv2d(n_filter_list[i], n_filter_list[i + 1],
kernel_size=(kernel_size, kernel_size),
stride=(stride, stride),
padding=(padding, padding), bias=conv_bias),
nn.Identity() if activation is None else activation(),
nn.MaxPool2d(kernel_size=pooling_kernel_size,
stride=pooling_stride,
padding=pooling_padding) if max_pool else nn.Identity()
)
for i in range(n_conv_layers)
])
self.flattener = nn.Flatten(2, 3)
self.apply(self.init_weight)
def sequence_length(self, n_channels=3, height=224, width=224):
return self.forward(torch.zeros((1, n_channels, height, width))).shape[1]
def forward(self, x):
return self.flattener(self.conv_layers(x)).transpose(-2, -1)
@staticmethod
def init_weight(m):
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
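# Illustrative sketch of the image tokenizer; the kernel/stride/padding values below
# mirror the cct_14_7x2_384 configuration and are assumptions for illustration:
#
#   tok = Tokenizer(kernel_size=7, stride=2, padding=3, n_conv_layers=2,
#                   n_output_channels=384, activation=nn.ReLU)
#   tok.sequence_length(n_channels=3, height=384, width=384)  # -> 576
#   tok(torch.rand(1, 3, 384, 384)).shape                     # -> (1, 576, 384)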
class TextTokenizer(nn.Module):
def __init__(self,
kernel_size, stride, padding,
pooling_kernel_size=3, pooling_stride=2, pooling_padding=1,
embedding_dim=300,
n_output_channels=128,
activation=None,
max_pool=True,
*args, **kwargs):
super(TextTokenizer, self).__init__()
self.max_pool = max_pool
self.conv_layers = nn.Sequential(
nn.Conv2d(1, n_output_channels,
kernel_size=(kernel_size, embedding_dim),
stride=(stride, 1),
padding=(padding, 0), bias=False),
nn.Identity() if activation is None else activation(),
nn.MaxPool2d(
kernel_size=(pooling_kernel_size, 1),
stride=(pooling_stride, 1),
padding=(pooling_padding, 0)
) if max_pool else nn.Identity()
)
self.apply(self.init_weight)
def seq_len(self, seq_len=32, embed_dim=300):
return self.forward(torch.zeros((1, seq_len, embed_dim)))[0].shape[1]
def forward_mask(self, mask):
new_mask = mask.unsqueeze(1).float()
cnn_weight = torch.ones(
(1, 1, self.conv_layers[0].kernel_size[0]),
device=mask.device,
dtype=torch.float)
new_mask = F.conv1d(
new_mask, cnn_weight, None,
self.conv_layers[0].stride[0], self.conv_layers[0].padding[0], 1, 1)
if self.max_pool:
new_mask = F.max_pool1d(
new_mask, self.conv_layers[2].kernel_size[0],
self.conv_layers[2].stride[0], self.conv_layers[2].padding[0], 1, False, False)
new_mask = new_mask.squeeze(1)
new_mask = (new_mask > 0)
return new_mask
def forward(self, x, mask=None):
x = x.unsqueeze(1)
x = self.conv_layers(x)
x = x.transpose(1, 3).squeeze(1)
x = x if mask is None else x * self.forward_mask(mask).unsqueeze(-1).float()
return x, mask
@staticmethod
def init_weight(m):
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
================================================
FILE: model/cct/transformers.py
================================================
import torch
from torch.nn import Module, ModuleList, Linear, Dropout, LayerNorm, Identity, Parameter, init
import torch.nn.functional as F
from .stochastic_depth import DropPath
class Attention(Module):
"""
Obtained from timm: github.com:rwightman/pytorch-image-models
"""
def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):
super().__init__()
self.num_heads = num_heads
head_dim = dim // self.num_heads
self.scale = head_dim ** -0.5
self.qkv = Linear(dim, dim * 3, bias=False)
self.attn_drop = Dropout(attention_dropout)
self.proj = Linear(dim, dim)
self.proj_drop = Dropout(projection_dropout)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
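# Shape sketch (illustrative): self-attention preserves the (batch, tokens, dim)
# shape of its input, e.g. with the CCT-14 embedding size and head count:
#
#   attn = Attention(dim=384, num_heads=6)
#   attn(torch.rand(2, 576, 384)).shape  # -> (2, 576, 384)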
class MaskedAttention(Module):
def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):
super().__init__()
self.num_heads = num_heads
head_dim = dim // self.num_heads
self.scale = head_dim ** -0.5
self.qkv = Linear(dim, dim * 3, bias=False)
self.attn_drop = Dropout(attention_dropout)
self.proj = Linear(dim, dim)
self.proj_drop = Dropout(projection_dropout)
def forward(self, x, mask=None):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
if mask is not None:
mask_value = -torch.finfo(attn.dtype).max
assert mask.shape[-1] == attn.shape[-1], 'mask has incorrect dimensions'
mask = mask[:, None, :] * mask[:, :, None]
mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)
attn.masked_fill_(~mask, mask_value)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class TransformerEncoderLayer(Module):
"""
Inspired by torch.nn.TransformerEncoderLayer and timm.
"""
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
attention_dropout=0.1, drop_path_rate=0.1):
super(TransformerEncoderLayer, self).__init__()
self.pre_norm = LayerNorm(d_model)
self.self_attn = Attention(dim=d_model, num_heads=nhead,
attention_dropout=attention_dropout, projection_dropout=dropout)
self.linear1 = Linear(d_model, dim_feedforward)
self.dropout1 = Dropout(dropout)
self.norm1 = LayerNorm(d_model)
self.linear2 = Linear(dim_feedforward, d_model)
self.dropout2 = Dropout(dropout)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else Identity()
self.activation = F.gelu
def forward(self, src: torch.Tensor, *args, **kwargs) -> torch.Tensor:
src = src + self.drop_path(self.self_attn(self.pre_norm(src)))
src = self.norm1(src)
src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))
src = src + self.drop_path(self.dropout2(src2))
return src
class MaskedTransformerEncoderLayer(Module):
"""
Inspired by torch.nn.TransformerEncoderLayer and timm.
"""
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
attention_dropout=0.1, drop_path_rate=0.1):
super(MaskedTransformerEncoderLayer, self).__init__()
self.pre_norm = LayerNorm(d_model)
self.self_attn = MaskedAttention(dim=d_model, num_heads=nhead,
attention_dropout=attention_dropout, projection_dropout=dropout)
self.linear1 = Linear(d_model, dim_feedforward)
self.dropout1 = Dropout(dropout)
self.norm1 = LayerNorm(d_model)
self.linear2 = Linear(dim_feedforward, d_model)
self.dropout2 = Dropout(dropout)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else Identity()
self.activation = F.gelu
def forward(self, src: torch.Tensor, mask=None, *args, **kwargs) -> torch.Tensor:
src = src + self.drop_path(self.self_attn(self.pre_norm(src), mask))
src = self.norm1(src)
src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))
src = src + self.drop_path(self.dropout2(src2))
return src
class TransformerClassifier(Module):
def __init__(self,
seq_pool=True,
embedding_dim=768,
num_layers=12,
num_heads=12,
mlp_ratio=4.0,
num_classes=1000,
dropout=0.1,
attention_dropout=0.1,
stochastic_depth=0.1,
positional_embedding='learnable',
sequence_length=None):
super().__init__()
positional_embedding = positional_embedding if \
positional_embedding in ['sine', 'learnable', 'none'] else 'sine'
dim_feedforward = int(embedding_dim * mlp_ratio)
self.embedding_dim = embedding_dim
self.sequence_length = sequence_length
self.seq_pool = seq_pool
assert sequence_length is not None or positional_embedding == 'none', \
f"Positional embedding is set to {positional_embedding} and" \
f" the sequence length was not specified."
if not seq_pool:
sequence_length += 1
self.class_emb = Parameter(torch.zeros(1, 1, self.embedding_dim),
requires_grad=True)
else:
self.attention_pool = Linear(self.embedding_dim, 1)
if positional_embedding != 'none':
if positional_embedding == 'learnable':
self.positional_emb = Parameter(torch.zeros(1, sequence_length, embedding_dim),
requires_grad=True)
init.trunc_normal_(self.positional_emb, std=0.2)
else:
self.positional_emb = Parameter(self.sinusoidal_embedding(sequence_length, embedding_dim),
requires_grad=False)
else:
self.positional_emb = None
self.dropout = Dropout(p=dropout)
dpr = [x.item() for x in torch.linspace(0, stochastic_depth, num_layers)]
self.blocks = ModuleList([
TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,
dim_feedforward=dim_feedforward, dropout=dropout,
attention_dropout=attention_dropout, drop_path_rate=dpr[i])
for i in range(num_layers)])
self.norm = LayerNorm(embedding_dim)
# self.fc = Linear(embedding_dim, num_classes)
self.apply(self.init_weight)
def forward(self, x):
if self.positional_emb is None and x.size(1) < self.sequence_length:
x = F.pad(x, (0, 0, 0, self.sequence_length - x.size(1)), mode='constant', value=0)
if not self.seq_pool:
cls_token = self.class_emb.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
if self.positional_emb is not None:
x += self.positional_emb
x = self.dropout(x)
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
# TODO: TOREMOVE
# if self.seq_pool:
# x = torch.matmul(F.softmax(self.attention_pool(x), dim=1).transpose(-1, -2), x).squeeze(-2)
#else:
# x = x[:, 0]
# x = self.fc(x)
return x
@staticmethod
def init_weight(m):
if isinstance(m, Linear):
    init.trunc_normal_(m.weight, std=.02)
    if m.bias is not None:
        init.constant_(m.bias, 0)
elif isinstance(m, LayerNorm):
    init.constant_(m.bias, 0)
    init.constant_(m.weight, 1.0)
@staticmethod
def sinusoidal_embedding(n_channels, dim):
pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
for p in range(n_channels)])
pe[:, 0::2] = torch.sin(pe[:, 0::2])
pe[:, 1::2] = torch.cos(pe[:, 1::2])
return pe.unsqueeze(0)
class MaskedTransformerClassifier(Module):
def __init__(self,
seq_pool=True,
embedding_dim=768,
num_layers=12,
num_heads=12,
mlp_ratio=4.0,
num_classes=1000,
dropout=0.1,
attention_dropout=0.1,
stochastic_depth=0.1,
positional_embedding='sine',
seq_len=None,
*args, **kwargs):
super().__init__()
positional_embedding = positional_embedding if \
positional_embedding in ['sine', 'learnable', 'none'] else 'sine'
dim_feedforward = int(embedding_dim * mlp_ratio)
self.embedding_dim = embedding_dim
self.seq_len = seq_len
self.seq_pool = seq_pool
assert seq_len is not None or positional_embedding == 'none', \
f"Positional embedding is set to {positional_embedding} and" \
f" the sequence length was not specified."
if not seq_pool:
seq_len += 1
self.class_emb = Parameter(torch.zeros(1, 1, self.embedding_dim),
requires_grad=True)
else:
self.attention_pool = Linear(self.embedding_dim, 1)
if positional_embedding != 'none':
if positional_embedding == 'learnable':
seq_len += 1 # padding idx
self.positional_emb = Parameter(torch.zeros(1, seq_len, embedding_dim),
requires_grad=True)
init.trunc_normal_(self.positional_emb, std=0.2)
else:
self.positional_emb = Parameter(self.sinusoidal_embedding(seq_len,
embedding_dim,
padding_idx=True),
requires_grad=False)
else:
self.positional_emb = None
self.dropout = Dropout(p=dropout)
dpr = [x.item() for x in torch.linspace(0, stochastic_depth, num_layers)]
self.blocks = ModuleList([
MaskedTransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,
dim_feedforward=dim_feedforward, dropout=dropout,
attention_dropout=attention_dropout, drop_path_rate=dpr[i])
for i in range(num_layers)])
self.norm = LayerNorm(embedding_dim)
self.fc = Linear(embedding_dim, num_classes)
self.apply(self.init_weight)
def forward(self, x, mask=None):
if self.positional_emb is None and x.size(1) < self.seq_len:
x = F.pad(x, (0, 0, 0, self.seq_len - x.size(1)), mode='constant', value=0)
if not self.seq_pool:
cls_token = self.class_emb.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
if mask is not None:
mask = torch.cat([torch.ones(size=(mask.shape[0], 1), device=mask.device), mask.float()], dim=1)
mask = (mask > 0)
if self.positional_emb is not None:
x += self.positional_emb
x = self.dropout(x)
for blk in self.blocks:
x = blk(x, mask=mask)
x = self.norm(x)
if self.seq_pool:
x = torch.matmul(F.softmax(self.attention_pool(x), dim=1).transpose(-1, -2), x).squeeze(-2)
else:
x = x[:, 0]
x = self.fc(x)
return x
@staticmethod
def init_weight(m):
if isinstance(m, Linear):
    init.trunc_normal_(m.weight, std=.02)
    if m.bias is not None:
        init.constant_(m.bias, 0)
elif isinstance(m, LayerNorm):
    init.constant_(m.bias, 0)
    init.constant_(m.weight, 1.0)
@staticmethod
def sinusoidal_embedding(n_channels, dim, padding_idx=False):
pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
for p in range(n_channels)])
pe[:, 0::2] = torch.sin(pe[:, 0::2])
pe[:, 1::2] = torch.cos(pe[:, 1::2])
pe = pe.unsqueeze(0)
if padding_idx:
return torch.cat([torch.zeros((1, 1, dim)), pe], dim=1)
return pe
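# Illustrative sketch: the table built by sinusoidal_embedding for the 'sine'
# positional embedding interleaves sin/cos at geometrically spaced frequencies.
if __name__ == "__main__":
    pe = TransformerClassifier.sinusoidal_embedding(n_channels=576, dim=384)
    print(pe.shape)  # torch.Size([1, 576, 384])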
================================================
FILE: model/functional.py
================================================
import math
import torch
import torch.nn.functional as F
def sare_ind(query, positive, negative):
'''all 3 inputs are supposed to be shape 1xn_features'''
dist_pos = ((query - positive)**2).sum(1)
dist_neg = ((query - negative)**2).sum(1)
dist = - torch.cat((dist_pos, dist_neg))
dist = F.log_softmax(dist, 0)
# on a whole batch this would be: loss = (-dist[:, 0]).mean()
loss = -dist[0]
return loss
def sare_joint(query, positive, negatives):
'''query and positive must have shape 1 x n_features, whereas negatives must have
shape n_negative x n_features. n_negative is usually 10.'''
# NOTE: with batch_size=1 the implementation is identical to sare_ind, as all
# operations are vectorized. If there were an additional batch dimension, this
# case would need to be handled differently here. The function is kept separate
# for clarity: the two should be called in different situations, and although
# calling one in place of the other would raise no exception, it would be a
# conceptual error.
return sare_ind(query, positive, negatives)
def mac(x):
return F.adaptive_max_pool2d(x, (1,1))
def spoc(x):
return F.adaptive_avg_pool2d(x, (1,1))
def gem(x, p=3, eps=1e-6, work_with_tokens=False):
if work_with_tokens:
x = x.permute(0, 2, 1)
# unsqueeze to maintain compatibility with Flatten
return F.avg_pool1d(x.clamp(min=eps).pow(p), (x.size(-1))).pow(1./p).unsqueeze(3)
else:
return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p)
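# Worked example (illustrative): GeM interpolates between average pooling (p=1)
# and max pooling (p -> inf), so for large p it approaches mac():
#
#   x = torch.rand(2, 512, 7, 7)
#   gem(x, p=1).shape   # -> (2, 512, 1, 1); equals spoc(x) for inputs above eps
#   gem(x, p=100)       # ~ mac(x), element-wise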
def rmac(x, L=3, eps=1e-6):
ovr = 0.4 # desired overlap of neighboring regions
steps = torch.Tensor([2, 3, 4, 5, 6, 7]) # possible regions for the long dimension
W = x.size(3)
H = x.size(2)
w = min(W, H)
# w2 = math.floor(w/2.0 - 1)
b = (max(H, W)-w)/(steps-1)
(tmp, idx) = torch.min(torch.abs(((w**2 - w*b)/w**2)-ovr), 0) # steps(idx) regions for long dimension
# extra regions per dimension
Wd = 0
Hd = 0
if H < W:
Wd = idx.item() + 1
elif H > W:
Hd = idx.item() + 1
v = F.max_pool2d(x, (x.size(-2), x.size(-1)))
v = v / (torch.norm(v, p=2, dim=1, keepdim=True) + eps).expand_as(v)
for l in range(1, L+1):
wl = math.floor(2*w/(l+1))
wl2 = math.floor(wl/2 - 1)
if l+Wd == 1:
b = 0
else:
b = (W-wl)/(l+Wd-1)
cenW = torch.floor(wl2 + torch.Tensor(range(l-1+Wd+1))*b) - wl2 # center coordinates
if l+Hd == 1:
b = 0
else:
b = (H-wl)/(l+Hd-1)
cenH = torch.floor(wl2 + torch.Tensor(range(l-1+Hd+1))*b) - wl2 # center coordinates
for i_ in cenH.tolist():
for j_ in cenW.tolist():
if wl == 0:
continue
R = x[:,:,(int(i_)+torch.Tensor(range(wl)).long()).tolist(),:]
R = R[:,:,:,(int(j_)+torch.Tensor(range(wl)).long()).tolist()]
vt = F.max_pool2d(R, (R.size(-2), R.size(-1)))
vt = vt / (torch.norm(vt, p=2, dim=1, keepdim=True) + eps).expand_as(vt)
v += vt
return v
================================================
FILE: model/network.py
================================================
import os
import torch
import logging
import torchvision
from torch import nn
from os.path import join
from transformers import ViTModel
from google_drive_downloader import GoogleDriveDownloader as gdd
from model.cct import cct_14_7x2_384
from model.aggregation import Flatten
from model.normalization import L2Norm
import model.aggregation as aggregation
# Pretrained models on Google Landmarks v2 and Places 365
PRETRAINED_MODELS = {
'resnet18_places' : '1DnEQXhmPxtBUrRc81nAvT8z17bk-GBj5',
'resnet50_places' : '1zsY4mN4jJ-AsmV3h4hjbT72CBfJsgSGC',
'resnet101_places' : '1E1ibXQcg7qkmmmyYgmwMTh7Xf1cDNQXa',
'vgg16_places' : '1UWl1uz6rZ6Nqmp1K5z3GHAIZJmDh4bDu',
'resnet18_gldv2' : '1wkUeUXFXuPHuEvGTXVpuP5BMB-JJ1xke',
'resnet50_gldv2' : '1UDUv6mszlXNC1lv6McLdeBNMq9-kaA70',
'resnet101_gldv2' : '1apiRxMJpDlV0XmKlC5Na_Drg2jtGL-uE',
'vgg16_gldv2' : '10Ov9JdO7gbyz6mB5x0v_VSAUMj91Ta4o'
}
class GeoLocalizationNet(nn.Module):
"""The used networks are composed of a backbone and an aggregation layer.
"""
def __init__(self, args):
super().__init__()
self.backbone = get_backbone(args)
self.arch_name = args.backbone
self.aggregation = get_aggregation(args)
if args.aggregation in ["gem", "spoc", "mac", "rmac"]:
if args.l2 == "before_pool":
self.aggregation = nn.Sequential(L2Norm(), self.aggregation, Flatten())
elif args.l2 == "after_pool":
self.aggregation = nn.Sequential(self.aggregation, L2Norm(), Flatten())
elif args.l2 == "none":
self.aggregation = nn.Sequential(self.aggregation, Flatten())
if args.fc_output_dim is not None:
    # Append a fully connected layer (followed by L2 normalization) to the aggregation module
self.aggregation = nn.Sequential(self.aggregation,
nn.Linear(args.features_dim, args.fc_output_dim),
L2Norm())
args.features_dim = args.fc_output_dim
def forward(self, x):
x = self.backbone(x)
x = self.aggregation(x)
return x
def get_aggregation(args):
if args.aggregation == "gem":
return aggregation.GeM(work_with_tokens=args.work_with_tokens)
elif args.aggregation == "spoc":
return aggregation.SPoC()
elif args.aggregation == "mac":
return aggregation.MAC()
elif args.aggregation == "rmac":
return aggregation.RMAC()
elif args.aggregation == "netvlad":
return aggregation.NetVLAD(clusters_num=args.netvlad_clusters, dim=args.features_dim,
work_with_tokens=args.work_with_tokens)
elif args.aggregation == 'crn':
return aggregation.CRN(clusters_num=args.netvlad_clusters, dim=args.features_dim)
elif args.aggregation == "rrm":
return aggregation.RRM(args.features_dim)
elif args.aggregation in ['cls', 'seqpool']:
return nn.Identity()
def get_pretrained_model(args):
if args.pretrain == 'places': num_classes = 365
elif args.pretrain == 'gldv2': num_classes = 512
if args.backbone.startswith("resnet18"):
model = torchvision.models.resnet18(num_classes=num_classes)
elif args.backbone.startswith("resnet50"):
model = torchvision.models.resnet50(num_classes=num_classes)
elif args.backbone.startswith("resnet101"):
model = torchvision.models.resnet101(num_classes=num_classes)
elif args.backbone.startswith("vgg16"):
model = torchvision.models.vgg16(num_classes=num_classes)
if args.backbone.startswith('resnet'):
model_name = args.backbone.split('conv')[0] + "_" + args.pretrain
else:
model_name = args.backbone + "_" + args.pretrain
file_path = join("data", "pretrained_nets", model_name +".pth")
if not os.path.exists(file_path):
gdd.download_file_from_google_drive(file_id=PRETRAINED_MODELS[model_name],
dest_path=file_path)
state_dict = torch.load(file_path, map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
return model
def get_backbone(args):
# The aggregation layer works differently based on the type of architecture
args.work_with_tokens = args.backbone.startswith('cct') or args.backbone.startswith('vit')
if args.backbone.startswith("resnet"):
if args.pretrain in ['places', 'gldv2']:
backbone = get_pretrained_model(args)
elif args.backbone.startswith("resnet18"):
backbone = torchvision.models.resnet18(pretrained=True)
elif args.backbone.startswith("resnet50"):
backbone = torchvision.models.resnet50(pretrained=True)
elif args.backbone.startswith("resnet101"):
backbone = torchvision.models.resnet101(pretrained=True)
for name, child in backbone.named_children():
# Freeze layers before conv_3
if name == "layer3":
break
for params in child.parameters():
params.requires_grad = False
if args.backbone.endswith("conv4"):
logging.debug(f"Train only conv4_x of the resnet{args.backbone.split('conv')[0]} (remove conv5_x), freeze the previous ones")
layers = list(backbone.children())[:-3]
elif args.backbone.endswith("conv5"):
logging.debug(f"Train only conv4_x and conv5_x of the resnet{args.backbone.split('conv')[0]}, freeze the previous ones")
layers = list(backbone.children())[:-2]
elif args.backbone == "vgg16":
if args.pretrain in ['places', 'gldv2']:
backbone = get_pretrained_model(args)
else:
backbone = torchvision.models.vgg16(pretrained=True)
layers = list(backbone.features.children())[:-2]
for l in layers[:-5]:
for p in l.parameters(): p.requires_grad = False
logging.debug("Train last layers of the vgg16, freeze the previous ones")
elif args.backbone == "alexnet":
backbone = torchvision.models.alexnet(pretrained=True)
layers = list(backbone.features.children())[:-2]
for l in layers[:5]:
for p in l.parameters(): p.requires_grad = False
logging.debug("Train last layers of the alexnet, freeze the previous ones")
elif args.backbone.startswith("cct"):
if args.backbone.startswith("cct384"):
backbone = cct_14_7x2_384(pretrained=True, progress=True, aggregation=args.aggregation)
if args.trunc_te:
logging.debug(f"Truncate CCT at transformers encoder {args.trunc_te}")
backbone.classifier.blocks = torch.nn.ModuleList(backbone.classifier.blocks[:args.trunc_te].children())
if args.freeze_te:
logging.debug(f"Freeze all the layers up to tranformer encoder {args.freeze_te}")
for p in backbone.parameters():
p.requires_grad = False
for name, child in backbone.classifier.blocks.named_children():
if int(name) > args.freeze_te:
for params in child.parameters():
params.requires_grad = True
args.features_dim = 384
return backbone
elif args.backbone.startswith("vit"):
assert args.resize[0] in [224, 384], f'Image size for ViT must be either 224 or 384, but it\'s {args.resize[0]}'
if args.resize[0] == 224:
backbone = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
elif args.resize[0] == 384:
backbone = ViTModel.from_pretrained('google/vit-base-patch16-384')
if args.trunc_te:
logging.debug(f"Truncate ViT at transformers encoder {args.trunc_te}")
backbone.encoder.layer = backbone.encoder.layer[:args.trunc_te]
if args.freeze_te:
logging.debug(f"Freeze all the layers up to tranformer encoder {args.freeze_te+1}")
for p in backbone.parameters():
p.requires_grad = False
for name, child in backbone.encoder.layer.named_children():
if int(name) > args.freeze_te:
for params in child.parameters():
params.requires_grad = True
backbone = VitWrapper(backbone, args.aggregation)
args.features_dim = 768
return backbone
backbone = torch.nn.Sequential(*layers)
args.features_dim = get_output_channels_dim(backbone)  # Dynamically obtain the number of output channels
return backbone
class VitWrapper(nn.Module):
def __init__(self, vit_model, aggregation):
super().__init__()
self.vit_model = vit_model
self.aggregation = aggregation
def forward(self, x):
if self.aggregation in ["netvlad", "gem"]:
return self.vit_model(x).last_hidden_state[:, 1:, :]
else:
return self.vit_model(x).last_hidden_state[:, 0, :]
def get_output_channels_dim(model):
"""Return the number of channels in the output of a model."""
return model(torch.ones([1, 3, 224, 224])).shape[1]
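# Minimal construction sketch (illustrative; the real args come from parser.py, and
# the Namespace fields below are assumptions). This path uses torchvision's
# ImageNet-pretrained ResNet-18 weights.
if __name__ == "__main__":
    from argparse import Namespace
    args = Namespace(backbone="resnet18conv4", pretrain="imagenet", aggregation="gem",
                     l2="before_pool", fc_output_dim=None, netvlad_clusters=64,
                     features_dim=None, work_with_tokens=False, trunc_te=None,
                     freeze_te=None, resize=[224, 224])
    model = GeoLocalizationNet(args)
    print(model(torch.rand(1, 3, 224, 224)).shape)  # (1, args.features_dim)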
================================================
FILE: model/normalization.py
================================================
import torch.nn as nn
import torch.nn.functional as F
class L2Norm(nn.Module):
def __init__(self, dim=1):
super().__init__()
self.dim = dim
def forward(self, x):
return F.normalize(x, p=2, dim=self.dim)
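# Illustrative sketch: L2Norm rescales each descriptor to unit Euclidean norm along
# `dim`, which is what the retrieval stage expects before nearest-neighbor search.
if __name__ == "__main__":
    import torch
    x = torch.rand(4, 256)
    print(L2Norm()(x).norm(dim=1))  # all ones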
================================================
FILE: model/sync_batchnorm/__init__.py
================================================
# -*- coding: utf-8 -*-
# File : __init__.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
from .batchnorm import set_sbn_eps_mode
from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
from .batchnorm import patch_sync_batchnorm, convert_model
from .replicate import DataParallelWithCallback, patch_replication_callback
================================================
FILE: model/sync_batchnorm/batchnorm.py
================================================
# -*- coding: utf-8 -*-
# File : batchnorm.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import collections
import contextlib
import torch
import torch.nn.functional as F
from torch.nn.modules.batchnorm import _BatchNorm
try:
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
except ImportError:
ReduceAddCoalesced = Broadcast = None
try:
from jactorch.parallel.comm import SyncMaster
from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback
except ImportError:
from .comm import SyncMaster
from .replicate import DataParallelWithCallback
__all__ = [
'set_sbn_eps_mode',
'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d',
'patch_sync_batchnorm', 'convert_model'
]
SBN_EPS_MODE = 'clamp'
def set_sbn_eps_mode(mode):
global SBN_EPS_MODE
assert mode in ('clamp', 'plus')
SBN_EPS_MODE = mode
def _sum_ft(tensor):
"""sum over the first and last dimention"""
return tensor.sum(dim=0).sum(dim=-1)
def _unsqueeze_ft(tensor):
"""add new dimensions at the front and the tail"""
return tensor.unsqueeze(0).unsqueeze(-1)
_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
class _SynchronizedBatchNorm(_BatchNorm):
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
assert ReduceAddCoalesced is not None, 'Cannot use Synchronized Batch Normalization without CUDA support.'
super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine,
track_running_stats=track_running_stats)
if not self.track_running_stats:
import warnings
warnings.warn('track_running_stats=False is not supported by the SynchronizedBatchNorm.')
self._sync_master = SyncMaster(self._data_parallel_master)
self._is_parallel = False
self._parallel_id = None
self._slave_pipe = None
def forward(self, input):
# If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
if not (self._is_parallel and self.training):
return F.batch_norm(
input, self.running_mean, self.running_var, self.weight, self.bias,
self.training, self.momentum, self.eps)
# Resize the input to (B, C, -1).
input_shape = input.size()
assert input.size(1) == self.num_features, 'Channel size mismatch: got {}, expect {}.'.format(input.size(1), self.num_features)
input = input.view(input.size(0), self.num_features, -1)
# Compute the sum and square-sum.
sum_size = input.size(0) * input.size(2)
input_sum = _sum_ft(input)
input_ssum = _sum_ft(input ** 2)
# Reduce-and-broadcast the statistics.
if self._parallel_id == 0:
mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
else:
mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
# Compute the output.
if self.affine:
# MJY:: Fuse the multiplication for speed.
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
else:
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
# Reshape it.
return output.view(input_shape)
def __data_parallel_replicate__(self, ctx, copy_id):
self._is_parallel = True
self._parallel_id = copy_id
# parallel_id == 0 means master device.
if self._parallel_id == 0:
ctx.sync_master = self._sync_master
else:
self._slave_pipe = ctx.sync_master.register_slave(copy_id)
def _data_parallel_master(self, intermediates):
"""Reduce the sum and square-sum, compute the statistics, and broadcast it."""
# Always using same "device order" makes the ReduceAdd operation faster.
# Thanks to:: Tete Xiao (http://tetexiao.com/)
intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
to_reduce = [i[1][:2] for i in intermediates]
to_reduce = [j for i in to_reduce for j in i] # flatten
target_gpus = [i[1].sum.get_device() for i in intermediates]
sum_size = sum([i[1].sum_size for i in intermediates])
sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
outputs = []
for i, rec in enumerate(intermediates):
outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
return outputs
def _compute_mean_std(self, sum_, ssum, size):
"""Compute the mean and standard-deviation with sum and square-sum. This method
also maintains the moving average on the master device."""
assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
mean = sum_ / size
sumvar = ssum - sum_ * mean
unbias_var = sumvar / (size - 1)
bias_var = sumvar / size
if hasattr(torch, 'no_grad'):
with torch.no_grad():
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
else:
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
if SBN_EPS_MODE == 'clamp':
return mean, bias_var.clamp(self.eps) ** -0.5
elif SBN_EPS_MODE == 'plus':
return mean, (bias_var + self.eps) ** -0.5
else:
raise ValueError('Unknown EPS mode: {}.'.format(SBN_EPS_MODE))
class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
mini-batch.
.. math::
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
This module differs from the built-in PyTorch BatchNorm1d as the mean and
standard-deviation are reduced across all devices during training.
For example, when one uses `nn.DataParallel` to wrap the network during
training, PyTorch's implementation normalizes the tensor on each device using
only that device's statistics, which speeds up the computation and is easy to
implement, but the statistics might be inaccurate.
Instead, in this synchronized version, the statistics will be computed
over all training samples distributed on multiple devices.
Note that, in the one-GPU or CPU-only case, this module behaves exactly the
same as the built-in PyTorch implementation.
The mean and standard-deviation are calculated per-dimension over
the mini-batches and gamma and beta are learnable parameter vectors
of size C (where C is the input size).
During training, this layer keeps a running estimate of its computed mean
and variance. The running sum is kept with a default momentum of 0.1.
During evaluation, this running mean/variance is used for normalization.
Because the BatchNorm is done over the `C` dimension, computing statistics
on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
Args:
num_features: num_features from an expected input of size
`batch_size x num_features [x width]`
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Default: 0.1
affine: a boolean value that when set to ``True``, gives the layer learnable
affine parameters. Default: ``True``
Shape::
- Input: :math:`(N, C)` or :math:`(N, C, L)`
- Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
Examples:
>>> # With Learnable Parameters
>>> m = SynchronizedBatchNorm1d(100)
>>> # Without Learnable Parameters
>>> m = SynchronizedBatchNorm1d(100, affine=False)
>>> input = torch.autograd.Variable(torch.randn(20, 100))
>>> output = m(input)
"""
def _check_input_dim(self, input):
if input.dim() != 2 and input.dim() != 3:
raise ValueError('expected 2D or 3D input (got {}D input)'
.format(input.dim()))
class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
of 3d inputs
.. math::
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
This module differs from the built-in PyTorch BatchNorm2d as the mean and
standard-deviation are reduced across all devices during training.
For example, when one uses `nn.DataParallel` to wrap the network during
training, PyTorch's implementation normalizes the tensor on each device using
only that device's statistics, which speeds up the computation and is easy to
implement, but the statistics might be inaccurate.
Instead, in this synchronized version, the statistics will be computed
over all training samples distributed on multiple devices.
Note that, in the one-GPU or CPU-only case, this module behaves exactly the
same as the built-in PyTorch implementation.
The mean and standard-deviation are calculated per-dimension over
the mini-batches and gamma and beta are learnable parameter vectors
of size C (where C is the input size).
During training, this layer keeps a running estimate of its computed mean
and variance. The running sum is kept with a default momentum of 0.1.
During evaluation, this running mean/variance is used for normalization.
Because the BatchNorm is done over the `C` dimension, computing statistics
on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
Args:
num_features: num_features from an expected input of
size batch_size x num_features x height x width
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Default: 0.1
affine: a boolean value that when set to ``True``, gives the layer learnable
affine parameters. Default: ``True``
Shape::
- Input: :math:`(N, C, H, W)`
- Output: :math:`(N, C, H, W)` (same shape as input)
Examples:
>>> # With Learnable Parameters
>>> m = SynchronizedBatchNorm2d(100)
>>> # Without Learnable Parameters
>>> m = SynchronizedBatchNorm2d(100, affine=False)
>>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
>>> output = m(input)
"""
def _check_input_dim(self, input):
if input.dim() != 4:
raise ValueError('expected 4D input (got {}D input)'
.format(input.dim()))
class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
of 4d inputs
.. math::
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
This module differs from the built-in PyTorch BatchNorm3d as the mean and
standard-deviation are reduced across all devices during training.
For example, when one uses `nn.DataParallel` to wrap the network during
training, PyTorch's implementation normalizes the tensor on each device using
only that device's statistics, which speeds up the computation and is easy to
implement, but the statistics might be inaccurate.
Instead, in this synchronized version, the statistics will be computed
over all training samples distributed on multiple devices.
Note that, in the one-GPU or CPU-only case, this module behaves exactly the
same as the built-in PyTorch implementation.
The mean and standard-deviation are calculated per-dimension over
the mini-batches and gamma and beta are learnable parameter vectors
of size C (where C is the input size).
During training, this layer keeps a running estimate of its computed mean
and variance. The running sum is kept with a default momentum of 0.1.
During evaluation, this running mean/variance is used for normalization.
Because the BatchNorm is done over the `C` dimension, computing statistics
on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
or Spatio-temporal BatchNorm
Args:
num_features: num_features from an expected input of
size batch_size x num_features x depth x height x width
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Default: 0.1
affine: a boolean value that when set to ``True``, gives the layer learnable
affine parameters. Default: ``True``
Shape::
- Input: :math:`(N, C, D, H, W)`
- Output: :math:`(N, C, D, H, W)` (same shape as input)
Examples:
>>> # With Learnable Parameters
>>> m = SynchronizedBatchNorm3d(100)
>>> # Without Learnable Parameters
>>> m = SynchronizedBatchNorm3d(100, affine=False)
>>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
>>> output = m(input)
"""
def _check_input_dim(self, input):
if input.dim() != 5:
raise ValueError('expected 5D input (got {}D input)'
.format(input.dim()))
@contextlib.contextmanager
def patch_sync_batchnorm():
import torch.nn as nn
backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d
nn.BatchNorm1d = SynchronizedBatchNorm1d
nn.BatchNorm2d = SynchronizedBatchNorm2d
nn.BatchNorm3d = SynchronizedBatchNorm3d
yield
nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup
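# Usage sketch (illustrative, mirroring the convert_model example below): any
# BatchNorm layers constructed inside the `with` block become synchronized variants.
#
#   with patch_sync_batchnorm():
#       model = torchvision.models.resnet18()  # its BatchNorm2d layers are now SynchronizedBatchNorm2d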
def convert_model(module):
"""Traverse the input module and its child recursively
and replace all instance of torch.nn.modules.batchnorm.BatchNorm*N*d
to SynchronizedBatchNorm*N*d
Args:
module: the input module needs to be convert to SyncBN model
Examples:
>>> import torch.nn as nn
>>> import torchvision
>>> # m is a standard pytorch model
>>> m = torchvision.models.resnet18(True)
>>> m = nn.DataParallel(m)
>>> # after convert, m is using SyncBN
>>> m = convert_model(m)
"""
if isinstance(module, torch.nn.DataParallel):
mod = module.module
mod = convert_model(mod)
mod = DataParallelWithCallback(mod, device_ids=module.device_ids)
return mod
mod = module
for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d,
torch.nn.modules.batchnorm.BatchNorm2d,
torch.nn.modules.batchnorm.BatchNorm3d],
[SynchronizedBatchNorm1d,
SynchronizedBatchNorm2d,
SynchronizedBatchNorm3d]):
if isinstance(module, pth_module):
mod = sync_module(module.num_features, module.eps, module.momentum, module.affine)
mod.running_mean = module.running_mean
mod.running_var = module.running_var
if module.affine:
mod.weight.data = module.weight.data.clone().detach()
mod.bias.data = module.bias.data.clone().detach()
for name, child in module.named_children():
mod.add_module(name, convert_model(child))
return mod
================================================
FILE: model/sync_batchnorm/batchnorm_reimpl.py
================================================
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : batchnorm_reimpl.py
# Author : acgtyrant
# Date : 11/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import torch
import torch.nn as nn
import torch.nn.init as init
__all__ = ['BatchNorm2dReimpl']
class BatchNorm2dReimpl(nn.Module):
"""
A re-implementation of batch normalization, used for testing the numerical
stability.
Author: acgtyrant
See also:
https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
"""
def __init__(self, num_features, eps=1e-5, momentum=0.1):
super().__init__()
self.num_features = num_features
self.eps = eps
self.momentum = momentum
self.weight = nn.Parameter(torch.empty(num_features))
self.bias = nn.Parameter(torch.empty(num_features))
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
self.reset_parameters()
def reset_running_stats(self):
self.running_mean.zero_()
self.running_var.fill_(1)
def reset_parameters(self):
self.reset_running_stats()
init.uniform_(self.weight)
init.zeros_(self.bias)
def forward(self, input_):
batchsize, channels, height, width = input_.size()
numel = batchsize * height * width
input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
sum_ = input_.sum(1)
sum_of_square = input_.pow(2).sum(1)
mean = sum_ / numel
sumvar = sum_of_square - sum_ * mean
self.running_mean = (
(1 - self.momentum) * self.running_mean
+ self.momentum * mean.detach()
)
unbias_var = sumvar / (numel - 1)
self.running_var = (
(1 - self.momentum) * self.running_var
+ self.momentum * unbias_var.detach()
)
bias_var = sumvar / numel
inv_std = 1 / (bias_var + self.eps).pow(0.5)
output = (
(input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
self.weight.unsqueeze(1) + self.bias.unsqueeze(1))
return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()
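# Numerical sanity sketch (illustrative): in training mode the re-implementation
# should match torch.nn.BatchNorm2d up to floating-point error.
if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.rand(8, 3, 5, 5)
    ref, reimpl = nn.BatchNorm2d(3), BatchNorm2dReimpl(3)
    reimpl.weight.data.copy_(ref.weight.data)
    reimpl.bias.data.copy_(ref.bias.data)
    print(torch.allclose(ref(x), reimpl(x), atol=1e-5))  # True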
================================================
FILE: model/sync_batchnorm/comm.py
================================================
# -*- coding: utf-8 -*-
# File : comm.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import queue
import collections
import threading
__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
class FutureResult(object):
"""A thread-safe future implementation. Used only as one-to-one pipe."""
def __init__(self):
self._result = None
self._lock = threading.Lock()
self._cond = threading.Condition(self._lock)
def put(self, result):
with self._lock:
assert self._result is None, 'Previous result hasn\'t been fetched.'
self._result = result
self._cond.notify()
def get(self):
with self._lock:
if self._result is None:
self._cond.wait()
res = self._result
self._result = None
return res
_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
class SlavePipe(_SlavePipeBase):
"""Pipe for master-slave communication."""
def run_slave(self, msg):
self.queue.put((self.identifier, msg))
ret = self.result.get()
self.queue.put(True)
return ret
class SyncMaster(object):
"""An abstract `SyncMaster` object.
- During replication, as data parallel triggers a callback on each module, all slave devices should
call `register(id)` and obtain a `SlavePipe` to communicate with the master.
- During the forward pass, the master device invokes `run_master`; all messages from the slave devices
are collected and passed to a registered callback.
- After receiving the messages, the master device gathers the information and determines the message
to be passed back to each slave device.
"""
def __init__(self, master_callback):
"""
Args:
master_callback: a callback to be invoked after having collected messages from slave devices.
"""
self._master_callback = master_callback
self._queue = queue.Queue()
self._registry = collections.OrderedDict()
self._activated = False
def __getstate__(self):
return {'master_callback': self._master_callback}
def __setstate__(self, state):
self.__init__(state['master_callback'])
def register_slave(self, identifier):
"""
Register a slave device.
Args:
    identifier: an identifier, usually the device id.
Returns: a `SlavePipe` object which can be used to communicate with the master device.
"""
if self._activated:
assert self._queue.empty(), 'Queue is not clean before next initialization.'
self._activated = False
self._registry.clear()
future = FutureResult()
self._registry[identifier] = _MasterRegistry(future)
return SlavePipe(identifier, self._queue, future)
def run_master(self, master_msg):
"""
Main entry for the master device in each forward pass.
The messages are first collected from each device (including the master device), and then
a callback is invoked to compute the message to be sent back to each device
(including the master device).
Args:
    master_msg: the message that the master wants to send to itself. This will be placed as the first
message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
Returns: the message to be sent back to the master device.
"""
self._activated = True
intermediates = [(0, master_msg)]
for i in range(self.nr_slaves):
intermediates.append(self._queue.get())
results = self._master_callback(intermediates)
assert results[0][0] == 0, 'The first result should belong to the master.'
for i, res in results:
if i == 0:
continue
self._registry[i].result.put(res)
for i in range(self.nr_slaves):
assert self._queue.get() is True
return results[0][1]
@property
def nr_slaves(self):
return len(self._registry)
================================================
FILE: model/sync_batchnorm/replicate.py
================================================
# -*- coding: utf-8 -*-
# File : replicate.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import functools
from torch.nn.parallel.data_parallel import DataParallel
__all__ = [
'CallbackContext',
'execute_replication_callbacks',
'DataParallelWithCallback',
'patch_replication_callback'
]
class CallbackContext(object):
pass
def execute_replication_callbacks(modules):
"""
Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
Note that, as all module copies are isomorphic, we assign each sub-module a context
(shared among the multiple copies of this module on different devices).
Through this context, different copies can share some information.
We guarantee that the callback on the master copy (the first copy) will be called before the callbacks
of any slave copies.
"""
master_copy = modules[0]
nr_modules = len(list(master_copy.modules()))
ctxs = [CallbackContext() for _ in range(nr_modules)]
for i, module in enumerate(modules):
for j, m in enumerate(module.modules()):
if hasattr(m, '__data_parallel_replicate__'):
m.__data_parallel_replicate__(ctxs[j], i)
class DataParallelWithCallback(DataParallel):
"""
Data Parallel with a replication callback.
A replication callback `__data_parallel_replicate__` is invoked on each module after it is created by the
original `replicate` function.
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
Examples:
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
# sync_bn.__data_parallel_replicate__ will be invoked.
"""
def replicate(self, module, device_ids):
modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
execute_replication_callbacks(modules)
return modules
def patch_replication_callback(data_parallel):
"""
Monkey-patch an existing `DataParallel` object. Add the replication callback.
Useful when you have customized `DataParallel` implementation.
Examples:
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
> sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
> patch_replication_callback(sync_bn)
# this is equivalent to
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
"""
assert isinstance(data_parallel, DataParallel)
old_replicate = data_parallel.replicate
@functools.wraps(old_replicate)
def new_replicate(module, device_ids):
modules = old_replicate(module, device_ids)
execute_replication_callbacks(modules)
return modules
data_parallel.replicate = new_replicate
================================================
FILE: model/sync_batchnorm/unittest.py
================================================
# -*- coding: utf-8 -*-
# File : unittest.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import unittest
import torch
class TorchTestCase(unittest.TestCase):
def assertTensorClose(self, x, y):
adiff = float((x - y).abs().max())
if (y == 0).all():
rdiff = 'NaN'
else:
rdiff = float((adiff / y).abs().max())
message = (
'Tensor close check failed\n'
'adiff={}\n'
'rdiff={}\n'
).format(adiff, rdiff)
self.assertTrue(torch.allclose(x, y, atol=1e-5, rtol=1e-3), message)
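# A minimal usage sketch (hypothetical test case, not part of the original suite):
#   class SanityTest(TorchTestCase):
#       def test_identity(self):
#           x = torch.randn(4, 3)
#           self.assertTensorClose(x, x.clone())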
================================================
FILE: parser.py
================================================
import os
import torch
import argparse
def parse_arguments():
parser = argparse.ArgumentParser(description="Benchmarking Visual Geolocalization",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# Training parameters
parser.add_argument("--train_batch_size", type=int, default=4,
help="Number of triplets (query, pos, negs) in a batch. Each triplet consists of 12 images")
parser.add_argument("--infer_batch_size", type=int, default=16,
help="Batch size for inference (caching and testing)")
parser.add_argument("--criterion", type=str, default='triplet', help='loss to be used',
choices=["triplet", "sare_ind", "sare_joint"])
parser.add_argument("--margin", type=float, default=0.1,
help="margin for the triplet loss")
parser.add_argument("--epochs_num", type=int, default=1000,
help="number of epochs to train for")
parser.add_argument("--patience", type=int, default=3)
parser.add_argument("--lr", type=float, default=0.00001, help="_")
parser.add_argument("--lr_crn_layer", type=float, default=5e-3, help="Learning rate for the CRN layer")
parser.add_argument("--lr_crn_net", type=float, default=5e-4, help="Learning rate to finetune pretrained network when using CRN")
parser.add_argument("--optim", type=str, default="adam", help="_", choices=["adam", "sgd"])
parser.add_argument("--cache_refresh_rate", type=int, default=1000,
help="How often to refresh cache, in number of queries")
parser.add_argument("--queries_per_epoch", type=int, default=5000,
help="How many queries to consider for one epoch. Must be multiple of cache_refresh_rate")
parser.add_argument("--negs_num_per_query", type=int, default=10,
help="How many negatives to consider per each query in the loss")
parser.add_argument("--neg_samples_num", type=int, default=1000,
help="How many negatives to use to compute the hardest ones")
parser.add_argument("--mining", type=str, default="partial", choices=["partial", "full", "random", "msls_weighted"])
# Model parameters
parser.add_argument("--backbone", type=str, default="resnet18conv4",
choices=["alexnet", "vgg16", "resnet18conv4", "resnet18conv5",
"resnet50conv4", "resnet50conv5", "resnet101conv4", "resnet101conv5",
"cct384", "vit"], help="_")
parser.add_argument("--l2", type=str, default="before_pool", choices=["before_pool", "after_pool", "none"],
help="When (and if) to apply the l2 norm with shallow aggregation layers")
parser.add_argument("--aggregation", type=str, default="netvlad", choices=["netvlad", "gem", "spoc", "mac", "rmac", "crn", "rrm",
"cls", "seqpool"])
parser.add_argument('--netvlad_clusters', type=int, default=64, help="Number of clusters for NetVLAD layer.")
parser.add_argument('--pca_dim', type=int, default=None, help="PCA dimension (number of principal components). If None, PCA is not used.")
parser.add_argument('--fc_output_dim', type=int, default=None,
help="Output dimension of fully connected layer. If None, don't use a fully connected layer.")
parser.add_argument('--pretrain', type=str, default="imagenet", choices=['imagenet', 'gldv2', 'places'],
help="Select the pretrained weights for the starting network")
parser.add_argument("--off_the_shelf", type=str, default="imagenet", choices=["imagenet", "radenovic_sfm", "radenovic_gldv1", "naver"],
help="Off-the-shelf networks from popular GitHub repos. Only with ResNet-50/101 + GeM + FC 2048")
parser.add_argument("--trunc_te", type=int, default=None, choices=list(range(0, 14)))
parser.add_argument("--freeze_te", type=int, default=None, choices=list(range(-1, 14)))
# Initialization parameters
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--resume", type=str, default=None,
help="Path to load checkpoint from, for resuming training or testing.")
# Other parameters
parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"])
parser.add_argument("--num_workers", type=int, default=8, help="num_workers for all dataloaders")
parser.add_argument('--resize', type=int, default=[480, 640], nargs=2, help="Resizing shape for images (HxW).")
parser.add_argument('--test_method', type=str, default="hard_resize",
choices=["hard_resize", "single_query", "central_crop", "five_crops", "nearest_crop", "maj_voting"],
help="This includes pre/post-processing methods and prediction refinement")
parser.add_argument("--majority_weight", type=float, default=0.01,
help="only for majority voting, scale factor, the higher it is the more importance is given to agreement")
parser.add_argument("--efficient_ram_testing", action='store_true', help="_")
parser.add_argument("--val_positive_dist_threshold", type=int, default=25, help="_")
parser.add_argument("--train_positives_dist_threshold", type=int, default=10, help="_")
parser.add_argument('--recall_values', type=int, default=[1, 5, 10, 20], nargs="+",
help="Recalls to be computed, such as R@5.")
# Data augmentation parameters
parser.add_argument("--brightness", type=float, default=0, help="_")
parser.add_argument("--contrast", type=float, default=0, help="_")
parser.add_argument("--saturation", type=float, default=0, help="_")
parser.add_argument("--hue", type=float, default=0, help="_")
parser.add_argument("--rand_perspective", type=float, default=0, help="_")
parser.add_argument("--horizontal_flip", action='store_true', help="_")
parser.add_argument("--random_resized_crop", type=float, default=0, help="_")
parser.add_argument("--random_rotation", type=float, default=0, help="_")
# Paths parameters
parser.add_argument("--datasets_folder", type=str, default=None, help="Path with all datasets")
parser.add_argument("--dataset_name", type=str, default="pitts30k", help="Relative path of the dataset")
parser.add_argument("--pca_dataset_folder", type=str, default=None,
help="Path with images to be used to compute PCA (ie: pitts30k/images/train")
parser.add_argument("--save_dir", type=str, default="default",
help="Folder name of the current run (saved in ./logs/)")
args = parser.parse_args()
if args.datasets_folder is None:
try:
args.datasets_folder = os.environ['DATASETS_FOLDER']
except KeyError:
raise Exception("You should set the parameter --datasets_folder or export " +
"the DATASETS_FOLDER environment variable as such \n" +
"export DATASETS_FOLDER=../datasets_vg/datasets")
if args.aggregation == "crn" and args.resume is None:
raise ValueError("CRN must be resumed from a trained NetVLAD checkpoint, but you set resume=None.")
if args.queries_per_epoch % args.cache_refresh_rate != 0:
raise ValueError("Ensure that queries_per_epoch is divisible by cache_refresh_rate, " +
f"because {args.queries_per_epoch} is not divisible by {args.cache_refresh_rate}")
if torch.cuda.device_count() >= 2 and args.criterion in ['sare_joint', "sare_ind"]:
raise NotImplementedError("SARE losses are not implemented for multiple GPUs, " +
f"but you're using {torch.cuda.device_count()} GPUs and {args.criterion} loss.")
if args.mining == "msls_weighted" and args.dataset_name != "msls":
raise ValueError("msls_weighted mining can only be applied to msls dataset, but you're using it on {args.dataset_name}")
if args.off_the_shelf in ["radenovic_sfm", "radenovic_gldv1", "naver"]:
if args.backbone not in ["resnet50conv5", "resnet101conv5"] or args.aggregation != "gem" or args.fc_output_dim != 2048:
raise ValueError("Off-the-shelf models are trained only with ResNet-50/101 + GeM + FC 2048")
if args.pca_dim is not None and args.pca_dataset_folder is None:
raise ValueError("Please specify --pca_dataset_folder when using pca")
if args.backbone == "vit":
if args.resize != [224, 224] and args.resize != [384, 384]:
            raise ValueError(f"Image size for ViT must be either 224 or 384, but it is {args.resize}")
if args.backbone == "cct384":
if args.resize != [384, 384]:
raise ValueError(f'Image size for CCT384 must be 384, but it is {args.resize}')
if args.backbone in ["alexnet", "vgg16", "resnet18conv4", "resnet18conv5",
"resnet50conv4", "resnet50conv5", "resnet101conv4", "resnet101conv5"]:
if args.aggregation in ["cls", "seqpool"]:
raise ValueError(f"CNNs like {args.backbone} can't work with aggregation {args.aggregation}")
if args.backbone in ["cct384"]:
if args.aggregation in ["spoc", "mac", "rmac", "crn", "rrm"]:
raise ValueError(f"CCT can't work with aggregation {args.aggregation}. Please use one among [netvlad, gem, cls, seqpool]")
if args.backbone == "vit":
if args.aggregation not in ["cls", "gem", "netvlad"]:
raise ValueError(f"ViT can't work with aggregation {args.aggregation}. Please use one among [netvlad, gem, cls]")
return args
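# Illustrative flag combination accepted by the off-the-shelf check above
# (any other backbone/aggregation/fc_output_dim combination raises a ValueError):
#   --backbone=resnet50conv5 --aggregation=gem --fc_output_dim=2048 --off_the_shelf=radenovic_sfm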
================================================
FILE: requirements.txt
================================================
numpy==1.19.4
torchvision==0.8.1
psutil==5.6.7
faiss_cpu==1.5.3
tqdm==4.48.2
torch==1.7.0
Pillow==8.2.0
scikit_learn==0.24.1
torchscan==0.1.1
googledrivedownloader==0.4
requests==2.26.0
timm==0.4.12
transformers==4.8.2
einops
================================================
FILE: test.py
================================================
import faiss
import torch
import logging
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Subset
def test_efficient_ram_usage(args, eval_ds, model, test_method="hard_resize"):
"""This function gives the same output as test(), but uses much less RAM.
This can be useful when testing with large descriptors (e.g. NetVLAD) on large datasets (e.g. San Francisco).
Obviously it is slower than test(), and can't be used with PCA.
"""
model = model.eval()
if test_method == 'nearest_crop' or test_method == "maj_voting":
distances = np.empty([eval_ds.queries_num * 5, eval_ds.database_num], dtype=np.float32)
else:
distances = np.empty([eval_ds.queries_num, eval_ds.database_num], dtype=np.float32)
with torch.no_grad():
if test_method == 'nearest_crop' or test_method == 'maj_voting':
queries_features = np.ones((eval_ds.queries_num * 5, args.features_dim), dtype="float32")
else:
queries_features = np.ones((eval_ds.queries_num, args.features_dim), dtype="float32")
logging.debug("Extracting queries features for evaluation/testing")
queries_infer_batch_size = 1 if test_method == "single_query" else args.infer_batch_size
eval_ds.test_method = test_method
queries_subset_ds = Subset(eval_ds, list(range(eval_ds.database_num, eval_ds.database_num+eval_ds.queries_num)))
queries_dataloader = DataLoader(dataset=queries_subset_ds, num_workers=args.num_workers,
batch_size=queries_infer_batch_size, pin_memory=(args.device == "cuda"))
for inputs, indices in tqdm(queries_dataloader, ncols=100):
if test_method == "five_crops" or test_method == "nearest_crop" or test_method == 'maj_voting':
inputs = torch.cat(tuple(inputs)) # shape = 5*bs x 3 x 480 x 480
features = model(inputs.to(args.device))
if test_method == "five_crops": # Compute mean along the 5 crops
features = torch.stack(torch.split(features, 5)).mean(1)
if test_method == "nearest_crop" or test_method == 'maj_voting':
start_idx = (indices[0] - eval_ds.database_num) * 5
end_idx = start_idx + indices.shape[0] * 5
indices = np.arange(start_idx, end_idx)
queries_features[indices, :] = features.cpu().numpy()
else:
queries_features[indices.numpy()-eval_ds.database_num, :] = features.cpu().numpy()
queries_features = torch.tensor(queries_features).type(torch.float32).cuda()
logging.debug("Extracting database features for evaluation/testing")
# For database use "hard_resize", although it usually has no effect because database images have same resolution
eval_ds.test_method = "hard_resize"
database_subset_ds = Subset(eval_ds, list(range(eval_ds.database_num)))
database_dataloader = DataLoader(dataset=database_subset_ds, num_workers=args.num_workers,
batch_size=args.infer_batch_size, pin_memory=(args.device == "cuda"))
for inputs, indices in tqdm(database_dataloader, ncols=100):
inputs = inputs.to(args.device)
features = model(inputs)
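            # Each database descriptor fills one column of the distance matrix with
            # its squared L2 distance to every query descriptor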
for pn, (index, pred_feature) in enumerate(zip(indices, features)):
distances[:, index] = ((queries_features-pred_feature)**2).sum(1).cpu().numpy()
del features, queries_features, pred_feature
predictions = distances.argsort(axis=1)[:, :max(args.recall_values)]
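    # N.B. the nearest_crop / maj_voting reshapes below assume max(args.recall_values) == 20 (the default)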
if test_method == 'nearest_crop':
distances = np.array([distances[row, index] for row, index in enumerate(predictions)])
distances = np.reshape(distances, (eval_ds.queries_num, 20 * 5))
predictions = np.reshape(predictions, (eval_ds.queries_num, 20 * 5))
for q in range(eval_ds.queries_num):
# sort predictions by distance
sort_idx = np.argsort(distances[q])
predictions[q] = predictions[q, sort_idx]
# remove duplicated predictions, i.e. keep only the closest ones
_, unique_idx = np.unique(predictions[q], return_index=True)
# unique_idx is sorted based on the unique values, sort it again
predictions[q, :20] = predictions[q, np.sort(unique_idx)][:20]
        predictions = predictions[:, :20]  # keep only the closest 20 predictions for each query
elif test_method == 'maj_voting':
distances = np.array([distances[row, index] for row, index in enumerate(predictions)])
distances = np.reshape(distances, (eval_ds.queries_num, 5, 20))
predictions = np.reshape(predictions, (eval_ds.queries_num, 5, 20))
for q in range(eval_ds.queries_num):
# votings, modify distances in-place
top_n_voting('top1', predictions[q], distances[q], args.majority_weight)
top_n_voting('top5', predictions[q], distances[q], args.majority_weight)
top_n_voting('top10', predictions[q], distances[q], args.majority_weight)
# flatten dist and preds from 5, 20 -> 20*5
# and then proceed as usual to keep only first 20
dists = distances[q].flatten()
preds = predictions[q].flatten()
# sort predictions by distance
sort_idx = np.argsort(dists)
preds = preds[sort_idx]
# remove duplicated predictions, i.e. keep only the closest ones
_, unique_idx = np.unique(preds, return_index=True)
# unique_idx is sorted based on the unique values, sort it again
# here the row corresponding to the first crop is used as a
# 'buffer' for each query, and in the end the dimension
# relative to crops is eliminated
predictions[q, 0, :20] = preds[np.sort(unique_idx)][:20]
        predictions = predictions[:, 0, :20]  # keep only the closest 20 predictions for each query
del distances
#### For each query, check if the predictions are correct
positives_per_query = eval_ds.get_positives()
# args.recall_values by default is [1, 5, 10, 20]
recalls = np.zeros(len(args.recall_values))
for query_index, pred in enumerate(predictions):
for i, n in enumerate(args.recall_values):
if np.any(np.in1d(pred[:n], positives_per_query[query_index])):
recalls[i:] += 1
break
recalls = recalls / eval_ds.queries_num * 100
recalls_str = ", ".join([f"R@{val}: {rec:.1f}" for val, rec in zip(args.recall_values, recalls)])
return recalls, recalls_str
def test(args, eval_ds, model, test_method="hard_resize", pca=None):
"""Compute features of the given dataset and compute the recalls."""
assert test_method in ["hard_resize", "single_query", "central_crop", "five_crops",
"nearest_crop", "maj_voting"], f"test_method can't be {test_method}"
if args.efficient_ram_testing:
return test_efficient_ram_usage(args, eval_ds, model, test_method)
model = model.eval()
with torch.no_grad():
logging.debug("Extracting database features for evaluation/testing")
        # For the database use "hard_resize", although it usually has no effect because database images have the same resolution
eval_ds.test_method = "hard_resize"
database_subset_ds = Subset(eval_ds, list(range(eval_ds.database_num)))
database_dataloader = DataLoader(dataset=database_subset_ds, num_workers=args.num_workers,
batch_size=args.infer_batch_size, pin_memory=(args.device == "cuda"))
if test_method == "nearest_crop" or test_method == 'maj_voting':
all_features = np.empty((5 * eval_ds.queries_num + eval_ds.database_num, args.features_dim), dtype="float32")
else:
all_features = np.empty((len(eval_ds), args.features_dim), dtype="float32")
for inputs, indices in tqdm(database_dataloader, ncols=100):
features = model(inputs.to(args.device))
features = features.cpu().numpy()
if pca is not None:
features = pca.transform(features)
all_features[indices.numpy(), :] = features
logging.debug("Extracting queries features for evaluation/testing")
queries_infer_batch_size = 1 if test_method == "single_query" else args.infer_batch_size
eval_ds.test_method = test_method
queries_subset_ds = Subset(eval_ds, list(range(eval_ds.database_num, eval_ds.database_num+eval_ds.queries_num)))
queries_dataloader = DataLoader(dataset=queries_subset_ds, num_workers=args.num_workers,
batch_size=queries_infer_batch_size, pin_memory=(args.device == "cuda"))
for inputs, indices in tqdm(queries_dataloader, ncols=100):
if test_method == "five_crops" or test_method == "nearest_crop" or test_method == 'maj_voting':
inputs = torch.cat(tuple(inputs)) # shape = 5*bs x 3 x 480 x 480
features = model(inputs.to(args.device))
if test_method == "five_crops": # Compute mean along the 5 crops
features = torch.stack(torch.split(features, 5)).mean(1)
features = features.cpu().numpy()
if pca is not None:
features = pca.transform(features)
if test_method == "nearest_crop" or test_method == 'maj_voting': # store the features of all 5 crops
start_idx = eval_ds.database_num + (indices[0] - eval_ds.database_num) * 5
end_idx = start_idx + indices.shape[0] * 5
indices = np.arange(start_idx, end_idx)
all_features[indices, :] = features
else:
all_features[indices.numpy(), :] = features
queries_features = all_features[eval_ds.database_num:]
database_features = all_features[:eval_ds.database_num]
faiss_index = faiss.IndexFlatL2(args.features_dim)
faiss_index.add(database_features)
del database_features, all_features
logging.debug("Calculating recalls")
distances, predictions = faiss_index.search(queries_features, max(args.recall_values))
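    # As in test_efficient_ram_usage, the nearest_crop / maj_voting reshapes below
    # assume max(args.recall_values) == 20 (the default)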
if test_method == 'nearest_crop':
distances = np.reshape(distances, (eval_ds.queries_num, 20 * 5))
predictions = np.reshape(predictions, (eval_ds.queries_num, 20 * 5))
for q in range(eval_ds.queries_num):
# sort predictions by distance
sort_idx = np.argsort(distances[q])
predictions[q] = predictions[q, sort_idx]
# remove duplicated predictions, i.e. keep only the closest ones
_, unique_idx = np.unique(predictions[q], return_index=True)
# unique_idx is sorted based on the unique values, sort it again
predictions[q, :20] = predictions[q, np.sort(unique_idx)][:20]
        predictions = predictions[:, :20]  # keep only the closest 20 predictions for each query
elif test_method == 'maj_voting':
distances = np.reshape(distances, (eval_ds.queries_num, 5, 20))
predictions = np.reshape(predictions, (eval_ds.queries_num, 5, 20))
for q in range(eval_ds.queries_num):
# votings, modify distances in-place
top_n_voting('top1', predictions[q], distances[q], args.majority_weight)
top_n_voting('top5', predictions[q], distances[q], args.majority_weight)
top_n_voting('top10', predictions[q], distances[q], args.majority_weight)
# flatten dist and preds from 5, 20 -> 20*5
# and then proceed as usual to keep only first 20
dists = distances[q].flatten()
preds = predictions[q].flatten()
# sort predictions by distance
sort_idx = np.argsort(dists)
preds = preds[sort_idx]
# remove duplicated predictions, i.e. keep only the closest ones
_, unique_idx = np.unique(preds, return_index=True)
# unique_idx is sorted based on the unique values, sort it again
# here the row corresponding to the first crop is used as a
# 'buffer' for each query, and in the end the dimension
# relative to crops is eliminated
predictions[q, 0, :20] = preds[np.sort(unique_idx)][:20]
        predictions = predictions[:, 0, :20]  # keep only the closest 20 predictions for each query
#### For each query, check if the predictions are correct
positives_per_query = eval_ds.get_positives()
# args.recall_values by default is [1, 5, 10, 20]
recalls = np.zeros(len(args.recall_values))
for query_index, pred in enumerate(predictions):
for i, n in enumerate(args.recall_values):
if np.any(np.in1d(pred[:n], positives_per_query[query_index])):
recalls[i:] += 1
break
    # Divide by the number of queries and multiply by 100, so the recalls are percentages
recalls = recalls / eval_ds.queries_num * 100
recalls_str = ", ".join([f"R@{val}: {rec:.1f}" for val, rec in zip(args.recall_values, recalls)])
return recalls, recalls_str
def top_n_voting(topn, predictions, distances, maj_weight):
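    """Re-rank crop predictions by majority voting across the 5 crops.
    For every prediction that appears in the selected columns ('top1', 'top5'
    or 'top10') of more than one crop, subtract maj_weight * count / n from its
    distance in-place, so predictions agreed upon by several crops rank closer.
    """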
if topn == 'top1':
n = 1
selected = 0
elif topn == 'top5':
n = 5
selected = slice(0, 5)
    elif topn == 'top10':
        n = 10
        selected = slice(0, 10)
    else:
        raise ValueError(f"topn must be 'top1', 'top5' or 'top10', but it is {topn}")
    # find predictions that repeat in the first, first five,
    # or first ten columns for each crop
vals, counts = np.unique(predictions[:, selected], return_counts=True)
# for each prediction that repeats more than once,
# subtract from its score
for val, count in zip(vals[counts > 1], counts[counts > 1]):
mask = (predictions[:, selected] == val)
distances[:, selected][mask] -= maj_weight * count/n
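# Toy illustration (hypothetical numbers): if prediction 7 appears in the top-1
# column of 3 of the 5 crops, top_n_voting('top1', ...) subtracts
# maj_weight * 3 / 1 from its distance in those positions, so it climbs the
# ranking when the flattened predictions are re-sorted by distance.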
================================================
FILE: train.py
================================================
import math
import torch
import logging
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import multiprocessing
from os.path import join
from datetime import datetime
import torchvision.transforms as transforms
from torch.utils.data.dataloader import DataLoader
import util
import test
import parser
import commons
import datasets_ws
from model import network
from model.sync_batchnorm import convert_model
from model.functional import sare_ind, sare_joint
torch.backends.cudnn.benchmark = True # Provides a speedup
#### Initial setup: parser, logging...
args = parser.parse_arguments()
start_time = datetime.now()
args.save_dir = join("logs", args.save_dir, start_time.strftime('%Y-%m-%d_%H-%M-%S'))
commons.setup_logging(args.save_dir)
commons.make_deterministic(args.seed)
logging.info(f"Arguments: {args}")
logging.info(f"The outputs are being saved in {args.save_dir}")
logging.info(f"Using {torch.cuda.device_count()} GPUs and {multiprocessing.cpu_count()} CPUs")
#### Creation of Datasets
logging.debug(f"Loading dataset {args.dataset_name} from folder {args.datasets_folder}")
triplets_ds = datasets_ws.TripletsDataset(args, args.datasets_folder, args.dataset_name, "train", args.negs_num_per_query)
logging.info(f"Train query set: {triplets_ds}")
val_ds = datasets_ws.BaseDataset(args, args.datasets_folder, args.dataset_name, "val")
logging.info(f"Val set: {val_ds}")
test_ds = datasets_ws.BaseDataset(args, args.datasets_folder, args.dataset_name, "test")
logging.info(f"Test set: {test_ds}")
#### Initialize model
model = network.GeoLocalizationNet(args)
model = model.to(args.device)
if args.aggregation in ["netvlad", "crn"]: # If using NetVLAD layer, initialize it
if not args.resume:
triplets_ds.is_inference = True
model.aggregation.initialize_netvlad_layer(args, triplets_ds, model.backbone)
args.features_dim *= args.netvlad_clusters
model = torch.nn.DataParallel(model)
#### Setup Optimizer and Loss
if args.aggregation == "crn":
crn_params = list(model.module.aggregation.crn.parameters())
net_params = list(model.module.backbone.parameters()) + \
list([m[1] for m in model.module.aggregation.named_parameters() if not m[0].startswith('crn')])
if args.optim == "adam":
optimizer = torch.optim.Adam([{'params': crn_params, 'lr': args.lr_crn_layer},
{'params': net_params, 'lr': args.lr_crn_net}])
logging.info("You're using CRN with Adam, it is advised to use SGD")
elif args.optim == "sgd":
optimizer = torch.optim.SGD([{'params': crn_params, 'lr': args.lr_crn_layer, 'momentum': 0.9, 'weight_decay': 0.001},
{'params': net_params, 'lr': args.lr_crn_net, 'momentum': 0.9, 'weight_decay': 0.001}])
else:
if args.optim == "adam":
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
elif args.optim == "sgd":
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=0.001)
if args.criterion == "triplet":
criterion_triplet = nn.TripletMarginLoss(margin=args.margin, p=2, reduction="sum")
elif args.criterion == "sare_ind":
criterion_triplet = sare_ind
elif args.criterion == "sare_joint":
criterion_triplet = sare_joint
#### Resume model, optimizer, and other training parameters
if args.resume:
if args.aggregation != 'crn':
model, optimizer, best_r5, start_epoch_num, not_improved_num = util.resume_train(args, model, optimizer)
else:
        # CRN is initialized from a pretrained NetVLAD checkpoint, so it must be loaded
        # with strict=False, and the optimizer is not loaded from the checkpoint file.
model, _, best_r5, start_epoch_num, not_improved_num = util.resume_train(args, model, strict=False)
logging.info(f"Resuming from epoch {start_epoch_num} with best recall@5 {best_r5:.1f}")
else:
best_r5 = start_epoch_num = not_improved_num = 0
if args.backbone.startswith('vit'):
logging.info(f"Output dimension of the model is {args.features_dim}")
else:
logging.info(f"Output dimension of the model is {args.features_dim}, with {util.get_flops(model, args.resize)}")
if torch.cuda.device_count() >= 2:
    # When using more than 1 GPU, use sync_batchnorm for torch.nn.DataParallel
model = convert_model(model)
model = model.cuda()
#### Training loop
for epoch_num in range(start_epoch_num, args.epochs_num):
logging.info(f"Start training epoch: {epoch_num:02d}")
epoch_start_time = datetime.now()
epoch_losses = np.zeros((0, 1), dtype=np.float32)
    # How many cache-refresh loops an epoch lasts (default is 5000/1000=5)
loops_num = math.ceil(args.queries_per_epoch / args.cache_refresh_rate)
for loop_num in range(loops_num):
logging.debug(f"Cache: {loop_num} / {loops_num}")
# Compute triplets to use in the triplet loss
triplets_ds.is_inference = True
triplets_ds.compute_triplets(args, model)
triplets_ds.is_inference = False
triplets_dl = DataLoader(dataset=triplets_ds, num_workers=args.num_workers,
batch_size=args.train_batch_size,
collate_fn=datasets_ws.collate_fn,
pin_memory=(args.device == "cuda"),
drop_last=True)
model = model.train()
# images shape: (train_batch_size*12)*3*H*W ; by default train_batch_size=4, H=480, W=640
        # triplets_local_indexes shape: (train_batch_size*negs_num_per_query)*3 ; because there is one triplet per negative (default 10 per query)
for images, triplets_local_indexes, _ in tqdm(triplets_dl, ncols=100):
# Flip all triplets or none
if args.horizontal_flip:
images = transforms.RandomHorizontalFlip()(images)
# Compute features of all images (images contains queries, positives and negatives)
features = model(images.to(args.device))
loss_triplet = 0
if args.criterion == "triplet":
triplets_local_indexes = torch.transpose(
triplets_local_indexes.view(args.train_batch_size, args.negs_num_per_query, 3), 1, 0)
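                # After the view + transpose, shape is negs_num_per_query x train_batch_size x 3,
                # so each iteration below processes one triplet per query across the whole batch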
for triplets in triplets_local_indexes:
queries_indexes, positives_indexes, negatives_indexes = triplets.T
loss_triplet += criterion_triplet(features[queries_indexes],
features[positives_indexes],
features[negatives_indexes])
elif args.criterion == 'sare_joint':
# sare_joint needs to receive all the negatives at once
                triplet_index_batch = triplets_local_indexes.view(args.train_batch_size, args.negs_num_per_query, 3)
for batch_triplet_index in triplet_index_batch:
q = features[batch_triplet_index[0, 0]].unsqueeze(0) # obtain query as tensor of shape 1xn_features
p = features[batch_triplet_index[0, 1]].unsqueeze(0) # obtain positive as tensor of shape 1xn_features
                    n = features[batch_triplet_index[:, 2]]  # obtain negatives as tensor of shape negs_num_per_query x n_features
loss_triplet += criterion_triplet(q, p, n)
elif args.criterion == "sare_ind":
for triplet in triplets_local_indexes:
# triplet is a 1-D tensor with the 3 scalars indexes of the triplet
q_i, p_i, n_i = triplet
loss_triplet += criterion_triplet(features[q_i:q_i+1], features[p_i:p_i+1], features[n_i:n_i+1])
del features
loss_triplet /= (args.train_batch_size * args.negs_num_per_query)
optimizer.zero_grad()
loss_triplet.backward()
optimizer.step()
# Keep track of all losses by appending them to epoch_losses
batch_loss = loss_triplet.item()
epoch_losses = np.append(epoch_losses, batch_loss)
del loss_triplet
logging.debug(f"Epoch[{epoch_num:02d}]({loop_num}/{loops_num}): " +
f"current batch triplet loss = {batch_loss:.4f}, " +
f"average epoch triplet loss = {epoch_losses.mean():.4f}")
logging.info(f"Finished epoch {epoch_num:02d} in {str(datetime.now() - epoch_start_time)[:-7]}, "
f"average epoch triplet loss = {epoch_losses.mean():.4f}")
# Compute recalls on validation set
recalls, recalls_str = test.test(args, val_ds, model)
logging.info(f"Recalls on val set {val_ds}: {recalls_str}")
is_best = recalls[1] > best_r5
# Save checkpoint, which contains all training parameters
util.save_checkpoint(args, {
"epoch_num": epoch_num, "model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(), "recalls": recalls, "best_r5": best_r5,
"not_improved_num": not_improved_num
}, is_best, filename="last_model.pth")
# If recall@5 did not improve for "many" epochs, stop training
if is_best:
logging.info(f"Improved: previous best R@5 = {best_r5:.1f}, current R@5 = {recalls[1]:.1f}")
best_r5 = recalls[1]
not_improved_num = 0
else:
not_improved_num += 1
logging.info(f"Not improved: {not_improved_num} / {args.patience}: best R@5 = {best_r5:.1f}, current R@5 = {recalls[1]:.1f}")
if not_improved_num >= args.patience:
logging.info(f"Performance did not improve for {not_improved_num} epochs. Stop training.")
break
logging.info(f"Best R@5: {best_r5:.1f}")
logging.info(f"Trained for {epoch_num+1:02d} epochs, in total in {str(datetime.now() - start_time)[:-7]}")
#### Test best model on test set
best_model_state_dict = torch.load(join(args.save_dir, "best_model.pth"))["model_state_dict"]
model.load_state_dict(best_model_state_dict)
recalls, recalls_str = test.test(args, test_ds, model, test_method=args.test_method)
logging.info(f"Recalls on {test_ds}: {recalls_str}")
================================================
FILE: util.py
================================================
import re
import torch
import shutil
import logging
import torchscan
import numpy as np
from collections import OrderedDict
from os.path import join
from sklearn.decomposition import PCA
import datasets_ws
def get_flops(model, input_shape=(480, 640)):
"""Return the FLOPs as a string, such as '22.33 GFLOPs'"""
assert len(input_shape) == 2, f"input_shape should have len==2, but it's {input_shape}"
module_info = torchscan.crawl_module(model, (3, input_shape[0], input_shape[1]))
output = torchscan.utils.format_info(module_info)
return re.findall("Floating Point Operations on forward: (.*)\n", output)[0]
def save_checkpoint(args, state, is_best, filename):
model_path = join(args.save_dir, filename)
torch.save(state, model_path)
if is_best:
shutil.copyfile(model_path, join(args.save_dir, "best_model.pth"))
def resume_model(args, model):
checkpoint = torch.load(args.resume, map_location=args.device)
if 'model_state_dict' in checkpoint:
state_dict = checkpoint['model_state_dict']
else:
        # The pre-trained models that we provide in the README do not have 'model_state_dict'
        # among their keys, as the checkpoint is directly the state dict
        state_dict = checkpoint
    # If the state dict keys contain the prefix "module", which is appended by
    # DataParallel, remove it to avoid errors when loading the dict
if list(state_dict.keys())[0].startswith('module'):
state_dict = OrderedDict({k.replace('module.', ''): v for (k, v) in state_dict.items()})
model.load_state_dict(state_dict)
return model
def resume_train(args, model, optimizer=None, strict=False):
"""Load model, optimizer, and other training parameters"""
logging.debug(f"Loading checkpoint: {args.resume}")
checkpoint = torch.load(args.resume)
start_epoch_num = checkpoint["epoch_num"]
model.load_state_dict(checkpoint["model_state_dict"], strict=strict)
if optimizer:
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
best_r5 = checkpoint["best_r5"]
not_improved_num = checkpoint["not_improved_num"]
logging.debug(f"Loaded checkpoint: start_epoch_num = {start_epoch_num}, "
f"current_best_R@5 = {best_r5:.1f}")
if args.resume.endswith("last_model.pth"): # Copy best model to current save_dir
shutil.copy(args.resume.replace("last_model.pth", "best_model.pth"), args.save_dir)
return model, optimizer, best_r5, start_epoch_num, not_improved_num
def compute_pca(args, model, pca_dataset_folder, full_features_dim):
model = model.eval()
pca_ds = datasets_ws.PCADataset(args, args.datasets_folder, pca_dataset_folder)
dl = torch.utils.data.DataLoader(pca_ds, args.infer_batch_size, shuffle=True)
pca_features = np.empty([min(len(pca_ds), 2**14), full_features_dim])
with torch.no_grad():
for i, images in enumerate(dl):
if i*args.infer_batch_size >= len(pca_features):
break
features = model(images).cpu().numpy()
pca_features[i*args.infer_batch_size : (i*args.infer_batch_size)+len(features)] = features
pca = PCA(args.pca_dim)
pca.fit(pca_features)
return pca
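# A minimal usage sketch (assumed flow; mirrors how test() applies pca.transform):
#   pca = compute_pca(args, model, args.pca_dataset_folder, full_features_dim)
#   features = pca.transform(features)  # (N, full_features_dim) -> (N, args.pca_dim)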