Repository: andreasveit/conditional-similarity-networks
Branch: master
Commit: 6021dfe5a7f3
Files: 9
Total size: 30.0 KB

Directory structure:
gitextract_xmmzkyrb/

├── LICENSE
├── README.md
├── Resnet_18.py
├── csn.py
├── get_data.py
├── main.py
├── requirements.txt
├── triplet_image_loader.py
└── tripletnet.py

================================================
FILE CONTENTS
================================================

================================================
FILE: LICENSE
================================================
BSD 3-Clause License

Copyright (c) 2017, Andreas Veit
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


================================================
FILE: README.md
================================================
# Conditional Similarity Networks (CSNs)

This repository contains a [PyTorch](http://pytorch.org/) implementation of the paper [Conditional Similarity Networks](https://arxiv.org/abs/1603.07810) presented at CVPR 2017. 

The code is based on the [PyTorch example for training ResNet on Imagenet](https://github.com/pytorch/examples/tree/master/imagenet) and the [Triplet Network example](https://github.com/andreasveit/triplet-network-pytorch).

## Table of Contents
0. [Introduction](#introduction)
0. [Usage](#usage)
0. [Citing](#citing)
0. [Contact](#contact)

## Introduction
What makes images similar? To measure the similarity between images, they are typically embedded in a feature-vector space in which their distances preserve the relative dissimilarity. However, when learning such similarity embeddings, the simplifying assumption is commonly made that images are compared according to only one unique measure of similarity.

[Conditional Similarity Networks](https://arxiv.org/abs/1603.07810) address this shortcoming by learning nonlinear embeddings that gracefully deal with multiple notions of similarity within a shared embedding. Different aspects of similarity are incorporated by assigning responsibility weights to each embedding dimension with respect to each aspect of similarity.

<img src="https://github.com/andreasveit/conditional-similarity-networks/blob/master/images/csn_overview.png?raw=true" width="600">

Images are passed through a convolutional network and projected into a nonlinear embedding such that different dimensions encode features for specific notions of similarity. Subsequent masks indicate which dimensions of the embedding are responsible for separate aspects of similarity. We can then compare objects according to various notions of similarity by selecting an appropriate masked subspace.
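
As a rough illustration of the masking idea, the sketch below compares two embeddings under two different masks. It uses plain Python lists instead of tensors, and the vectors, the masks and the helper `masked_distance` are made up for the example rather than taken from the repository.

```python
import math


def masked_distance(emb_a, emb_b, mask):
    """Euclidean distance between two embeddings after element-wise masking."""
    return math.sqrt(sum((m * (a - b)) ** 2 for a, b, m in zip(emb_a, emb_b, mask)))


# Hypothetical 4-dimensional embeddings and two disjoint binary masks.
shoe_a = [1.0, 0.0, 3.0, 4.0]
shoe_b = [1.0, 0.0, 0.0, 0.0]
mask_first = [1, 1, 0, 0]   # illustrative first notion: first two dimensions
mask_second = [0, 0, 1, 1]  # illustrative second notion: last two dimensions

print(masked_distance(shoe_a, shoe_b, mask_first))   # 0.0
print(masked_distance(shoe_a, shoe_b, mask_second))  # 5.0
```

The same pair of items is identical under the first mask and far apart under the second, which is the behaviour the masked subspaces are meant to provide.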

## Usage
The default setting for this repo is a CSN with fixed masks, an embedding dimension of 64 and four notions of similarity.
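
With fixed masks, `csn.py` gives each notion of similarity its own disjoint block of embedding dimensions. A minimal plain-Python sketch of that layout for the default sizes follows; the helper name `fixed_masks` is illustrative, and lists stand in for the tensors the repository uses.

```python
def fixed_masks(n_conditions=4, embedding_size=64):
    """Build one binary mask per condition, each covering a disjoint block."""
    block = embedding_size // n_conditions
    masks = []
    for i in range(n_conditions):
        mask = [0.0] * embedding_size
        for j in range(i * block, (i + 1) * block):
            mask[j] = 1.0
        masks.append(mask)
    return masks


masks = fixed_masks()
print([int(sum(m)) for m in masks])  # [16, 16, 16, 16]
print(masks[1].index(1.0))           # 16
```

Each of the four masks switches on 16 of the 64 dimensions, and no dimension is shared between two masks.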

You can download the Zappos dataset as well as the training, validation and test triplets used in the paper with

```sh
python get_data.py
```
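
According to the docstring in `triplet_image_loader.py`, each line of a triplet list holds three integers that index into the filenames file, in the order anchor, far, close. The sketch below parses one such line; the function name `parse_triplet_line` is hypothetical, and unlike the loader (which keeps the indices as strings until an item is fetched) it converts them to integers right away.

```python
def parse_triplet_line(line, condition):
    """Split one line of a triplet list into (anchor, far, close, condition)."""
    anchor, far, close = (int(tok) for tok in line.split()[:3])
    return anchor, far, close, condition


# Example line taken from the docstring in triplet_image_loader.py.
print(parse_triplet_line("0 2017 42", condition=0))  # (0, 2017, 42, 0)
```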

The network can be simply trained with `python main.py` or with optional arguments for different hyperparameters:
```sh
$ python main.py --name {your experiment name} --learned --num_traintriplets 200000
```
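
The hyperparameters `--margin`, `--embed_loss` and `--mask_loss` enter the training objective in `main.py`, which combines a margin ranking loss on the two distances with penalties on the embedding norm and the mask norm. A plain-Python sketch of that combination is shown below; the function names and the example distances are illustrative, and the defaults are the ones from the argument parser.

```python
def triplet_hinge_loss(dist_far, dist_close, margin=0.2):
    """Mean of max(0, dist_close - dist_far + margin) over a batch."""
    terms = [max(0.0, dc - df + margin) for df, dc in zip(dist_far, dist_close)]
    return sum(terms) / len(terms)


def total_loss(loss_triplet, loss_embed, loss_mask, embed_loss=5e-3, mask_loss=5e-4):
    """Weighted sum of the triplet term and the two norm penalties."""
    return loss_triplet + embed_loss * loss_embed + mask_loss * loss_mask


# Hypothetical distances for three triplets.
dist_far = [1.0, 0.5, 0.9]
dist_close = [0.2, 0.5, 1.0]
print(triplet_hinge_loss(dist_far, dist_close))
```

Only the second and third triplets contribute here, because the first one already separates the far and close images by more than the margin.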

Training progress can be easily tracked with [visdom](https://github.com/facebookresearch/visdom) using the `--visdom` flag. It tracks the learning rate, the loss, the training and validation accuracy (over all triplets and separately for each notion of similarity), the embedding norm, the mask norm and the masks themselves.

<img src="https://github.com/andreasveit/conditional-similarity-networks/blob/master/images/visdom.png?raw=true" width="500">
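
The learning rate curve follows the schedule in `adjust_learning_rate` in `main.py`, which multiplies the initial rate by `(1 - 0.015)` once per epoch. A small sketch of that formula, with the helper name `learning_rate` chosen for illustration and the default `--lr` of 5e-5 assumed:

```python
def learning_rate(epoch, base_lr=5e-5, decay=0.015):
    """Learning rate for a given epoch under per-epoch exponential decay."""
    return base_lr * (1 - decay) ** epoch


for epoch in (1, 50, 200):
    print(epoch, learning_rate(epoch))
```

Because training starts at epoch 1, the first epoch already runs at a slightly reduced rate.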

By default, the training code keeps track of the model with the highest performance on the validation set. Thus, after the model has converged, it can be directly evaluated on the test set as follows:
```sh
$ python main.py --test --resume runs/{your experiment name}/model_best.pth.tar
```
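
The accuracy reported during evaluation is the fraction of triplets for which the distance to the far image exceeds the distance to the close image, computed over all triplets and per notion of similarity. The following plain-Python sketch mirrors that logic; the function names and example values are illustrative, and returning `None` for a condition without triplets is an assumption of this sketch, not the repository's behaviour.

```python
def triplet_accuracy(dist_far, dist_close):
    """Fraction of triplets whose far distance exceeds the close distance."""
    correct = sum(1 for df, dc in zip(dist_far, dist_close) if df > dc)
    return correct / len(dist_far)


def condition_accuracy(dist_far, dist_close, conditions, condition_id):
    """Accuracy restricted to triplets of one condition; None if there are none."""
    picked = [(df, dc) for df, dc, c in zip(dist_far, dist_close, conditions)
              if c == condition_id]
    if not picked:
        return None
    return sum(1 for df, dc in picked if df > dc) / len(picked)


# Hypothetical distances and condition labels for four triplets.
dist_far = [1.0, 0.5, 0.9, 2.0]
dist_close = [0.2, 0.5, 1.0, 1.0]
conditions = [0, 0, 1, 1]
print(triplet_accuracy(dist_far, dist_close))                   # 0.5
print(condition_accuracy(dist_far, dist_close, conditions, 1))  # 0.5
```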

## Citing
If you find this helps your research, please consider citing:

```
@conference{Veit2017,
title = {Conditional Similarity Networks},
author = {Andreas Veit and Serge Belongie and Theofanis Karaletsos},
year = {2017},
journal = {Computer Vision and Pattern Recognition (CVPR)},
}
```

## Contact
andreas at cs dot cornell dot edu 

Any discussions, suggestions and questions are welcome!


================================================
FILE: Resnet_18.py
================================================
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo


__all__ = ['ResNet', 'resnet18']


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class ResNet(nn.Module):

    def __init__(self, block, layers, embedding_size=64):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.avgpool = nn.AvgPool2d(7)
        self.fc_embed = nn.Linear(256 * block.expansion, embedding_size)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc_embed(x)

        return x


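def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 style model with three residual stages and a
    linear embedding layer.

    Args:
        pretrained (bool): If True, loads ImageNet-pretrained ResNet-18 weights
            for every parameter whose name matches the checkpoint
    """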
    model = ResNet(BasicBlock, [2, 2, 2], **kwargs)
    if pretrained:
        state = model.state_dict()
        loaded_state_dict = model_zoo.load_url(model_urls['resnet18'])
        for k in loaded_state_dict:
            if k in state:
                state[k] = loaded_state_dict[k]
        model.load_state_dict(state)
    return model


================================================
FILE: csn.py
================================================
import torch
import torch.nn as nn
import numpy as np

class ConditionalSimNet(nn.Module):
    def __init__(self, embeddingnet, n_conditions, embedding_size, learnedmask=True, prein=False):
        """ embeddingnet: The network that projects the inputs into an embedding of embedding_size
            n_conditions: Integer defining number of different similarity notions
            embedding_size: Number of dimensions of the embedding output from the embeddingnet
            learnedmask: Boolean indicating whether masks are learned or fixed
            prein: Boolean indicating whether masks are initialized in equally sized disjoint 
                sections or random otherwise"""
        super(ConditionalSimNet, self).__init__()
        self.learnedmask = learnedmask
        self.embeddingnet = embeddingnet
        # create the mask
        if learnedmask:
            if prein:
                # define masks 
                self.masks = torch.nn.Embedding(n_conditions, embedding_size)
                # initialize masks
                mask_array = np.zeros([n_conditions, embedding_size])
                mask_array.fill(0.1)
                mask_len = int(embedding_size / n_conditions)
                for i in range(n_conditions):
                    mask_array[i, i*mask_len:(i+1)*mask_len] = 1
                # keep gradients enabled so the pre-initialized masks can still be learned
                self.masks.weight = torch.nn.Parameter(torch.Tensor(mask_array), requires_grad=True)
            else:
                # define masks with gradients
                self.masks = torch.nn.Embedding(n_conditions, embedding_size)
                # initialize weights
                self.masks.weight.data.normal_(0.9, 0.7) # 0.1, 0.005
        else:
            # define masks 
            self.masks = torch.nn.Embedding(n_conditions, embedding_size)
            # initialize masks
            mask_array = np.zeros([n_conditions, embedding_size])
            mask_len = int(embedding_size / n_conditions)
            for i in range(n_conditions):
                mask_array[i, i*mask_len:(i+1)*mask_len] = 1
            # no gradients for the masks
            self.masks.weight = torch.nn.Parameter(torch.Tensor(mask_array), requires_grad=False)
    def forward(self, x, c):
        embedded_x = self.embeddingnet(x)
        self.mask = self.masks(c)
        if self.learnedmask:
            self.mask = torch.nn.functional.relu(self.mask)
        masked_embedding = embedded_x * self.mask
        return masked_embedding, self.mask.norm(1), embedded_x.norm(2), masked_embedding.norm(2)

================================================
FILE: get_data.py
================================================
import urllib.request
import os
import os.path
import zipfile

if not os.path.exists(os.path.join('data')):
    os.makedirs('data')

if os.path.exists(os.path.join('data', 'ut-zap50k-images')):
    pass
else:
    urllib.request.urlretrieve("http://vision.cs.utexas.edu/projects/finegrained/utzap50k/ut-zap50k-images.zip", filename="data/ut-zap50k-imgs.zip")
    zip_ref = zipfile.ZipFile("data/ut-zap50k-imgs.zip", 'r')
    zip_ref.extractall("data")
    zip_ref.close()
    os.remove("data/ut-zap50k-imgs.zip")

if os.path.exists(os.path.join('data', 'tripletlists')):
    pass
else:    
    urllib.request.urlretrieve("https://vision.cornell.edu/se3/wp-content/uploads/2019/05/csn_zappos_triplets.zip", filename="data/triplets.zip")
    zip_ref = zipfile.ZipFile("data/triplets.zip", 'r')
    zip_ref.extractall("data")
    zip_ref.close()
    os.remove("data/triplets.zip")


================================================
FILE: main.py
================================================
from __future__ import print_function
import argparse
import os
import sys
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
from triplet_image_loader import TripletImageLoader
from tripletnet import CS_Tripletnet
from visdom import Visdom
import numpy as np
import Resnet_18
from csn import ConditionalSimNet

# Training settings
parser = argparse.ArgumentParser(description='PyTorch Conditional Similarity Networks')
parser.add_argument('--batch-size', type=int, default=256, metavar='N',
                    help='input batch size for training (default: 256)')
parser.add_argument('--epochs', type=int, default=200, metavar='N',
                    help='number of epochs to train (default: 200)')
parser.add_argument('--start_epoch', type=int, default=1, metavar='N',
                    help='number of start epoch (default: 1)')
parser.add_argument('--lr', type=float, default=5e-5, metavar='LR',
                    help='learning rate (default: 5e-5)')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--log-interval', type=int, default=20, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--margin', type=float, default=0.2, metavar='M',
                    help='margin for triplet loss (default: 0.2)')
parser.add_argument('--resume', default='', type=str,
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--name', default='Conditional_Similarity_Network', type=str,
                    help='name of experiment')
parser.add_argument('--embed_loss', type=float, default=5e-3, metavar='M',
                    help='parameter for loss for embedding norm')
parser.add_argument('--mask_loss', type=float, default=5e-4, metavar='M',
                    help='parameter for loss for mask norm')
parser.add_argument('--num_traintriplets', type=int, default=100000, metavar='N',
                    help='how many unique training triplets (default: 100000)')
parser.add_argument('--dim_embed', type=int, default=64, metavar='N',
                    help='how many dimensions in embedding (default: 64)')
parser.add_argument('--test', dest='test', action='store_true',
                    help='To only run inference on test set')
parser.add_argument('--learned', dest='learned', action='store_true',
                    help='To learn masks from random initialization')
parser.add_argument('--prein', dest='prein', action='store_true',
                    help='To initialize masks to be disjoint')
parser.add_argument('--visdom', dest='visdom', action='store_true',
                    help='Use visdom to track and plot')
parser.add_argument('--conditions', nargs='*', type=int,
                    help='Set of similarity notions')
parser.set_defaults(test=False)
parser.set_defaults(learned=False)
parser.set_defaults(prein=False)
parser.set_defaults(visdom=False)

best_acc = 0

def main():
    global args, best_acc
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
    if args.visdom:
        global plotter 
        plotter = VisdomLinePlotter(env_name=args.name)
    
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    global conditions
    if args.conditions is not None:
        conditions = args.conditions
    else:
        conditions = [0,1,2,3]
    
    kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}
    train_loader = torch.utils.data.DataLoader(
        TripletImageLoader('data', 'ut-zap50k-images', 'filenames.json', 
            conditions, 'train', n_triplets=args.num_traintriplets,
                        transform=transforms.Compose([
                            transforms.Resize(112),
                            transforms.CenterCrop(112),
                            transforms.RandomHorizontalFlip(),
                            transforms.ToTensor(),
                            normalize,
                    ])),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        TripletImageLoader('data', 'ut-zap50k-images', 'filenames.json', 
            conditions, 'test', n_triplets=160000,
                        transform=transforms.Compose([
                            transforms.Resize(112),
                            transforms.CenterCrop(112),
                            transforms.ToTensor(),
                            normalize,
                    ])),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    val_loader = torch.utils.data.DataLoader(
        TripletImageLoader('data', 'ut-zap50k-images', 'filenames.json', 
            conditions, 'val', n_triplets=80000,
                        transform=transforms.Compose([
                            transforms.Resize(112),
                            transforms.CenterCrop(112),
                            transforms.ToTensor(),
                            normalize,
                    ])),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    
    model = Resnet_18.resnet18(pretrained=True, embedding_size=args.dim_embed)
    csn_model = ConditionalSimNet(model, n_conditions=len(conditions), 
        embedding_size=args.dim_embed, learnedmask=args.learned, prein=args.prein)
    global mask_var
    mask_var = csn_model.masks.weight
    tnet = CS_Tripletnet(csn_model)
    if args.cuda:
        tnet.cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # restore the best validation accuracy so later epochs compare against it
            best_acc = checkpoint['best_prec1']
            tnet.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                    .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    criterion = torch.nn.MarginRankingLoss(margin = args.margin)
    parameters = filter(lambda p: p.requires_grad, tnet.parameters())
    optimizer = optim.Adam(parameters, lr=args.lr)

    n_parameters = sum([p.data.nelement() for p in tnet.parameters()])
    print('  + Number of params: {}'.format(n_parameters))

    if args.test:
        test_acc = test(test_loader, tnet, criterion, 1)
        sys.exit()

    for epoch in range(args.start_epoch, args.epochs + 1):
        # update learning rate
        adjust_learning_rate(optimizer, epoch)
        # train for one epoch
        train(train_loader, tnet, criterion, optimizer, epoch)
        # evaluate on validation set
        acc = test(val_loader, tnet, criterion, epoch)

        # remember best acc and save checkpoint
        is_best = acc > best_acc
        best_acc = max(acc, best_acc)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': tnet.state_dict(),
            'best_prec1': best_acc,
        }, is_best)

def train(train_loader, tnet, criterion, optimizer, epoch):
    losses = AverageMeter()
    accs = AverageMeter()
    emb_norms = AverageMeter()
    mask_norms = AverageMeter()

    # switch to train mode
    tnet.train()
    for batch_idx, (data1, data2, data3, c) in enumerate(train_loader):
        if args.cuda:
            data1, data2, data3, c = data1.cuda(), data2.cuda(), data3.cuda(), c.cuda()
        data1, data2, data3, c = Variable(data1), Variable(data2), Variable(data3), Variable(c)

        # compute output
        dista, distb, mask_norm, embed_norm, mask_embed_norm = tnet(data1, data2, data3, c)
        # 1 means, dista should be larger than distb
        target = torch.FloatTensor(dista.size()).fill_(1)
        if args.cuda:
            target = target.cuda()
        target = Variable(target)
        
        loss_triplet = criterion(dista, distb, target)
        loss_embedd = embed_norm / np.sqrt(data1.size(0))
        loss_mask = mask_norm / data1.size(0)
        loss = loss_triplet + args.embed_loss * loss_embedd + args.mask_loss * loss_mask

        # measure accuracy and record loss
        acc = accuracy(dista, distb)
        losses.update(loss_triplet.data.item(), data1.size(0))
        accs.update(acc, data1.size(0))
        emb_norms.update(loss_embedd.data.item())
        mask_norms.update(loss_mask.data.item())

        # compute gradient and do optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{}]\t'
                  'Loss: {:.4f} ({:.4f}) \t'
                  'Acc: {:.2f}% ({:.2f}%) \t'
                  'Emb_Norm: {:.2f} ({:.2f})'.format(
                epoch, batch_idx * len(data1), len(train_loader.dataset),
                losses.val, losses.avg, 
                100. * accs.val, 100. * accs.avg, emb_norms.val, emb_norms.avg))

    # log avg values to visdom
    if args.visdom:
        plotter.plot('acc', 'train', epoch, accs.avg)
        plotter.plot('loss', 'train', epoch, losses.avg)
        plotter.plot('emb_norms', 'train', epoch, emb_norms.avg)
        plotter.plot('mask_norms', 'train', epoch, mask_norms.avg)
        if epoch % 10 == 0:
            plotter.plot_mask(torch.nn.functional.relu(mask_var).data.cpu().numpy().T, epoch)

def test(test_loader, tnet, criterion, epoch):
    losses = AverageMeter()
    accs = AverageMeter()
    accs_cs = {}
    for condition in conditions:
        accs_cs[condition] = AverageMeter()

    # switch to evaluation mode
    tnet.eval()
    for batch_idx, (data1, data2, data3, c) in enumerate(test_loader):
        if args.cuda:
            data1, data2, data3, c = data1.cuda(), data2.cuda(), data3.cuda(), c.cuda()
        data1, data2, data3, c = Variable(data1), Variable(data2), Variable(data3), Variable(c)
        c_test = c

        # compute output
        dista, distb, _, _, _ = tnet(data1, data2, data3, c)
        target = torch.FloatTensor(dista.size()).fill_(1)
        if args.cuda:
            target = target.cuda()
        target = Variable(target)
        test_loss =  criterion(dista, distb, target).data.item()

        # measure accuracy and record loss
        acc = accuracy(dista, distb)
        accs.update(acc, data1.size(0))
        for condition in conditions:
            accs_cs[condition].update(accuracy_id(dista, distb, c_test, condition), data1.size(0))
        losses.update(test_loss, data1.size(0))      

    print('\nTest set: Average loss: {:.4f}, Accuracy: {:.2f}%\n'.format(
        losses.avg, 100. * accs.avg))
    if args.visdom:
        for condition in conditions:
            plotter.plot('accs', 'acc_{}'.format(condition), epoch, accs_cs[condition].avg)
        plotter.plot(args.name, args.name, epoch, accs.avg, env='overview')
        plotter.plot('acc', 'test', epoch, accs.avg)
        plotter.plot('loss', 'test', epoch, losses.avg)
    return accs.avg

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Saves checkpoint to disk"""
    directory = "runs/%s/"%(args.name)
    if not os.path.exists(directory):
        os.makedirs(directory)
    filename = directory + filename
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'runs/%s/'%(args.name) + 'model_best.pth.tar')

class VisdomLinePlotter(object):
    """Plots to Visdom"""
    def __init__(self, env_name='main'):
        self.viz = Visdom()
        self.env = env_name
        self.plots = {}
    def plot(self, var_name, split_name, x, y, env=None):
        if env is not None:
            print_env = env
        else:
            print_env = self.env
        if var_name not in self.plots:
            self.plots[var_name] = self.viz.line(X=np.array([x,x]), Y=np.array([y,y]), env=print_env, opts=dict(
                legend=[split_name],
                title=var_name,
                xlabel='Epochs',
                ylabel=var_name
            ))
        else:
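            # updateTrace is not available in recent visdom releases; append to the existing line instead
            self.viz.line(X=np.array([x]), Y=np.array([y]), env=print_env, win=self.plots[var_name], name=split_name, update='append')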
    def plot_mask(self, masks, epoch):
        self.viz.bar(
            X=masks,
            env=self.env,
            opts=dict(
                stacked=True,
                title=epoch,
            )
        )

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 1.5% every epoch"""
    lr = args.lr * ((1 - 0.015) ** epoch)
    if args.visdom:
        plotter.plot('lr', 'learning rate', epoch, lr)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def accuracy(dista, distb):
    margin = 0
    pred = (dista - distb - margin).cpu().data
    return (pred > 0).sum()*1.0/dista.size()[0]

def accuracy_id(dista, distb, c, c_id):
    margin = 0
    pred = (dista - distb - margin).cpu().data
    return ((pred > 0)*(c.cpu().data == c_id)).sum()*1.0/(c.cpu().data == c_id).sum()

if __name__ == '__main__':
    main()    


================================================
FILE: requirements.txt
================================================
torch
torchvision
visdom


================================================
FILE: triplet_image_loader.py
================================================
from PIL import Image
import os
import os.path
import torch.utils.data
import torchvision.transforms as transforms
import numpy as np

filenames = {'train': ['class_tripletlist_train.txt', 'closure_tripletlist_train.txt', 
                'gender_tripletlist_train.txt', 'heel_tripletlist_train.txt'],
             'val': ['class_tripletlist_val.txt', 'closure_tripletlist_val.txt', 
                'gender_tripletlist_val.txt', 'heel_tripletlist_val.txt'],
             'test': ['class_tripletlist_test.txt', 'closure_tripletlist_test.txt', 
                'gender_tripletlist_test.txt', 'heel_tripletlist_test.txt']}

def default_image_loader(path):
    return Image.open(path).convert('RGB')

class TripletImageLoader(torch.utils.data.Dataset):
    def __init__(self, root, base_path, filenames_filename, conditions, split, n_triplets, transform=None,
                 loader=default_image_loader):
        """ filenames_filename: A text file with each line containing the path to an image e.g.,
                images/class1/sample.jpg
            triplet list files (chosen by split and conditions): Text files with each line
                containing three integers, where integer i refers to the i-th image in the
                filenames file. For a line of integers 'a b c', a triplet is defined such that
                image a is more similar to image c than it is to image b, e.g.,
                0 2017 42 """
        self.root = root
        self.base_path = base_path  
        self.filenamelist = []
        for line in open(os.path.join(self.root, filenames_filename)):
            self.filenamelist.append(line.rstrip('\n'))
        triplets = []
        if split == 'train':
            fnames = filenames['train']
        elif split == 'val':
            fnames = filenames['val']
        else:
            fnames = filenames['test']
        for condition in conditions:
            for line in open(os.path.join(self.root, 'tripletlists', fnames[condition])):
                triplets.append((line.split()[0], line.split()[1], line.split()[2], condition)) # anchor, far, close   
        # print(triplets[:100])   
        np.random.shuffle(triplets)
        # print(triplets[:100])  
        self.triplets = triplets[:int(n_triplets * 1.0 * len(conditions) / 4)]
        self.transform = transform
        self.loader = loader

    def __getitem__(self, index):
        path1, path2, path3, c = self.triplets[index]
        # only load the triplet if all three image files exist
        if (os.path.exists(os.path.join(self.root, self.base_path, self.filenamelist[int(path1)])) and
                os.path.exists(os.path.join(self.root, self.base_path, self.filenamelist[int(path2)])) and
                os.path.exists(os.path.join(self.root, self.base_path, self.filenamelist[int(path3)]))):
            img1 = self.loader(os.path.join(self.root, self.base_path, self.filenamelist[int(path1)]))
            img2 = self.loader(os.path.join(self.root, self.base_path, self.filenamelist[int(path2)]))
            img3 = self.loader(os.path.join(self.root, self.base_path, self.filenamelist[int(path3)]))
            if self.transform is not None:
                img1 = self.transform(img1)
                img2 = self.transform(img2)
                img3 = self.transform(img3)
            return img1, img2, img3, c
        else:
            return None

    def __len__(self):
        return len(self.triplets)


================================================
FILE: tripletnet.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F

class CS_Tripletnet(nn.Module):
    def __init__(self, embeddingnet):
        super(CS_Tripletnet, self).__init__()
        self.embeddingnet = embeddingnet

    def forward(self, x, y, z, c):
        """ x: Anchor image,
            y: Distant (negative) image,
            z: Close (positive) image,
            c: Integer indicating according to which notion of similarity images are compared"""
        embedded_x, masknorm_norm_x, embed_norm_x, tot_embed_norm_x = self.embeddingnet(x, c)
        embedded_y, masknorm_norm_y, embed_norm_y, tot_embed_norm_y = self.embeddingnet(y, c)
        embedded_z, masknorm_norm_z, embed_norm_z, tot_embed_norm_z = self.embeddingnet(z, c)
        mask_norm = (masknorm_norm_x + masknorm_norm_y + masknorm_norm_z) / 3
        embed_norm = (embed_norm_x + embed_norm_y + embed_norm_z) / 3
        mask_embed_norm = (tot_embed_norm_x + tot_embed_norm_y + tot_embed_norm_z) / 3
        dist_a = F.pairwise_distance(embedded_x, embedded_y, 2)
        dist_b = F.pairwise_distance(embedded_x, embedded_z, 2)
        return dist_a, dist_b, mask_norm, embed_norm, mask_embed_norm
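
`forward` returns `dist_a` (anchor to distant image) and `dist_b` (anchor to close image), so a well-trained network should produce `dist_a > dist_b`. A common way to train on such pairs is a margin ranking (hinge) loss. The sketch below shows that loss in plain Python on scalar distances. The function name `triplet_hinge` and the margin value are illustrative assumptions, and whether `main.py` uses exactly this criterion is not shown in this section.

```python
def triplet_hinge(dist_a, dist_b, margin):
    """Hinge loss that is zero once dist_a exceeds dist_b by at least `margin`.

    dist_a: distance from the anchor to the distant (negative) image.
    dist_b: distance from the anchor to the close (positive) image.
    """
    return max(0.0, margin - (dist_a - dist_b))


# Correctly ordered triplet with a gap wider than the margin: no loss.
print(triplet_hinge(1.0, 0.5, margin=0.25))  # 0.0
# Violated triplet: the close image is farther away than the distant one.
print(triplet_hinge(0.5, 1.0, margin=0.25))  # 0.75
```

In PyTorch the same expression corresponds to `torch.nn.MarginRankingLoss(margin)` called with `dist_a`, `dist_b`, and a target tensor of ones, averaged over the batch by default.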
SYMBOL INDEX (35 symbols across 5 files)

FILE: Resnet_18.py
  function conv3x3 (line 14) | def conv3x3(in_planes, out_planes, stride=1):
  class BasicBlock (line 20) | class BasicBlock(nn.Module):
    method __init__ (line 23) | def __init__(self, inplanes, planes, stride=1, downsample=None):
    method forward (line 33) | def forward(self, x):
  class ResNet (line 51) | class ResNet(nn.Module):
    method __init__ (line 53) | def __init__(self, block, layers, embedding_size=64):
    method _make_layer (line 75) | def _make_layer(self, block, planes, blocks, stride=1):
    method forward (line 92) | def forward(self, x):
  function resnet18 (line 109) | def resnet18(pretrained=False, **kwargs):

FILE: csn.py
  class ConditionalSimNet (line 5) | class ConditionalSimNet(nn.Module):
    method __init__ (line 6) | def __init__(self, embeddingnet, n_conditions, embedding_size, learned...
    method forward (line 44) | def forward(self, x, c):

FILE: main.py
  function main (line 67) | def main():
  function train (line 172) | def train(train_loader, tnet, criterion, optimizer, epoch):
  function test (line 228) | def test(test_loader, tnet, criterion, epoch):
  function save_checkpoint (line 268) | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
  class VisdomLinePlotter (line 278) | class VisdomLinePlotter(object):
    method __init__ (line 280) | def __init__(self, env_name='main'):
    method plot (line 284) | def plot(self, var_name, split_name, x, y, env=None):
    method plot_mask (line 298) | def plot_mask(self, masks, epoch):
  class AverageMeter (line 308) | class AverageMeter(object):
    method __init__ (line 310) | def __init__(self):
    method reset (line 313) | def reset(self):
    method update (line 319) | def update(self, val, n=1):
  function adjust_learning_rate (line 325) | def adjust_learning_rate(optimizer, epoch):
  function accuracy (line 333) | def accuracy(dista, distb):
  function accuracy_id (line 338) | def accuracy_id(dista, distb, c, c_id):

FILE: triplet_image_loader.py
  function default_image_loader (line 15) | def default_image_loader(path):
  class TripletImageLoader (line 18) | class TripletImageLoader(torch.utils.data.Dataset):
    method __init__ (line 19) | def __init__(self, root, base_path, filenames_filename, conditions, sp...
    method __getitem__ (line 50) | def __getitem__(self, index):
    method __len__ (line 64) | def __len__(self):

FILE: tripletnet.py
  class CS_Tripletnet (line 5) | class CS_Tripletnet(nn.Module):
    method __init__ (line 6) | def __init__(self, embeddingnet):
    method forward (line 10) | def forward(self, x, y, z, c):
