Repository: gmberton/deep-visual-geo-localization-benchmark Branch: main Commit: 4af519437403 Files: 29 Total size: 172.7 KB Directory structure: gitextract_q_wormse/ ├── .gitignore ├── LICENSE ├── README.md ├── commons.py ├── datasets_ws.py ├── eval.py ├── model/ │ ├── __init__.py │ ├── aggregation.py │ ├── cct/ │ │ ├── __init__.py │ │ ├── cct.py │ │ ├── embedder.py │ │ ├── helpers.py │ │ ├── stochastic_depth.py │ │ ├── tokenizer.py │ │ └── transformers.py │ ├── functional.py │ ├── network.py │ ├── normalization.py │ └── sync_batchnorm/ │ ├── __init__.py │ ├── batchnorm.py │ ├── batchnorm_reimpl.py │ ├── comm.py │ ├── replicate.py │ └── unittest.py ├── parser.py ├── requirements.txt ├── test.py ├── train.py └── util.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # Autogenerated folders __pycache__ logs test data # IDEs generated folders .spyproject venv/ .idea/ __MACOSX/ **/.DS_Store # other pretrained *.pth ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2016-2019 VRG, CTU Prague Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ # Deep Visual Geo-localization Benchmark This is the official repository for the CVPR 2022 Oral paper [Deep Visual Geo-localization Benchmark](https://arxiv.org/abs/2204.03444). It can be used to reproduce results from the paper, and to compute a wide range of experiments, by changing the components of a Visual Geo-localization pipeline. ## Setup Before you begin experimenting with this toolbox, your dataset should be organized in a directory tree as such: ``` . ├── benchmarking_vg └── datasets_vg └── datasets └── pitts30k └── images ├── train │ ├── database │ └── queries ├── val │ ├── database │ └── queries └── test ├── database └── queries ``` The [VPR-datasets-downloader](https://github.com/gmberton/VPR-datasets-downloader) repo can be used to download a number of datasets. Detailed instructions on how to download datasets are in the repo. Note that many datasets are available, and _pitts30k_ is just an example. ## Running experiments ### Basic experiment For a basic experiment run `$ python3 train.py --dataset_name=pitts30k` this will train a ResNet-18 + NetVLAD on Pitts30k. The experiment creates a folder named `./logs/default/YYYY-MM-DD_HH-mm-ss`, where checkpoints are saved, as well as an `info.log` file with training logs and other information, such as model size, FLOPs and descriptors dimensionality. ### Architectures and mining You can replace the backbone and the aggregation as such `$ python3 train.py --dataset_name=pitts30k --backbone=resnet50conv4 --aggregation=gem` you can easily use ResNets cropped at conv4 or conv5. #### Add a fully connected layer To add a fully connected layer of dimension 2048 to GeM pooling: `$ python3 train.py --dataset_name=pitts30k --backbone=resnet50conv4 --aggregation=gem --fc_output_dim=2048` #### Add PCA To add PCA to a NetVLAD layer just do: `$ python3 eval.py --dataset_name=pitts30k --backbone=resnet50conv4 --aggregation=netvlad --pca_dim=2048 --pca_dataset_folder=pitts30k/images/train` where _pca_dataset_folder_ points to the folder with the images used to compute PCA. In the paper we compute PCA's principal components on the train set as it showed best results. PCA is used only at test time. #### Evaluate trained models To evaluate the trained model on other datasets (this example is with the St Lucia dataset), simply run `$ python3 eval.py --backbone=resnet50conv4 --aggregation=gem --resume=logs/default/YYYY-MM-DD_HH-mm-ss/best_model.pth --dataset_name=st_lucia` #### Reproduce the results Finally, to reproduce our results, use the appropriate mining method: _full_ for _pitts30k_ and _partial_ for _msls_ as such: `$ python3 train.py --dataset_name=pitts30k --mining=full` As simple as this, you can replicate all results from tables 3, 4, 5 of the main paper, as well as tables 2, 3, 4 of the supplementary. ### Resize To resize the images simply pass the parameters _resize_ with the target resolution. For example, 80% of resolution to the full _pitts30k_ images, would be 384, 512, because the full images are 480, 640: `$ python3 train.py --dataset_name=pitts30k --resize=384 512` ### Query pre/post-processing and predictions refinement We gather all such methods under the _test_method_ parameter. The available methods are _hard_resize_, _single_query_, _central_crop_, _five_crops_mean_, _nearest_crop_ and _majority_voting_. Although _hard_resize_ is the default, in most datasets it doesn't apply any transformation at all (see the paper for more information), because all images have the same resolution. `$ python3 eval.py --resume=logs/default/YYYY-MM-DD_HH-mm-ss/best_model.pth --dataset_name=tokyo247 --test_method=nearest_crop` ### Data augmentation You can reproduce all data augmentation techniques from the paper with simple commands, for example: `$ python3 train.py --dataset_name=pitts30k --horizontal_flipping --saturation 2 --brightness 1` ### Off-the-shelf models trained on Landmark Recognition datasets The code allows to automatically download and use models trained on Landmark Recognition datasets from popular repositories: [radenovic](https://github.com/filipradenovic/cnnimageretrieval-pytorch) and [naver](https://github.com/naver/deep-image-retrieval). These repos offer ResNets-50/101 with GeM and FC 2048 trained on such datasets, and can be used as such: `$ python eval.py --off_the_shelf=radenovic_gldv1 --l2=after_pool --backbone=r101l4 --aggregation=gem --fc_output_dim=2048` `$ python eval.py --dataset_name=pitts30k --off_the_shelf=naver --l2=none --backbone=r101l4 --aggregation=gem --fc_output_dim=2048` ### Using pretrained networks on other datasets Check out our [pretrain_vg](https://github.com/rm-wu/pretrain_vg) repo which we use to train such models. You can automatically download and train on those models as such `$ python train.py --dataset_name=pitts30k --pretrained=places` ### Changing the threshold distance You can use a different distance than the default 25 meters as simply as this (for example to 100 meters): `$ python3 eval.py --resume=logs/default/YYYY-MM-DD_HH-mm-ss/best_model.pth --val_positive_dist_threshold=100` ### Changing the recall values (R@N) By default the toolbox computes recalls@ 1, 5, 10, 20, but you can compute other recalls as such: `$ python3 eval.py --resume=logs/default/YYYY-MM-DD_HH-mm-ss/best_model.pth --recall_values 1 5 10 15 20 50 100` ### Model Zoo We are currently exploring hosting options, so this is a partial list of models. More models will be added soon!!
Pretrained models with different backbones
Pretained networks employing different backbones.

Model Training on Pitts30k Training on MSLS
Pitts30k (R@1) MSLS (R@1) Download Pitts30k (R@1) MSLS (R@1) Download
vgg16-gem 78.5 43.4 [Link] 70.2 66.7 [Link]
resnet18-gem 77.8 35.3 [Link] 71.6 65.3 [Link]
resnet50-gem 82.0 38.0 [Link] 77.4 72.0 [Link]
resnet101-gem 82.4 39.6 [Link] 77.2 72.5 [Link]
ViT(224)-CLS _ _ _ 80.4 69.3 [Link]
vgg16-netvlad 83.2 50.9 [Link] 79.0 74.6 [Link]
resnet18-netvlad 86.4 47.4 [Link] 81.6 75.8 [Link]
resnet50-netvlad 86.0 50.7 [Link] 80.9 76.9 [Link]
resnet101-netvlad 86.5 51.8 [Link] 80.8 77.7 [Link]
cct384-netvlad 85.0 52.5 [Link] 80.3 85.1 [Link]
Pretrained models with different aggregation methods
Pretrained networks trained using different aggregation methods.

Model Training on Pitts30k (R@1) Training on MSLS (R@1)
Pitts30k (R@1) MSLS (R@1) Download Pitts30k (R@1) MSLS (R@1) Download
resnet50-gem 82.0 38.0 [Link] 77.4 72.0 [Link]
resnet50-gem-fc2048 80.1 33.7 [Link] 79.2 73.5 [Link]
resnet50-gem-fc65536 80.8 35.8 [Link] 79.0 74.4 [Link]
resnet50-netvlad 86.0 50.7 [Link] 80.9 76.9 [Link]
resnet50-crn 85.8 54.0 [Link] 80.8 77.8 [Link]
Pretrained models with different mining methods
Pretained networks trained using three different mining methods (random, full database mining and partial database mining):

