Repository: qianqianwang68/caps Branch: master Commit: 1cb601a2b77f Files: 20 Total size: 1.6 MB Directory structure: gitextract_bzvz72tp/ ├── .gitignore ├── CAPS/ │ ├── __init__.py │ ├── caps_model.py │ ├── criterion.py │ └── network.py ├── LICENSE ├── README.md ├── config.py ├── configs/ │ ├── extract_features_hpatches.yaml │ └── train_megadepth.yaml ├── dataloader/ │ ├── __init__.py │ ├── data_utils.py │ └── megadepth.py ├── environment.yml ├── extract_features.py ├── jupyter/ │ ├── functions.py │ └── visualization.ipynb ├── test/ │ └── eval_pose_megadepth.py ├── train.py └── utils.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ __pycache__/ CAPS/__pycache__/ jupyter/.ipynb_checkpoints/ jupyter/__pycache__/ logs/ out/ models/ ================================================ FILE: CAPS/__init__.py ================================================ ================================================ FILE: CAPS/caps_model.py ================================================ import os import cv2 import numpy as np import torch from torch import optim from torch.autograd import Variable import utils import torchvision.utils as vutils from CAPS.criterion import CtoFCriterion from CAPS.network import CAPSNet class CAPSModel(): def name(self): return 'CAPS Model' def __init__(self, args): self.args = args self.device = 'cuda' if torch.cuda.is_available() else 'cpu' # init model, optimizer, scheduler self.model = CAPSNet(args, self.device) self.optimizer = optim.Adam(self.model.parameters(), lr=self.args.lr) self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=args.lrate_decay_steps, gamma=args.lrate_decay_factor) # reloading from checkpoints self.start_step = self.load_from_ckpt() # define loss function self.criterion = CtoFCriterion(args).to(self.device) def set_input(self, data): self.im1 = Variable(data['im1'].to(self.device)) self.im2 = Variable(data['im2'].to(self.device)) self.coord1 = Variable(data['coord1'].to(self.device)) self.fmatrix = data['F'].cuda() self.pose = Variable(data['pose'].to(self.device)) self.intrinsic1 = data['intrinsic1'].to(self.device) self.intrinsic2 = data['intrinsic2'].to(self.device) self.im1_ori = data['im1_ori'] self.im2_ori = data['im2_ori'] self.batch_size = len(self.im1) self.imsize = self.im1.size()[2:] def forward(self): self.out = self.model.forward(self.im1, self.im2, self.coord1) def backward_net(self): loss = self.criterion(self.coord1, self.out, self.fmatrix, self.pose, self.imsize) self.j_loss, self.eloss_c, self.eloss_f, self.closs_c, self.closs_f, self.std_loss = loss self.j_loss.backward() def optimize_parameters(self): self.optimizer.zero_grad() self.forward() self.backward_net() self.optimizer.step() self.scheduler.step() def test(self): self.model.eval() with torch.no_grad(): coord2_e, std = self.model.test(self.im1, self.im2, self.coord1) return coord2_e, std def extract_features(self, im, coord): self.model.eval() with torch.no_grad(): feat_c, feat_f = self.model.extract_features(im, coord) return feat_c, feat_f def write_summary(self, writer, n_iter): print("%s | Step: %d, Loss: %2.5f" % (self.args.exp_name, n_iter, self.j_loss.item())) # write scalar if n_iter % self.args.log_scalar_interval == 0: writer.add_scalar('Total_loss', self.j_loss.item(), n_iter) writer.add_scalar('epipolar_loss_coarse', self.eloss_c.item(), n_iter) 
writer.add_scalar('epipolar_loss_fine', self.eloss_f.item(), n_iter) writer.add_scalar('cycle_loss_coarse', self.closs_c.item(), n_iter) writer.add_scalar('cycle_loss_fine', self.closs_f.item(), n_iter) # write image if n_iter % self.args.log_img_interval == 0: # this visualization shows a number of query points in the first image, # and their predicted correspondences in the second image, # the groundtruth epipolar lines for the query points are plotted in the second image num_kpts_display = 20 im1_o = self.im1_ori[0].numpy() im2_o = self.im2_ori[0].numpy() kpt1 = self.coord1.cpu().numpy()[0][:num_kpts_display, :] # predicted correspondence correspondence = self.out['coord2_ef'] kpt2 = correspondence.detach().cpu().numpy()[0][:num_kpts_display, :] lines2 = cv2.computeCorrespondEpilines(kpt1.reshape(-1, 1, 2), 1, self.fmatrix[0].cpu().numpy()) lines2 = lines2.reshape(-1, 3) im2_o, im1_o = utils.drawlines(im2_o, im1_o, lines2, kpt2, kpt1) vis = np.concatenate((im1_o, im2_o), 1) vis = torch.from_numpy(vis.transpose(2, 0, 1)).float().unsqueeze(0) x = vutils.make_grid(vis, normalize=True) writer.add_image('Image', x, n_iter) def load_model(self, filename): to_load = torch.load(filename) self.model.load_state_dict(to_load['state_dict']) if 'optimizer' in to_load.keys(): self.optimizer.load_state_dict(to_load['optimizer']) if 'scheduler' in to_load.keys(): self.scheduler.load_state_dict(to_load['scheduler']) return to_load['step'] def load_from_ckpt(self): ''' load model from existing checkpoints and return the current step :param ckpt_dir: the directory that stores ckpts :return: the current starting step ''' # load from the specified ckpt path if self.args.ckpt_path != "": print("Reloading from {}".format(self.args.ckpt_path)) if os.path.isfile(self.args.ckpt_path): step = self.load_model(self.args.ckpt_path) else: raise Exception('no checkpoint found in the following path:{}'.format(self.args.ckpt_path)) else: ckpt_folder = os.path.join(self.args.outdir, self.args.exp_name) os.makedirs(ckpt_folder, exist_ok=True) # load from the most recent ckpt from all existing ckpts ckpts = [os.path.join(ckpt_folder, f) for f in sorted(os.listdir(ckpt_folder)) if f.endswith('.pth')] if len(ckpts) > 0: fpath = ckpts[-1] step = self.load_model(fpath) print('Reloading from {}, starting at step={}'.format(fpath, step)) else: print('No ckpts found, training from scratch...') step = 0 return step def save_model(self, step): ckpt_folder = os.path.join(self.args.outdir, self.args.exp_name) os.makedirs(ckpt_folder, exist_ok=True) save_path = os.path.join(ckpt_folder, "{:06d}.pth".format(step)) print('saving ckpts {}...'.format(save_path)) torch.save({'step': step, 'state_dict': self.model.state_dict(), 'optimizer': self.optimizer.state_dict(), 'scheduler': self.scheduler.state_dict(), }, save_path) ================================================ FILE: CAPS/criterion.py ================================================ import torch import torch.nn as nn class CtoFCriterion(nn.Module): def __init__(self, args): super(CtoFCriterion, self).__init__() self.args = args self.w_ec = args.w_epipolar_coarse self.w_ef = args.w_epipolar_fine self.w_cc = args.w_cycle_coarse self.w_cf = args.w_cycle_coarse self.w_std = args.w_std def homogenize(self, coord): coord = torch.cat((coord, torch.ones_like(coord[:, :, [0]])), -1) return coord def set_weight(self, std, mask=None, regularizer=0.0): if self.args.std: inverse_std = 1. 
/ torch.clamp(std+regularizer, min=1e-10) weight = inverse_std / torch.mean(inverse_std) weight = weight.detach() # Bxn else: weight = torch.ones_like(std) if mask is not None: weight *= mask.float() weight /= (torch.mean(weight) + 1e-8) return weight def epipolar_cost(self, coord1, coord2, fmatrix): coord1_h = self.homogenize(coord1).transpose(1, 2) coord2_h = self.homogenize(coord2).transpose(1, 2) epipolar_line = fmatrix.bmm(coord1_h) # Bx3xn epipolar_line_ = epipolar_line / torch.clamp(torch.norm(epipolar_line[:, :2, :], dim=1, keepdim=True), min=1e-8) essential_cost = torch.abs(torch.sum(coord2_h * epipolar_line_, dim=1)) # Bxn return essential_cost def epipolar_loss(self, coord1, coord2, fmatrix, weight): essential_cost = self.epipolar_cost(coord1, coord2, fmatrix) loss = torch.mean(weight * essential_cost) return loss def cycle_consistency_loss(self, coord1, coord1_loop, weight, th=40): ''' compute the cycle consistency loss :param coord1: [batch_size, n_pts, 2] :param coord1_loop: the predicted location [batch_size, n_pts, 2] :param weight: the weight [batch_size, n_pts] :param th: the threshold, only consider distances under this threshold :return: the cycle consistency loss value ''' distance = torch.norm(coord1 - coord1_loop, dim=-1) distance_ = torch.zeros_like(distance) distance_[distance < th] = distance[distance < th] loss = torch.mean(weight * distance_) return loss def forward(self, coord1, data, fmatrix, pose, im_size): coord2_ec = data['coord2_ec'] coord2_ef = data['coord2_ef'] coord1_lc = data['coord1_lc'] coord1_lf = data['coord1_lf'] std_c = data['std_c'] std_f = data['std_f'] std_lc = data['std_lc'] std_lf = data['std_lf'] shorter_edge, longer_edge = min(im_size), max(im_size) epipolar_cost_c = self.epipolar_cost(coord1, coord2_ec, fmatrix) # only add fine level loss if the coarse level prediction is close enough to gt epipolar line mask_ctof = (epipolar_cost_c < (shorter_edge * self.args.window_size)) # only add cycle consistency loss if the coarse level prediction is close enough to gt epipolar line mask_epip_c = (epipolar_cost_c < (shorter_edge * self.args.th_epipolar)) mask_cycle_c = (epipolar_cost_c < (shorter_edge * self.args.th_cycle)) epipolar_cost_f = self.epipolar_cost(coord1, coord2_ef, fmatrix) # only add cycle consistency loss if the fine level prediction is close enough to gt epipolar line mask_epip_f = (epipolar_cost_f < (shorter_edge * self.args.th_epipolar)) mask_cycle_f = (epipolar_cost_f < shorter_edge * self.args.th_cycle) weight_c = self.set_weight(std_c, mask=mask_epip_c) weight_f = self.set_weight(std_f, mask=mask_epip_f*mask_ctof) eloss_c = torch.mean(epipolar_cost_c * weight_c) / longer_edge eloss_f = torch.mean(epipolar_cost_f * weight_f) / longer_edge weight_cycle_c = self.set_weight(std_c * std_lc, mask=mask_cycle_c) weight_cycle_f = self.set_weight(std_f * std_lf, mask=mask_cycle_f) closs_c = self.cycle_consistency_loss(coord1, coord1_lc, weight_cycle_c) / longer_edge closs_f = self.cycle_consistency_loss(coord1, coord1_lf, weight_cycle_f) / longer_edge loss = self.w_ec * eloss_c + self.w_ef * eloss_f + self.w_cc * closs_c + self.w_cf * closs_f std_loss = torch.mean(std_c) + torch.mean(std_f) loss += self.w_std * std_loss return loss, eloss_c, eloss_f, closs_c, closs_f, std_loss ================================================ FILE: CAPS/network.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F import importlib class CAPSNet(nn.Module): def __init__(self, args, device): 
super(CAPSNet, self).__init__() self.args = args self.device = device self.net = ResUNet(pretrained=args.pretrained, encoder=args.backbone, coarse_out_ch=args.coarse_feat_dim, fine_out_ch=args.fine_feat_dim).to(self.device) @staticmethod def normalize(coord, h, w): ''' turn the coordinates from pixel indices to the range of [-1, 1] :param coord: [..., 2] :param h: the image height :param w: the image width :return: the normalized coordinates [..., 2] ''' c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).to(coord.device).float() coord_norm = (coord - c) / c return coord_norm @staticmethod def denormalize(coord_norm, h, w): ''' turn the coordinates from normalized value ([-1, 1]) to actual pixel indices :param coord_norm: [..., 2] :param h: the image height :param w: the image width :return: actual pixel coordinates ''' c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).to(coord_norm.device) coord = coord_norm * c + c return coord def ind2coord(self, ind, width): ind = ind.unsqueeze(-1) x = ind % width y = ind // width coord = torch.cat((x, y), -1).float() return coord def gen_grid(self, h_min, h_max, w_min, w_max, len_h, len_w): x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w), torch.linspace(h_min, h_max, len_h)]) grid = torch.stack((x, y), -1).transpose(0, 1).reshape(-1, 2).float().to(self.device) return grid def sample_feat_by_coord(self, x, coord_n, norm=False): ''' sample from normalized coordinates :param x: feature map [batch_size, n_dim, h, w] :param coord_n: normalized coordinates, [batch_size, n_pts, 2] :param norm: if l2 normalize features :return: the extracted features, [batch_size, n_pts, n_dim] ''' feat = F.grid_sample(x, coord_n.unsqueeze(2)).squeeze(-1) if norm: feat = F.normalize(feat) feat = feat.transpose(1, 2) return feat def compute_prob(self, feat1, feat2): ''' compute probability :param feat1: query features, [batch_size, m, n_dim] :param feat2: reference features, [batch_size, n, n_dim] :return: probability, [batch_size, m, n] ''' assert self.args.prob_from in ['correlation', 'distance'] if self.args.prob_from == 'correlation': sim = feat1.bmm(feat2.transpose(1, 2)) prob = F.softmax(sim, dim=-1) # Bxmxn else: dist = torch.sum(feat1**2, dim=-1, keepdim=True) + \ torch.sum(feat2**2, dim=-1, keepdim=True).transpose(1, 2) - \ 2 * feat1.bmm(feat2.transpose(1, 2)) prob = F.softmax(-dist, dim=-1) # Bxmxn return prob def get_1nn_coord(self, feat1, featmap2): ''' find the coordinates of nearest neighbor match :param feat1: query features, [batch_size, n_pts, n_dim] :param featmap2: the feature maps of the other image :return: normalized correspondence locations [batch_size, n_pts, 2] ''' batch_size, d, h, w = featmap2.shape feat2_flatten = featmap2.reshape(batch_size, d, h*w).transpose(1, 2) # Bx(hw)xd assert self.args.prob_from in ['correlation', 'distance'] if self.args.prob_from == 'correlation': sim = feat1.bmm(feat2_flatten.transpose(1, 2)) ind2_1nn = torch.max(sim, dim=-1)[1] else: dist = torch.sum(feat1**2, dim=-1, keepdim=True) + \ torch.sum(feat2_flatten**2, dim=-1, keepdim=True).transpose(1, 2) - \ 2 * feat1.bmm(feat2_flatten.transpose(1, 2)) ind2_1nn = torch.min(dist, dim=-1)[1] coord2 = self.ind2coord(ind2_1nn, w) coord2_n = self.normalize(coord2, h, w) return coord2_n def get_expected_correspondence_locs(self, feat1, featmap2, with_std=False): ''' compute the expected correspondence locations :param feat1: the feature vectors of query points [batch_size, n_pts, n_dim] :param featmap2: the feature maps of the reference image [batch_size, n_dim, h, w] :param 
with_std: if return the standard deviation :return: the normalized expected correspondence locations [batch_size, n_pts, 2] ''' B, d, h2, w2 = featmap2.size() grid_n = self.gen_grid(-1, 1, -1, 1, h2, w2) featmap2_flatten = featmap2.reshape(B, d, h2*w2).transpose(1, 2) # BX(hw)xd prob = self.compute_prob(feat1, featmap2_flatten) # Bxnx(hw) grid_n = grid_n.unsqueeze(0).unsqueeze(0) # 1x1x(hw)x2 expected_coord_n = torch.sum(grid_n * prob.unsqueeze(-1), dim=2) # Bxnx2 if with_std: # convert to normalized scale [-1, 1] var = torch.sum(grid_n**2 * prob.unsqueeze(-1), dim=2) - expected_coord_n**2 # Bxnx2 std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # Bxn return expected_coord_n, std else: return expected_coord_n def get_expected_correspondence_within_window(self, feat1, featmap2, coord2_n, with_std=False): ''' :param feat1: the feature vectors of query points [batch_size, n_pts, n_dim] :param featmap2: the feature maps of the reference image [batch_size, n_dim, h, w] :param coord2_n: normalized center locations [batch_size, n_pts, 2] :param with_std: if return the standard deviation :return: the normalized expected correspondence locations, [batch_size, n_pts, 2], optionally with std ''' batch_size, n_dim, h2, w2 = featmap2.shape n_pts = coord2_n.shape[1] grid_n = self.gen_grid(h_min=-self.args.window_size, h_max=self.args.window_size, w_min=-self.args.window_size, w_max=self.args.window_size, len_h=int(self.args.window_size*h2), len_w=int(self.args.window_size*w2)) grid_n_ = grid_n.repeat(batch_size, 1, 1, 1) # Bx1xhwx2 coord2_n_grid = coord2_n.unsqueeze(-2) + grid_n_ # Bxnxhwx2 feat2_win = F.grid_sample(featmap2, coord2_n_grid, padding_mode='zeros').permute(0, 2, 3, 1) # Bxnxhwxd feat1 = feat1.unsqueeze(-2) prob = self.compute_prob(feat1.reshape(batch_size*n_pts, -1, n_dim), feat2_win.reshape(batch_size*n_pts, -1, n_dim)).reshape(batch_size, n_pts, -1) expected_coord2_n = torch.sum(coord2_n_grid * prob.unsqueeze(-1), dim=2) # Bxnx2 if with_std: var = torch.sum(coord2_n_grid**2 * prob.unsqueeze(-1), dim=2) - expected_coord2_n**2 # Bxnx2 std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # Bxn return expected_coord2_n, std else: return expected_coord2_n def forward(self, im1, im2, coord1): # extract features for both images xc1, xf1 = self.net(im1) xc2, xf2 = self.net(im2) # image width and height h1i, w1i = im1.size()[2:] h2i, w2i = im2.size()[2:] coord1_n = self.normalize(coord1, h1i, w1i) feat1_coarse = self.sample_feat_by_coord(xc1, coord1_n) # Bxnxd coord2_ec_n, std_c = self.get_expected_correspondence_locs(feat1_coarse, xc2, with_std=True) # the center locations of the local window for fine level computation coord2_ec_n_ = self.get_1nn_coord(feat1_coarse, xc2) if self.args.use_nn else coord2_ec_n feat1_fine = self.sample_feat_by_coord(xf1, coord1_n) # Bxnxd coord2_ef_n, std_f = self.get_expected_correspondence_within_window(feat1_fine, xf2, coord2_ec_n_, with_std=True) feat2_coarse = self.sample_feat_by_coord(xc2, coord2_ec_n_) coord1_lc_n, std_lc = self.get_expected_correspondence_locs(feat2_coarse, xc1, with_std=True) feat2_fine = self.sample_feat_by_coord(xf2, coord2_ef_n) # Bxnxd coord1_lf_n, std_lf = self.get_expected_correspondence_within_window(feat2_fine, xf1, coord1_n, with_std=True) coord2_ec = self.denormalize(coord2_ec_n, h2i, w2i) coord2_ef = self.denormalize(coord2_ef_n, h2i, w2i) coord1_lc = self.denormalize(coord1_lc_n, h1i, w1i) coord1_lf = self.denormalize(coord1_lf_n, h1i, w1i) return {'coord2_ec': coord2_ec, 'coord2_ef': coord2_ef, 'coord1_lc': 
coord1_lc, 'coord1_lf': coord1_lf, 'std_c': std_c, 'std_f': std_f, 'std_lc': std_lc, 'std_lf': std_lf, } def extract_features(self, im, coord): ''' extract coarse and fine level features given the input image and 2d locations :param im: [batch_size, 3, h, w] :param coord: [batch_size, n_pts, 2] :return: coarse features [batch_size, n_pts, coarse_feat_dim] and fine features [batch_size, n_pts, fine_feat_dim] ''' xc, xf = self.net(im) hi, wi = im.size()[2:] coord_n = self.normalize(coord, hi, wi) feat_c = self.sample_feat_by_coord(xc, coord_n) feat_f = self.sample_feat_by_coord(xf, coord_n) return feat_c, feat_f def test(self, im1, im2, coord1): ''' given a pair of images im1, im2, compute the coorrespondences for query points coord1. We performa full image search at coarse level and local search at fine level :param im1: [batch_size, 3, h, w] :param im2: [batch_size, 3, h, w] :param coord1: [batch_size, n_pts, 2] :return: the fine level correspondence location [batch_size, n_pts, 2] ''' xc1, xf1 = self.net(im1) xc2, xf2 = self.net(im2) h1i, w1i = im1.shape[2:] h2i, w2i = im2.shape[2:] coord1_n = self.normalize(coord1, h1i, w1i) feat1_c = self.sample_feat_by_coord(xc1, coord1_n) _, std_c = self.get_expected_correspondence_locs(feat1_c, xc2, with_std=True) coord2_ec_n = self.get_1nn_coord(feat1_c, xc2) feat1_f = self.sample_feat_by_coord(xf1, coord1_n) _, std_f = self.get_expected_correspondence_within_window(feat1_f, xf2, coord2_ec_n, with_std=True) coord2_ef_n = self.get_1nn_coord(feat1_f, xf2) coord2_ef = self.denormalize(coord2_ef_n, h2i, w2i) std = (std_c + std_f)/2 return coord2_ef, std ####################### ResUnet ########################## def class_for_name(module_name, class_name): # load the module, will raise ImportError if module cannot be loaded m = importlib.import_module(module_name) return getattr(m, class_name) class conv(nn.Module): def __init__(self, num_in_layers, num_out_layers, kernel_size, stride): super(conv, self).__init__() self.kernel_size = kernel_size self.conv = nn.Conv2d(num_in_layers, num_out_layers, kernel_size=kernel_size, stride=stride, padding=(self.kernel_size - 1) // 2) self.bn = nn.BatchNorm2d(num_out_layers) def forward(self, x): return F.elu(self.bn(self.conv(x)), inplace=True) class upconv(nn.Module): def __init__(self, num_in_layers, num_out_layers, kernel_size, scale): super(upconv, self).__init__() self.scale = scale self.conv = conv(num_in_layers, num_out_layers, kernel_size, 1) def forward(self, x): x = nn.functional.interpolate(x, scale_factor=self.scale, align_corners=True, mode='bilinear') return self.conv(x) class ResUNet(nn.Module): def __init__(self, encoder='resnet50', pretrained=True, coarse_out_ch=128, fine_out_ch=128 ): super(ResUNet, self).__init__() assert encoder in ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'], "Incorrect encoder type" if encoder in ['resnet18', 'resnet34']: filters = [64, 128, 256, 512] else: filters = [256, 512, 1024, 2048] resnet = class_for_name("torchvision.models", encoder)(pretrained=pretrained) self.firstconv = resnet.conv1 # H/2 self.firstbn = resnet.bn1 self.firstrelu = resnet.relu self.firstmaxpool = resnet.maxpool # H/4 # encoder self.layer1 = resnet.layer1 # H/4 self.layer2 = resnet.layer2 # H/8 self.layer3 = resnet.layer3 # H/16 # coarse-level conv self.conv_coarse = conv(filters[2], coarse_out_ch, 1, 1) # decoder self.upconv3 = upconv(filters[2], 512, 3, 2) self.iconv3 = conv(filters[1] + 512, 512, 3, 1) self.upconv2 = upconv(512, 256, 3, 2) self.iconv2 = conv(filters[0] + 256, 
256, 3, 1) # fine-level conv self.conv_fine = conv(256, fine_out_ch, 1, 1) def skipconnect(self, x1, x2): diffY = x2.size()[2] - x1.size()[2] diffX = x2.size()[3] - x1.size()[3] x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2)) # for padding issues, see # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd x = torch.cat([x2, x1], dim=1) return x def forward(self, x): x = self.firstrelu(self.firstbn(self.firstconv(x))) x = self.firstmaxpool(x) x1 = self.layer1(x) x2 = self.layer2(x1) x3 = self.layer3(x2) x_coarse = self.conv_coarse(x3) x = self.upconv3(x3) x = self.skipconnect(x2, x) x = self.iconv3(x) x = self.upconv2(x) x = self.skipconnect(x1, x) x = self.iconv2(x) x_fine = self.conv_fine(x) return [x_coarse, x_fine] ================================================ FILE: LICENSE ================================================ Copyright 2021 Qianqian Wang Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ # Learning Feature Descriptors using Camera Pose Supervision This repository contains a PyTorch implementation of the paper: [*Learning Feature Descriptors using Camera Pose Supervision*](https://qianqianwang68.github.io/CAPS/) [[Project page]](https://qianqianwang68.github.io/CAPS/) [[Arxiv]](https://arxiv.org/abs/2004.13324) [Qianqian Wang](https://www.cs.cornell.edu/~qqw/), [Xiaowei Zhou](http://www.cad.zju.edu.cn/home/xzhou/), [Bharath Hariharan](http://home.bharathh.info/), [Noah Snavely](http://www.cs.cornell.edu/~snavely/) ECCV 2020 (*Oral*) ![Teaser](assets/teaser.jpg) ## Abstract Recent research on learned visual descriptors has shown promising improvements in correspondence estimation, a key component of many 3D vision tasks. However, existing descriptor learning frameworks typically require ground-truth correspondences between feature points for training, which are challenging to acquire at scale. In this paper we propose a novel weakly-supervised framework that can learn feature descriptors solely from relative camera poses between images. To do so, we devise both a new loss function that exploits the epipolar constraint given by camera poses, and a new model architecture that makes the whole pipeline differentiable and efficient. 
Because we no longer need pixel-level ground-truth correspondences, our framework opens up the possibility of training on much larger and more diverse datasets for better and unbiased descriptors. We call the resulting descriptors CAmera Pose Supervised, or CAPS, descriptors. Though trained with weak supervision, CAPS descriptors outperform even prior fully-supervised descriptors and achieve state-of-the-art performance on a variety of geometric tasks. ## Requirements ```bash # Create conda environment with torch 1.0.1 and CUDA 10.0 conda env create -f environment.yml conda activate caps ``` If you encounter problems with OpenCV, try to uninstall your current opencv packages and reinstall them again ```bash pip uninstall opencv-python pip uninstall opencv-contrib-python pip install opencv-python==3.4.2.17 pip install opencv-contrib-python==3.4.2.17 ``` ## Pretrained Model Pretrained model can be downloaded using this google drive [link](https://drive.google.com/file/d/1KfIQYyM7vlSvvQ3X7xi1YSPrietqeXYd/view?usp=sharing), or this BaiduYun [link](https://pan.baidu.com/s/1rGt4okK3KsumPLUVhOJdXQ) (password: 3na7). ## Dataset Please download the preprocessed MegaDepth dataset using this google drive [link](https://drive.google.com/file/d/1Xp6BRKx_5sIwdJWVK0kJTmtUcQNMY5yL/view?usp=sharing) or this BaiduYun [link](https://pan.baidu.com/s/1rGt4okK3KsumPLUVhOJdXQ) (password: 3na7 ## Training To start training, please download our [training data](https://drive.google.com/file/d/1Xp6BRKx_5sIwdJWVK0kJTmtUcQNMY5yL/view?usp=sharing), and run the following command: ```bash # example usage python train.py --config configs/train_megadepth.yaml ``` ## Feature extraction We provide code for extracting CAPS descriptors on HPatches dataset. To download and use the HPatches Sequences, please refer to this [link](https://github.com/mihaidusmanu/d2-net/tree/master/hpatches_sequences). To extract CAPS features on HPatches dataset, download the pretrained model, modify paths in ```configs/extract_features_hpatches.yaml``` and run ```bash python extract_features.py --config configs/extract_features_hpatches.yaml ``` ## Interactive demo We provide an interactive demo where you could click on locations in the first image and see their predicted correspondences in the second image. Please refer to ```jupyter/visualization.ipynb``` for more details. ## Cite Please cite our work if you find it useful: ```bibtex @inproceedings{wang2020learning, Title = {Learning Feature Descriptors using Camera Pose Supervision}, Author = {Qianqian Wang and Xiaowei Zhou and Bharath Hariharan and Noah Snavely}, booktitle = {Proc. European Conference on Computer Vision (ECCV)}, Year = {2020}, } ``` **Acknowledgements**. We thank Kai Zhang, Zixin Luo, Zhengqi Li for helpful discussion and comments. This work was partly supported by a DARPA LwLL grant, and in part by the generosity of Eric and Wendy Schmidt by recommendation of the Schmidt Futures program. 
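**Working with the extracted descriptors**. `extract_features.py` saves, for every image, the SIFT keypoint locations and the concatenated coarse+fine CAPS descriptors (256-D with the default `coarse_feat_dim`/`fine_feat_dim` of 128 each) into a NumPy archive written with a `.caps` extension. The snippet below is only an illustrative sketch and is not part of this repository: the file paths are placeholders for two images of the same HPatches sequence, and the matching is done with OpenCV's standard brute-force matcher.

```python
import numpy as np
import cv2

# The .caps files are written with np.savez, so they are ordinary .npz archives
# containing 'keypoints', 'scores' and 'descriptors' arrays.
# The paths below are placeholders, not files shipped with the repository.
f1 = np.load('out/v_wall-1.ppm.caps')
f2 = np.load('out/v_wall-2.ppm.caps')
kpts1, desc1 = f1['keypoints'], f1['descriptors'].astype(np.float32)
kpts2, desc2 = f2['keypoints'], f2['descriptors'].astype(np.float32)

# Mutual nearest-neighbour matching (cross-check) under L2 distance.
matcher = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)
matches = matcher.match(desc1, desc2)
pts1 = kpts1[[m.queryIdx for m in matches]]  # matched keypoint locations in image 1
pts2 = kpts2[[m.trainIdx for m in matches]]  # corresponding locations in image 2
print('{} mutual matches'.format(len(matches)))
```

The resulting point pairs can then be fed to any downstream geometric routine (e.g. `cv2.findHomography` or `cv2.findFundamentalMat` with RANSAC), in the same spirit as the evaluation scripts in `test/`.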
================================================ FILE: config.py ================================================ import configargparse def get_args(): parser = configargparse.ArgParser(config_file_parser_class=configargparse.YAMLConfigFileParser) parser.add_argument('--config', is_config_file=True, help='config file path') ## path options parser.add_argument('--datadir', type=str, help='the dataset directory') parser.add_argument("--logdir", type=str, default='./logs/', help='dir of tensorboard logs') parser.add_argument("--outdir", type=str, default='./out/', help='dir of output e.g., ckpts') parser.add_argument("--ckpt_path", type=str, default="", help='specific checkpoint path to load the model from, ' 'if not specified, automatically reload from most recent checkpoints') ## general options parser.add_argument("--exp_name", type=str, help='experiment name') parser.add_argument('--n_iters', type=int, default=200000, help='max number of training iterations') parser.add_argument('--phase', type=str, default='train', help='train/val/test') # data options parser.add_argument('--workers', type=int, help='number of data loading workers', default=4) parser.add_argument('--num_pts', type=int, default=500, help='num of points trained in each pair') parser.add_argument('--train_kp', type=str, default='mixed', help='sift/random/mixed') parser.add_argument('--prune_kp', type=int, default=1, help='if prune non-matchable keypoints') # training options parser.add_argument('--batch_size', type=int, default=6, help='input batch size') parser.add_argument('--lr', type=float, default=1e-4, help='base learning rate') parser.add_argument("--lrate_decay_steps", type=int, default=80000, help='decay learning rate by a factor every specified number of steps') parser.add_argument("--lrate_decay_factor", type=float, default=0.5, help='decay learning rate by a factor every specified number of steps') ## model options parser.add_argument('--backbone', type=str, default='resnet50', help='backbone for feature representation extraction. 
supported: resent') parser.add_argument('--pretrained', type=int, default=1, help='if use ImageNet pretrained weights to initialize the network') parser.add_argument('--coarse_feat_dim', type=int, default=128, help='the feature dimension for coarse level features') parser.add_argument('--fine_feat_dim', type=int, default=128, help='the feature dimension for fine level features') parser.add_argument('--prob_from', type=str, default='correlation', help='compute prob by softmax(correlation score), or softmax(-distance),' 'options: correlation|distance') parser.add_argument('--window_size', type=float, default=0.125, help='the size of the window, w.r.t image width at the fine level') parser.add_argument('--use_nn', type=int, default=1, help='if use nearest neighbor in the coarse level') ## loss function options parser.add_argument('--std', type=int, default=1, help='reweight loss using the standard deviation') parser.add_argument('--w_epipolar_coarse', type=float, default=1, help='coarse level epipolar loss weight') parser.add_argument('--w_epipolar_fine', type=float, default=1, help='fine level epipolar loss weight') parser.add_argument('--w_cycle_coarse', type=float, default=0.1, help='coarse level cycle consistency loss weight') parser.add_argument('--w_cycle_fine', type=float, default=0.1, help='fine level cycle consistency loss weight') parser.add_argument('--w_std', type=float, default=0, help='the weight for the loss on std') parser.add_argument('--th_cycle', type=float, default=0.025, help='if the distance (normalized scale) from the prediction to epipolar line > this th, ' 'do not add the cycle consistency loss') parser.add_argument('--th_epipolar', type=float, default=0.5, help='if the distance (normalized scale) from the prediction to epipolar line > this th, ' 'do not add the epipolar loss') ## logging options parser.add_argument('--log_scalar_interval', type=int, default=20, help='print interval') parser.add_argument('--log_img_interval', type=int, default=500, help='log image interval') parser.add_argument("--save_interval", type=int, default=10000, help='frequency of weight ckpt saving') ## eval options parser.add_argument('--extract_img_dir', type=str, help='the directory of images to extract features') parser.add_argument('--extract_out_dir', type=str, help='the directory of images to extract features') args = parser.parse_known_args()[0] return args ================================================ FILE: configs/extract_features_hpatches.yaml ================================================ # extract feature descriptors using pretrained model weights expname: feature_extraction # replace the following with your pretrained model path, img_dir and out_dir, respectively ckpt_path: 'pretrained/caps-pretrained.pth' extract_img_dir: '/phoenix/S7/qw246/hpatches-benchmark/data/d2net/hpatches-sequences-release/' extract_out_dir: '/phoenix/S7/qw246/hpatches-benchmark/data/output/desc/caps/' ================================================ FILE: configs/train_megadepth.yaml ================================================ # training from scratch using a single gpu, default configs in config.py file exp_name: train_caps datadir: /phoenix/S7/qw246/CAPS-MegaDepth-release-light # replace with your data dir ================================================ FILE: dataloader/__init__.py ================================================ ================================================ FILE: dataloader/data_utils.py ================================================ import numpy as np import cv2 def 
skew(x): return np.array([[0, -x[2], x[1]], [x[2], 0, -x[0]], [-x[1], x[0], 0]]) def rotateImage(image, angle): h, w = image.shape[:2] angle_radius = np.abs(angle / 180. * np.pi) cos = np.cos(angle_radius) sin = np.sin(angle_radius) tan = np.tan(angle_radius) scale_h = (h / cos + (w - h * tan) * sin) / h scale_w = (h / sin + (w - h / tan) * cos) / w scale = max(scale_h, scale_w) image_center = tuple(np.array(image.shape[1::-1]) / 2.) rot_mat = cv2.getRotationMatrix2D(image_center, angle, scale) result = cv2.warpAffine(image, rot_mat, image.shape[1::-1], flags=cv2.INTER_LINEAR) rotation = np.eye(4) rotation[:2, :2] = rot_mat[:2, :2] return result, rotation def perspective_transform(img, param=0.001): h, w = img.shape[:2] random_state = np.random.RandomState(None) M = np.array([[1 - param + 2 * param * random_state.rand(), -param + 2 * param * random_state.rand(), -param + 2 * param * random_state.rand()], [-param + 2 * param * random_state.rand(), 1 - param + 2 * param * random_state.rand(), -param + 2 * param * random_state.rand()], [-param + 2 * param * random_state.rand(), -param + 2 * param * random_state.rand(), 1 - param + 2 * param * random_state.rand()]]) dst = cv2.warpPerspective(img, M, (w, h)) return dst, M def generate_query_kpts(img, mode, num_pts, h, w): # generate candidate query points if mode == 'random': kp1_x = np.random.rand(num_pts) * (w - 1) kp1_y = np.random.rand(num_pts) * (h - 1) coord = np.stack((kp1_x, kp1_y)).T elif mode == 'sift': gray1 = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) sift = cv2.xfeatures2d.SIFT_create(nfeatures=num_pts) kp1 = sift.detect(gray1) coord = np.array([[kp.pt[0], kp.pt[1]] for kp in kp1]) elif mode == 'mixed': kp1_x = np.random.rand(1 * int(0.1 * num_pts)) * (w - 1) kp1_y = np.random.rand(1 * int(0.1 * num_pts)) * (h - 1) kp1_rand = np.stack((kp1_x, kp1_y)).T sift = cv2.xfeatures2d.SIFT_create(nfeatures=int(0.9 * num_pts)) gray1 = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) kp1_sift = sift.detect(gray1) kp1_sift = np.array([[kp.pt[0], kp.pt[1]] for kp in kp1_sift]) if len(kp1_sift) == 0: coord = kp1_rand else: coord = np.concatenate((kp1_rand, kp1_sift), 0) else: raise Exception('unknown type of keypoints') return coord def prune_kpts(coord1, F_gt, im2_size, intrinsic1, intrinsic2, pose, d_min, d_max): # compute the epipolar lines corresponding to coord1 coord1_h = np.concatenate([coord1, np.ones_like(coord1[:, [0]])], axis=1).T # 3xn epipolar_line = F_gt.dot(coord1_h) # 3xn epipolar_line /= np.clip(np.linalg.norm(epipolar_line[:2], axis=0), a_min=1e-10, a_max=None) # 3xn # determine whether the epipolar lines intersect with the second image h2, w2 = im2_size corners = np.array([[0, 0, 1], [0, h2 - 1, 1], [w2 - 1, 0, 1], [w2 - 1, h2 - 1, 1]]) # 4x3 dists = np.abs(corners.dot(epipolar_line)) # if the epipolar line is far away from any image corners than sqrt(h^2+w^2) # it doesn't intersect with the image non_intersect = (dists > np.sqrt(w2 ** 2 + h2 ** 2)).any(axis=0) # determine if points in coord1 is likely to have correspondence in the other image by the rough depth range intrinsic1_4x4 = np.eye(4) intrinsic1_4x4[:3, :3] = intrinsic1 intrinsic2_4x4 = np.eye(4) intrinsic2_4x4[:3, :3] = intrinsic2 coord1_h_min = np.concatenate([d_min * coord1, d_min * np.ones_like(coord1[:, [0]]), np.ones_like(coord1[:, [0]])], axis=1).T coord1_h_max = np.concatenate([d_max * coord1, d_max * np.ones_like(coord1[:, [0]]), np.ones_like(coord1[:, [0]])], axis=1).T coord2_h_min = intrinsic2_4x4.dot(pose).dot(np.linalg.inv(intrinsic1_4x4)).dot(coord1_h_min) coord2_h_max 
= intrinsic2_4x4.dot(pose).dot(np.linalg.inv(intrinsic1_4x4)).dot(coord1_h_max) coord2_min = coord2_h_min[:2] / (coord1_h_min[2] + 1e-10) coord2_max = coord2_h_max[:2] / (coord1_h_max[2] + 1e-10) out_range = ((coord2_min[0] < 0) & (coord2_max[0] < 0)) | \ ((coord2_min[1] < 0) & (coord2_max[1] < 0)) | \ ((coord2_min[0] > w2 - 1) & (coord2_max[0] > w2 - 1)) | \ ((coord2_min[1] > h2 - 1) & (coord2_max[1] > h2 - 1)) ind_intersect = ~(non_intersect | out_range) return ind_intersect ================================================ FILE: dataloader/megadepth.py ================================================ import torch from torch.utils.data import Dataset import os import numpy as np import cv2 import skimage.io as io import torchvision.transforms as transforms import utils import collections from tqdm import tqdm import dataloader.data_utils as data_utils rand = np.random.RandomState(234) class MegaDepthLoader(): def __init__(self, args): self.args = args self.dataset = MegaDepth(args) self.data_loader = torch.utils.data.DataLoader(self.dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, collate_fn=self.my_collate) def my_collate(self, batch): ''' Puts each data field into a tensor with outer dimension batch size ''' batch = list(filter(lambda b: b is not None, batch)) return torch.utils.data.dataloader.default_collate(batch) def load_data(self): return self.data_loader def name(self): return 'MegaDepthLoader' def __len__(self): return len(self.dataset) class MegaDepth(Dataset): def __init__(self, args): self.args = args if args.phase == 'train': # augment during training self.transform = transforms.Compose([transforms.ToPILImage(), transforms.ColorJitter (brightness=1, contrast=1, saturation=1, hue=0.4), transforms.ToTensor(), transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), ]) else: self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), ]) self.phase = args.phase self.root = os.path.join(args.datadir, self.phase) self.images = self.read_img_cam() self.imf1s, self.imf2s = self.read_pairs() print('total number of image pairs loaded: {}'.format(len(self.imf1s))) # shuffle data index = np.arange(len(self.imf1s)) rand.shuffle(index) self.imf1s = list(np.array(self.imf1s)[index]) self.imf2s = list(np.array(self.imf2s)[index]) def read_img_cam(self): images = {} Image = collections.namedtuple( "Image", ["name", "w", "h", "fx", "fy", "cx", "cy", "rvec", "tvec"]) for scene_id in os.listdir(self.root): densefs = [f for f in os.listdir(os.path.join(self.root, scene_id)) if 'dense' in f and os.path.isdir(os.path.join(self.root, scene_id, f))] for densef in densefs: folder = os.path.join(self.root, scene_id, densef, 'aligned') img_cam_txt_path = os.path.join(folder, 'img_cam.txt') with open(img_cam_txt_path, "r") as fid: while True: line = fid.readline() if not line: break line = line.strip() if len(line) > 0 and line[0] != "#": elems = line.split() image_name = elems[0] img_path = os.path.join(folder, 'images', image_name) w, h = int(elems[1]), int(elems[2]) fx, fy = float(elems[3]), float(elems[4]) cx, cy = float(elems[5]), float(elems[6]) R = np.array(elems[7:16]) T = np.array(elems[16:19]) images[img_path] = Image( name=image_name, w=w, h=h, fx=fx, fy=fy, cx=cx, cy=cy, rvec=R, tvec=T ) return images def read_pairs(self): imf1s, imf2s = [], [] print('reading image pairs from {}...'.format(self.root)) for scene_id in tqdm(os.listdir(self.root), desc='# loading data 
from scene folders'): densefs = [f for f in os.listdir(os.path.join(self.root, scene_id)) if 'dense' in f and os.path.isdir(os.path.join(self.root, scene_id, f))] for densef in densefs: imf1s_ = [] imf2s_ = [] folder = os.path.join(self.root, scene_id, densef, 'aligned') pairf = os.path.join(folder, 'pairs.txt') if os.path.exists(pairf): f = open(pairf, 'r') for line in f: imf1, imf2 = line.strip().split(' ') imf1s_.append(os.path.join(folder, 'images', imf1)) imf2s_.append(os.path.join(folder, 'images', imf2)) # make # image pairs per scene more balanced if len(imf1s_) > 5000: index = np.arange(len(imf1s_)) rand.shuffle(index) imf1s_ = list(np.array(imf1s_)[index[:5000]]) imf2s_ = list(np.array(imf2s_)[index[:5000]]) imf1s.extend(imf1s_) imf2s.extend(imf2s_) return imf1s, imf2s @staticmethod def get_intrinsics(im_meta): return np.array([[im_meta.fx, 0, im_meta.cx], [0, im_meta.fy, im_meta.cy], [0, 0, 1]]) @staticmethod def get_extrinsics(im_meta): R = im_meta.rvec.reshape(3, 3) t = im_meta.tvec extrinsic = np.eye(4) extrinsic[:3, :3] = R extrinsic[:3, 3] = t return extrinsic def __getitem__(self, item): imf1 = self.imf1s[item] imf2 = self.imf2s[item] im1_meta = self.images[imf1] im2_meta = self.images[imf2] im1 = io.imread(imf1) im2 = io.imread(imf2) h, w = im1.shape[:2] intrinsic1 = self.get_intrinsics(im1_meta) intrinsic2 = self.get_intrinsics(im2_meta) extrinsic1 = self.get_extrinsics(im1_meta) extrinsic2 = self.get_extrinsics(im2_meta) relative = extrinsic2.dot(np.linalg.inv(extrinsic1)) R = relative[:3, :3] # remove pairs that have a relative rotation angle larger than 80 degrees theta = np.arccos(np.clip((np.trace(R) - 1) / 2, -1, 1)) * 180 / np.pi if theta > 80 and self.phase == 'train': return None T = relative[:3, 3] tx = data_utils.skew(T) E_gt = np.dot(tx, R) F_gt = np.linalg.inv(intrinsic2).T.dot(E_gt).dot(np.linalg.inv(intrinsic1)) # generate candidate query points coord1 = data_utils.generate_query_kpts(im1, self.args.train_kp, 10*self.args.num_pts, h, w) # if no keypoints are detected if len(coord1) == 0: return None # prune query keypoints that are not likely to have correspondence in the other image if self.args.prune_kp: ind_intersect = data_utils.prune_kpts(coord1, F_gt, im2.shape[:2], intrinsic1, intrinsic2, relative, d_min=4, d_max=400) if np.sum(ind_intersect) == 0: return None coord1 = coord1[ind_intersect] coord1 = utils.random_choice(coord1, self.args.num_pts) coord1 = torch.from_numpy(coord1).float() im1_ori, im2_ori = torch.from_numpy(im1), torch.from_numpy(im2) F_gt = torch.from_numpy(F_gt).float() / (F_gt[-1, -1] + 1e-10) intrinsic1 = torch.from_numpy(intrinsic1).float() intrinsic2 = torch.from_numpy(intrinsic2).float() pose = torch.from_numpy(relative[:3, :]).float() im1_tensor = self.transform(im1) im2_tensor = self.transform(im2) out = {'im1': im1_tensor, 'im2': im2_tensor, 'im1_ori': im1_ori, 'im2_ori': im2_ori, 'pose': pose, 'F': F_gt, 'intrinsic1': intrinsic1, 'intrinsic2': intrinsic2, 'coord1': coord1} return out def __len__(self): return len(self.imf1s) ================================================ FILE: environment.yml ================================================ name: caps channels: - pytorch - defaults dependencies: - _libgcc_mutex=0.1=main - blas=1.0=mkl - ca-certificates=2020.7.22=0 - certifi=2020.6.20=py37_0 - cffi=1.14.1=py37he30daa8_0 - cudatoolkit=10.0.130=0 - freetype=2.10.2=h5ab3b9f_0 - intel-openmp=2020.1=217 - jpeg=9b=h024ee3a_2 - lcms2=2.11=h396b838_0 - ld_impl_linux-64=2.33.1=h53a641e_7 - libedit=3.1.20191231=h14c3975_1 - 
libffi=3.3=he6710b0_2 - libgcc-ng=9.1.0=hdf63c60_0 - libpng=1.6.37=hbc83047_0 - libstdcxx-ng=9.1.0=hdf63c60_0 - libtiff=4.1.0=h2733197_1 - lz4-c=1.9.2=he6710b0_1 - mkl=2020.1=217 - mkl-service=2.3.0=py37he904b0f_0 - mkl_fft=1.1.0=py37h23d657b_0 - mkl_random=1.1.1=py37h0573a6f_0 - ncurses=6.2=he6710b0_1 - ninja=1.10.0=py37hfd86e86_0 - numpy=1.19.1=py37hbc911f0_0 - numpy-base=1.19.1=py37hfa32c7d_0 - olefile=0.46=py37_0 - openssl=1.1.1g=h7b6447c_0 - pillow=7.2.0=py37hb39fc2d_0 - pip=20.2.2=py37_0 - pycparser=2.20=py_2 - python=3.7.7=hcff3b4d_5 - pytorch=1.0.1=py3.7_cuda10.0.130_cudnn7.4.2_2 - readline=8.0=h7b6447c_0 - setuptools=49.6.0=py37_0 - six=1.15.0=py_0 - sqlite=3.33.0=h62c20be_0 - tk=8.6.10=hbc83047_0 - torchvision=0.2.2=py_3 - wheel=0.34.2=py37_0 - xz=5.2.5=h7b6447c_0 - zlib=1.2.11=h7b6447c_3 - zstd=1.4.5=h9ceee32_0 - pip: - chardet==3.0.4 - configargparse==1.2.3 - cycler==0.10.0 - decorator==4.4.2 - filelock==3.0.12 - gdown==3.12.2 - idna==2.10 - imageio==2.9.0 - kiwisolver==1.2.0 - matplotlib==3.3.1 - networkx==2.5 - opencv-contrib-python==3.4.2.17 - opencv-python==3.4.2.17 - protobuf==3.13.0 - pyparsing==2.4.7 - pysocks==1.7.1 - python-dateutil==2.8.1 - pywavelets==1.1.1 - pyyaml==5.3.1 - requests==2.24.0 - scikit-image==0.15.0 - scipy==1.3.1 - tensorboardx==2.1 - tqdm==4.48.2 - urllib3==1.25.10 prefix: /phoenix/S7/qw246/anaconda3/envs/caps ================================================ FILE: extract_features.py ================================================ import torch from torch.utils.data import Dataset import os import numpy as np import cv2 import skimage.io as io import torchvision.transforms as transforms import config from tqdm import tqdm from CAPS.caps_model import CAPSModel class HPatchDataset(Dataset): def __init__(self, imdir): self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), ]) self.imfs = [] for f in os.listdir(imdir): scene_dir = os.path.join(imdir, f) self.imfs.extend([os.path.join(scene_dir, '{}.ppm').format(ind) for ind in range(1, 7)]) def __getitem__(self, item): imf = self.imfs[item] im = io.imread(imf) im_tensor = self.transform(im) # using sift keypoints sift = cv2.xfeatures2d.SIFT_create() gray = cv2.cvtColor(im, cv2.COLOR_RGB2GRAY) kpts = sift.detect(gray) kpts = np.array([[kp.pt[0], kp.pt[1]] for kp in kpts]) coord = torch.from_numpy(kpts).float() out = {'im': im_tensor, 'coord': coord, 'imf': imf} return out def __len__(self): return len(self.imfs) if __name__ == '__main__': # example code for extracting features for HPatches dataset, SIFT keypoint is used args = config.get_args() device = 'cuda' if torch.cuda.is_available() else 'cpu' dataset = HPatchDataset(args.extract_img_dir) data_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=args.workers) model = CAPSModel(args) outdir = args.extract_out_dir os.makedirs(outdir, exist_ok=True) with torch.no_grad(): for data in tqdm(data_loader): im = data['im'].to(device) img_path = data['imf'][0] coord = data['coord'].to(device) feat_c, feat_f = model.extract_features(im, coord) desc = torch.cat((feat_c, feat_f), -1).squeeze(0).detach().cpu().numpy() kpt = coord.cpu().numpy().squeeze(0) out_path = os.path.join(outdir, '{}-{}'.format(os.path.basename(os.path.dirname(img_path)), os.path.basename(img_path), )) with open(out_path + '.caps', 'wb') as output_file: np.savez( output_file, keypoints=kpt, scores=[], descriptors=desc ) ================================================ FILE: 
jupyter/functions.py ================================================ import numpy as np import matplotlib.pyplot as plt import matplotlib.lines as mlines import torch import sys sys.path.append('../') from dataloader import megadepth import torch.utils.data from CAPS.caps_model import CAPSModel import cv2 class Visualization(object): def __init__(self, args): dataset = megadepth.MegaDepth(args) self.dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=1) self.model = CAPSModel(args) self.loader_iter = iter(self.dataloader) def random_sample(self): self.sample = next(self.loader_iter) def plot_img_pair(self, with_std=False, with_epipline=False): self.coords = [] self.colors = [] self.with_std = with_std self.with_epipline = with_epipline im1 = self.sample['im1_ori'] im2 = self.sample['im2_ori'] self.h, self.w = im1.shape[1], im1.shape[2] im1 = im1.squeeze().cpu().numpy() im2 = im2.squeeze().cpu().numpy() blank = np.ones((self.h, 5, 3)) * 255 out = np.concatenate((im1, blank, im2), 1).astype(np.uint8) self.fig = plt.figure(figsize=(12, 5)) self.ax = self.fig.add_subplot(111) self.ax.imshow(out) self.ax.axis('off') plt.tight_layout() cid = self.fig.canvas.mpl_connect('button_press_event', self.onclick) def onclick(self, event): color = tuple(np.random.rand(3).tolist()) coord = [event.xdata, event.ydata] self.coord = coord self.color = color self.coords.append(coord) self.colors.append(color) self.ax.scatter(event.xdata, event.ydata, c=color) self.find_correspondence() self.plot_correspondence() def find_correspondence(self): data_in = self.sample data_in['coord1'] = torch.from_numpy(np.array(self.coord)).float().cuda().unsqueeze(0).unsqueeze(0) data_in['coord2'] = data_in['coord1'] self.model.set_input(data_in) coord2_e, std = self.model.test() self.correspondence = coord2_e.squeeze().cpu().numpy() self.std = std.squeeze().cpu().numpy() def plot_correspondence(self): point1 = self.coord point2 = self.correspondence point2[0] += self.w + 5 self.ax.scatter(point2[0], point2[1], color=self.color) if self.with_std: circle = plt.Circle((point2[0], point2[1]), radius=100 * self.std, fill=False, color=self.color) self.ax.add_patch(circle) if self.with_epipline: line2 = cv2.computeCorrespondEpilines(np.array(point1).reshape(-1, 1, 2), 1, self.sample['F'].squeeze().cpu().numpy()) line2 = np.array(line2).squeeze() intersection = np.array([[0, -line2[2]/line2[1]], [-line2[2]/line2[0], 0], [self.w-1, -(line2[2]+line2[0]*(self.w-1))/line2[1]], [-(line2[1]*(self.h-1)+line2[2])/line2[0], self.h-1]]) valid = (intersection[:, 0] >= 0) & (intersection[:, 0] <= self.w-1) & \ (intersection[:, 1] >= 0) & (intersection[:, 1] <= self.h-1) if np.sum(valid) == 2: intersection = intersection[valid].astype(int) x0, y0 = intersection[0] x1, y1 = intersection[1] l = mlines.Line2D([x0+self.w + 5, x1+self.w + 5], [y0, y1], color=self.color) self.ax.add_line(l) plt.show() ================================================ FILE: jupyter/visualization.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "### Interactive demo\n", "\n", "This interactive demo shows the correspondences our method obtain by **densely** searching in the image space.\n", "\n", "\n", "Run the first code block and wait for it to complete ..." ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The autoreload extension is already loaded. 
To reload it, use:\n", "  %reload_ext autoreload\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "# loading data from scene folders:   0%|          | 0/37 [00:00" ] }
[remainder of visualization.ipynb omitted: the rest of the saved cell output consists only of the matplotlib nbagg widget JavaScript/HTML that the notebook backend embeds for the interactive figure]