Repository: andreasveit/conditional-similarity-networks
Branch: master
Commit: 6021dfe5a7f3
Files: 9
Total size: 30.0 KB

Directory structure:
gitextract_xmmzkyrb/
├── LICENSE
├── README.md
├── Resnet_18.py
├── csn.py
├── get_data.py
├── main.py
├── requirements.txt
├── triplet_image_loader.py
└── tripletnet.py

================================================
FILE CONTENTS
================================================

================================================
FILE: LICENSE
================================================

BSD 3-Clause License

Copyright (c) 2017, Andreas Veit
All rights reserved.

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

================================================
FILE: README.md
================================================

# Conditional Similarity Networks (CSNs)

This repository contains a [PyTorch](http://pytorch.org/) implementation of the paper [Conditional Similarity Networks](https://arxiv.org/abs/1603.07810) presented at CVPR 2017. The code is based on the [PyTorch example for training ResNet on Imagenet](https://github.com/pytorch/examples/tree/master/imagenet) and the [Triplet Network example](https://github.com/andreasveit/triplet-network-pytorch).

## Table of Contents
0. [Introduction](#introduction)
0. [Usage](#usage)
0. [Citing](#citing)
0. [Contact](#contact)

## Introduction

What makes images similar? To measure the similarity between images, they are typically embedded in a feature-vector space in which their distance preserves relative dissimilarity. However, when learning such similarity embeddings, the simplifying assumption is commonly made that images are compared with respect to only one unique measure of similarity. [Conditional Similarity Networks](https://arxiv.org/abs/1603.07810) address this shortcoming by learning nonlinear embeddings that gracefully deal with multiple notions of similarity within a shared embedding. Different aspects of similarity are incorporated by assigning responsibility weights to each embedding dimension with respect to each aspect of similarity.
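The responsibility weights described above amount to an elementwise mask per notion of similarity. The following minimal sketch illustrates the fixed-mask case; the sizes mirror the default setting of this repo (64-dimensional embedding, four notions), but the snippet itself is illustrative and not part of the code base:

```python
import torch

# Assumed default setting: 64-d embedding, 4 notions of similarity,
# fixed disjoint masks of 16 dimensions each.
embedding_size, n_conditions = 64, 4
mask_len = embedding_size // n_conditions

masks = torch.zeros(n_conditions, embedding_size)
for i in range(n_conditions):
    # condition i "owns" dimensions [16*i, 16*(i+1))
    masks[i, i * mask_len:(i + 1) * mask_len] = 1.0

x = torch.randn(embedding_size)  # embedding of one image
c = 2                            # chosen notion of similarity
masked = x * masks[c]            # only the 16 dims responsible for notion c survive
```

Distances between images are then computed in the masked subspace, so the same shared embedding supports several notions of similarity at once.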
Images are passed through a convolutional network and projected into a nonlinear embedding such that different dimensions encode features for specific notions of similarity. Subsequent masks indicate which dimensions of the embedding are responsible for separate aspects of similarity. We can then compare objects according to various notions of similarity by selecting an appropriate masked subspace.

## Usage

The default setting for this repo is a CSN with fixed masks, an embedding dimension of 64 and four notions of similarity. You can download the Zappos dataset as well as the training, validation and test triplets used in the paper with

```sh
python get_data.py
```

The network can be trained simply with `python main.py` or with optional arguments for different hyperparameters:

```sh
$ python main.py --name {your experiment name} --learned --num_traintriplets 200000
```

Training progress can be easily tracked with [visdom](https://github.com/facebookresearch/visdom) using the `--visdom` flag. It keeps track of the learning rate, the loss, the training and validation accuracy (both over all triplets and separately for each notion of similarity), the embedding norm, the mask norm, and the masks themselves. By default the training code keeps track of the model with the highest performance on the validation set. Thus, after the model has converged, it can be directly evaluated on the test set as follows:

```sh
$ python main.py --test --resume runs/{your experiment name}/model_best.pth.tar
```

## Citing

If you find this work helps your research, please consider citing:

```
@inproceedings{Veit2017,
  title     = {Conditional Similarity Networks},
  author    = {Andreas Veit and Serge Belongie and Theofanis Karaletsos},
  booktitle = {Computer Vision and Pattern Recognition (CVPR)},
  year      = {2017},
}
```

## Contact

andreas at cs dot cornell dot edu

Any discussions, suggestions and questions are welcome!
================================================
FILE: Resnet_18.py
================================================

import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo

__all__ = ['ResNet', 'resnet18']

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
}

def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out

class ResNet(nn.Module):

    def __init__(self, block, layers, embedding_size=64):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.avgpool = nn.AvgPool2d(7)
        self.fc_embed = nn.Linear(256 * block.expansion, embedding_size)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc_embed(x)
        return x

def resnet18(pretrained=False, **kwargs):
    """Constructs a truncated ResNet-18 model (only the first three residual
    stages, followed by an embedding layer instead of the classifier).

    Args:
        pretrained (bool): If True, initializes the matching layers with
            weights pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2], **kwargs)
    if pretrained:
        state = model.state_dict()
        loaded_state_dict = model_zoo.load_url(model_urls['resnet18'])
        # copy only the parameters that exist in the truncated model
        for k in loaded_state_dict:
            if k in state:
                state[k] = loaded_state_dict[k]
        model.load_state_dict(state)
    return model

================================================
FILE: csn.py
================================================

import torch
import torch.nn as nn
import numpy as np

class ConditionalSimNet(nn.Module):
    def __init__(self, embeddingnet, n_conditions, embedding_size,
                 learnedmask=True, prein=False):
        """ embeddingnet: The network that projects the inputs into an embedding of embedding_size
            n_conditions: Integer defining number of different similarity notions
            embedding_size: Number of dimensions of the embedding output from the embeddingnet
            learnedmask: Boolean indicating whether masks are learned or fixed
            prein: Boolean indicating whether masks are initialized in equally sized disjoint sections or randomly otherwise"""
        super(ConditionalSimNet,
self).__init__()
        self.learnedmask = learnedmask
        self.embeddingnet = embeddingnet
        # create the masks
        if learnedmask:
            if prein:
                # define masks
                self.masks = torch.nn.Embedding(n_conditions, embedding_size)
                # initialize masks in equally sized disjoint sections
                mask_array = np.zeros([n_conditions, embedding_size])
                mask_array.fill(0.1)
                mask_len = int(embedding_size / n_conditions)
                for i in range(n_conditions):
                    mask_array[i, i*mask_len:(i+1)*mask_len] = 1
                # masks are learned (requires_grad=True), starting from the disjoint initialization
                self.masks.weight = torch.nn.Parameter(torch.Tensor(mask_array), requires_grad=True)
            else:
                # define masks with gradients
                self.masks = torch.nn.Embedding(n_conditions, embedding_size)
                # initialize weights randomly
                self.masks.weight.data.normal_(0.9, 0.7)  # 0.1, 0.005
        else:
            # define masks
            self.masks = torch.nn.Embedding(n_conditions, embedding_size)
            # initialize masks in equally sized disjoint sections
            mask_array = np.zeros([n_conditions, embedding_size])
            mask_len = int(embedding_size / n_conditions)
            for i in range(n_conditions):
                mask_array[i, i*mask_len:(i+1)*mask_len] = 1
            # no gradients for the fixed masks
            self.masks.weight = torch.nn.Parameter(torch.Tensor(mask_array), requires_grad=False)

    def forward(self, x, c):
        embedded_x = self.embeddingnet(x)
        self.mask = self.masks(c)
        if self.learnedmask:
            self.mask = torch.nn.functional.relu(self.mask)
        masked_embedding = embedded_x * self.mask
        return masked_embedding, self.mask.norm(1), embedded_x.norm(2), masked_embedding.norm(2)

================================================
FILE: get_data.py
================================================

import urllib.request
import os
import os.path
import zipfile

if not os.path.exists('data'):
    os.makedirs('data')

if os.path.exists(os.path.join('data', 'ut-zap50k-images')):
    pass
else:
    urllib.request.urlretrieve("http://vision.cs.utexas.edu/projects/finegrained/utzap50k/ut-zap50k-images.zip",
                               filename="data/ut-zap50k-imgs.zip")
    zip_ref = zipfile.ZipFile("data/ut-zap50k-imgs.zip", 'r')
    zip_ref.extractall("data")
    zip_ref.close()
    os.remove("data/ut-zap50k-imgs.zip")

if os.path.exists(os.path.join('data', 'tripletlists')):
    pass
else:
    urllib.request.urlretrieve("https://vision.cornell.edu/se3/wp-content/uploads/2019/05/csn_zappos_triplets.zip",
                               filename="data/triplets.zip")
    zip_ref = zipfile.ZipFile("data/triplets.zip", 'r')
    zip_ref.extractall("data")
    zip_ref.close()
    os.remove("data/triplets.zip")

================================================
FILE: main.py
================================================

from __future__ import print_function
import argparse
import os
import sys
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
from triplet_image_loader import TripletImageLoader
from tripletnet import CS_Tripletnet
from visdom import Visdom
import numpy as np
import Resnet_18
from csn import ConditionalSimNet

# Training settings
parser = argparse.ArgumentParser(description='PyTorch Conditional Similarity Network Example')
parser.add_argument('--batch-size', type=int, default=256, metavar='N',
                    help='input batch size for training (default: 256)')
parser.add_argument('--epochs', type=int, default=200, metavar='N',
                    help='number of epochs to train (default: 200)')
parser.add_argument('--start_epoch', type=int, default=1, metavar='N',
                    help='number of start epoch (default: 1)')
parser.add_argument('--lr', type=float, default=5e-5, metavar='LR',
                    help='learning rate (default: 5e-5)')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--log-interval', type=int, default=20, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--margin', type=float, default=0.2, metavar='M',
                    help='margin for triplet loss (default: 0.2)')
parser.add_argument('--resume', default='', type=str,
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--name', default='Conditional_Similarity_Network', type=str,
                    help='name of experiment')
parser.add_argument('--embed_loss', type=float, default=5e-3, metavar='M',
                    help='weight of the embedding-norm regularization loss')
parser.add_argument('--mask_loss', type=float, default=5e-4, metavar='M',
                    help='weight of the mask-norm regularization loss')
parser.add_argument('--num_traintriplets', type=int, default=100000, metavar='N',
                    help='how many unique training triplets (default: 100000)')
parser.add_argument('--dim_embed', type=int, default=64, metavar='N',
                    help='how many dimensions in embedding (default: 64)')
parser.add_argument('--test', dest='test', action='store_true',
                    help='only run inference on the test set')
parser.add_argument('--learned', dest='learned', action='store_true',
                    help='learn masks from random initialization')
parser.add_argument('--prein', dest='prein', action='store_true',
                    help='initialize masks to be disjoint')
parser.add_argument('--visdom', dest='visdom', action='store_true',
                    help='use visdom to track and plot')
parser.add_argument('--conditions', nargs='*', type=int,
                    help='set of similarity notions')
parser.set_defaults(test=False)
parser.set_defaults(learned=False)
parser.set_defaults(prein=False)
parser.set_defaults(visdom=False)

best_acc = 0

def main():
    global args, best_acc
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
    if args.visdom:
        global plotter
        plotter = VisdomLinePlotter(env_name=args.name)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    global conditions
    if args.conditions is not None:
        conditions = args.conditions
    else:
        conditions = [0, 1, 2, 3]

    kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}
    train_loader = torch.utils.data.DataLoader(
        TripletImageLoader('data', 'ut-zap50k-images', 'filenames.json',
                           conditions, 'train',
n_triplets=args.num_traintriplets,
                           transform=transforms.Compose([
                               transforms.Resize(112),
                               transforms.CenterCrop(112),
                               transforms.RandomHorizontalFlip(),
                               transforms.ToTensor(),
                               normalize,
                           ])),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        TripletImageLoader('data', 'ut-zap50k-images', 'filenames.json',
                           conditions, 'test', n_triplets=160000,
                           transform=transforms.Compose([
                               transforms.Resize(112),
                               transforms.CenterCrop(112),
                               transforms.ToTensor(),
                               normalize,
                           ])),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    val_loader = torch.utils.data.DataLoader(
        TripletImageLoader('data', 'ut-zap50k-images', 'filenames.json',
                           conditions, 'val', n_triplets=80000,
                           transform=transforms.Compose([
                               transforms.Resize(112),
                               transforms.CenterCrop(112),
                               transforms.ToTensor(),
                               normalize,
                           ])),
        batch_size=args.batch_size, shuffle=True, **kwargs)

    model = Resnet_18.resnet18(pretrained=True, embedding_size=args.dim_embed)
    csn_model = ConditionalSimNet(model, n_conditions=len(conditions),
                                  embedding_size=args.dim_embed,
                                  learnedmask=args.learned, prein=args.prein)
    global mask_var
    mask_var = csn_model.masks.weight
    tnet = CS_Tripletnet(csn_model)
    if args.cuda:
        tnet.cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # restore the best validation accuracy seen so far
            best_acc = checkpoint['best_prec1']
            tnet.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True
    criterion = torch.nn.MarginRankingLoss(margin=args.margin)
    parameters = filter(lambda p: p.requires_grad, tnet.parameters())
    optimizer = optim.Adam(parameters, lr=args.lr)

    n_parameters = sum([p.data.nelement() for p in tnet.parameters()])
    print(' + Number of params: {}'.format(n_parameters))

    if args.test:
        test_acc = test(test_loader, tnet, criterion, 1)
        sys.exit()

    for epoch in range(args.start_epoch, args.epochs + 1):
        # update learning rate
        adjust_learning_rate(optimizer, epoch)
        # train for one epoch
        train(train_loader, tnet, criterion, optimizer, epoch)
        # evaluate on validation set
        acc = test(val_loader, tnet, criterion, epoch)

        # remember best acc and save checkpoint
        is_best = acc > best_acc
        best_acc = max(acc, best_acc)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': tnet.state_dict(),
            'best_prec1': best_acc,
        }, is_best)

def train(train_loader, tnet, criterion, optimizer, epoch):
    losses = AverageMeter()
    accs = AverageMeter()
    emb_norms = AverageMeter()
    mask_norms = AverageMeter()

    # switch to train mode
    tnet.train()
    for batch_idx, (data1, data2, data3, c) in enumerate(train_loader):
        if args.cuda:
            data1, data2, data3, c = data1.cuda(), data2.cuda(), data3.cuda(), c.cuda()
        data1, data2, data3, c = Variable(data1), Variable(data2), Variable(data3), Variable(c)

        # compute output
        dista, distb, mask_norm, embed_norm, mask_embed_norm = tnet(data1, data2, data3, c)
        # a target of 1 means dista should be larger than distb
        target = torch.FloatTensor(dista.size()).fill_(1)
        if args.cuda:
            target = target.cuda()
        target = Variable(target)

        loss_triplet = criterion(dista, distb, target)
        loss_embedd = embed_norm / np.sqrt(data1.size(0))
        loss_mask = mask_norm / data1.size(0)
        loss = loss_triplet + args.embed_loss * loss_embedd + args.mask_loss * loss_mask

        # measure accuracy and record loss
        acc = accuracy(dista, distb)
        losses.update(loss_triplet.data.item(), data1.size(0))
        accs.update(acc, data1.size(0))
        emb_norms.update(loss_embedd.data.item())
        mask_norms.update(loss_mask.data.item())

        # compute gradient and do optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{}]\t'
                  'Loss: {:.4f} ({:.4f}) \t'
                  'Acc: {:.2f}% ({:.2f}%) \t'
                  'Emb_Norm: {:.2f} ({:.2f})'.format(
                epoch, batch_idx * len(data1), len(train_loader.dataset),
                losses.val, losses.avg,
                100. * accs.val, 100. * accs.avg,
                emb_norms.val, emb_norms.avg))

    # log avg values to visdom
    if args.visdom:
        plotter.plot('acc', 'train', epoch, accs.avg)
        plotter.plot('loss', 'train', epoch, losses.avg)
        plotter.plot('emb_norms', 'train', epoch, emb_norms.avg)
        plotter.plot('mask_norms', 'train', epoch, mask_norms.avg)
        if epoch % 10 == 0:
            plotter.plot_mask(torch.nn.functional.relu(mask_var).data.cpu().numpy().T, epoch)

def test(test_loader, tnet, criterion, epoch):
    losses = AverageMeter()
    accs = AverageMeter()
    accs_cs = {}
    for condition in conditions:
        accs_cs[condition] = AverageMeter()

    # switch to evaluation mode
    tnet.eval()
    for batch_idx, (data1, data2, data3, c) in enumerate(test_loader):
        if args.cuda:
            data1, data2, data3, c = data1.cuda(), data2.cuda(), data3.cuda(), c.cuda()
        data1, data2, data3, c = Variable(data1), Variable(data2), Variable(data3), Variable(c)
        c_test = c

        # compute output
        dista, distb, _, _, _ = tnet(data1, data2, data3, c)
        target = torch.FloatTensor(dista.size()).fill_(1)
        if args.cuda:
            target = target.cuda()
        target = Variable(target)
        test_loss = criterion(dista, distb, target).data.item()

        # measure accuracy and record loss
        acc = accuracy(dista, distb)
        accs.update(acc, data1.size(0))
        for condition in conditions:
            accs_cs[condition].update(accuracy_id(dista, distb, c_test, condition), data1.size(0))
        losses.update(test_loss, data1.size(0))

    print('\nTest set: Average loss: {:.4f}, Accuracy: {:.2f}%\n'.format(
        losses.avg, 100.
* accs.avg))
    if args.visdom:
        for condition in conditions:
            plotter.plot('accs', 'acc_{}'.format(condition), epoch, accs_cs[condition].avg)
        plotter.plot(args.name, args.name, epoch, accs.avg, env='overview')
        plotter.plot('acc', 'test', epoch, accs.avg)
        plotter.plot('loss', 'test', epoch, losses.avg)
    return accs.avg

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Saves checkpoint to disk"""
    directory = "runs/%s/" % (args.name)
    if not os.path.exists(directory):
        os.makedirs(directory)
    filename = directory + filename
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'runs/%s/' % (args.name) + 'model_best.pth.tar')

class VisdomLinePlotter(object):
    """Plots to Visdom"""
    def __init__(self, env_name='main'):
        self.viz = Visdom()
        self.env = env_name
        self.plots = {}

    def plot(self, var_name, split_name, x, y, env=None):
        if env is not None:
            print_env = env
        else:
            print_env = self.env
        if var_name not in self.plots:
            self.plots[var_name] = self.viz.line(X=np.array([x, x]), Y=np.array([y, y]), env=print_env, opts=dict(
                legend=[split_name],
                title=var_name,
                xlabel='Epochs',
                ylabel=var_name
            ))
        else:
            self.viz.updateTrace(X=np.array([x]), Y=np.array([y]), env=print_env,
                                 win=self.plots[var_name], name=split_name)

    def plot_mask(self, masks, epoch):
        self.viz.bar(
            X=masks,
            env=self.env,
            opts=dict(
                stacked=True,
                title=epoch,
            )
        )

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def adjust_learning_rate(optimizer, epoch):
    """Decays the learning rate exponentially by 1.5% per epoch"""
    lr = args.lr * ((1 - 0.015) ** epoch)
    if args.visdom:
        plotter.plot('lr', 'learning rate', epoch, lr)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def accuracy(dista, distb):
    margin = 0
    pred = (dista -
distb - margin).cpu().data
    return (pred > 0).sum() * 1.0 / dista.size()[0]

def accuracy_id(dista, distb, c, c_id):
    margin = 0
    pred = (dista - distb - margin).cpu().data
    return ((pred > 0) * (c.cpu().data == c_id)).sum() * 1.0 / (c.cpu().data == c_id).sum()

if __name__ == '__main__':
    main()

================================================
FILE: requirements.txt
================================================

torch
torchvision
visdom

================================================
FILE: triplet_image_loader.py
================================================

from PIL import Image
import os
import os.path
import torch.utils.data
import torchvision.transforms as transforms
import numpy as np

filenames = {'train': ['class_tripletlist_train.txt', 'closure_tripletlist_train.txt',
                       'gender_tripletlist_train.txt', 'heel_tripletlist_train.txt'],
             'val': ['class_tripletlist_val.txt', 'closure_tripletlist_val.txt',
                     'gender_tripletlist_val.txt', 'heel_tripletlist_val.txt'],
             'test': ['class_tripletlist_test.txt', 'closure_tripletlist_test.txt',
                      'gender_tripletlist_test.txt', 'heel_tripletlist_test.txt']}

def default_image_loader(path):
    return Image.open(path).convert('RGB')

class TripletImageLoader(torch.utils.data.Dataset):
    def __init__(self, root, base_path, filenames_filename, conditions, split,
                 n_triplets, transform=None, loader=default_image_loader):
        """ filenames_filename: A text file with each line containing the path to an image,
                e.g., images/class1/sample.jpg
            triplets_file_name: A text file with each line containing three integers,
                where integer i refers to the i-th image in the filenames file.
For a line of integers 'a b c', a triplet is defined such that image a is
                more similar to image c than it is to image b, e.g., 0 2017 42 """
        self.root = root
        self.base_path = base_path
        self.filenamelist = []
        for line in open(os.path.join(self.root, filenames_filename)):
            self.filenamelist.append(line.rstrip('\n'))
        triplets = []
        if split == 'train':
            fnames = filenames['train']
        elif split == 'val':
            fnames = filenames['val']
        else:
            fnames = filenames['test']
        for condition in conditions:
            for line in open(os.path.join(self.root, 'tripletlists', fnames[condition])):
                triplets.append((line.split()[0], line.split()[1], line.split()[2], condition))  # anchor, far, close
        np.random.shuffle(triplets)
        self.triplets = triplets[:int(n_triplets * 1.0 * len(conditions) / 4)]
        self.transform = transform
        self.loader = loader

    def __getitem__(self, index):
        path1, path2, path3, c = self.triplets[index]
        # check that all three images of the triplet exist on disk
        if (os.path.exists(os.path.join(self.root, self.base_path, self.filenamelist[int(path1)]))
                and os.path.exists(os.path.join(self.root, self.base_path, self.filenamelist[int(path2)]))
                and os.path.exists(os.path.join(self.root, self.base_path, self.filenamelist[int(path3)]))):
            img1 = self.loader(os.path.join(self.root, self.base_path, self.filenamelist[int(path1)]))
            img2 = self.loader(os.path.join(self.root, self.base_path, self.filenamelist[int(path2)]))
            img3 = self.loader(os.path.join(self.root, self.base_path, self.filenamelist[int(path3)]))
            if self.transform is not None:
                img1 = self.transform(img1)
                img2 = self.transform(img2)
                img3 = self.transform(img3)
            return img1, img2, img3, c
        else:
            return None

    def __len__(self):
        return len(self.triplets)

================================================
FILE: tripletnet.py
================================================

import torch
import torch.nn as nn
import torch.nn.functional as F

class CS_Tripletnet(nn.Module):
    def __init__(self, embeddingnet):
        super(CS_Tripletnet, self).__init__()
        self.embeddingnet = embeddingnet

    def forward(self, x, y, z, c):
        """ x: Anchor image,
            y: Distant (negative) image,
            z: Close (positive) image,
            c: Integer indicating according to which notion of similarity images are compared"""
        embedded_x, masknorm_norm_x, embed_norm_x, tot_embed_norm_x = self.embeddingnet(x, c)
        embedded_y, masknorm_norm_y, embed_norm_y, tot_embed_norm_y = self.embeddingnet(y, c)
        embedded_z, masknorm_norm_z, embed_norm_z, tot_embed_norm_z = self.embeddingnet(z, c)
        mask_norm = (masknorm_norm_x + masknorm_norm_y + masknorm_norm_z) / 3
        embed_norm = (embed_norm_x + embed_norm_y + embed_norm_z) / 3
        mask_embed_norm = (tot_embed_norm_x + tot_embed_norm_y + tot_embed_norm_z) / 3
        dist_a = F.pairwise_distance(embedded_x, embedded_y, 2)
        dist_b = F.pairwise_distance(embedded_x, embedded_z, 2)
        return dist_a, dist_b, mask_norm, embed_norm, mask_embed_norm
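Reading the pieces together: `ConditionalSimNet` produces masked embeddings, `CS_Tripletnet` computes the anchor-to-distant and anchor-to-close distances, and `main.py` feeds those distances to `MarginRankingLoss` with a target of 1. The following self-contained sketch traces that data flow with a toy linear embedding net standing in for the ResNet; `ToyCSN` and all sizes here are illustrative, not part of the repo, and its `forward` returns only the masked embedding rather than the repo's full norm tuple:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# ToyCSN is a stand-in for Resnet_18 + ConditionalSimNet (illustrative only):
# a linear embedding followed by a fixed, disjoint condition mask.
class ToyCSN(nn.Module):
    def __init__(self, in_dim=8, embed_dim=4, n_conditions=2):
        super().__init__()
        self.embed = nn.Linear(in_dim, embed_dim)
        masks = torch.zeros(n_conditions, embed_dim)
        mask_len = embed_dim // n_conditions
        for i in range(n_conditions):
            masks[i, i * mask_len:(i + 1) * mask_len] = 1.0
        # fixed masks: no gradients, like the repo's learnedmask=False case
        self.masks = nn.Embedding.from_pretrained(masks, freeze=True)

    def forward(self, x, c):
        e = self.embed(x)   # raw embedding
        m = self.masks(c)   # per-sample mask for the chosen condition
        return e * m        # masked embedding

net = ToyCSN()
x, y, z = torch.randn(3, 8), torch.randn(3, 8), torch.randn(3, 8)  # anchor, far, close
c = torch.tensor([0, 1, 0])  # notion of similarity for each triplet

ex, ey, ez = net(x, c), net(y, c), net(z, c)
dist_a = F.pairwise_distance(ex, ey, 2)  # anchor to distant image
dist_b = F.pairwise_distance(ex, ez, 2)  # anchor to close image

# A target of 1 asks the loss to make dist_a exceed dist_b by the margin.
target = torch.ones_like(dist_a)
loss = nn.MarginRankingLoss(margin=0.2)(dist_a, dist_b, target)
```

The regularizers in `main.py` (`--embed_loss`, `--mask_loss`) would be added on top of this ranking loss using the norms that the real `ConditionalSimNet.forward` also returns.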