Model Training on Pitts30k (R@1) Training on MSLS (R@1)
Pitts30k (R@1) MSLS (R@1) Download Pitts30k (R@1) MSLS (R@1) Download
resnet18-gem-random 73.7 30.5 [Link] 62.2 50.6 [Link]
resnet18-gem-full 77.8 35.3 [Link] 70.161.8 [Link]
resnet18-gem-partial 76.5 34.2 [Link] 71.6 65.3 [Link]
resnet18-netvlad-random 83.9 43.6 [Link] 73.3 61.5 [Link]
resnet18-netvlad-full 86.4 47.4 [Link] -- -
resnet18-netvlad-partial 86.2 47.3 [Link] 81.6 75.8 [Link]
resnet50-gem-random 77.9 34.3 [Link] 69.5 57.4 [Link]
resnet50-gem-full 82.0 38.0 [Link] 77.3 69.7 [Link]
resnet50-gem-partial 82.3 39.0 [Link] 77.4 72.0 [Link]
resnet50-netvlad-random 83.4 45.0 [Link] 74.9 63.6 [Link]
resnet50-netvlad-full 86.0 50.7 [Link] -- -
resnet50-netvlad-partial 85.5 48.6 [Link] 80.9 76.9 [Link]
If you find our work useful in your research please consider citing our paper: ```bibtex @inproceedings{Berton_CVPR_2022_benchmark, author = {Berton, Gabriele and Mereu, Riccardo and Trivigno, Gabriele and Masone, Carlo and Csurka, Gabriela and Sattler, Torsten and Caputo, Barbara}, title = {Deep Visual Geo-Localization Benchmark}, booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition}, month = {June}, year = {2022} } ``` ## Acknowledgements Parts of this repo are inspired by the following great repositories: - [NetVLAD's original code](https://github.com/Relja/netvlad) (in MATLAB) - [NetVLAD layer in PyTorch](https://github.com/lyakaap/NetVLAD-pytorch) - [NetVLAD training in PyTorch](https://github.com/Nanne/pytorch-NetVlad/) - [GeM layer](https://github.com/filipradenovic/cnnimageretrieval-pytorch) - [Deep Image Retrieval](https://github.com/naver/deep-image-retrieval) - [Mapillary Street-level Sequences](https://github.com/mapillary/mapillary_sls) - [Compact Convolutional Transformers](https://github.com/SHI-Labs/Compact-Transformers) Check out also our other repo [_CosPlace_](https://github.com/gmberton/CosPlace), from the CVPR 2022 paper "Rethinking Visual Geo-localization for Large-Scale Applications", which provides a new SOTA in visual geo-localization / visual place recognition. ================================================ FILE: commons.py ================================================ """ This file contains some functions and classes which can be useful in very diverse projects. """ import os import sys import torch import random import logging import traceback import numpy as np from os.path import join def make_deterministic(seed=0): """Make results deterministic. If seed == -1, do not make deterministic. Running the script in a deterministic way might slow it down. """ if seed == -1: return random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False def setup_logging(save_dir, console="debug", info_filename="info.log", debug_filename="debug.log"): """Set up logging files and console output. Creates one file for INFO logs and one for DEBUG logs. Args: save_dir (str): creates the folder where to save the files. debug (str): if == "debug" prints on console debug messages and higher if == "info" prints on console info messages and higher if == None does not use console (useful when a logger has already been set) info_filename (str): the name of the info file. if None, don't create info file debug_filename (str): the name of the debug file. if None, don't create debug file """ if os.path.exists(save_dir): raise FileExistsError(f"{save_dir} already exists!") os.makedirs(save_dir, exist_ok=True) # logging.Logger.manager.loggerDict.keys() to check which loggers are in use base_formatter = logging.Formatter('%(asctime)s %(message)s', "%Y-%m-%d %H:%M:%S") logger = logging.getLogger('') logger.setLevel(logging.DEBUG) if info_filename is not None: info_file_handler = logging.FileHandler(join(save_dir, info_filename)) info_file_handler.setLevel(logging.INFO) info_file_handler.setFormatter(base_formatter) logger.addHandler(info_file_handler) if debug_filename is not None: debug_file_handler = logging.FileHandler(join(save_dir, debug_filename)) debug_file_handler.setLevel(logging.DEBUG) debug_file_handler.setFormatter(base_formatter) logger.addHandler(debug_file_handler) if console is not None: console_handler = logging.StreamHandler() if console == "debug": console_handler.setLevel(logging.DEBUG) if console == "info": console_handler.setLevel(logging.INFO) console_handler.setFormatter(base_formatter) logger.addHandler(console_handler) def exception_handler(type_, value, tb): logger.info("\n" + "".join(traceback.format_exception(type, value, tb))) sys.excepthook = exception_handler ================================================ FILE: datasets_ws.py ================================================ import os import torch import faiss import logging import numpy as np from glob import glob from tqdm import tqdm from PIL import Image from os.path import join import torch.utils.data as data import torchvision.transforms as T from torch.utils.data.dataset import Subset from sklearn.neighbors import NearestNeighbors from torch.utils.data.dataloader import DataLoader base_transform = T.Compose([ T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) def path_to_pil_img(path): return Image.open(path).convert("RGB") def collate_fn(batch): """Creates mini-batch tensors from the list of tuples (images, triplets_local_indexes, triplets_global_indexes). triplets_local_indexes are the indexes referring to each triplet within images. triplets_global_indexes are the global indexes of each image. Args: batch: list of tuple (images, triplets_local_indexes, triplets_global_indexes). considering each query to have 10 negatives (negs_num_per_query=10): - images: torch tensor of shape (12, 3, h, w). - triplets_local_indexes: torch tensor of shape (10, 3). - triplets_global_indexes: torch tensor of shape (12). Returns: images: torch tensor of shape (batch_size*12, 3, h, w). triplets_local_indexes: torch tensor of shape (batch_size*10, 3). triplets_global_indexes: torch tensor of shape (batch_size, 12). """ images = torch.cat([e[0] for e in batch]) triplets_local_indexes = torch.cat([e[1][None] for e in batch]) triplets_global_indexes = torch.cat([e[2][None] for e in batch]) for i, (local_indexes, global_indexes) in enumerate(zip(triplets_local_indexes, triplets_global_indexes)): local_indexes += len(global_indexes) * i # Increment local indexes by offset (len(global_indexes) is 12) return images, torch.cat(tuple(triplets_local_indexes)), triplets_global_indexes class PCADataset(data.Dataset): def __init__(self, args, datasets_folder="dataset", dataset_folder="pitts30k/images/train"): dataset_folder_full_path = join(datasets_folder, dataset_folder) if not os.path.exists(dataset_folder_full_path): raise FileNotFoundError(f"Folder {dataset_folder_full_path} does not exist") self.images_paths = sorted(glob(join(dataset_folder_full_path, "**", "*.jpg"), recursive=True)) def __getitem__(self, index): return base_transform(path_to_pil_img(self.images_paths[index])) def __len__(self): return len(self.images_paths) class BaseDataset(data.Dataset): """Dataset with images from database and queries, used for inference (testing and building cache). """ def __init__(self, args, datasets_folder="datasets", dataset_name="pitts30k", split="train"): super().__init__() self.args = args self.dataset_name = dataset_name self.dataset_folder = join(datasets_folder, dataset_name, "images", split) if not os.path.exists(self.dataset_folder): raise FileNotFoundError(f"Folder {self.dataset_folder} does not exist") self.resize = args.resize self.test_method = args.test_method #### Read paths and UTM coordinates for all images. database_folder = join(self.dataset_folder, "database") queries_folder = join(self.dataset_folder, "queries") if not os.path.exists(database_folder): raise FileNotFoundError(f"Folder {database_folder} does not exist") if not os.path.exists(queries_folder): raise FileNotFoundError(f"Folder {queries_folder} does not exist") self.database_paths = sorted(glob(join(database_folder, "**", "*.jpg"), recursive=True)) self.queries_paths = sorted(glob(join(queries_folder, "**", "*.jpg"), recursive=True)) # The format must be path/to/file/@utm_easting@utm_northing@...@.jpg self.database_utms = np.array([(path.split("@")[1], path.split("@")[2]) for path in self.database_paths]).astype(float) self.queries_utms = np.array([(path.split("@")[1], path.split("@")[2]) for path in self.queries_paths]).astype(float) # Find soft_positives_per_query, which are within val_positive_dist_threshold (deafult 25 meters) knn = NearestNeighbors(n_jobs=-1) knn.fit(self.database_utms) self.soft_positives_per_query = knn.radius_neighbors(self.queries_utms, radius=args.val_positive_dist_threshold, return_distance=False) self.images_paths = list(self.database_paths) + list(self.queries_paths) self.database_num = len(self.database_paths) self.queries_num = len(self.queries_paths) def __getitem__(self, index): img = path_to_pil_img(self.images_paths[index]) img = base_transform(img) # With database images self.test_method should always be "hard_resize" if self.test_method == "hard_resize": # self.test_method=="hard_resize" is the default, resizes all images to the same size. img = T.functional.resize(img, self.resize) else: img = self._test_query_transform(img) return img, index def _test_query_transform(self, img): """Transform query image according to self.test_method.""" C, H, W = img.shape if self.test_method == "single_query": # self.test_method=="single_query" is used when queries have varying sizes, and can't be stacked in a batch. processed_img = T.functional.resize(img, min(self.resize)) elif self.test_method == "central_crop": # Take the biggest central crop of size self.resize. Preserves ratio. scale = max(self.resize[0]/H, self.resize[1]/W) processed_img = torch.nn.functional.interpolate(img.unsqueeze(0), scale_factor=scale).squeeze(0) processed_img = T.functional.center_crop(processed_img, self.resize) assert processed_img.shape[1:] == torch.Size(self.resize), f"{processed_img.shape[1:]} {self.resize}" elif self.test_method == "five_crops" or self.test_method == 'nearest_crop' or self.test_method == 'maj_voting': # Get 5 square crops with size==shorter_side (usually 480). Preserves ratio and allows batches. shorter_side = min(self.resize) processed_img = T.functional.resize(img, shorter_side) processed_img = torch.stack(T.functional.five_crop(processed_img, shorter_side)) assert processed_img.shape == torch.Size([5, 3, shorter_side, shorter_side]), \ f"{processed_img.shape} {torch.Size([5, 3, shorter_side, shorter_side])}" return processed_img def __len__(self): return len(self.images_paths) def __repr__(self): return f"< {self.__class__.__name__}, {self.dataset_name} - #database: {self.database_num}; #queries: {self.queries_num} >" def get_positives(self): return self.soft_positives_per_query class TripletsDataset(BaseDataset): """Dataset used for training, it is used to compute the triplets with TripletsDataset.compute_triplets() with various mining methods. If is_inference == True, uses methods of the parent class BaseDataset, this is used for example when computing the cache, because we compute features of each image, not triplets. """ def __init__(self, args, datasets_folder="datasets", dataset_name="pitts30k", split="train", negs_num_per_query=10): super().__init__(args, datasets_folder, dataset_name, split) self.mining = args.mining self.neg_samples_num = args.neg_samples_num # Number of negatives to randomly sample self.negs_num_per_query = negs_num_per_query # Number of negatives per query in each batch if self.mining == "full": # "Full database mining" keeps a cache with last used negatives self.neg_cache = [np.empty((0,), dtype=np.int32) for _ in range(self.queries_num)] self.is_inference = False identity_transform = T.Lambda(lambda x: x) self.resized_transform = T.Compose([ T.Resize(self.resize) if self.resize is not None else identity_transform, base_transform ]) self.query_transform = T.Compose([ T.ColorJitter(args.brightness, args.contrast, args.saturation, args.hue), T.RandomPerspective(args.rand_perspective), T.RandomResizedCrop(size=self.resize, scale=(1-args.random_resized_crop, 1)), T.RandomRotation(degrees=args.random_rotation), self.resized_transform, ]) # Find hard_positives_per_query, which are within train_positives_dist_threshold (10 meters) knn = NearestNeighbors(n_jobs=-1) knn.fit(self.database_utms) self.hard_positives_per_query = list(knn.radius_neighbors(self.queries_utms, radius=args.train_positives_dist_threshold, # 10 meters return_distance=False)) #### Some queries might have no positive, we should remove those queries. queries_without_any_hard_positive = np.where(np.array([len(p) for p in self.hard_positives_per_query], dtype=object) == 0)[0] if len(queries_without_any_hard_positive) != 0: logging.info(f"There are {len(queries_without_any_hard_positive)} queries without any positives " + "within the training set. They won't be considered as they're useless for training.") # Remove queries without positives self.hard_positives_per_query = np.delete(self.hard_positives_per_query, queries_without_any_hard_positive) self.soft_positives_per_query = np.delete(self.soft_positives_per_query, queries_without_any_hard_positive) self.queries_paths = np.delete(self.queries_paths, queries_without_any_hard_positive) # Recompute images_paths and queries_num because some queries might have been removed self.images_paths = list(self.database_paths) + list(self.queries_paths) self.queries_num = len(self.queries_paths) # msls_weighted refers to the mining presented in MSLS paper's supplementary. # Basically, images from uncommon domains are sampled more often. Works only with MSLS dataset. if self.mining == "msls_weighted": notes = [p.split("@")[-2] for p in self.queries_paths] try: night_indexes = np.where(np.array([n.split("_")[0] == "night" for n in notes]))[0] sideways_indexes = np.where(np.array([n.split("_")[1] == "sideways" for n in notes]))[0] except IndexError: raise RuntimeError("You're using msls_weighted mining but this dataset " + "does not have night/sideways information. Are you using Mapillary SLS?") self.weights = np.ones(self.queries_num) assert len(night_indexes) != 0 and len(sideways_indexes) != 0, \ "There should be night and sideways images for msls_weighted mining, but there are none. Are you using Mapillary SLS?" self.weights[night_indexes] += self.queries_num / len(night_indexes) self.weights[sideways_indexes] += self.queries_num / len(sideways_indexes) self.weights /= self.weights.sum() logging.info(f"#sideways_indexes [{len(sideways_indexes)}/{self.queries_num}]; " + "#night_indexes; [{len(night_indexes)}/{self.queries_num}]") def __getitem__(self, index): if self.is_inference: # At inference time return the single image. This is used for caching or computing NetVLAD's clusters return super().__getitem__(index) query_index, best_positive_index, neg_indexes = torch.split(self.triplets_global_indexes[index], (1, 1, self.negs_num_per_query)) query = self.query_transform(path_to_pil_img(self.queries_paths[query_index])) positive = self.resized_transform(path_to_pil_img(self.database_paths[best_positive_index])) negatives = [self.resized_transform(path_to_pil_img(self.database_paths[i])) for i in neg_indexes] images = torch.stack((query, positive, *negatives), 0) triplets_local_indexes = torch.empty((0, 3), dtype=torch.int) for neg_num in range(len(neg_indexes)): triplets_local_indexes = torch.cat((triplets_local_indexes, torch.tensor([0, 1, 2 + neg_num]).reshape(1, 3))) return images, triplets_local_indexes, self.triplets_global_indexes[index] def __len__(self): if self.is_inference: # At inference time return the number of images. This is used for caching or computing NetVLAD's clusters return super().__len__() else: return len(self.triplets_global_indexes) def compute_triplets(self, args, model): self.is_inference = True if self.mining == "full": self.compute_triplets_full(args, model) elif self.mining == "partial" or self.mining == "msls_weighted": self.compute_triplets_partial(args, model) elif self.mining == "random": self.compute_triplets_random(args, model) @staticmethod def compute_cache(args, model, subset_ds, cache_shape): """Compute the cache containing features of images, which is used to find best positive and hardest negatives.""" subset_dl = DataLoader(dataset=subset_ds, num_workers=args.num_workers, batch_size=args.infer_batch_size, shuffle=False, pin_memory=(args.device == "cuda")) model = model.eval() # RAMEfficient2DMatrix can be replaced by np.zeros, but using # RAMEfficient2DMatrix is RAM efficient for full database mining. cache = RAMEfficient2DMatrix(cache_shape, dtype=np.float32) with torch.no_grad(): for images, indexes in tqdm(subset_dl, ncols=100): images = images.to(args.device) features = model(images) cache[indexes.numpy()] = features.cpu().numpy() return cache def get_query_features(self, query_index, cache): query_features = cache[query_index + self.database_num] if query_features is None: raise RuntimeError(f"For query {self.queries_paths[query_index]} " + f"with index {query_index} features have not been computed!\n" + "There might be some bug with caching") return query_features def get_best_positive_index(self, args, query_index, cache, query_features): positives_features = cache[self.hard_positives_per_query[query_index]] faiss_index = faiss.IndexFlatL2(args.features_dim) faiss_index.add(positives_features) # Search the best positive (within 10 meters AND nearest in features space) _, best_positive_num = faiss_index.search(query_features.reshape(1, -1), 1) best_positive_index = self.hard_positives_per_query[query_index][best_positive_num[0]].item() return best_positive_index def get_hardest_negatives_indexes(self, args, cache, query_features, neg_samples): neg_features = cache[neg_samples] faiss_index = faiss.IndexFlatL2(args.features_dim) faiss_index.add(neg_features) # Search the 10 nearest negatives (further than 25 meters and nearest in features space) _, neg_nums = faiss_index.search(query_features.reshape(1, -1), self.negs_num_per_query) neg_nums = neg_nums.reshape(-1) neg_indexes = neg_samples[neg_nums].astype(np.int32) return neg_indexes def compute_triplets_random(self, args, model): self.triplets_global_indexes = [] # Take 1000 random queries sampled_queries_indexes = np.random.choice(self.queries_num, args.cache_refresh_rate, replace=False) # Take all the positives positives_indexes = [self.hard_positives_per_query[i] for i in sampled_queries_indexes] positives_indexes = [p for pos in positives_indexes for p in pos] # Flatten list of lists to a list positives_indexes = list(np.unique(positives_indexes)) # Compute the cache only for queries and their positives, in order to find the best positive subset_ds = Subset(self, positives_indexes + list(sampled_queries_indexes + self.database_num)) cache = self.compute_cache(args, model, subset_ds, (len(self), args.features_dim)) # This loop's iterations could be done individually in the __getitem__(). This way is slower but clearer (and yields same results) for query_index in tqdm(sampled_queries_indexes, ncols=100): query_features = self.get_query_features(query_index, cache) best_positive_index = self.get_best_positive_index(args, query_index, cache, query_features) # Choose some random database images, from those remove the soft_positives, and then take the first 10 images as neg_indexes soft_positives = self.soft_positives_per_query[query_index] neg_indexes = np.random.choice(self.database_num, size=self.negs_num_per_query+len(soft_positives), replace=False) neg_indexes = np.setdiff1d(neg_indexes, soft_positives, assume_unique=True)[:self.negs_num_per_query] self.triplets_global_indexes.append((query_index, best_positive_index, *neg_indexes)) # self.triplets_global_indexes is a tensor of shape [1000, 12] self.triplets_global_indexes = torch.tensor(self.triplets_global_indexes) def compute_triplets_full(self, args, model): self.triplets_global_indexes = [] # Take 1000 random queries sampled_queries_indexes = np.random.choice(self.queries_num, args.cache_refresh_rate, replace=False) # Take all database indexes database_indexes = list(range(self.database_num)) # Compute features for all images and store them in cache subset_ds = Subset(self, database_indexes + list(sampled_queries_indexes + self.database_num)) cache = self.compute_cache(args, model, subset_ds, (len(self), args.features_dim)) # This loop's iterations could be done individually in the __getitem__(). This way is slower but clearer (and yields same results) for query_index in tqdm(sampled_queries_indexes, ncols=100): query_features = self.get_query_features(query_index, cache) best_positive_index = self.get_best_positive_index(args, query_index, cache, query_features) # Choose 1000 random database images (neg_indexes) neg_indexes = np.random.choice(self.database_num, self.neg_samples_num, replace=False) # Remove the eventual soft_positives from neg_indexes soft_positives = self.soft_positives_per_query[query_index] neg_indexes = np.setdiff1d(neg_indexes, soft_positives, assume_unique=True) # Concatenate neg_indexes with the previous top 10 negatives (neg_cache) neg_indexes = np.unique(np.concatenate([self.neg_cache[query_index], neg_indexes])) # Search the hardest negatives neg_indexes = self.get_hardest_negatives_indexes(args, cache, query_features, neg_indexes) # Update nearest negatives in neg_cache self.neg_cache[query_index] = neg_indexes self.triplets_global_indexes.append((query_index, best_positive_index, *neg_indexes)) # self.triplets_global_indexes is a tensor of shape [1000, 12] self.triplets_global_indexes = torch.tensor(self.triplets_global_indexes) def compute_triplets_partial(self, args, model): self.triplets_global_indexes = [] # Take 1000 random queries if self.mining == "partial": sampled_queries_indexes = np.random.choice(self.queries_num, args.cache_refresh_rate, replace=False) elif self.mining == "msls_weighted": # Pick night and sideways queries with higher probability sampled_queries_indexes = np.random.choice(self.queries_num, args.cache_refresh_rate, replace=False, p=self.weights) # Sample 1000 random database images for the negatives sampled_database_indexes = np.random.choice(self.database_num, self.neg_samples_num, replace=False) # Take all the positives positives_indexes = [self.hard_positives_per_query[i] for i in sampled_queries_indexes] positives_indexes = [p for pos in positives_indexes for p in pos] # Merge them into database_indexes and remove duplicates database_indexes = list(sampled_database_indexes) + positives_indexes database_indexes = list(np.unique(database_indexes)) subset_ds = Subset(self, database_indexes + list(sampled_queries_indexes + self.database_num)) cache = self.compute_cache(args, model, subset_ds, cache_shape=(len(self), args.features_dim)) # This loop's iterations could be done individually in the __getitem__(). This way is slower but clearer (and yields same results) for query_index in tqdm(sampled_queries_indexes, ncols=100): query_features = self.get_query_features(query_index, cache) best_positive_index = self.get_best_positive_index(args, query_index, cache, query_features) # Choose the hardest negatives within sampled_database_indexes, ensuring that there are no positives soft_positives = self.soft_positives_per_query[query_index] neg_indexes = np.setdiff1d(sampled_database_indexes, soft_positives, assume_unique=True) # Take all database images that are negatives and are within the sampled database images (aka database_indexes) neg_indexes = self.get_hardest_negatives_indexes(args, cache, query_features, neg_indexes) self.triplets_global_indexes.append((query_index, best_positive_index, *neg_indexes)) # self.triplets_global_indexes is a tensor of shape [1000, 12] self.triplets_global_indexes = torch.tensor(self.triplets_global_indexes) class RAMEfficient2DMatrix: """This class behaves similarly to a numpy.ndarray initialized with np.zeros(), but is implemented to save RAM when the rows within the 2D array are sparse. In this case it's needed because we don't always compute features for each image, just for few of them""" def __init__(self, shape, dtype=np.float32): self.shape = shape self.dtype = dtype self.matrix = [None] * shape[0] def __setitem__(self, indexes, vals): assert vals.shape[1] == self.shape[1], f"{vals.shape[1]} {self.shape[1]}" for i, val in zip(indexes, vals): self.matrix[i] = val.astype(self.dtype, copy=False) def __getitem__(self, index): if hasattr(index, "__len__"): return np.array([self.matrix[i] for i in index]) else: return self.matrix[index] ================================================ FILE: eval.py ================================================ """ With this script you can evaluate checkpoints or test models from two popular landmark retrieval github repos. The first is https://github.com/naver/deep-image-retrieval from Naver labs, provides ResNet-50 and ResNet-101 trained with AP on Google Landmarks 18 clean. $ python eval.py --off_the_shelf=naver --l2=none --backbone=resnet101conv5 --aggregation=gem --fc_output_dim=2048 The second is https://github.com/filipradenovic/cnnimageretrieval-pytorch from Radenovic, provides ResNet-50 and ResNet-101 trained with a triplet loss on Google Landmarks 18 and sfm120k. $ python eval.py --off_the_shelf=radenovic_gldv1 --l2=after_pool --backbone=resnet101conv5 --aggregation=gem --fc_output_dim=2048 $ python eval.py --off_the_shelf=radenovic_sfm --l2=after_pool --backbone=resnet101conv5 --aggregation=gem --fc_output_dim=2048 Note that although the architectures are almost the same, Naver's implementation does not use a l2 normalization before/after the GeM aggregation, while Radenovic's uses it after (and we use it before, which shows better results in VG) """ import os import sys import torch import parser import logging import sklearn from os.path import join from datetime import datetime from torch.utils.model_zoo import load_url from google_drive_downloader import GoogleDriveDownloader as gdd import test import util import commons import datasets_ws from model import network OFF_THE_SHELF_RADENOVIC = { 'resnet50conv5_sfm' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet50-gem-w-97bf910.pth', 'resnet101conv5_sfm' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet101-gem-w-a155e54.pth', 'resnet50conv5_gldv1' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet50-gem-w-83fdc30.pth', 'resnet101conv5_gldv1' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet101-gem-w-a4d43db.pth', } OFF_THE_SHELF_NAVER = { "resnet50conv5" : "1oPtE_go9tnsiDLkWjN4NMpKjh-_md1G5", 'resnet101conv5' : "1UWJGDuHtzaQdFhSMojoYVQjmCXhIwVvy" } ######################################### SETUP ######################################### args = parser.parse_arguments() start_time = datetime.now() args.save_dir = join("test", args.save_dir, start_time.strftime('%Y-%m-%d_%H-%M-%S')) commons.setup_logging(args.save_dir) commons.make_deterministic(args.seed) logging.info(f"Arguments: {args}") logging.info(f"The outputs are being saved in {args.save_dir}") ######################################### MODEL ######################################### model = network.GeoLocalizationNet(args) model = model.to(args.device) if args.aggregation in ["netvlad", "crn"]: args.features_dim *= args.netvlad_clusters if args.off_the_shelf.startswith("radenovic") or args.off_the_shelf.startswith("naver"): if args.off_the_shelf.startswith("radenovic"): pretrain_dataset_name = args.off_the_shelf.split("_")[1] # sfm or gldv1 datasets url = OFF_THE_SHELF_RADENOVIC[f"{args.backbone}_{pretrain_dataset_name}"] state_dict = load_url(url, model_dir=join("data", "off_the_shelf_nets")) else: # This is a hacky workaround to maintain compatibility sys.modules['sklearn.decomposition.pca'] = sklearn.decomposition._pca zip_file_path = join("data", "off_the_shelf_nets", args.backbone + "_naver.zip") if not os.path.exists(zip_file_path): gdd.download_file_from_google_drive(file_id=OFF_THE_SHELF_NAVER[args.backbone], dest_path=zip_file_path, unzip=True) if args.backbone == "resnet50conv5": state_dict_filename = "Resnet50-AP-GeM.pt" elif args.backbone == "resnet101conv5": state_dict_filename = "Resnet-101-AP-GeM.pt" state_dict = torch.load(join("data", "off_the_shelf_nets", state_dict_filename)) state_dict = state_dict["state_dict"] model_keys = model.state_dict().keys() renamed_state_dict = {k: v for k, v in zip(model_keys, state_dict.values())} model.load_state_dict(renamed_state_dict) elif args.resume is not None: logging.info(f"Resuming model from {args.resume}") model = util.resume_model(args, model) # Enable DataParallel after loading checkpoint, otherwise doing it before # would append "module." in front of the keys of the state dict triggering errors model = torch.nn.DataParallel(model) if args.pca_dim is None: pca = None else: full_features_dim = args.features_dim args.features_dim = args.pca_dim pca = util.compute_pca(args, model, args.pca_dataset_folder, full_features_dim) ######################################### DATASETS ######################################### test_ds = datasets_ws.BaseDataset(args, args.datasets_folder, args.dataset_name, "test") logging.info(f"Test set: {test_ds}") ######################################### TEST on TEST SET ######################################### recalls, recalls_str = test.test(args, test_ds, model, args.test_method, pca) logging.info(f"Recalls on {test_ds}: {recalls_str}") logging.info(f"Finished in {str(datetime.now() - start_time)[:-7]}") ================================================ FILE: model/__init__.py ================================================ ================================================ FILE: model/aggregation.py ================================================ import math import torch import faiss import logging import numpy as np from tqdm import tqdm import torch.nn as nn import torch.nn.functional as F from torch.nn.parameter import Parameter from torch.utils.data import DataLoader, SubsetRandomSampler import model.functional as LF import model.normalization as normalization class MAC(nn.Module): def __init__(self): super().__init__() def forward(self, x): return LF.mac(x) def __repr__(self): return self.__class__.__name__ + '()' class SPoC(nn.Module): def __init__(self): super().__init__() def forward(self, x): return LF.spoc(x) def __repr__(self): return self.__class__.__name__ + '()' class GeM(nn.Module): def __init__(self, p=3, eps=1e-6, work_with_tokens=False): super().__init__() self.p = Parameter(torch.ones(1)*p) self.eps = eps self.work_with_tokens=work_with_tokens def forward(self, x): return LF.gem(x, p=self.p, eps=self.eps, work_with_tokens=self.work_with_tokens) def __repr__(self): return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')' class RMAC(nn.Module): def __init__(self, L=3, eps=1e-6): super().__init__() self.L = L self.eps = eps def forward(self, x): return LF.rmac(x, L=self.L, eps=self.eps) def __repr__(self): return self.__class__.__name__ + '(' + 'L=' + '{}'.format(self.L) + ')' class Flatten(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): assert x.shape[2] == x.shape[3] == 1; return x[:,:,0,0] class RRM(nn.Module): """Residual Retrieval Module as described in the paper `Leveraging EfficientNet and Contrastive Learning for AccurateGlobal-scale Location Estimation ` """ def __init__(self, dim): super().__init__() self.avgpool = nn.AdaptiveAvgPool2d(output_size=1) self.flatten = Flatten() self.ln1 = nn.LayerNorm(normalized_shape=dim) self.fc1 = nn.Linear(in_features=dim, out_features=dim) self.relu = nn.ReLU() self.fc2 = nn.Linear(in_features=dim, out_features=dim) self.ln2 = nn.LayerNorm(normalized_shape=dim) self.l2 = normalization.L2Norm() def forward(self, x): x = self.avgpool(x) x = self.flatten(x) x = self.ln1(x) identity = x out = self.fc2(self.relu(self.fc1(x))) out += identity out = self.l2(self.ln2(out)) return out # based on https://github.com/lyakaap/NetVLAD-pytorch/blob/master/netvlad.py class NetVLAD(nn.Module): """NetVLAD layer implementation""" def __init__(self, clusters_num=64, dim=128, normalize_input=True, work_with_tokens=False): """ Args: clusters_num : int The number of clusters dim : int Dimension of descriptors alpha : float Parameter of initialization. Larger value is harder assignment. normalize_input : bool If true, descriptor-wise L2 normalization is applied to input. """ super().__init__() self.clusters_num = clusters_num self.dim = dim self.alpha = 0 self.normalize_input = normalize_input self.work_with_tokens = work_with_tokens if work_with_tokens: self.conv = nn.Conv1d(dim, clusters_num, kernel_size=1, bias=False) else: self.conv = nn.Conv2d(dim, clusters_num, kernel_size=(1, 1), bias=False) self.centroids = nn.Parameter(torch.rand(clusters_num, dim)) def init_params(self, centroids, descriptors): centroids_assign = centroids / np.linalg.norm(centroids, axis=1, keepdims=True) dots = np.dot(centroids_assign, descriptors.T) dots.sort(0) dots = dots[::-1, :] # sort, descending self.alpha = (-np.log(0.01) / np.mean(dots[0,:] - dots[1,:])).item() self.centroids = nn.Parameter(torch.from_numpy(centroids)) if self.work_with_tokens: self.conv.weight = nn.Parameter(torch.from_numpy(self.alpha * centroids_assign).unsqueeze(2)) else: self.conv.weight = nn.Parameter(torch.from_numpy(self.alpha*centroids_assign).unsqueeze(2).unsqueeze(3)) self.conv.bias = None def forward(self, x): if self.work_with_tokens: x = x.permute(0, 2, 1) N, D, _ = x.shape[:] else: N, D, H, W = x.shape[:] if self.normalize_input: x = F.normalize(x, p=2, dim=1) # Across descriptor dim x_flatten = x.view(N, D, -1) soft_assign = self.conv(x).view(N, self.clusters_num, -1) soft_assign = F.softmax(soft_assign, dim=1) vlad = torch.zeros([N, self.clusters_num, D], dtype=x_flatten.dtype, device=x_flatten.device) for D in range(self.clusters_num): # Slower than non-looped, but lower memory usage residual = x_flatten.unsqueeze(0).permute(1, 0, 2, 3) - \ self.centroids[D:D+1, :].expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).unsqueeze(0) residual = residual * soft_assign[:,D:D+1,:].unsqueeze(2) vlad[:,D:D+1,:] = residual.sum(dim=-1) vlad = F.normalize(vlad, p=2, dim=2) # intra-normalization vlad = vlad.view(N, -1) # Flatten vlad = F.normalize(vlad, p=2, dim=1) # L2 normalize return vlad def initialize_netvlad_layer(self, args, cluster_ds, backbone): descriptors_num = 50000 descs_num_per_image = 100 images_num = math.ceil(descriptors_num / descs_num_per_image) random_sampler = SubsetRandomSampler(np.random.choice(len(cluster_ds), images_num, replace=False)) random_dl = DataLoader(dataset=cluster_ds, num_workers=args.num_workers, batch_size=args.infer_batch_size, sampler=random_sampler) with torch.no_grad(): backbone = backbone.eval() logging.debug("Extracting features to initialize NetVLAD layer") descriptors = np.zeros(shape=(descriptors_num, args.features_dim), dtype=np.float32) for iteration, (inputs, _) in enumerate(tqdm(random_dl, ncols=100)): inputs = inputs.to(args.device) outputs = backbone(inputs) norm_outputs = F.normalize(outputs, p=2, dim=1) image_descriptors = norm_outputs.view(norm_outputs.shape[0], args.features_dim, -1).permute(0, 2, 1) image_descriptors = image_descriptors.cpu().numpy() batchix = iteration * args.infer_batch_size * descs_num_per_image for ix in range(image_descriptors.shape[0]): sample = np.random.choice(image_descriptors.shape[1], descs_num_per_image, replace=False) startix = batchix + ix * descs_num_per_image descriptors[startix:startix + descs_num_per_image, :] = image_descriptors[ix, sample, :] kmeans = faiss.Kmeans(args.features_dim, self.clusters_num, niter=100, verbose=False) kmeans.train(descriptors) logging.debug(f"NetVLAD centroids shape: {kmeans.centroids.shape}") self.init_params(kmeans.centroids, descriptors) self = self.to(args.device) class CRNModule(nn.Module): def __init__(self, dim): super().__init__() # Downsample pooling self.downsample_pool = nn.AvgPool2d(kernel_size=3, stride=(2, 2), padding=0, ceil_mode=True) # Multiscale Context Filters self.filter_3_3 = nn.Conv2d(in_channels=dim, out_channels=32, kernel_size=(3, 3), padding=1) self.filter_5_5 = nn.Conv2d(in_channels=dim, out_channels=32, kernel_size=(5, 5), padding=2) self.filter_7_7 = nn.Conv2d(in_channels=dim, out_channels=20, kernel_size=(7, 7), padding=3) # Accumulation weight self.acc_w = nn.Conv2d(in_channels=84, out_channels=1, kernel_size=(1, 1)) # Upsampling self.upsample = F.interpolate self._initialize_weights() def _initialize_weights(self): # Initialize Context Filters torch.nn.init.xavier_normal_(self.filter_3_3.weight) torch.nn.init.constant_(self.filter_3_3.bias, 0.0) torch.nn.init.xavier_normal_(self.filter_5_5.weight) torch.nn.init.constant_(self.filter_5_5.bias, 0.0) torch.nn.init.xavier_normal_(self.filter_7_7.weight) torch.nn.init.constant_(self.filter_7_7.bias, 0.0) torch.nn.init.constant_(self.acc_w.weight, 1.0) torch.nn.init.constant_(self.acc_w.bias, 0.0) self.acc_w.weight.requires_grad = False self.acc_w.bias.requires_grad = False def forward(self, x): # Contextual Reweighting Network x_crn = self.downsample_pool(x) # Compute multiscale context filters g_n g_3 = self.filter_3_3(x_crn) g_5 = self.filter_5_5(x_crn) g_7 = self.filter_7_7(x_crn) g = torch.cat((g_3, g_5, g_7), dim=1) g = F.relu(g) w = F.relu(self.acc_w(g)) # Accumulation weight mask = self.upsample(w, scale_factor=2, mode='bilinear') # Reweighting Mask return mask class CRN(NetVLAD): def __init__(self, clusters_num=64, dim=128, normalize_input=True): super().__init__(clusters_num, dim, normalize_input) self.crn = CRNModule(dim) def forward(self, x): N, D, H, W = x.shape[:] if self.normalize_input: x = F.normalize(x, p=2, dim=1) # Across descriptor dim mask = self.crn(x) x_flatten = x.view(N, D, -1) soft_assign = self.conv(x).view(N, self.clusters_num, -1) soft_assign = F.softmax(soft_assign, dim=1) # Weight soft_assign using CRN's mask soft_assign = soft_assign * mask.view(N, 1, H * W) vlad = torch.zeros([N, self.clusters_num, D], dtype=x_flatten.dtype, device=x_flatten.device) for D in range(self.clusters_num): # Slower than non-looped, but lower memory usage residual = x_flatten.unsqueeze(0).permute(1, 0, 2, 3) - \ self.centroids[D:D + 1, :].expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).unsqueeze(0) residual = residual * soft_assign[:, D:D + 1, :].unsqueeze(2) vlad[:, D:D + 1, :] = residual.sum(dim=-1) vlad = F.normalize(vlad, p=2, dim=2) # intra-normalization vlad = vlad.view(N, -1) # Flatten vlad = F.normalize(vlad, p=2, dim=1) # L2 normalize return vlad ================================================ FILE: model/cct/__init__.py ================================================ from .cct import cct_14_7x2_384, cct_14_7x2_224 ================================================ FILE: model/cct/cct.py ================================================ from torch.hub import load_state_dict_from_url import torch.nn as nn import torch import torch.nn.functional as F from .transformers import TransformerClassifier from .tokenizer import Tokenizer from .helpers import pe_check from timm.models.registry import register_model model_urls = { 'cct_7_3x1_32': 'https://shi-labs.com/projects/cct/checkpoints/pretrained/cct_7_3x1_32_cifar10_300epochs.pth', 'cct_7_3x1_32_sine': 'https://shi-labs.com/projects/cct/checkpoints/pretrained/cct_7_3x1_32_sine_cifar10_5000epochs.pth', 'cct_7_3x1_32_c100': 'https://shi-labs.com/projects/cct/checkpoints/pretrained/cct_7_3x1_32_cifar100_300epochs.pth', 'cct_7_3x1_32_sine_c100': 'https://shi-labs.com/projects/cct/checkpoints/pretrained/cct_7_3x1_32_sine_cifar100_5000epochs.pth', 'cct_7_7x2_224_sine': 'https://shi-labs.com/projects/cct/checkpoints/pretrained/cct_7_7x2_224_flowers102.pth', 'cct_14_7x2_224': 'https://shi-labs.com/projects/cct/checkpoints/pretrained/cct_14_7x2_224_imagenet.pth', 'cct_14_7x2_384': 'https://shi-labs.com/projects/cct/checkpoints/finetuned/cct_14_7x2_384_imagenet.pth', 'cct_14_7x2_384_fl': 'https://shi-labs.com/projects/cct/checkpoints/finetuned/cct_14_7x2_384_flowers102.pth', } class CCT(nn.Module): def __init__(self, img_size=224, embedding_dim=768, n_input_channels=3, n_conv_layers=1, kernel_size=7, stride=2, padding=3, pooling_kernel_size=3, pooling_stride=2, pooling_padding=1, dropout=0., attention_dropout=0.1, stochastic_depth=0.1, num_layers=14, num_heads=6, mlp_ratio=4.0, num_classes=1000, positional_embedding='learnable', aggregation=None, *args, **kwargs): super(CCT, self).__init__() self.tokenizer = Tokenizer(n_input_channels=n_input_channels, n_output_channels=embedding_dim, kernel_size=kernel_size, stride=stride, padding=padding, pooling_kernel_size=pooling_kernel_size, pooling_stride=pooling_stride, pooling_padding=pooling_padding, max_pool=True, activation=nn.ReLU, n_conv_layers=n_conv_layers, conv_bias=False) self.classifier = TransformerClassifier( sequence_length=self.tokenizer.sequence_length(n_channels=n_input_channels, height=img_size, width=img_size), embedding_dim=embedding_dim, seq_pool=True, dropout=dropout, attention_dropout=attention_dropout, stochastic_depth=stochastic_depth, num_layers=num_layers, num_heads=num_heads, mlp_ratio=mlp_ratio, num_classes=num_classes, positional_embedding=positional_embedding ) if aggregation in ['cls', 'seqpool']: self.aggregation = aggregation else: self.aggregation = None def forward(self, x): x = self.tokenizer(x) x = self.classifier(x) if self.aggregation == 'cls': return x[:, 0] elif self.aggregation == 'seqpool': x = torch.matmul(F.softmax(self.classifier.attention_pool(x), dim=1).transpose(-1, -2), x).squeeze(-2) return x else: # x = x.permute(0, 2, 1) return x def _cct(arch, pretrained, progress, num_layers, num_heads, mlp_ratio, embedding_dim, kernel_size=3, stride=None, padding=None, aggregation=None, *args, **kwargs): stride = stride if stride is not None else max(1, (kernel_size // 2) - 1) padding = padding if padding is not None else max(1, (kernel_size // 2)) model = CCT(num_layers=num_layers, num_heads=num_heads, mlp_ratio=mlp_ratio, embedding_dim=embedding_dim, kernel_size=kernel_size, stride=stride, padding=padding, aggregation=aggregation, *args, **kwargs) if pretrained: if arch in model_urls: state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) state_dict = pe_check(model, state_dict) model.load_state_dict(state_dict, strict=False) else: raise RuntimeError(f'Variant {arch} does not yet have pretrained weights.') return model def cct_2(arch, pretrained, progress, aggregation=None, *args, **kwargs): return _cct(arch, pretrained, progress, num_layers=2, num_heads=2, mlp_ratio=1, embedding_dim=128, aggregation=aggregation, *args, **kwargs) def cct_4(arch, pretrained, progress, aggregation=None, *args, **kwargs): return _cct(arch, pretrained, progress, num_layers=4, num_heads=2, mlp_ratio=1, embedding_dim=128, aggregation=aggregation, *args, **kwargs) def cct_6(arch, pretrained, progress, aggregation=None, *args, **kwargs): return _cct(arch, pretrained, progress, num_layers=6, num_heads=4, mlp_ratio=2, embedding_dim=256, aggregation=aggregation, *args, **kwargs) def cct_7(arch, pretrained, progress, aggregation=None, *args, **kwargs): return _cct(arch, pretrained, progress, num_layers=7, num_heads=4, mlp_ratio=2, embedding_dim=256, aggregation=aggregation, *args, **kwargs) def cct_14(arch, pretrained, progress, aggregation=None, *args, **kwargs): return _cct(arch, pretrained, progress, num_layers=14, num_heads=6, mlp_ratio=3, embedding_dim=384, aggregation=aggregation, *args, **kwargs) @register_model def cct_2_3x2_32(pretrained=False, progress=False, img_size=32, positional_embedding='learnable', num_classes=10, aggregation=None, *args, **kwargs): return cct_2('cct_2_3x2_32', pretrained, progress, kernel_size=3, n_conv_layers=2, img_size=img_size, positional_embedding=positional_embedding, num_classes=num_classes, aggregation=aggregation, *args, **kwargs) @register_model def cct_2_3x2_32_sine(pretrained=False, progress=False, img_size=32, positional_embedding='sine', num_classes=10, aggregation=None, *args, **kwargs): return cct_2('cct_2_3x2_32_sine', pretrained, progress, kernel_size=3, n_conv_layers=2, img_size=img_size, positional_embedding=positional_embedding, num_classes=num_classes, aggregation=aggregation, *args, **kwargs) @register_model def cct_4_3x2_32(pretrained=False, progress=False, img_size=32, positional_embedding='learnable', num_classes=10, aggregation=None, *args, **kwargs): return cct_4('cct_4_3x2_32', pretrained, progress, kernel_size=3, n_conv_layers=2, img_size=img_size, positional_embedding=positional_embedding, num_classes=num_classes, aggregation=aggregation, *args, **kwargs) @register_model def cct_4_3x2_32_sine(pretrained=False, progress=False, img_size=32, positional_embedding='sine', num_classes=10, aggregation=None, *args, **kwargs): return cct_4('cct_4_3x2_32_sine', pretrained, progress, kernel_size=3, n_conv_layers=2, img_size=img_size, positional_embedding=positional_embedding, num_classes=num_classes, aggregation=aggregation, *args, **kwargs) @register_model def cct_6_3x1_32(pretrained=False, progress=False, img_size=32, positional_embedding='learnable', num_classes=10, aggregation=None, *args, **kwargs): return cct_6('cct_6_3x1_32', pretrained, progress, kernel_size=3, n_conv_layers=1, img_size=img_size, positional_embedding=positional_embedding, num_classes=num_classes, aggregation=aggregation, *args, **kwargs) @register_model def cct_6_3x1_32_sine(pretrained=False, progress=False, img_size=32, positional_embedding='sine', num_classes=10, aggregation=None, *args, **kwargs): return cct_6('cct_6_3x1_32_sine', pretrained, progress, kernel_size=3, n_conv_layers=1, img_size=img_size, positional_embedding=positional_embedding, num_classes=num_classes, aggregation=aggregation, *args, **kwargs) @register_model def cct_6_3x2_32(pretrained=False, progress=False, img_size=32, positional_embedding='learnable', num_classes=10, aggregation=None, *args, **kwargs): return cct_6('cct_6_3x2_32', pretrained, progress, kernel_size=3, n_conv_layers=2, img_size=img_size, positional_embedding=positional_embedding, num_classes=num_classes, aggregation=aggregation, *args, **kwargs) @register_model def cct_6_3x2_32_sine(pretrained=False, progress=False, img_size=32, positional_embedding='sine', num_classes=10, aggregation=None, *args, **kwargs): return cct_6('cct_6_3x2_32_sine', pretrained, progress, kernel_size=3, n_conv_layers=2, img_size=img_size, positional_embedding=positional_embedding, num_classes=num_classes, aggregation=aggregation, *args, **kwargs) @register_model def cct_7_3x1_32(pretrained=False, progress=False, img_size=32, positional_embedding='learnable', num_classes=10, aggregation=None, *args, **kwargs): return cct_7('cct_7_3x1_32', pretrained, progress, kernel_size=3, n_conv_layers=1, img_size=img_size, positional_embedding=positional_embedding, num_classes=num_classes, aggregation=aggregation, *args, **kwargs) @register_model def cct_7_3x1_32_sine(pretrained=False, progress=False, img_size=32, positional_embedding='sine', num_classes=10, aggregation=None, *args, **kwargs): return cct_7('cct_7_3x1_32_sine', pretrained, progress, kernel_size=3, n_conv_layers=1, img_size=img_size, positional_embedding=positional_embedding, num_classes=num_classes, aggregation=aggregation, *args, **kwargs) @register_model def cct_7_3x1_32_c100(pretrained=False, progress=False, img_size=32, positional_embedding='learnable', num_classes=100, aggregation=None, *args, **kwargs): return cct_7('cct_7_3x1_32_c100', pretrained, progress, kernel_size=3, n_conv_layers=1, img_size=img_size, positional_embedding=positional_embedding, num_classes=num_classes, aggregation=aggregation, *args, **kwargs) @register_model def cct_7_3x1_32_sine_c100(pretrained=False, progress=False, img_size=32, positional_embedding='sine', num_classes=100, aggregation=None, *args, **kwargs): return cct_7('cct_7_3x1_32_sine_c100', pretrained, progress, kernel_size=3, n_conv_layers=1, img_size=img_size, positional_embedding=positional_embedding, num_classes=num_classes, aggregation=aggregation, *args, **kwargs) @register_model def cct_7_3x2_32(pretrained=False, progress=False, img_size=32, positional_embedding='learnable', num_classes=10, aggregation=None, *args, **kwargs): return cct_7('cct_7_3x2_32', pretrained, progress, kernel_size=3, n_conv_layers=2, img_size=img_size, positional_embedding=positional_embedding, num_classes=num_classes, aggregation=aggregation, *args, **kwargs) @register_model def cct_7_3x2_32_sine(pretrained=False, progress=False, img_size=32, positional_embedding='sine', num_classes=10, aggregation=None, *args, **kwargs): return cct_7('cct_7_3x2_32_sine', pretrained, progress, kernel_size=3, n_conv_layers=2, img_size=img_size, positional_embedding=positional_embedding, num_classes=num_classes, aggregation=aggregation, *args, **kwargs) @register_model def cct_7_7x2_224(pretrained=False, progress=False, img_size=224, positional_embedding='learnable', num_classes=102, aggregation=None, *args, **kwargs): return cct_7('cct_7_7x2_224', pretrained, progress, kernel_size=7, n_conv_layers=2, img_size=img_size, positional_embedding=positional_embedding, num_classes=num_classes, aggregation=aggregation, *args, **kwargs) @register_model def cct_7_7x2_224_sine(pretrained=False, progress=False, img_size=224, positional_embedding='sine', num_classes=102, aggregation=None, *args, **kwargs): return cct_7('cct_7_7x2_224_sine', pretrained, progress, kernel_size=7, n_conv_layers=2, img_size=img_size, positional_embedding=positional_embedding, num_classes=num_classes, aggregation=aggregation, *args, **kwargs) @register_model def cct_14_7x2_224(pretrained=False, progress=False, img_size=224, positional_embedding='learnable', num_classes=1000, aggregation=None, *args, **kwargs): return cct_14('cct_14_7x2_224', pretrained, progress, kernel_size=7, n_conv_layers=2, img_size=img_size, positional_embedding=positional_embedding, num_classes=num_classes, aggregation=aggregation, *args, **kwargs) @register_model def cct_14_7x2_384(pretrained=False, progress=False, img_size=384, positional_embedding='learnable', num_classes=1000, aggregation=None, *args, **kwargs): return cct_14('cct_14_7x2_384', pretrained, progress, kernel_size=7, n_conv_layers=2, img_size=img_size, positional_embedding=positional_embedding, num_classes=num_classes, aggregation=aggregation, *args, **kwargs) @register_model def cct_14_7x2_384_fl(pretrained=False, progress=False, img_size=384, positional_embedding='learnable', num_classes=102, aggregation=None, *args, **kwargs): return cct_14('cct_14_7x2_384_fl', pretrained, progress, kernel_size=7, n_conv_layers=2, img_size=img_size, positional_embedding=positional_embedding, num_classes=num_classes, aggregation=aggregation, *args, **kwargs) ================================================ FILE: model/cct/embedder.py ================================================ import torch.nn as nn class Embedder(nn.Module): def __init__(self, word_embedding_dim=300, vocab_size=100000, padding_idx=1, pretrained_weight=None, embed_freeze=False, *args, **kwargs): super(Embedder, self).__init__() self.embeddings = nn.Embedding.from_pretrained(pretrained_weight, freeze=embed_freeze) \ if pretrained_weight is not None else \ nn.Embedding(vocab_size, word_embedding_dim, padding_idx=padding_idx) self.embeddings.weight.requires_grad = not embed_freeze def forward_mask(self, mask): bsz, seq_len = mask.shape new_mask = mask.view(bsz, seq_len, 1) new_mask = new_mask.sum(-1) new_mask = (new_mask > 0) return new_mask def forward(self, x, mask=None): embed = self.embeddings(x) embed = embed if mask is None else embed * self.forward_mask(mask).unsqueeze(-1).float() return embed, mask @staticmethod def init_weight(m): if isinstance(m, nn.Linear): nn.init.trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) else: nn.init.normal_(m.weight) ================================================ FILE: model/cct/helpers.py ================================================ import math import torch import torch.nn.functional as F def resize_pos_embed(posemb, posemb_new, num_tokens=1): # Copied from `timm` by Ross Wightman: # github.com/rwightman/pytorch-image-models # Rescale the grid of position embeddings when loading from state_dict. Adapted from # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 ntok_new = posemb_new.shape[1] if num_tokens: posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:] ntok_new -= num_tokens else: posemb_tok, posemb_grid = posemb[:, :0], posemb[0] gs_old = int(math.sqrt(len(posemb_grid))) gs_new = int(math.sqrt(ntok_new)) posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bilinear') posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1) posemb = torch.cat([posemb_tok, posemb_grid], dim=1) return posemb def pe_check(model, state_dict, pe_key='classifier.positional_emb'): if pe_key is not None and pe_key in state_dict.keys() and pe_key in model.state_dict().keys(): if model.state_dict()[pe_key].shape != state_dict[pe_key].shape: state_dict[pe_key] = resize_pos_embed(state_dict[pe_key], model.state_dict()[pe_key], num_tokens=model.classifier.num_tokens) return state_dict ================================================ FILE: model/cct/stochastic_depth.py ================================================ # Thanks to rwightman's timm package # github.com:rwightman/pytorch-image-models import torch import torch.nn as nn def drop_path(x, drop_prob: float = 0., training: bool = False): """ Obtained from: github.com:rwightman/pytorch-image-models Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument. """ if drop_prob == 0. or not training: return x keep_prob = 1 - drop_prob shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) random_tensor.floor_() # binarize output = x.div(keep_prob) * random_tensor return output class DropPath(nn.Module): """ Obtained from: github.com:rwightman/pytorch-image-models Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). """ def __init__(self, drop_prob=None): super(DropPath, self).__init__() self.drop_prob = drop_prob def forward(self, x): return drop_path(x, self.drop_prob, self.training) ================================================ FILE: model/cct/tokenizer.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F class Tokenizer(nn.Module): def __init__(self, kernel_size, stride, padding, pooling_kernel_size=3, pooling_stride=2, pooling_padding=1, n_conv_layers=1, n_input_channels=3, n_output_channels=64, in_planes=64, activation=None, max_pool=True, conv_bias=False): super(Tokenizer, self).__init__() n_filter_list = [n_input_channels] + \ [in_planes for _ in range(n_conv_layers - 1)] + \ [n_output_channels] self.conv_layers = nn.Sequential( *[nn.Sequential( nn.Conv2d(n_filter_list[i], n_filter_list[i + 1], kernel_size=(kernel_size, kernel_size), stride=(stride, stride), padding=(padding, padding), bias=conv_bias), nn.Identity() if activation is None else activation(), nn.MaxPool2d(kernel_size=pooling_kernel_size, stride=pooling_stride, padding=pooling_padding) if max_pool else nn.Identity() ) for i in range(n_conv_layers) ]) self.flattener = nn.Flatten(2, 3) self.apply(self.init_weight) def sequence_length(self, n_channels=3, height=224, width=224): return self.forward(torch.zeros((1, n_channels, height, width))).shape[1] def forward(self, x): return self.flattener(self.conv_layers(x)).transpose(-2, -1) @staticmethod def init_weight(m): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight) class TextTokenizer(nn.Module): def __init__(self, kernel_size, stride, padding, pooling_kernel_size=3, pooling_stride=2, pooling_padding=1, embedding_dim=300, n_output_channels=128, activation=None, max_pool=True, *args, **kwargs): super(TextTokenizer, self).__init__() self.max_pool = max_pool self.conv_layers = nn.Sequential( nn.Conv2d(1, n_output_channels, kernel_size=(kernel_size, embedding_dim), stride=(stride, 1), padding=(padding, 0), bias=False), nn.Identity() if activation is None else activation(), nn.MaxPool2d( kernel_size=(pooling_kernel_size, 1), stride=(pooling_stride, 1), padding=(pooling_padding, 0) ) if max_pool else nn.Identity() ) self.apply(self.init_weight) def seq_len(self, seq_len=32, embed_dim=300): return self.forward(torch.zeros((1, seq_len, embed_dim)))[0].shape[1] def forward_mask(self, mask): new_mask = mask.unsqueeze(1).float() cnn_weight = torch.ones( (1, 1, self.conv_layers[0].kernel_size[0]), device=mask.device, dtype=torch.float) new_mask = F.conv1d( new_mask, cnn_weight, None, self.conv_layers[0].stride[0], self.conv_layers[0].padding[0], 1, 1) if self.max_pool: new_mask = F.max_pool1d( new_mask, self.conv_layers[2].kernel_size[0], self.conv_layers[2].stride[0], self.conv_layers[2].padding[0], 1, False, False) new_mask = new_mask.squeeze(1) new_mask = (new_mask > 0) return new_mask def forward(self, x, mask=None): x = x.unsqueeze(1) x = self.conv_layers(x) x = x.transpose(1, 3).squeeze(1) x = x if mask is None else x * self.forward_mask(mask).unsqueeze(-1).float() return x, mask @staticmethod def init_weight(m): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight) ================================================ FILE: model/cct/transformers.py ================================================ import torch from torch.nn import Module, ModuleList, Linear, Dropout, LayerNorm, Identity, Parameter, init import torch.nn.functional as F from .stochastic_depth import DropPath class Attention(Module): """ Obtained from timm: github.com:rwightman/pytorch-image-models """ def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1): super().__init__() self.num_heads = num_heads head_dim = dim // self.num_heads self.scale = head_dim ** -0.5 self.qkv = Linear(dim, dim * 3, bias=False) self.attn_drop = Dropout(attention_dropout) self.proj = Linear(dim, dim) self.proj_drop = Dropout(projection_dropout) def forward(self, x): B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x class MaskedAttention(Module): def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1): super().__init__() self.num_heads = num_heads head_dim = dim // self.num_heads self.scale = head_dim ** -0.5 self.qkv = Linear(dim, dim * 3, bias=False) self.attn_drop = Dropout(attention_dropout) self.proj = Linear(dim, dim) self.proj_drop = Dropout(projection_dropout) def forward(self, x, mask=None): B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] attn = (q @ k.transpose(-2, -1)) * self.scale if mask is not None: mask_value = -torch.finfo(attn.dtype).max assert mask.shape[-1] == attn.shape[-1], 'mask has incorrect dimensions' mask = mask[:, None, :] * mask[:, :, None] mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1) attn.masked_fill_(~mask, mask_value) attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x class TransformerEncoderLayer(Module): """ Inspired by torch.nn.TransformerEncoderLayer and timm. """ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, attention_dropout=0.1, drop_path_rate=0.1): super(TransformerEncoderLayer, self).__init__() self.pre_norm = LayerNorm(d_model) self.self_attn = Attention(dim=d_model, num_heads=nhead, attention_dropout=attention_dropout, projection_dropout=dropout) self.linear1 = Linear(d_model, dim_feedforward) self.dropout1 = Dropout(dropout) self.norm1 = LayerNorm(d_model) self.linear2 = Linear(dim_feedforward, d_model) self.dropout2 = Dropout(dropout) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else Identity() self.activation = F.gelu def forward(self, src: torch.Tensor, *args, **kwargs) -> torch.Tensor: src = src + self.drop_path(self.self_attn(self.pre_norm(src))) src = self.norm1(src) src2 = self.linear2(self.dropout1(self.activation(self.linear1(src)))) src = src + self.drop_path(self.dropout2(src2)) return src class MaskedTransformerEncoderLayer(Module): """ Inspired by torch.nn.TransformerEncoderLayer and timm. """ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, attention_dropout=0.1, drop_path_rate=0.1): super(MaskedTransformerEncoderLayer, self).__init__() self.pre_norm = LayerNorm(d_model) self.self_attn = MaskedAttention(dim=d_model, num_heads=nhead, attention_dropout=attention_dropout, projection_dropout=dropout) self.linear1 = Linear(d_model, dim_feedforward) self.dropout1 = Dropout(dropout) self.norm1 = LayerNorm(d_model) self.linear2 = Linear(dim_feedforward, d_model) self.dropout2 = Dropout(dropout) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else Identity() self.activation = F.gelu def forward(self, src: torch.Tensor, mask=None, *args, **kwargs) -> torch.Tensor: src = src + self.drop_path(self.self_attn(self.pre_norm(src), mask)) src = self.norm1(src) src2 = self.linear2(self.dropout1(self.activation(self.linear1(src)))) src = src + self.drop_path(self.dropout2(src2)) return src class TransformerClassifier(Module): def __init__(self, seq_pool=True, embedding_dim=768, num_layers=12, num_heads=12, mlp_ratio=4.0, num_classes=1000, dropout=0.1, attention_dropout=0.1, stochastic_depth=0.1, positional_embedding='learnable', sequence_length=None): super().__init__() positional_embedding = positional_embedding if \ positional_embedding in ['sine', 'learnable', 'none'] else 'sine' dim_feedforward = int(embedding_dim * mlp_ratio) self.embedding_dim = embedding_dim self.sequence_length = sequence_length self.seq_pool = seq_pool assert sequence_length is not None or positional_embedding == 'none', \ f"Positional embedding is set to {positional_embedding} and" \ f" the sequence length was not specified." if not seq_pool: sequence_length += 1 self.class_emb = Parameter(torch.zeros(1, 1, self.embedding_dim), requires_grad=True) else: self.attention_pool = Linear(self.embedding_dim, 1) if positional_embedding != 'none': if positional_embedding == 'learnable': self.positional_emb = Parameter(torch.zeros(1, sequence_length, embedding_dim), requires_grad=True) init.trunc_normal_(self.positional_emb, std=0.2) else: self.positional_emb = Parameter(self.sinusoidal_embedding(sequence_length, embedding_dim), requires_grad=False) else: self.positional_emb = None self.dropout = Dropout(p=dropout) dpr = [x.item() for x in torch.linspace(0, stochastic_depth, num_layers)] self.blocks = ModuleList([ TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads, dim_feedforward=dim_feedforward, dropout=dropout, attention_dropout=attention_dropout, drop_path_rate=dpr[i]) for i in range(num_layers)]) self.norm = LayerNorm(embedding_dim) # self.fc = Linear(embedding_dim, num_classes) self.apply(self.init_weight) def forward(self, x): if self.positional_emb is None and x.size(1) < self.sequence_length: x = F.pad(x, (0, 0, 0, self.n_channels - x.size(1)), mode='constant', value=0) if not self.seq_pool: cls_token = self.class_emb.expand(x.shape[0], -1, -1) x = torch.cat((cls_token, x), dim=1) if self.positional_emb is not None: x += self.positional_emb x = self.dropout(x) for blk in self.blocks: x = blk(x) x = self.norm(x) # TODO: TOREMOVE # if self.seq_pool: # x = torch.matmul(F.softmax(self.attention_pool(x), dim=1).transpose(-1, -2), x).squeeze(-2) #else: # x = x[:, 0] # x = self.fc(x) return x @staticmethod def init_weight(m): if isinstance(m, Linear): init.trunc_normal_(m.weight, std=.02) if isinstance(m, Linear) and m.bias is not None: init.constant_(m.bias, 0) elif isinstance(m, LayerNorm): init.constant_(m.bias, 0) init.constant_(m.weight, 1.0) @staticmethod def sinusoidal_embedding(n_channels, dim): pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)] for p in range(n_channels)]) pe[:, 0::2] = torch.sin(pe[:, 0::2]) pe[:, 1::2] = torch.cos(pe[:, 1::2]) return pe.unsqueeze(0) class MaskedTransformerClassifier(Module): def __init__(self, seq_pool=True, embedding_dim=768, num_layers=12, num_heads=12, mlp_ratio=4.0, num_classes=1000, dropout=0.1, attention_dropout=0.1, stochastic_depth=0.1, positional_embedding='sine', seq_len=None, *args, **kwargs): super().__init__() positional_embedding = positional_embedding if \ positional_embedding in ['sine', 'learnable', 'none'] else 'sine' dim_feedforward = int(embedding_dim * mlp_ratio) self.embedding_dim = embedding_dim self.seq_len = seq_len self.seq_pool = seq_pool assert seq_len is not None or positional_embedding == 'none', \ f"Positional embedding is set to {positional_embedding} and" \ f" the sequence length was not specified." if not seq_pool: seq_len += 1 self.class_emb = Parameter(torch.zeros(1, 1, self.embedding_dim), requires_grad=True) else: self.attention_pool = Linear(self.embedding_dim, 1) if positional_embedding != 'none': if positional_embedding == 'learnable': seq_len += 1 # padding idx self.positional_emb = Parameter(torch.zeros(1, seq_len, embedding_dim), requires_grad=True) init.trunc_normal_(self.positional_emb, std=0.2) else: self.positional_emb = Parameter(self.sinusoidal_embedding(seq_len, embedding_dim, padding_idx=True), requires_grad=False) else: self.positional_emb = None self.dropout = Dropout(p=dropout) dpr = [x.item() for x in torch.linspace(0, stochastic_depth, num_layers)] self.blocks = ModuleList([ MaskedTransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads, dim_feedforward=dim_feedforward, dropout=dropout, attention_dropout=attention_dropout, drop_path_rate=dpr[i]) for i in range(num_layers)]) self.norm = LayerNorm(embedding_dim) self.fc = Linear(embedding_dim, num_classes) self.apply(self.init_weight) def forward(self, x, mask=None): if self.positional_emb is None and x.size(1) < self.seq_len: x = F.pad(x, (0, 0, 0, self.n_channels - x.size(1)), mode='constant', value=0) if not self.seq_pool: cls_token = self.class_emb.expand(x.shape[0], -1, -1) x = torch.cat((cls_token, x), dim=1) if mask is not None: mask = torch.cat([torch.ones(size=(mask.shape[0], 1), device=mask.device), mask.float()], dim=1) mask = (mask > 0) if self.positional_emb is not None: x += self.positional_emb x = self.dropout(x) for blk in self.blocks: x = blk(x, mask=mask) x = self.norm(x) if self.seq_pool: x = torch.matmul(F.softmax(self.attention_pool(x), dim=1).transpose(-1, -2), x).squeeze(-2) else: x = x[:, 0] x = self.fc(x) return x @staticmethod def init_weight(m): if isinstance(m, Linear): init.trunc_normal_(m.weight, std=.02) if isinstance(m, Linear) and m.bias is not None: init.constant_(m.bias, 0) elif isinstance(m, LayerNorm): init.constant_(m.bias, 0) init.constant_(m.weight, 1.0) @staticmethod def sinusoidal_embedding(n_channels, dim, padding_idx=False): pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)] for p in range(n_channels)]) pe[:, 0::2] = torch.sin(pe[:, 0::2]) pe[:, 1::2] = torch.cos(pe[:, 1::2]) pe = pe.unsqueeze(0) if padding_idx: return torch.cat([torch.zeros((1, 1, dim)), pe], dim=1) return pe ================================================ FILE: model/functional.py ================================================ import math import torch import torch.nn.functional as F def sare_ind(query, positive, negative): '''all 3 inputs are supposed to be shape 1xn_features''' dist_pos = ((query - positive)**2).sum(1) dist_neg = ((query - negative)**2).sum(1) dist = - torch.cat((dist_pos, dist_neg)) dist = F.log_softmax(dist, 0) #loss = (- dist[:, 0]).mean() on a batch loss = -dist[0] return loss def sare_joint(query, positive, negatives): '''query and positive have to be 1xn_features; whereas negatives has to be shape n_negative x n_features. n_negative is usually 10''' # NOTE: the implementation is the same if batch_size=1 as all operations # are vectorial. If there were the additional n_batch dimension a different # handling of that situation would have to be implemented here. # This function is declared anyway for the sake of clarity as the 2 should # be called in different situations because, even though there would be # no Exceptions, there would actually be a conceptual error. return sare_ind(query, positive, negatives) def mac(x): return F.adaptive_max_pool2d(x, (1,1)) def spoc(x): return F.adaptive_avg_pool2d(x, (1,1)) def gem(x, p=3, eps=1e-6, work_with_tokens=False): if work_with_tokens: x = x.permute(0, 2, 1) # unseqeeze to maintain compatibility with Flatten return F.avg_pool1d(x.clamp(min=eps).pow(p), (x.size(-1))).pow(1./p).unsqueeze(3) else: return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p) def rmac(x, L=3, eps=1e-6): ovr = 0.4 # desired overlap of neighboring regions steps = torch.Tensor([2, 3, 4, 5, 6, 7]) # possible regions for the long dimension W = x.size(3) H = x.size(2) w = min(W, H) # w2 = math.floor(w/2.0 - 1) b = (max(H, W)-w)/(steps-1) (tmp, idx) = torch.min(torch.abs(((w**2 - w*b)/w**2)-ovr), 0) # steps(idx) regions for long dimension # region overplus per dimension Wd = 0; Hd = 0; if H < W: Wd = idx.item() + 1 elif H > W: Hd = idx.item() + 1 v = F.max_pool2d(x, (x.size(-2), x.size(-1))) v = v / (torch.norm(v, p=2, dim=1, keepdim=True) + eps).expand_as(v) for l in range(1, L+1): wl = math.floor(2*w/(l+1)) wl2 = math.floor(wl/2 - 1) if l+Wd == 1: b = 0 else: b = (W-wl)/(l+Wd-1) cenW = torch.floor(wl2 + torch.Tensor(range(l-1+Wd+1))*b) - wl2 # center coordinates if l+Hd == 1: b = 0 else: b = (H-wl)/(l+Hd-1) cenH = torch.floor(wl2 + torch.Tensor(range(l-1+Hd+1))*b) - wl2 # center coordinates for i_ in cenH.tolist(): for j_ in cenW.tolist(): if wl == 0: continue R = x[:,:,(int(i_)+torch.Tensor(range(wl)).long()).tolist(),:] R = R[:,:,:,(int(j_)+torch.Tensor(range(wl)).long()).tolist()] vt = F.max_pool2d(R, (R.size(-2), R.size(-1))) vt = vt / (torch.norm(vt, p=2, dim=1, keepdim=True) + eps).expand_as(vt) v += vt return v ================================================ FILE: model/network.py ================================================ import os import torch import logging import torchvision from torch import nn from os.path import join from transformers import ViTModel from google_drive_downloader import GoogleDriveDownloader as gdd from model.cct import cct_14_7x2_384 from model.aggregation import Flatten from model.normalization import L2Norm import model.aggregation as aggregation # Pretrained models on Google Landmarks v2 and Places 365 PRETRAINED_MODELS = { 'resnet18_places' : '1DnEQXhmPxtBUrRc81nAvT8z17bk-GBj5', 'resnet50_places' : '1zsY4mN4jJ-AsmV3h4hjbT72CBfJsgSGC', 'resnet101_places' : '1E1ibXQcg7qkmmmyYgmwMTh7Xf1cDNQXa', 'vgg16_places' : '1UWl1uz6rZ6Nqmp1K5z3GHAIZJmDh4bDu', 'resnet18_gldv2' : '1wkUeUXFXuPHuEvGTXVpuP5BMB-JJ1xke', 'resnet50_gldv2' : '1UDUv6mszlXNC1lv6McLdeBNMq9-kaA70', 'resnet101_gldv2' : '1apiRxMJpDlV0XmKlC5Na_Drg2jtGL-uE', 'vgg16_gldv2' : '10Ov9JdO7gbyz6mB5x0v_VSAUMj91Ta4o' } class GeoLocalizationNet(nn.Module): """The used networks are composed of a backbone and an aggregation layer. """ def __init__(self, args): super().__init__() self.backbone = get_backbone(args) self.arch_name = args.backbone self.aggregation = get_aggregation(args) if args.aggregation in ["gem", "spoc", "mac", "rmac"]: if args.l2 == "before_pool": self.aggregation = nn.Sequential(L2Norm(), self.aggregation, Flatten()) elif args.l2 == "after_pool": self.aggregation = nn.Sequential(self.aggregation, L2Norm(), Flatten()) elif args.l2 == "none": self.aggregation = nn.Sequential(self.aggregation, Flatten()) if args.fc_output_dim != None: # Concatenate fully connected layer to the aggregation layer self.aggregation = nn.Sequential(self.aggregation, nn.Linear(args.features_dim, args.fc_output_dim), L2Norm()) args.features_dim = args.fc_output_dim def forward(self, x): x = self.backbone(x) x = self.aggregation(x) return x def get_aggregation(args): if args.aggregation == "gem": return aggregation.GeM(work_with_tokens=args.work_with_tokens) elif args.aggregation == "spoc": return aggregation.SPoC() elif args.aggregation == "mac": return aggregation.MAC() elif args.aggregation == "rmac": return aggregation.RMAC() elif args.aggregation == "netvlad": return aggregation.NetVLAD(clusters_num=args.netvlad_clusters, dim=args.features_dim, work_with_tokens=args.work_with_tokens) elif args.aggregation == 'crn': return aggregation.CRN(clusters_num=args.netvlad_clusters, dim=args.features_dim) elif args.aggregation == "rrm": return aggregation.RRM(args.features_dim) elif args.aggregation in ['cls', 'seqpool']: return nn.Identity() def get_pretrained_model(args): if args.pretrain == 'places': num_classes = 365 elif args.pretrain == 'gldv2': num_classes = 512 if args.backbone.startswith("resnet18"): model = torchvision.models.resnet18(num_classes=num_classes) elif args.backbone.startswith("resnet50"): model = torchvision.models.resnet50(num_classes=num_classes) elif args.backbone.startswith("resnet101"): model = torchvision.models.resnet101(num_classes=num_classes) elif args.backbone.startswith("vgg16"): model = torchvision.models.vgg16(num_classes=num_classes) if args.backbone.startswith('resnet'): model_name = args.backbone.split('conv')[0] + "_" + args.pretrain else: model_name = args.backbone + "_" + args.pretrain file_path = join("data", "pretrained_nets", model_name +".pth") if not os.path.exists(file_path): gdd.download_file_from_google_drive(file_id=PRETRAINED_MODELS[model_name], dest_path=file_path) state_dict = torch.load(file_path, map_location=torch.device('cpu')) model.load_state_dict(state_dict) return model def get_backbone(args): # The aggregation layer works differently based on the type of architecture args.work_with_tokens = args.backbone.startswith('cct') or args.backbone.startswith('vit') if args.backbone.startswith("resnet"): if args.pretrain in ['places', 'gldv2']: backbone = get_pretrained_model(args) elif args.backbone.startswith("resnet18"): backbone = torchvision.models.resnet18(pretrained=True) elif args.backbone.startswith("resnet50"): backbone = torchvision.models.resnet50(pretrained=True) elif args.backbone.startswith("resnet101"): backbone = torchvision.models.resnet101(pretrained=True) for name, child in backbone.named_children(): # Freeze layers before conv_3 if name == "layer3": break for params in child.parameters(): params.requires_grad = False if args.backbone.endswith("conv4"): logging.debug(f"Train only conv4_x of the resnet{args.backbone.split('conv')[0]} (remove conv5_x), freeze the previous ones") layers = list(backbone.children())[:-3] elif args.backbone.endswith("conv5"): logging.debug(f"Train only conv4_x and conv5_x of the resnet{args.backbone.split('conv')[0]}, freeze the previous ones") layers = list(backbone.children())[:-2] elif args.backbone == "vgg16": if args.pretrain in ['places', 'gldv2']: backbone = get_pretrained_model(args) else: backbone = torchvision.models.vgg16(pretrained=True) layers = list(backbone.features.children())[:-2] for l in layers[:-5]: for p in l.parameters(): p.requires_grad = False logging.debug("Train last layers of the vgg16, freeze the previous ones") elif args.backbone == "alexnet": backbone = torchvision.models.alexnet(pretrained=True) layers = list(backbone.features.children())[:-2] for l in layers[:5]: for p in l.parameters(): p.requires_grad = False logging.debug("Train last layers of the alexnet, freeze the previous ones") elif args.backbone.startswith("cct"): if args.backbone.startswith("cct384"): backbone = cct_14_7x2_384(pretrained=True, progress=True, aggregation=args.aggregation) if args.trunc_te: logging.debug(f"Truncate CCT at transformers encoder {args.trunc_te}") backbone.classifier.blocks = torch.nn.ModuleList(backbone.classifier.blocks[:args.trunc_te].children()) if args.freeze_te: logging.debug(f"Freeze all the layers up to tranformer encoder {args.freeze_te}") for p in backbone.parameters(): p.requires_grad = False for name, child in backbone.classifier.blocks.named_children(): if int(name) > args.freeze_te: for params in child.parameters(): params.requires_grad = True args.features_dim = 384 return backbone elif args.backbone.startswith("vit"): assert args.resize[0] in [224, 384], f'Image size for ViT must be either 224 or 384, but it\'s {args.resize[0]}' if args.resize[0] == 224: backbone = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k') elif args.resize[0] == 384: backbone = ViTModel.from_pretrained('google/vit-base-patch16-384') if args.trunc_te: logging.debug(f"Truncate ViT at transformers encoder {args.trunc_te}") backbone.encoder.layer = backbone.encoder.layer[:args.trunc_te] if args.freeze_te: logging.debug(f"Freeze all the layers up to tranformer encoder {args.freeze_te+1}") for p in backbone.parameters(): p.requires_grad = False for name, child in backbone.encoder.layer.named_children(): if int(name) > args.freeze_te: for params in child.parameters(): params.requires_grad = True backbone = VitWrapper(backbone, args.aggregation) args.features_dim = 768 return backbone backbone = torch.nn.Sequential(*layers) args.features_dim = get_output_channels_dim(backbone) # Dinamically obtain number of channels in output return backbone class VitWrapper(nn.Module): def __init__(self, vit_model, aggregation): super().__init__() self.vit_model = vit_model self.aggregation = aggregation def forward(self, x): if self.aggregation in ["netvlad", "gem"]: return self.vit_model(x).last_hidden_state[:, 1:, :] else: return self.vit_model(x).last_hidden_state[:, 0, :] def get_output_channels_dim(model): """Return the number of channels in the output of a model.""" return model(torch.ones([1, 3, 224, 224])).shape[1] ================================================ FILE: model/normalization.py ================================================ import torch.nn as nn import torch.nn.functional as F class L2Norm(nn.Module): def __init__(self, dim=1): super().__init__() self.dim = dim def forward(self, x): return F.normalize(x, p=2, dim=self.dim) ================================================ FILE: model/sync_batchnorm/__init__.py ================================================ # -*- coding: utf-8 -*- # File : __init__.py # Author : Jiayuan Mao # Email : maojiayuan@gmail.com # Date : 27/01/2018 # # This file is part of Synchronized-BatchNorm-PyTorch. # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch # Distributed under MIT License. from .batchnorm import set_sbn_eps_mode from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d from .batchnorm import patch_sync_batchnorm, convert_model from .replicate import DataParallelWithCallback, patch_replication_callback ================================================ FILE: model/sync_batchnorm/batchnorm.py ================================================ # -*- coding: utf-8 -*- # File : batchnorm.py # Author : Jiayuan Mao # Email : maojiayuan@gmail.com # Date : 27/01/2018 # # This file is part of Synchronized-BatchNorm-PyTorch. # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch # Distributed under MIT License. import collections import contextlib import torch import torch.nn.functional as F from torch.nn.modules.batchnorm import _BatchNorm try: from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast except ImportError: ReduceAddCoalesced = Broadcast = None try: from jactorch.parallel.comm import SyncMaster from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback except ImportError: from .comm import SyncMaster from .replicate import DataParallelWithCallback __all__ = [ 'set_sbn_eps_mode', 'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d', 'patch_sync_batchnorm', 'convert_model' ] SBN_EPS_MODE = 'clamp' def set_sbn_eps_mode(mode): global SBN_EPS_MODE assert mode in ('clamp', 'plus') SBN_EPS_MODE = mode def _sum_ft(tensor): """sum over the first and last dimention""" return tensor.sum(dim=0).sum(dim=-1) def _unsqueeze_ft(tensor): """add new dimensions at the front and the tail""" return tensor.unsqueeze(0).unsqueeze(-1) _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) class _SynchronizedBatchNorm(_BatchNorm): def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True): assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.' super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats) if not self.track_running_stats: import warnings warnings.warn('track_running_stats=False is not supported by the SynchronizedBatchNorm.') self._sync_master = SyncMaster(self._data_parallel_master) self._is_parallel = False self._parallel_id = None self._slave_pipe = None def forward(self, input): # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. if not (self._is_parallel and self.training): return F.batch_norm( input, self.running_mean, self.running_var, self.weight, self.bias, self.training, self.momentum, self.eps) # Resize the input to (B, C, -1). input_shape = input.size() assert input.size(1) == self.num_features, 'Channel size mismatch: got {}, expect {}.'.format(input.size(1), self.num_features) input = input.view(input.size(0), self.num_features, -1) # Compute the sum and square-sum. sum_size = input.size(0) * input.size(2) input_sum = _sum_ft(input) input_ssum = _sum_ft(input ** 2) # Reduce-and-broadcast the statistics. if self._parallel_id == 0: mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) else: mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) # Compute the output. if self.affine: # MJY:: Fuse the multiplication for speed. output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) else: output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) # Reshape it. return output.view(input_shape) def __data_parallel_replicate__(self, ctx, copy_id): self._is_parallel = True self._parallel_id = copy_id # parallel_id == 0 means master device. if self._parallel_id == 0: ctx.sync_master = self._sync_master else: self._slave_pipe = ctx.sync_master.register_slave(copy_id) def _data_parallel_master(self, intermediates): """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" # Always using same "device order" makes the ReduceAdd operation faster. # Thanks to:: Tete Xiao (http://tetexiao.com/) intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) to_reduce = [i[1][:2] for i in intermediates] to_reduce = [j for i in to_reduce for j in i] # flatten target_gpus = [i[1].sum.get_device() for i in intermediates] sum_size = sum([i[1].sum_size for i in intermediates]) sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) broadcasted = Broadcast.apply(target_gpus, mean, inv_std) outputs = [] for i, rec in enumerate(intermediates): outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) return outputs def _compute_mean_std(self, sum_, ssum, size): """Compute the mean and standard-deviation with sum and square-sum. This method also maintains the moving average on the master device.""" assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' mean = sum_ / size sumvar = ssum - sum_ * mean unbias_var = sumvar / (size - 1) bias_var = sumvar / size if hasattr(torch, 'no_grad'): with torch.no_grad(): self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data else: self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data if SBN_EPS_MODE == 'clamp': return mean, bias_var.clamp(self.eps) ** -0.5 elif SBN_EPS_MODE == 'plus': return mean, (bias_var + self.eps) ** -0.5 else: raise ValueError('Unknown EPS mode: {}.'.format(SBN_EPS_MODE)) class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a mini-batch. .. math:: y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta This module differs from the built-in PyTorch BatchNorm1d as the mean and standard-deviation are reduced across all devices during training. For example, when one uses `nn.DataParallel` to wrap the network during training, PyTorch's implementation normalize the tensor on each device using the statistics only on that device, which accelerated the computation and is also easy to implement, but the statistics might be inaccurate. Instead, in this synchronized version, the statistics will be computed over all training samples distributed on multiple devices. Note that, for one-GPU or CPU-only case, this module behaves exactly same as the built-in PyTorch implementation. The mean and standard-deviation are calculated per-dimension over the mini-batches and gamma and beta are learnable parameter vectors of size C (where C is the input size). During training, this layer keeps a running estimate of its computed mean and variance. The running sum is kept with a default momentum of 0.1. During evaluation, this running mean/variance is used for normalization. Because the BatchNorm is done over the `C` dimension, computing statistics on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm Args: num_features: num_features from an expected input of size `batch_size x num_features [x width]` eps: a value added to the denominator for numerical stability. Default: 1e-5 momentum: the value used for the running_mean and running_var computation. Default: 0.1 affine: a boolean value that when set to ``True``, gives the layer learnable affine parameters. Default: ``True`` Shape:: - Input: :math:`(N, C)` or :math:`(N, C, L)` - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) Examples: >>> # With Learnable Parameters >>> m = SynchronizedBatchNorm1d(100) >>> # Without Learnable Parameters >>> m = SynchronizedBatchNorm1d(100, affine=False) >>> input = torch.autograd.Variable(torch.randn(20, 100)) >>> output = m(input) """ def _check_input_dim(self, input): if input.dim() != 2 and input.dim() != 3: raise ValueError('expected 2D or 3D input (got {}D input)' .format(input.dim())) class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch of 3d inputs .. math:: y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta This module differs from the built-in PyTorch BatchNorm2d as the mean and standard-deviation are reduced across all devices during training. For example, when one uses `nn.DataParallel` to wrap the network during training, PyTorch's implementation normalize the tensor on each device using the statistics only on that device, which accelerated the computation and is also easy to implement, but the statistics might be inaccurate. Instead, in this synchronized version, the statistics will be computed over all training samples distributed on multiple devices. Note that, for one-GPU or CPU-only case, this module behaves exactly same as the built-in PyTorch implementation. The mean and standard-deviation are calculated per-dimension over the mini-batches and gamma and beta are learnable parameter vectors of size C (where C is the input size). During training, this layer keeps a running estimate of its computed mean and variance. The running sum is kept with a default momentum of 0.1. During evaluation, this running mean/variance is used for normalization. Because the BatchNorm is done over the `C` dimension, computing statistics on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm Args: num_features: num_features from an expected input of size batch_size x num_features x height x width eps: a value added to the denominator for numerical stability. Default: 1e-5 momentum: the value used for the running_mean and running_var computation. Default: 0.1 affine: a boolean value that when set to ``True``, gives the layer learnable affine parameters. Default: ``True`` Shape:: - Input: :math:`(N, C, H, W)` - Output: :math:`(N, C, H, W)` (same shape as input) Examples: >>> # With Learnable Parameters >>> m = SynchronizedBatchNorm2d(100) >>> # Without Learnable Parameters >>> m = SynchronizedBatchNorm2d(100, affine=False) >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) >>> output = m(input) """ def _check_input_dim(self, input): if input.dim() != 4: raise ValueError('expected 4D input (got {}D input)' .format(input.dim())) class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch of 4d inputs .. math:: y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta This module differs from the built-in PyTorch BatchNorm3d as the mean and standard-deviation are reduced across all devices during training. For example, when one uses `nn.DataParallel` to wrap the network during training, PyTorch's implementation normalize the tensor on each device using the statistics only on that device, which accelerated the computation and is also easy to implement, but the statistics might be inaccurate. Instead, in this synchronized version, the statistics will be computed over all training samples distributed on multiple devices. Note that, for one-GPU or CPU-only case, this module behaves exactly same as the built-in PyTorch implementation. The mean and standard-deviation are calculated per-dimension over the mini-batches and gamma and beta are learnable parameter vectors of size C (where C is the input size). During training, this layer keeps a running estimate of its computed mean and variance. The running sum is kept with a default momentum of 0.1. During evaluation, this running mean/variance is used for normalization. Because the BatchNorm is done over the `C` dimension, computing statistics on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm or Spatio-temporal BatchNorm Args: num_features: num_features from an expected input of size batch_size x num_features x depth x height x width eps: a value added to the denominator for numerical stability. Default: 1e-5 momentum: the value used for the running_mean and running_var computation. Default: 0.1 affine: a boolean value that when set to ``True``, gives the layer learnable affine parameters. Default: ``True`` Shape:: - Input: :math:`(N, C, D, H, W)` - Output: :math:`(N, C, D, H, W)` (same shape as input) Examples: >>> # With Learnable Parameters >>> m = SynchronizedBatchNorm3d(100) >>> # Without Learnable Parameters >>> m = SynchronizedBatchNorm3d(100, affine=False) >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) >>> output = m(input) """ def _check_input_dim(self, input): if input.dim() != 5: raise ValueError('expected 5D input (got {}D input)' .format(input.dim())) @contextlib.contextmanager def patch_sync_batchnorm(): import torch.nn as nn backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d nn.BatchNorm1d = SynchronizedBatchNorm1d nn.BatchNorm2d = SynchronizedBatchNorm2d nn.BatchNorm3d = SynchronizedBatchNorm3d yield nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup def convert_model(module): """Traverse the input module and its child recursively and replace all instance of torch.nn.modules.batchnorm.BatchNorm*N*d to SynchronizedBatchNorm*N*d Args: module: the input module needs to be convert to SyncBN model Examples: >>> import torch.nn as nn >>> import torchvision >>> # m is a standard pytorch model >>> m = torchvision.models.resnet18(True) >>> m = nn.DataParallel(m) >>> # after convert, m is using SyncBN >>> m = convert_model(m) """ if isinstance(module, torch.nn.DataParallel): mod = module.module mod = convert_model(mod) mod = DataParallelWithCallback(mod, device_ids=module.device_ids) return mod mod = module for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d, torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.BatchNorm3d], [SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d]): if isinstance(module, pth_module): mod = sync_module(module.num_features, module.eps, module.momentum, module.affine) mod.running_mean = module.running_mean mod.running_var = module.running_var if module.affine: mod.weight.data = module.weight.data.clone().detach() mod.bias.data = module.bias.data.clone().detach() for name, child in module.named_children(): mod.add_module(name, convert_model(child)) return mod ================================================ FILE: model/sync_batchnorm/batchnorm_reimpl.py ================================================ #! /usr/bin/env python3 # -*- coding: utf-8 -*- # File : batchnorm_reimpl.py # Author : acgtyrant # Date : 11/01/2018 # # This file is part of Synchronized-BatchNorm-PyTorch. # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch # Distributed under MIT License. import torch import torch.nn as nn import torch.nn.init as init __all__ = ['BatchNorm2dReimpl'] class BatchNorm2dReimpl(nn.Module): """ A re-implementation of batch normalization, used for testing the numerical stability. Author: acgtyrant See also: https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 """ def __init__(self, num_features, eps=1e-5, momentum=0.1): super().__init__() self.num_features = num_features self.eps = eps self.momentum = momentum self.weight = nn.Parameter(torch.empty(num_features)) self.bias = nn.Parameter(torch.empty(num_features)) self.register_buffer('running_mean', torch.zeros(num_features)) self.register_buffer('running_var', torch.ones(num_features)) self.reset_parameters() def reset_running_stats(self): self.running_mean.zero_() self.running_var.fill_(1) def reset_parameters(self): self.reset_running_stats() init.uniform_(self.weight) init.zeros_(self.bias) def forward(self, input_): batchsize, channels, height, width = input_.size() numel = batchsize * height * width input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) sum_ = input_.sum(1) sum_of_square = input_.pow(2).sum(1) mean = sum_ / numel sumvar = sum_of_square - sum_ * mean self.running_mean = ( (1 - self.momentum) * self.running_mean + self.momentum * mean.detach() ) unbias_var = sumvar / (numel - 1) self.running_var = ( (1 - self.momentum) * self.running_var + self.momentum * unbias_var.detach() ) bias_var = sumvar / numel inv_std = 1 / (bias_var + self.eps).pow(0.5) output = ( (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() ================================================ FILE: model/sync_batchnorm/comm.py ================================================ # -*- coding: utf-8 -*- # File : comm.py # Author : Jiayuan Mao # Email : maojiayuan@gmail.com # Date : 27/01/2018 # # This file is part of Synchronized-BatchNorm-PyTorch. # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch # Distributed under MIT License. import queue import collections import threading __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] class FutureResult(object): """A thread-safe future implementation. Used only as one-to-one pipe.""" def __init__(self): self._result = None self._lock = threading.Lock() self._cond = threading.Condition(self._lock) def put(self, result): with self._lock: assert self._result is None, 'Previous result has\'t been fetched.' self._result = result self._cond.notify() def get(self): with self._lock: if self._result is None: self._cond.wait() res = self._result self._result = None return res _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) class SlavePipe(_SlavePipeBase): """Pipe for master-slave communication.""" def run_slave(self, msg): self.queue.put((self.identifier, msg)) ret = self.result.get() self.queue.put(True) return ret class SyncMaster(object): """An abstract `SyncMaster` object. - During the replication, as the data parallel will trigger an callback of each module, all slave devices should call `register(id)` and obtain an `SlavePipe` to communicate with the master. - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, and passed to a registered callback. - After receiving the messages, the master device should gather the information and determine to message passed back to each slave devices. """ def __init__(self, master_callback): """ Args: master_callback: a callback to be invoked after having collected messages from slave devices. """ self._master_callback = master_callback self._queue = queue.Queue() self._registry = collections.OrderedDict() self._activated = False def __getstate__(self): return {'master_callback': self._master_callback} def __setstate__(self, state): self.__init__(state['master_callback']) def register_slave(self, identifier): """ Register an slave device. Args: identifier: an identifier, usually is the device id. Returns: a `SlavePipe` object which can be used to communicate with the master device. """ if self._activated: assert self._queue.empty(), 'Queue is not clean before next initialization.' self._activated = False self._registry.clear() future = FutureResult() self._registry[identifier] = _MasterRegistry(future) return SlavePipe(identifier, self._queue, future) def run_master(self, master_msg): """ Main entry for the master device in each forward pass. The messages were first collected from each devices (including the master device), and then an callback will be invoked to compute the message to be sent back to each devices (including the master device). Args: master_msg: the message that the master want to send to itself. This will be placed as the first message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. Returns: the message to be sent back to the master device. """ self._activated = True intermediates = [(0, master_msg)] for i in range(self.nr_slaves): intermediates.append(self._queue.get()) results = self._master_callback(intermediates) assert results[0][0] == 0, 'The first result should belongs to the master.' for i, res in results: if i == 0: continue self._registry[i].result.put(res) for i in range(self.nr_slaves): assert self._queue.get() is True return results[0][1] @property def nr_slaves(self): return len(self._registry) ================================================ FILE: model/sync_batchnorm/replicate.py ================================================ # -*- coding: utf-8 -*- # File : replicate.py # Author : Jiayuan Mao # Email : maojiayuan@gmail.com # Date : 27/01/2018 # # This file is part of Synchronized-BatchNorm-PyTorch. # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch # Distributed under MIT License. import functools from torch.nn.parallel.data_parallel import DataParallel __all__ = [ 'CallbackContext', 'execute_replication_callbacks', 'DataParallelWithCallback', 'patch_replication_callback' ] class CallbackContext(object): pass def execute_replication_callbacks(modules): """ Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` Note that, as all modules are isomorphism, we assign each sub-module with a context (shared among multiple copies of this module on different devices). Through this context, different copies can share some information. We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback of any slave copies. """ master_copy = modules[0] nr_modules = len(list(master_copy.modules())) ctxs = [CallbackContext() for _ in range(nr_modules)] for i, module in enumerate(modules): for j, m in enumerate(module.modules()): if hasattr(m, '__data_parallel_replicate__'): m.__data_parallel_replicate__(ctxs[j], i) class DataParallelWithCallback(DataParallel): """ Data Parallel with a replication callback. An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by original `replicate` function. The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` Examples: > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) # sync_bn.__data_parallel_replicate__ will be invoked. """ def replicate(self, module, device_ids): modules = super(DataParallelWithCallback, self).replicate(module, device_ids) execute_replication_callbacks(modules) return modules def patch_replication_callback(data_parallel): """ Monkey-patch an existing `DataParallel` object. Add the replication callback. Useful when you have customized `DataParallel` implementation. Examples: > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) > patch_replication_callback(sync_bn) # this is equivalent to > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) """ assert isinstance(data_parallel, DataParallel) old_replicate = data_parallel.replicate @functools.wraps(old_replicate) def new_replicate(module, device_ids): modules = old_replicate(module, device_ids) execute_replication_callbacks(modules) return modules data_parallel.replicate = new_replicate ================================================ FILE: model/sync_batchnorm/unittest.py ================================================ # -*- coding: utf-8 -*- # File : unittest.py # Author : Jiayuan Mao # Email : maojiayuan@gmail.com # Date : 27/01/2018 # # This file is part of Synchronized-BatchNorm-PyTorch. # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch # Distributed under MIT License. import unittest import torch class TorchTestCase(unittest.TestCase): def assertTensorClose(self, x, y): adiff = float((x - y).abs().max()) if (y == 0).all(): rdiff = 'NaN' else: rdiff = float((adiff / y).abs().max()) message = ( 'Tensor close check failed\n' 'adiff={}\n' 'rdiff={}\n' ).format(adiff, rdiff) self.assertTrue(torch.allclose(x, y, atol=1e-5, rtol=1e-3), message) ================================================ FILE: parser.py ================================================ import os import torch import argparse def parse_arguments(): parser = argparse.ArgumentParser(description="Benchmarking Visual Geolocalization", formatter_class=argparse.ArgumentDefaultsHelpFormatter) # Training parameters parser.add_argument("--train_batch_size", type=int, default=4, help="Number of triplets (query, pos, negs) in a batch. Each triplet consists of 12 images") parser.add_argument("--infer_batch_size", type=int, default=16, help="Batch size for inference (caching and testing)") parser.add_argument("--criterion", type=str, default='triplet', help='loss to be used', choices=["triplet", "sare_ind", "sare_joint"]) parser.add_argument("--margin", type=float, default=0.1, help="margin for the triplet loss") parser.add_argument("--epochs_num", type=int, default=1000, help="number of epochs to train for") parser.add_argument("--patience", type=int, default=3) parser.add_argument("--lr", type=float, default=0.00001, help="_") parser.add_argument("--lr_crn_layer", type=float, default=5e-3, help="Learning rate for the CRN layer") parser.add_argument("--lr_crn_net", type=float, default=5e-4, help="Learning rate to finetune pretrained network when using CRN") parser.add_argument("--optim", type=str, default="adam", help="_", choices=["adam", "sgd"]) parser.add_argument("--cache_refresh_rate", type=int, default=1000, help="How often to refresh cache, in number of queries") parser.add_argument("--queries_per_epoch", type=int, default=5000, help="How many queries to consider for one epoch. Must be multiple of cache_refresh_rate") parser.add_argument("--negs_num_per_query", type=int, default=10, help="How many negatives to consider per each query in the loss") parser.add_argument("--neg_samples_num", type=int, default=1000, help="How many negatives to use to compute the hardest ones") parser.add_argument("--mining", type=str, default="partial", choices=["partial", "full", "random", "msls_weighted"]) # Model parameters parser.add_argument("--backbone", type=str, default="resnet18conv4", choices=["alexnet", "vgg16", "resnet18conv4", "resnet18conv5", "resnet50conv4", "resnet50conv5", "resnet101conv4", "resnet101conv5", "cct384", "vit"], help="_") parser.add_argument("--l2", type=str, default="before_pool", choices=["before_pool", "after_pool", "none"], help="When (and if) to apply the l2 norm with shallow aggregation layers") parser.add_argument("--aggregation", type=str, default="netvlad", choices=["netvlad", "gem", "spoc", "mac", "rmac", "crn", "rrm", "cls", "seqpool"]) parser.add_argument('--netvlad_clusters', type=int, default=64, help="Number of clusters for NetVLAD layer.") parser.add_argument('--pca_dim', type=int, default=None, help="PCA dimension (number of principal components). If None, PCA is not used.") parser.add_argument('--fc_output_dim', type=int, default=None, help="Output dimension of fully connected layer. If None, don't use a fully connected layer.") parser.add_argument('--pretrain', type=str, default="imagenet", choices=['imagenet', 'gldv2', 'places'], help="Select the pretrained weights for the starting network") parser.add_argument("--off_the_shelf", type=str, default="imagenet", choices=["imagenet", "radenovic_sfm", "radenovic_gldv1", "naver"], help="Off-the-shelf networks from popular GitHub repos. Only with ResNet-50/101 + GeM + FC 2048") parser.add_argument("--trunc_te", type=int, default=None, choices=list(range(0, 14))) parser.add_argument("--freeze_te", type=int, default=None, choices=list(range(-1, 14))) # Initialization parameters parser.add_argument("--seed", type=int, default=0) parser.add_argument("--resume", type=str, default=None, help="Path to load checkpoint from, for resuming training or testing.") # Other parameters parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"]) parser.add_argument("--num_workers", type=int, default=8, help="num_workers for all dataloaders") parser.add_argument('--resize', type=int, default=[480, 640], nargs=2, help="Resizing shape for images (HxW).") parser.add_argument('--test_method', type=str, default="hard_resize", choices=["hard_resize", "single_query", "central_crop", "five_crops", "nearest_crop", "maj_voting"], help="This includes pre/post-processing methods and prediction refinement") parser.add_argument("--majority_weight", type=float, default=0.01, help="only for majority voting, scale factor, the higher it is the more importance is given to agreement") parser.add_argument("--efficient_ram_testing", action='store_true', help="_") parser.add_argument("--val_positive_dist_threshold", type=int, default=25, help="_") parser.add_argument("--train_positives_dist_threshold", type=int, default=10, help="_") parser.add_argument('--recall_values', type=int, default=[1, 5, 10, 20], nargs="+", help="Recalls to be computed, such as R@5.") # Data augmentation parameters parser.add_argument("--brightness", type=float, default=0, help="_") parser.add_argument("--contrast", type=float, default=0, help="_") parser.add_argument("--saturation", type=float, default=0, help="_") parser.add_argument("--hue", type=float, default=0, help="_") parser.add_argument("--rand_perspective", type=float, default=0, help="_") parser.add_argument("--horizontal_flip", action='store_true', help="_") parser.add_argument("--random_resized_crop", type=float, default=0, help="_") parser.add_argument("--random_rotation", type=float, default=0, help="_") # Paths parameters parser.add_argument("--datasets_folder", type=str, default=None, help="Path with all datasets") parser.add_argument("--dataset_name", type=str, default="pitts30k", help="Relative path of the dataset") parser.add_argument("--pca_dataset_folder", type=str, default=None, help="Path with images to be used to compute PCA (ie: pitts30k/images/train") parser.add_argument("--save_dir", type=str, default="default", help="Folder name of the current run (saved in ./logs/)") args = parser.parse_args() if args.datasets_folder is None: try: args.datasets_folder = os.environ['DATASETS_FOLDER'] except KeyError: raise Exception("You should set the parameter --datasets_folder or export " + "the DATASETS_FOLDER environment variable as such \n" + "export DATASETS_FOLDER=../datasets_vg/datasets") if args.aggregation == "crn" and args.resume is None: raise ValueError("CRN must be resumed from a trained NetVLAD checkpoint, but you set resume=None.") if args.queries_per_epoch % args.cache_refresh_rate != 0: raise ValueError("Ensure that queries_per_epoch is divisible by cache_refresh_rate, " + f"because {args.queries_per_epoch} is not divisible by {args.cache_refresh_rate}") if torch.cuda.device_count() >= 2 and args.criterion in ['sare_joint', "sare_ind"]: raise NotImplementedError("SARE losses are not implemented for multiple GPUs, " + f"but you're using {torch.cuda.device_count()} GPUs and {args.criterion} loss.") if args.mining == "msls_weighted" and args.dataset_name != "msls": raise ValueError("msls_weighted mining can only be applied to msls dataset, but you're using it on {args.dataset_name}") if args.off_the_shelf in ["radenovic_sfm", "radenovic_gldv1", "naver"]: if args.backbone not in ["resnet50conv5", "resnet101conv5"] or args.aggregation != "gem" or args.fc_output_dim != 2048: raise ValueError("Off-the-shelf models are trained only with ResNet-50/101 + GeM + FC 2048") if args.pca_dim is not None and args.pca_dataset_folder is None: raise ValueError("Please specify --pca_dataset_folder when using pca") if args.backbone == "vit": if args.resize != [224, 224] and args.resize != [384, 384]: raise ValueError(f'Image size for ViT must be either 224 or 384 {args.resize}') if args.backbone == "cct384": if args.resize != [384, 384]: raise ValueError(f'Image size for CCT384 must be 384, but it is {args.resize}') if args.backbone in ["alexnet", "vgg16", "resnet18conv4", "resnet18conv5", "resnet50conv4", "resnet50conv5", "resnet101conv4", "resnet101conv5"]: if args.aggregation in ["cls", "seqpool"]: raise ValueError(f"CNNs like {args.backbone} can't work with aggregation {args.aggregation}") if args.backbone in ["cct384"]: if args.aggregation in ["spoc", "mac", "rmac", "crn", "rrm"]: raise ValueError(f"CCT can't work with aggregation {args.aggregation}. Please use one among [netvlad, gem, cls, seqpool]") if args.backbone == "vit": if args.aggregation not in ["cls", "gem", "netvlad"]: raise ValueError(f"ViT can't work with aggregation {args.aggregation}. Please use one among [netvlad, gem, cls]") return args ================================================ FILE: requirements.txt ================================================ numpy==1.19.4 torchvision==0.8.1 psutil==5.6.7 faiss_cpu==1.5.3 tqdm==4.48.2 torch==1.7.0 Pillow==8.2.0 scikit_learn==0.24.1 torchscan==0.1.1 googledrivedownloader==0.4 requests==2.26.0 timm==0.4.12 transformers==4.8.2 einops ================================================ FILE: test.py ================================================ import faiss import torch import logging import numpy as np from tqdm import tqdm from torch.utils.data import DataLoader from torch.utils.data.dataset import Subset def test_efficient_ram_usage(args, eval_ds, model, test_method="hard_resize"): """This function gives the same output as test(), but uses much less RAM. This can be useful when testing with large descriptors (e.g. NetVLAD) on large datasets (e.g. San Francisco). Obviously it is slower than test(), and can't be used with PCA. """ model = model.eval() if test_method == 'nearest_crop' or test_method == "maj_voting": distances = np.empty([eval_ds.queries_num * 5, eval_ds.database_num], dtype=np.float32) else: distances = np.empty([eval_ds.queries_num, eval_ds.database_num], dtype=np.float32) with torch.no_grad(): if test_method == 'nearest_crop' or test_method == 'maj_voting': queries_features = np.ones((eval_ds.queries_num * 5, args.features_dim), dtype="float32") else: queries_features = np.ones((eval_ds.queries_num, args.features_dim), dtype="float32") logging.debug("Extracting queries features for evaluation/testing") queries_infer_batch_size = 1 if test_method == "single_query" else args.infer_batch_size eval_ds.test_method = test_method queries_subset_ds = Subset(eval_ds, list(range(eval_ds.database_num, eval_ds.database_num+eval_ds.queries_num))) queries_dataloader = DataLoader(dataset=queries_subset_ds, num_workers=args.num_workers, batch_size=queries_infer_batch_size, pin_memory=(args.device == "cuda")) for inputs, indices in tqdm(queries_dataloader, ncols=100): if test_method == "five_crops" or test_method == "nearest_crop" or test_method == 'maj_voting': inputs = torch.cat(tuple(inputs)) # shape = 5*bs x 3 x 480 x 480 features = model(inputs.to(args.device)) if test_method == "five_crops": # Compute mean along the 5 crops features = torch.stack(torch.split(features, 5)).mean(1) if test_method == "nearest_crop" or test_method == 'maj_voting': start_idx = (indices[0] - eval_ds.database_num) * 5 end_idx = start_idx + indices.shape[0] * 5 indices = np.arange(start_idx, end_idx) queries_features[indices, :] = features.cpu().numpy() else: queries_features[indices.numpy()-eval_ds.database_num, :] = features.cpu().numpy() queries_features = torch.tensor(queries_features).type(torch.float32).cuda() logging.debug("Extracting database features for evaluation/testing") # For database use "hard_resize", although it usually has no effect because database images have same resolution eval_ds.test_method = "hard_resize" database_subset_ds = Subset(eval_ds, list(range(eval_ds.database_num))) database_dataloader = DataLoader(dataset=database_subset_ds, num_workers=args.num_workers, batch_size=args.infer_batch_size, pin_memory=(args.device == "cuda")) for inputs, indices in tqdm(database_dataloader, ncols=100): inputs = inputs.to(args.device) features = model(inputs) for pn, (index, pred_feature) in enumerate(zip(indices, features)): distances[:, index] = ((queries_features-pred_feature)**2).sum(1).cpu().numpy() del features, queries_features, pred_feature predictions = distances.argsort(axis=1)[:, :max(args.recall_values)] if test_method == 'nearest_crop': distances = np.array([distances[row, index] for row, index in enumerate(predictions)]) distances = np.reshape(distances, (eval_ds.queries_num, 20 * 5)) predictions = np.reshape(predictions, (eval_ds.queries_num, 20 * 5)) for q in range(eval_ds.queries_num): # sort predictions by distance sort_idx = np.argsort(distances[q]) predictions[q] = predictions[q, sort_idx] # remove duplicated predictions, i.e. keep only the closest ones _, unique_idx = np.unique(predictions[q], return_index=True) # unique_idx is sorted based on the unique values, sort it again predictions[q, :20] = predictions[q, np.sort(unique_idx)][:20] predictions = predictions[:, :20] # keep only the closer 20 predictions for each elif test_method == 'maj_voting': distances = np.array([distances[row, index] for row, index in enumerate(predictions)]) distances = np.reshape(distances, (eval_ds.queries_num, 5, 20)) predictions = np.reshape(predictions, (eval_ds.queries_num, 5, 20)) for q in range(eval_ds.queries_num): # votings, modify distances in-place top_n_voting('top1', predictions[q], distances[q], args.majority_weight) top_n_voting('top5', predictions[q], distances[q], args.majority_weight) top_n_voting('top10', predictions[q], distances[q], args.majority_weight) # flatten dist and preds from 5, 20 -> 20*5 # and then proceed as usual to keep only first 20 dists = distances[q].flatten() preds = predictions[q].flatten() # sort predictions by distance sort_idx = np.argsort(dists) preds = preds[sort_idx] # remove duplicated predictions, i.e. keep only the closest ones _, unique_idx = np.unique(preds, return_index=True) # unique_idx is sorted based on the unique values, sort it again # here the row corresponding to the first crop is used as a # 'buffer' for each query, and in the end the dimension # relative to crops is eliminated predictions[q, 0, :20] = preds[np.sort(unique_idx)][:20] predictions = predictions[:, 0, :20] # keep only the closer 20 predictions for each query del distances #### For each query, check if the predictions are correct positives_per_query = eval_ds.get_positives() # args.recall_values by default is [1, 5, 10, 20] recalls = np.zeros(len(args.recall_values)) for query_index, pred in enumerate(predictions): for i, n in enumerate(args.recall_values): if np.any(np.in1d(pred[:n], positives_per_query[query_index])): recalls[i:] += 1 break recalls = recalls / eval_ds.queries_num * 100 recalls_str = ", ".join([f"R@{val}: {rec:.1f}" for val, rec in zip(args.recall_values, recalls)]) return recalls, recalls_str def test(args, eval_ds, model, test_method="hard_resize", pca=None): """Compute features of the given dataset and compute the recalls.""" assert test_method in ["hard_resize", "single_query", "central_crop", "five_crops", "nearest_crop", "maj_voting"], f"test_method can't be {test_method}" if args.efficient_ram_testing: return test_efficient_ram_usage(args, eval_ds, model, test_method) model = model.eval() with torch.no_grad(): logging.debug("Extracting database features for evaluation/testing") # For database use "hard_resize", although it usually has no effect because database images have same resolution eval_ds.test_method = "hard_resize" database_subset_ds = Subset(eval_ds, list(range(eval_ds.database_num))) database_dataloader = DataLoader(dataset=database_subset_ds, num_workers=args.num_workers, batch_size=args.infer_batch_size, pin_memory=(args.device == "cuda")) if test_method == "nearest_crop" or test_method == 'maj_voting': all_features = np.empty((5 * eval_ds.queries_num + eval_ds.database_num, args.features_dim), dtype="float32") else: all_features = np.empty((len(eval_ds), args.features_dim), dtype="float32") for inputs, indices in tqdm(database_dataloader, ncols=100): features = model(inputs.to(args.device)) features = features.cpu().numpy() if pca is not None: features = pca.transform(features) all_features[indices.numpy(), :] = features logging.debug("Extracting queries features for evaluation/testing") queries_infer_batch_size = 1 if test_method == "single_query" else args.infer_batch_size eval_ds.test_method = test_method queries_subset_ds = Subset(eval_ds, list(range(eval_ds.database_num, eval_ds.database_num+eval_ds.queries_num))) queries_dataloader = DataLoader(dataset=queries_subset_ds, num_workers=args.num_workers, batch_size=queries_infer_batch_size, pin_memory=(args.device == "cuda")) for inputs, indices in tqdm(queries_dataloader, ncols=100): if test_method == "five_crops" or test_method == "nearest_crop" or test_method == 'maj_voting': inputs = torch.cat(tuple(inputs)) # shape = 5*bs x 3 x 480 x 480 features = model(inputs.to(args.device)) if test_method == "five_crops": # Compute mean along the 5 crops features = torch.stack(torch.split(features, 5)).mean(1) features = features.cpu().numpy() if pca is not None: features = pca.transform(features) if test_method == "nearest_crop" or test_method == 'maj_voting': # store the features of all 5 crops start_idx = eval_ds.database_num + (indices[0] - eval_ds.database_num) * 5 end_idx = start_idx + indices.shape[0] * 5 indices = np.arange(start_idx, end_idx) all_features[indices, :] = features else: all_features[indices.numpy(), :] = features queries_features = all_features[eval_ds.database_num:] database_features = all_features[:eval_ds.database_num] faiss_index = faiss.IndexFlatL2(args.features_dim) faiss_index.add(database_features) del database_features, all_features logging.debug("Calculating recalls") distances, predictions = faiss_index.search(queries_features, max(args.recall_values)) if test_method == 'nearest_crop': distances = np.reshape(distances, (eval_ds.queries_num, 20 * 5)) predictions = np.reshape(predictions, (eval_ds.queries_num, 20 * 5)) for q in range(eval_ds.queries_num): # sort predictions by distance sort_idx = np.argsort(distances[q]) predictions[q] = predictions[q, sort_idx] # remove duplicated predictions, i.e. keep only the closest ones _, unique_idx = np.unique(predictions[q], return_index=True) # unique_idx is sorted based on the unique values, sort it again predictions[q, :20] = predictions[q, np.sort(unique_idx)][:20] predictions = predictions[:, :20] # keep only the closer 20 predictions for each query elif test_method == 'maj_voting': distances = np.reshape(distances, (eval_ds.queries_num, 5, 20)) predictions = np.reshape(predictions, (eval_ds.queries_num, 5, 20)) for q in range(eval_ds.queries_num): # votings, modify distances in-place top_n_voting('top1', predictions[q], distances[q], args.majority_weight) top_n_voting('top5', predictions[q], distances[q], args.majority_weight) top_n_voting('top10', predictions[q], distances[q], args.majority_weight) # flatten dist and preds from 5, 20 -> 20*5 # and then proceed as usual to keep only first 20 dists = distances[q].flatten() preds = predictions[q].flatten() # sort predictions by distance sort_idx = np.argsort(dists) preds = preds[sort_idx] # remove duplicated predictions, i.e. keep only the closest ones _, unique_idx = np.unique(preds, return_index=True) # unique_idx is sorted based on the unique values, sort it again # here the row corresponding to the first crop is used as a # 'buffer' for each query, and in the end the dimension # relative to crops is eliminated predictions[q, 0, :20] = preds[np.sort(unique_idx)][:20] predictions = predictions[:, 0, :20] # keep only the closer 20 predictions for each query #### For each query, check if the predictions are correct positives_per_query = eval_ds.get_positives() # args.recall_values by default is [1, 5, 10, 20] recalls = np.zeros(len(args.recall_values)) for query_index, pred in enumerate(predictions): for i, n in enumerate(args.recall_values): if np.any(np.in1d(pred[:n], positives_per_query[query_index])): recalls[i:] += 1 break # Divide by the number of queries*100, so the recalls are in percentages recalls = recalls / eval_ds.queries_num * 100 recalls_str = ", ".join([f"R@{val}: {rec:.1f}" for val, rec in zip(args.recall_values, recalls)]) return recalls, recalls_str def top_n_voting(topn, predictions, distances, maj_weight): if topn == 'top1': n = 1 selected = 0 elif topn == 'top5': n = 5 selected = slice(0, 5) elif topn == 'top10': n = 10 selected = slice(0, 10) # find predictions that repeat in the first, first five, # or fist ten columns for each crop vals, counts = np.unique(predictions[:, selected], return_counts=True) # for each prediction that repeats more than once, # subtract from its score for val, count in zip(vals[counts > 1], counts[counts > 1]): mask = (predictions[:, selected] == val) distances[:, selected][mask] -= maj_weight * count/n ================================================ FILE: train.py ================================================ import math import torch import logging import numpy as np from tqdm import tqdm import torch.nn as nn import multiprocessing from os.path import join from datetime import datetime import torchvision.transforms as transforms from torch.utils.data.dataloader import DataLoader import util import test import parser import commons import datasets_ws from model import network from model.sync_batchnorm import convert_model from model.functional import sare_ind, sare_joint torch.backends.cudnn.benchmark = True # Provides a speedup #### Initial setup: parser, logging... args = parser.parse_arguments() start_time = datetime.now() args.save_dir = join("logs", args.save_dir, start_time.strftime('%Y-%m-%d_%H-%M-%S')) commons.setup_logging(args.save_dir) commons.make_deterministic(args.seed) logging.info(f"Arguments: {args}") logging.info(f"The outputs are being saved in {args.save_dir}") logging.info(f"Using {torch.cuda.device_count()} GPUs and {multiprocessing.cpu_count()} CPUs") #### Creation of Datasets logging.debug(f"Loading dataset {args.dataset_name} from folder {args.datasets_folder}") triplets_ds = datasets_ws.TripletsDataset(args, args.datasets_folder, args.dataset_name, "train", args.negs_num_per_query) logging.info(f"Train query set: {triplets_ds}") val_ds = datasets_ws.BaseDataset(args, args.datasets_folder, args.dataset_name, "val") logging.info(f"Val set: {val_ds}") test_ds = datasets_ws.BaseDataset(args, args.datasets_folder, args.dataset_name, "test") logging.info(f"Test set: {test_ds}") #### Initialize model model = network.GeoLocalizationNet(args) model = model.to(args.device) if args.aggregation in ["netvlad", "crn"]: # If using NetVLAD layer, initialize it if not args.resume: triplets_ds.is_inference = True model.aggregation.initialize_netvlad_layer(args, triplets_ds, model.backbone) args.features_dim *= args.netvlad_clusters model = torch.nn.DataParallel(model) #### Setup Optimizer and Loss if args.aggregation == "crn": crn_params = list(model.module.aggregation.crn.parameters()) net_params = list(model.module.backbone.parameters()) + \ list([m[1] for m in model.module.aggregation.named_parameters() if not m[0].startswith('crn')]) if args.optim == "adam": optimizer = torch.optim.Adam([{'params': crn_params, 'lr': args.lr_crn_layer}, {'params': net_params, 'lr': args.lr_crn_net}]) logging.info("You're using CRN with Adam, it is advised to use SGD") elif args.optim == "sgd": optimizer = torch.optim.SGD([{'params': crn_params, 'lr': args.lr_crn_layer, 'momentum': 0.9, 'weight_decay': 0.001}, {'params': net_params, 'lr': args.lr_crn_net, 'momentum': 0.9, 'weight_decay': 0.001}]) else: if args.optim == "adam": optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) elif args.optim == "sgd": optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=0.001) if args.criterion == "triplet": criterion_triplet = nn.TripletMarginLoss(margin=args.margin, p=2, reduction="sum") elif args.criterion == "sare_ind": criterion_triplet = sare_ind elif args.criterion == "sare_joint": criterion_triplet = sare_joint #### Resume model, optimizer, and other training parameters if args.resume: if args.aggregation != 'crn': model, optimizer, best_r5, start_epoch_num, not_improved_num = util.resume_train(args, model, optimizer) else: # CRN uses pretrained NetVLAD, then requires loading with strict=False and # does not load the optimizer from the checkpoint file. model, _, best_r5, start_epoch_num, not_improved_num = util.resume_train(args, model, strict=False) logging.info(f"Resuming from epoch {start_epoch_num} with best recall@5 {best_r5:.1f}") else: best_r5 = start_epoch_num = not_improved_num = 0 if args.backbone.startswith('vit'): logging.info(f"Output dimension of the model is {args.features_dim}") else: logging.info(f"Output dimension of the model is {args.features_dim}, with {util.get_flops(model, args.resize)}") if torch.cuda.device_count() >= 2: # When using more than 1GPU, use sync_batchnorm for torch.nn.DataParallel model = convert_model(model) model = model.cuda() #### Training loop for epoch_num in range(start_epoch_num, args.epochs_num): logging.info(f"Start training epoch: {epoch_num:02d}") epoch_start_time = datetime.now() epoch_losses = np.zeros((0, 1), dtype=np.float32) # How many loops should an epoch last (default is 5000/1000=5) loops_num = math.ceil(args.queries_per_epoch / args.cache_refresh_rate) for loop_num in range(loops_num): logging.debug(f"Cache: {loop_num} / {loops_num}") # Compute triplets to use in the triplet loss triplets_ds.is_inference = True triplets_ds.compute_triplets(args, model) triplets_ds.is_inference = False triplets_dl = DataLoader(dataset=triplets_ds, num_workers=args.num_workers, batch_size=args.train_batch_size, collate_fn=datasets_ws.collate_fn, pin_memory=(args.device == "cuda"), drop_last=True) model = model.train() # images shape: (train_batch_size*12)*3*H*W ; by default train_batch_size=4, H=480, W=640 # triplets_local_indexes shape: (train_batch_size*10)*3 ; because 10 triplets per query for images, triplets_local_indexes, _ in tqdm(triplets_dl, ncols=100): # Flip all triplets or none if args.horizontal_flip: images = transforms.RandomHorizontalFlip()(images) # Compute features of all images (images contains queries, positives and negatives) features = model(images.to(args.device)) loss_triplet = 0 if args.criterion == "triplet": triplets_local_indexes = torch.transpose( triplets_local_indexes.view(args.train_batch_size, args.negs_num_per_query, 3), 1, 0) for triplets in triplets_local_indexes: queries_indexes, positives_indexes, negatives_indexes = triplets.T loss_triplet += criterion_triplet(features[queries_indexes], features[positives_indexes], features[negatives_indexes]) elif args.criterion == 'sare_joint': # sare_joint needs to receive all the negatives at once triplet_index_batch = triplets_local_indexes.view(args.train_batch_size, 10, 3) for batch_triplet_index in triplet_index_batch: q = features[batch_triplet_index[0, 0]].unsqueeze(0) # obtain query as tensor of shape 1xn_features p = features[batch_triplet_index[0, 1]].unsqueeze(0) # obtain positive as tensor of shape 1xn_features n = features[batch_triplet_index[:, 2]] # obtain negatives as tensor of shape 10xn_features loss_triplet += criterion_triplet(q, p, n) elif args.criterion == "sare_ind": for triplet in triplets_local_indexes: # triplet is a 1-D tensor with the 3 scalars indexes of the triplet q_i, p_i, n_i = triplet loss_triplet += criterion_triplet(features[q_i:q_i+1], features[p_i:p_i+1], features[n_i:n_i+1]) del features loss_triplet /= (args.train_batch_size * args.negs_num_per_query) optimizer.zero_grad() loss_triplet.backward() optimizer.step() # Keep track of all losses by appending them to epoch_losses batch_loss = loss_triplet.item() epoch_losses = np.append(epoch_losses, batch_loss) del loss_triplet logging.debug(f"Epoch[{epoch_num:02d}]({loop_num}/{loops_num}): " + f"current batch triplet loss = {batch_loss:.4f}, " + f"average epoch triplet loss = {epoch_losses.mean():.4f}") logging.info(f"Finished epoch {epoch_num:02d} in {str(datetime.now() - epoch_start_time)[:-7]}, " f"average epoch triplet loss = {epoch_losses.mean():.4f}") # Compute recalls on validation set recalls, recalls_str = test.test(args, val_ds, model) logging.info(f"Recalls on val set {val_ds}: {recalls_str}") is_best = recalls[1] > best_r5 # Save checkpoint, which contains all training parameters util.save_checkpoint(args, { "epoch_num": epoch_num, "model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(), "recalls": recalls, "best_r5": best_r5, "not_improved_num": not_improved_num }, is_best, filename="last_model.pth") # If recall@5 did not improve for "many" epochs, stop training if is_best: logging.info(f"Improved: previous best R@5 = {best_r5:.1f}, current R@5 = {recalls[1]:.1f}") best_r5 = recalls[1] not_improved_num = 0 else: not_improved_num += 1 logging.info(f"Not improved: {not_improved_num} / {args.patience}: best R@5 = {best_r5:.1f}, current R@5 = {recalls[1]:.1f}") if not_improved_num >= args.patience: logging.info(f"Performance did not improve for {not_improved_num} epochs. Stop training.") break logging.info(f"Best R@5: {best_r5:.1f}") logging.info(f"Trained for {epoch_num+1:02d} epochs, in total in {str(datetime.now() - start_time)[:-7]}") #### Test best model on test set best_model_state_dict = torch.load(join(args.save_dir, "best_model.pth"))["model_state_dict"] model.load_state_dict(best_model_state_dict) recalls, recalls_str = test.test(args, test_ds, model, test_method=args.test_method) logging.info(f"Recalls on {test_ds}: {recalls_str}") ================================================ FILE: util.py ================================================ import re import torch import shutil import logging import torchscan import numpy as np from collections import OrderedDict from os.path import join from sklearn.decomposition import PCA import datasets_ws def get_flops(model, input_shape=(480, 640)): """Return the FLOPs as a string, such as '22.33 GFLOPs'""" assert len(input_shape) == 2, f"input_shape should have len==2, but it's {input_shape}" module_info = torchscan.crawl_module(model, (3, input_shape[0], input_shape[1])) output = torchscan.utils.format_info(module_info) return re.findall("Floating Point Operations on forward: (.*)\n", output)[0] def save_checkpoint(args, state, is_best, filename): model_path = join(args.save_dir, filename) torch.save(state, model_path) if is_best: shutil.copyfile(model_path, join(args.save_dir, "best_model.pth")) def resume_model(args, model): checkpoint = torch.load(args.resume, map_location=args.device) if 'model_state_dict' in checkpoint: state_dict = checkpoint['model_state_dict'] else: # The pre-trained models that we provide in the README do not have 'state_dict' in the keys as # the checkpoint is directly the state dict state_dict = checkpoint # if the model contains the prefix "module" which is appendend by # DataParallel, remove it to avoid errors when loading dict if list(state_dict.keys())[0].startswith('module'): state_dict = OrderedDict({k.replace('module.', ''): v for (k, v) in state_dict.items()}) model.load_state_dict(state_dict) return model def resume_train(args, model, optimizer=None, strict=False): """Load model, optimizer, and other training parameters""" logging.debug(f"Loading checkpoint: {args.resume}") checkpoint = torch.load(args.resume) start_epoch_num = checkpoint["epoch_num"] model.load_state_dict(checkpoint["model_state_dict"], strict=strict) if optimizer: optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) best_r5 = checkpoint["best_r5"] not_improved_num = checkpoint["not_improved_num"] logging.debug(f"Loaded checkpoint: start_epoch_num = {start_epoch_num}, " f"current_best_R@5 = {best_r5:.1f}") if args.resume.endswith("last_model.pth"): # Copy best model to current save_dir shutil.copy(args.resume.replace("last_model.pth", "best_model.pth"), args.save_dir) return model, optimizer, best_r5, start_epoch_num, not_improved_num def compute_pca(args, model, pca_dataset_folder, full_features_dim): model = model.eval() pca_ds = datasets_ws.PCADataset(args, args.datasets_folder, pca_dataset_folder) dl = torch.utils.data.DataLoader(pca_ds, args.infer_batch_size, shuffle=True) pca_features = np.empty([min(len(pca_ds), 2**14), full_features_dim]) with torch.no_grad(): for i, images in enumerate(dl): if i*args.infer_batch_size >= len(pca_features): break features = model(images).cpu().numpy() pca_features[i*args.infer_batch_size : (i*args.infer_batch_size)+len(features)] = features pca = PCA(args.pca_dim) pca.fit(pca_features) return pca