[
  {
    "path": "LICENSE",
    "content": "BSD 3-Clause License\n\nCopyright (c) 2017, Andreas Veit\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions are met:\n\n* Redistributions of source code must retain the above copyright notice, this\n  list of conditions and the following disclaimer.\n\n* Redistributions in binary form must reproduce the above copyright notice,\n  this list of conditions and the following disclaimer in the documentation\n  and/or other materials provided with the distribution.\n\n* Neither the name of the copyright holder nor the names of its\n  contributors may be used to endorse or promote products derived from\n  this software without specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\nAND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\nIMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\nDISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\nFOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\nDAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\nSERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\nCAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\nOR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\nOF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n"
  },
  {
    "path": "README.md",
    "content": "# Conditional Similarity Networks (CSNs)\n\nThis repository contains a [PyTorch](http://pytorch.org/) implementation of the paper [Conditional Similarity Networks](https://arxiv.org/abs/1603.07810) presented at CVPR 2017. \n\nThe code is based on the [PyTorch example for training ResNet on Imagenet](https://github.com/pytorch/examples/tree/master/imagenet) and the [Triplet Network example](https://github.com/andreasveit/triplet-network-pytorch).\n\n## Table of Contents\n0. [Introduction](#introduction)\n0. [Usage](#usage)\n0. [Citing](#citing)\n0. [Contact](#contact)\n\n## Introduction\nWhat makes images similar? To measure the similarity between images, they are typically embedded in a feature-vector space, in which their distance preserve the relative dissimilarity. However, when learning such similarity embeddings the simplifying assumption is commonly made that images are only compared to one unique measure of similarity.\n\n[Conditional Similarity Networks](https://arxiv.org/abs/1603.07810) address this shortcoming by learning a nonlinear embeddings that gracefully deals with multiple notions of similarity within a shared embedding. Different aspects of similarity are incorporated by assigning responsibility weights to each embedding dimension with respect to each aspect of similarity.\n\n<img src=\"https://github.com/andreasveit/conditional-similarity-networks/blob/master/images/csn_overview.png?raw=true\" width=\"600\">\n\nImages are passed through a convolutional network and projected into a nonlinear embedding such that different dimensions encode features for specific notions of similarity. Subsequent masks indicate which dimensions of the embedding are responsible for separate aspects of similarity. 
We can then compare objects according to various notions of similarity by selecting an appropriate masked subspace.\n\n## Usage\nThe default setting for this repo is a CSN with fixed masks, an embedding dimension of 64 and four notions of similarity.\n\nYou can download the Zappos dataset as well as the training, validation and test triplets used in the paper with\n\n```sh\npython get_data.py\n```\n\nThe network can simply be trained with `python main.py` or with optional arguments for different hyperparameters:\n```sh\n$ python main.py --name {your experiment name} --learned --num_traintriplets 200000\n```\n\nTraining progress can be easily tracked with [visdom](https://github.com/facebookresearch/visdom) using the `--visdom` flag. It keeps track of the learning rate, the loss, the training and validation accuracy (both over all triplets and separately for each notion of similarity), the embedding norm, the mask norm, and the masks themselves.\n\n<img src=\"https://github.com/andreasveit/conditional-similarity-networks/blob/master/images/visdom.png?raw=true\" width=\"500\">\n\nBy default the training code keeps track of the model with the highest performance on the validation set. Thus, after the model has converged, it can be directly evaluated on the test set as follows:\n```sh\n$ python main.py --test --resume runs/{your experiment name}/model_best.pth.tar\n```\n\n## Citing\nIf you find this code helps your research, please consider citing:\n\n```\n@inproceedings{Veit2017,\ntitle = {Conditional Similarity Networks},\nauthor = {Andreas Veit and Serge Belongie and Theofanis Karaletsos},\nyear = {2017},\nbooktitle = {Computer Vision and Pattern Recognition (CVPR)},\n}\n```\n\n## Contact\nandreas at cs dot cornell dot edu\n\nAny discussions, suggestions and questions are welcome!\n"
  },
  {
    "path": "Resnet_18.py",
    "content": "import torch.nn as nn\nimport math\nimport torch.utils.model_zoo as model_zoo\n\n\n__all__ = ['ResNet', 'resnet18']\n\n\nmodel_urls = {\n    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',\n}\n\n\ndef conv3x3(in_planes, out_planes, stride=1):\n    \"3x3 convolution with padding\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n                     padding=1, bias=False)\n\n\nclass BasicBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None):\n        super(BasicBlock, self).__init__()\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = nn.BatchNorm2d(planes)\n        self.relu = nn.ReLU(inplace=True)\n        self.conv2 = conv3x3(planes, planes)\n        self.bn2 = nn.BatchNorm2d(planes)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out += residual\n        out = self.relu(out)\n\n        return out\n\nclass ResNet(nn.Module):\n\n    def __init__(self, block, layers, embedding_size=64):\n        self.inplanes = 64\n        super(ResNet, self).__init__()\n        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,\n                               bias=False)\n        self.bn1 = nn.BatchNorm2d(64)\n        self.relu = nn.ReLU(inplace=True)\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n        self.layer1 = self._make_layer(block, 64, layers[0])\n        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)\n        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)\n        self.avgpool = nn.AvgPool2d(7)\n        self.fc_embed = nn.Linear(256 * 
block.expansion, embedding_size)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n                m.weight.data.normal_(0, math.sqrt(2. / n))\n            elif isinstance(m, nn.BatchNorm2d):\n                m.weight.data.fill_(1)\n                m.bias.data.zero_()\n\n    def _make_layer(self, block, planes, blocks, stride=1):\n        downsample = None\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                nn.Conv2d(self.inplanes, planes * block.expansion,\n                          kernel_size=1, stride=stride, bias=False),\n                nn.BatchNorm2d(planes * block.expansion),\n            )\n\n        layers = []\n        layers.append(block(self.inplanes, planes, stride, downsample))\n        self.inplanes = planes * block.expansion\n        for i in range(1, blocks):\n            layers.append(block(self.inplanes, planes))\n\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n        x = self.maxpool(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n\n        x = self.avgpool(x)\n        x = x.view(x.size(0), -1)\n        x = self.fc_embed(x)\n\n        return x\n\n\ndef resnet18(pretrained=False, **kwargs):\n    \"\"\"Constructs a truncated ResNet-18 model (stages 1-3 only, suited to 112x112 inputs).\n\n    Args:\n        pretrained (bool): If True, initializes all layers whose names match the\n            ImageNet pre-trained ResNet-18 with the pre-trained weights\n    \"\"\"\n    model = ResNet(BasicBlock, [2, 2, 2], **kwargs)\n    if pretrained:\n        state = model.state_dict()\n        loaded_state_dict = model_zoo.load_url(model_urls['resnet18'])\n        for k in loaded_state_dict:\n            if k in state:\n                state[k] = loaded_state_dict[k]\n        model.load_state_dict(state)\n    return model\n"
  },
  {
    "path": "csn.py",
    "content": "import torch\nimport torch.nn as nn\nimport numpy as np\n\nclass ConditionalSimNet(nn.Module):\n    def __init__(self, embeddingnet, n_conditions, embedding_size, learnedmask=True, prein=False):\n        \"\"\" embeddingnet: The network that projects the inputs into an embedding of embedding_size\n            n_conditions: Integer defining number of different similarity notions\n            embedding_size: Number of dimensions of the embedding output from the embeddingnet\n            learnedmask: Boolean indicating whether masks are learned or fixed\n            prein: Boolean indicating whether masks are initialized in equally sized disjoint \n                sections or random otherwise\"\"\"\n        super(ConditionalSimNet, self).__init__()\n        self.learnedmask = learnedmask\n        self.embeddingnet = embeddingnet\n        # create the mask\n        if learnedmask:\n            if prein:\n                # define masks \n                self.masks = torch.nn.Embedding(n_conditions, embedding_size)\n                # initialize masks\n                mask_array = np.zeros([n_conditions, embedding_size])\n                mask_array.fill(0.1)\n                mask_len = int(embedding_size / n_conditions)\n                for i in range(n_conditions):\n                    mask_array[i, i*mask_len:(i+1)*mask_len] = 1\n                # no gradients for the masks\n                self.masks.weight = torch.nn.Parameter(torch.Tensor(mask_array), requires_grad=True)\n            else:\n                # define masks with gradients\n                self.masks = torch.nn.Embedding(n_conditions, embedding_size)\n                # initialize weights\n                self.masks.weight.data.normal_(0.9, 0.7) # 0.1, 0.005\n        else:\n            # define masks \n            self.masks = torch.nn.Embedding(n_conditions, embedding_size)\n            # initialize masks\n            mask_array = np.zeros([n_conditions, embedding_size])\n            
mask_len = int(embedding_size / n_conditions)\n            for i in range(n_conditions):\n                mask_array[i, i*mask_len:(i+1)*mask_len] = 1\n            # no gradients for the fixed masks\n            self.masks.weight = torch.nn.Parameter(torch.Tensor(mask_array), requires_grad=False)\n\n    def forward(self, x, c):\n        embedded_x = self.embeddingnet(x)\n        # look up the mask for the given condition and apply it to the embedding\n        self.mask = self.masks(c)\n        if self.learnedmask:\n            self.mask = torch.nn.functional.relu(self.mask)\n        masked_embedding = embedded_x * self.mask\n        return masked_embedding, self.mask.norm(1), embedded_x.norm(2), masked_embedding.norm(2)\n"
  },
  {
    "path": "get_data.py",
    "content": "import urllib.request\nimport os\nimport os.path\nimport zipfile\n\nif not os.path.exists(os.path.join('data')):\n    os.makedirs('data')\n\nif os.path.exists(os.path.join('data', 'ut-zap50k-images')):\n    pass\nelse:\n    urllib.request.urlretrieve(\"http://vision.cs.utexas.edu/projects/finegrained/utzap50k/ut-zap50k-images.zip\", filename=\"data/ut-zap50k-imgs.zip\")\n    zip_ref = zipfile.ZipFile(\"data/ut-zap50k-imgs.zip\", 'r')\n    zip_ref.extractall(\"data\")\n    zip_ref.close()\n    os.remove(\"data/ut-zap50k-imgs.zip\")\n\nif os.path.exists(os.path.join('data', 'tripletlists')):\n    pass\nelse:    \n    urllib.request.urlretrieve(\"https://vision.cornell.edu/se3/wp-content/uploads/2019/05/csn_zappos_triplets.zip\", filename=\"data/triplets.zip\")\n    zip_ref = zipfile.ZipFile(\"data/triplets.zip\", 'r')\n    zip_ref.extractall(\"data\")\n    zip_ref.close()\n    os.remove(\"data/triplets.zip\")\n"
  },
  {
    "path": "main.py",
    "content": "from __future__ import print_function\nimport argparse\nimport os\nimport sys\nimport shutil\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom torchvision import transforms\nfrom torch.autograd import Variable\nimport torch.backends.cudnn as cudnn\nfrom triplet_image_loader import TripletImageLoader\nfrom tripletnet import CS_Tripletnet\nfrom visdom import Visdom\nimport numpy as np\nimport Resnet_18\nfrom csn import ConditionalSimNet\n\n# Training settings\nparser = argparse.ArgumentParser(description='PyTorch MNIST Example')\nparser.add_argument('--batch-size', type=int, default=256, metavar='N',\n                    help='input batch size for training (default: 64)')\nparser.add_argument('--epochs', type=int, default=200, metavar='N',\n                    help='number of epochs to train (default: 200)')\nparser.add_argument('--start_epoch', type=int, default=1, metavar='N',\n                    help='number of start epoch (default: 1)')\nparser.add_argument('--lr', type=float, default=5e-5, metavar='LR',\n                    help='learning rate (default: 5e-5)')\nparser.add_argument('--seed', type=int, default=1, metavar='S',\n                    help='random seed (default: 1)')\nparser.add_argument('--no-cuda', action='store_true', default=False,\n                    help='enables CUDA training')\nparser.add_argument('--log-interval', type=int, default=20, metavar='N',\n                    help='how many batches to wait before logging training status')\nparser.add_argument('--margin', type=float, default=0.2, metavar='M',\n                    help='margin for triplet loss (default: 0.2)')\nparser.add_argument('--resume', default='', type=str,\n                    help='path to latest checkpoint (default: none)')\nparser.add_argument('--name', default='Conditional_Similarity_Network', type=str,\n                    help='name of experiment')\nparser.add_argument('--embed_loss', type=float, 
default=5e-3, metavar='M',\n                    help='parameter for loss for embedding norm')\nparser.add_argument('--mask_loss', type=float, default=5e-4, metavar='M',\n                    help='parameter for loss for mask norm')\nparser.add_argument('--num_traintriplets', type=int, default=100000, metavar='N',\n                    help='how many unique training triplets (default: 100000)')\nparser.add_argument('--dim_embed', type=int, default=64, metavar='N',\n                    help='how many dimensions in embedding (default: 64)')\nparser.add_argument('--test', dest='test', action='store_true',\n                    help='To only run inference on test set')\nparser.add_argument('--learned', dest='learned', action='store_true',\n                    help='To learn masks from random initialization')\nparser.add_argument('--prein', dest='prein', action='store_true',\n                    help='To initialize masks to be disjoint')\nparser.add_argument('--visdom', dest='visdom', action='store_true',\n                    help='Use visdom to track and plot')\nparser.add_argument('--conditions', nargs='*', type=int,\n                    help='Set of similarity notions')\nparser.set_defaults(test=False)\nparser.set_defaults(learned=False)\nparser.set_defaults(prein=False)\nparser.set_defaults(visdom=False)\n\nbest_acc = 0\n\ndef main():\n    global args, best_acc\n    args = parser.parse_args()\n    args.cuda = not args.no_cuda and torch.cuda.is_available()\n    torch.manual_seed(args.seed)\n    if args.cuda:\n        torch.cuda.manual_seed(args.seed)\n    if args.visdom:\n        global plotter \n        plotter = VisdomLinePlotter(env_name=args.name)\n    \n    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],\n                                     std=[0.229, 0.224, 0.225])\n\n    global conditions\n    if args.conditions is not None:\n        conditions = args.conditions\n    else:\n        conditions = [0,1,2,3]\n    \n    kwargs = {'num_workers': 4, 
'pin_memory': True} if args.cuda else {}\n    train_loader = torch.utils.data.DataLoader(\n        TripletImageLoader('data', 'ut-zap50k-images', 'filenames.json', \n            conditions, 'train', n_triplets=args.num_traintriplets,\n                        transform=transforms.Compose([\n                            transforms.Resize(112),\n                            transforms.CenterCrop(112),\n                            transforms.RandomHorizontalFlip(),\n                            transforms.ToTensor(),\n                            normalize,\n                    ])),\n        batch_size=args.batch_size, shuffle=True, **kwargs)\n    test_loader = torch.utils.data.DataLoader(\n        TripletImageLoader('data', 'ut-zap50k-images', 'filenames.json', \n            conditions, 'test', n_triplets=160000,\n                        transform=transforms.Compose([\n                            transforms.Resize(112),\n                            transforms.CenterCrop(112),\n                            transforms.ToTensor(),\n                            normalize,\n                    ])),\n        batch_size=args.batch_size, shuffle=True, **kwargs)\n    val_loader = torch.utils.data.DataLoader(\n        TripletImageLoader('data', 'ut-zap50k-images', 'filenames.json', \n            conditions, 'val', n_triplets=80000,\n                        transform=transforms.Compose([\n                            transforms.Resize(112),\n                            transforms.CenterCrop(112),\n                            transforms.ToTensor(),\n                            normalize,\n                    ])),\n        batch_size=args.batch_size, shuffle=True, **kwargs)\n    \n    model = Resnet_18.resnet18(pretrained=True, embedding_size=args.dim_embed)\n    csn_model = ConditionalSimNet(model, n_conditions=len(conditions), \n        embedding_size=args.dim_embed, learnedmask=args.learned, prein=args.prein)\n    global mask_var\n    mask_var = csn_model.masks.weight\n    tnet = 
CS_Tripletnet(csn_model)\n    if args.cuda:\n        tnet.cuda()\n\n    # optionally resume from a checkpoint\n    if args.resume:\n        if os.path.isfile(args.resume):\n            print(\"=> loading checkpoint '{}'\".format(args.resume))\n            checkpoint = torch.load(args.resume)\n            args.start_epoch = checkpoint['epoch']\n            best_acc = checkpoint['best_prec1']\n            tnet.load_state_dict(checkpoint['state_dict'])\n            print(\"=> loaded checkpoint '{}' (epoch {})\"\n                    .format(args.resume, checkpoint['epoch']))\n        else:\n            print(\"=> no checkpoint found at '{}'\".format(args.resume))\n\n    cudnn.benchmark = True\n\n    criterion = torch.nn.MarginRankingLoss(margin=args.margin)\n    parameters = filter(lambda p: p.requires_grad, tnet.parameters())\n    optimizer = optim.Adam(parameters, lr=args.lr)\n\n    n_parameters = sum([p.data.nelement() for p in tnet.parameters()])\n    print('  + Number of params: {}'.format(n_parameters))\n\n    if args.test:\n        test(test_loader, tnet, criterion, 1)\n        sys.exit()\n\n    for epoch in range(args.start_epoch, args.epochs + 1):\n        # update learning rate\n        adjust_learning_rate(optimizer, epoch)\n        # train for one epoch\n        train(train_loader, tnet, criterion, optimizer, epoch)\n        # evaluate on validation set\n        acc = test(val_loader, tnet, criterion, epoch)\n\n        # remember best acc and save checkpoint\n        is_best = acc > best_acc\n        best_acc = max(acc, best_acc)\n        save_checkpoint({\n            'epoch': epoch + 1,\n            'state_dict': tnet.state_dict(),\n            'best_prec1': best_acc,\n        }, is_best)\n\ndef train(train_loader, tnet, criterion, optimizer, epoch):\n    losses = AverageMeter()\n    accs = AverageMeter()\n    emb_norms = AverageMeter()\n    mask_norms = AverageMeter()\n\n    # switch to train mode\n    tnet.train()\n    for batch_idx, 
(data1, data2, data3, c) in enumerate(train_loader):\n        if args.cuda:\n            data1, data2, data3, c = data1.cuda(), data2.cuda(), data3.cuda(), c.cuda()\n        data1, data2, data3, c = Variable(data1), Variable(data2), Variable(data3), Variable(c)\n\n        # compute output\n        dista, distb, mask_norm, embed_norm, mask_embed_norm = tnet(data1, data2, data3, c)\n        # a target of 1 means dista should be larger than distb\n        target = torch.FloatTensor(dista.size()).fill_(1)\n        if args.cuda:\n            target = target.cuda()\n        target = Variable(target)\n\n        loss_triplet = criterion(dista, distb, target)\n        loss_embedd = embed_norm / np.sqrt(data1.size(0))\n        loss_mask = mask_norm / data1.size(0)\n        loss = loss_triplet + args.embed_loss * loss_embedd + args.mask_loss * loss_mask\n\n        # measure accuracy and record loss\n        acc = accuracy(dista, distb)\n        losses.update(loss_triplet.data.item(), data1.size(0))\n        accs.update(acc, data1.size(0))\n        emb_norms.update(loss_embedd.data.item())\n        mask_norms.update(loss_mask.data.item())\n\n        # compute gradient and do optimizer step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        if batch_idx % args.log_interval == 0:\n            print('Train Epoch: {} [{}/{}]\\t'\n                  'Loss: {:.4f} ({:.4f}) \\t'\n                  'Acc: {:.2f}% ({:.2f}%) \\t'\n                  'Emb_Norm: {:.2f} ({:.2f})'.format(\n                epoch, batch_idx * len(data1), len(train_loader.dataset),\n                losses.val, losses.avg,\n                100. * accs.val, 100. 
* accs.avg, emb_norms.val, emb_norms.avg))\n\n    # log avg values to visdom\n    if args.visdom:\n        plotter.plot('acc', 'train', epoch, accs.avg)\n        plotter.plot('loss', 'train', epoch, losses.avg)\n        plotter.plot('emb_norms', 'train', epoch, emb_norms.avg)\n        plotter.plot('mask_norms', 'train', epoch, mask_norms.avg)\n        if epoch % 10 == 0:\n            plotter.plot_mask(torch.nn.functional.relu(mask_var).data.cpu().numpy().T, epoch)\n\ndef test(test_loader, tnet, criterion, epoch):\n    losses = AverageMeter()\n    accs = AverageMeter()\n    accs_cs = {}\n    for condition in conditions:\n        accs_cs[condition] = AverageMeter()\n\n    # switch to evaluation mode\n    tnet.eval()\n    for batch_idx, (data1, data2, data3, c) in enumerate(test_loader):\n        if args.cuda:\n            data1, data2, data3, c = data1.cuda(), data2.cuda(), data3.cuda(), c.cuda()\n        data1, data2, data3, c = Variable(data1), Variable(data2), Variable(data3), Variable(c)\n\n        # compute output\n        dista, distb, _, _, _ = tnet(data1, data2, data3, c)\n        target = torch.FloatTensor(dista.size()).fill_(1)\n        if args.cuda:\n            target = target.cuda()\n        target = Variable(target)\n        test_loss = criterion(dista, distb, target).data.item()\n\n        # measure accuracy and record loss\n        acc = accuracy(dista, distb)\n        accs.update(acc, data1.size(0))\n        for condition in conditions:\n            accs_cs[condition].update(accuracy_id(dista, distb, c, condition), data1.size(0))\n        losses.update(test_loss, data1.size(0))\n\n    print('\\nTest set: Average loss: {:.4f}, Accuracy: {:.2f}%\\n'.format(\n        losses.avg, 100. 
* accs.avg))\n    if args.visdom:\n        for condition in conditions:\n            plotter.plot('accs', 'acc_{}'.format(condition), epoch, accs_cs[condition].avg)\n        plotter.plot(args.name, args.name, epoch, accs.avg, env='overview')\n        plotter.plot('acc', 'test', epoch, accs.avg)\n        plotter.plot('loss', 'test', epoch, losses.avg)\n    return accs.avg\n\ndef save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):\n    \"\"\"Saves checkpoint to disk\"\"\"\n    directory = \"runs/%s/\"%(args.name)\n    if not os.path.exists(directory):\n        os.makedirs(directory)\n    filename = directory + filename\n    torch.save(state, filename)\n    if is_best:\n        shutil.copyfile(filename, 'runs/%s/'%(args.name) + 'model_best.pth.tar')\n\nclass VisdomLinePlotter(object):\n    \"\"\"Plots to Visdom\"\"\"\n    def __init__(self, env_name='main'):\n        self.viz = Visdom()\n        self.env = env_name\n        self.plots = {}\n    def plot(self, var_name, split_name, x, y, env=None):\n        if env is not None:\n            print_env = env\n        else:\n            print_env = self.env\n        if var_name not in self.plots:\n            self.plots[var_name] = self.viz.line(X=np.array([x,x]), Y=np.array([y,y]), env=print_env, opts=dict(\n                legend=[split_name],\n                title=var_name,\n                xlabel='Epochs',\n                ylabel=var_name\n            ))\n        else:\n            # updateTrace was removed from recent visdom releases; append to the existing window instead\n            self.viz.line(X=np.array([x]), Y=np.array([y]), env=print_env, win=self.plots[var_name], name=split_name, update='append')\n    def plot_mask(self, masks, epoch):\n        self.viz.bar(\n            X=masks,\n            env=self.env,\n            opts=dict(\n                stacked=True,\n                title='masks (epoch {})'.format(epoch),\n            )\n        )\n\nclass AverageMeter(object):\n    \"\"\"Computes and stores the average and current value\"\"\"\n    def __init__(self):\n        self.reset()\n\n    def reset(self):\n        self.val = 0\n       
 self.avg = 0\n        self.sum = 0\n        self.count = 0\n\n    def update(self, val, n=1):\n        self.val = val\n        self.sum += val * n\n        self.count += n\n        self.avg = self.sum / self.count\n\ndef adjust_learning_rate(optimizer, epoch):\n    \"\"\"Decays the learning rate exponentially by 1.5% every epoch\"\"\"\n    lr = args.lr * ((1 - 0.015) ** epoch)\n    if args.visdom:\n        plotter.plot('lr', 'learning rate', epoch, lr)\n    for param_group in optimizer.param_groups:\n        param_group['lr'] = lr\n\ndef accuracy(dista, distb):\n    margin = 0\n    pred = (dista - distb - margin).cpu().data\n    return (pred > 0).sum().item() * 1.0 / dista.size()[0]\n\ndef accuracy_id(dista, distb, c, c_id):\n    margin = 0\n    pred = (dista - distb - margin).cpu().data\n    return ((pred > 0) * (c.cpu().data == c_id)).sum().item() * 1.0 / (c.cpu().data == c_id).sum().item()\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "requirements.txt",
    "content": "torch\ntorchvision\nvisdom\n"
  },
  {
    "path": "triplet_image_loader.py",
    "content": "from PIL import Image\nimport os\nimport os.path\nimport torch.utils.data\nimport torchvision.transforms as transforms\nimport numpy as np\n\nfilenames = {'train': ['class_tripletlist_train.txt', 'closure_tripletlist_train.txt', \n                'gender_tripletlist_train.txt', 'heel_tripletlist_train.txt'],\n             'val': ['class_tripletlist_val.txt', 'closure_tripletlist_val.txt', \n                'gender_tripletlist_val.txt', 'heel_tripletlist_val.txt'],\n             'test': ['class_tripletlist_test.txt', 'closure_tripletlist_test.txt', \n                'gender_tripletlist_test.txt', 'heel_tripletlist_test.txt']}\n\ndef default_image_loader(path):\n    return Image.open(path).convert('RGB')\n\nclass TripletImageLoader(torch.utils.data.Dataset):\n    def __init__(self, root, base_path, filenames_filename, conditions, split, n_triplets, transform=None,\n                 loader=default_image_loader):\n        \"\"\" filenames_filename: A text file with each line containing the path to an image e.g.,\n                images/class1/sample.jpg\n            triplets_file_name: A text file with each line containing three integers, \n                where integer i refers to the i-th image in the filenames file. 
\n                For a line of integers 'a b c', a triplet is defined such that image a is more \n                similar to image c than it is to image b, e.g., \n                0 2017 42 \"\"\"\n        self.root = root\n        self.base_path = base_path\n        self.filenamelist = []\n        for line in open(os.path.join(self.root, filenames_filename)):\n            self.filenamelist.append(line.rstrip('\\n'))\n        triplets = []\n        if split == 'train':\n            fnames = filenames['train']\n        elif split == 'val':\n            fnames = filenames['val']\n        else:\n            fnames = filenames['test']\n        for condition in conditions:\n            for line in open(os.path.join(self.root, 'tripletlists', fnames[condition])):\n                triplets.append((line.split()[0], line.split()[1], line.split()[2], condition)) # anchor, far, close\n        np.random.shuffle(triplets)\n        # scale the triplet budget by the fraction of the four similarity notions in use\n        self.triplets = triplets[:int(n_triplets * 1.0 * len(conditions) / 4)]\n        self.transform = transform\n        self.loader = loader\n\n    def __getitem__(self, index):\n        path1, path2, path3, c = self.triplets[index]\n        if os.path.exists(os.path.join(self.root, self.base_path, self.filenamelist[int(path1)])) and os.path.exists(os.path.join(self.root, self.base_path, self.filenamelist[int(path2)])) and os.path.exists(os.path.join(self.root, self.base_path, self.filenamelist[int(path3)])):\n            img1 = self.loader(os.path.join(self.root, self.base_path, self.filenamelist[int(path1)]))\n            img2 = self.loader(os.path.join(self.root, self.base_path, self.filenamelist[int(path2)]))\n            img3 = self.loader(os.path.join(self.root, self.base_path, self.filenamelist[int(path3)]))\n            if self.transform is not None:\n                img1 = self.transform(img1)\n                img2 = self.transform(img2)\n                img3 = 
self.transform(img3)\n            return img1, img2, img3, c\n        else:\n            # NOTE: returning None for a missing file will break the default collate_fn;\n            # missing images are assumed to be rare\n            return None\n\n    def __len__(self):\n        return len(self.triplets)\n"
  },
  {
    "path": "tripletnet.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass CS_Tripletnet(nn.Module):\n    def __init__(self, embeddingnet):\n        super(CS_Tripletnet, self).__init__()\n        self.embeddingnet = embeddingnet\n\n    def forward(self, x, y, z, c):\n        \"\"\" x: Anchor image,\n            y: Distant (negative) image,\n            z: Close (positive) image,\n            c: Integer indicating according to which notion of similarity images are compared\"\"\"\n        embedded_x, masknorm_norm_x, embed_norm_x, tot_embed_norm_x = self.embeddingnet(x, c)\n        embedded_y, masknorm_norm_y, embed_norm_y, tot_embed_norm_y = self.embeddingnet(y, c)\n        embedded_z, masknorm_norm_z, embed_norm_z, tot_embed_norm_z = self.embeddingnet(z, c)\n        mask_norm = (masknorm_norm_x + masknorm_norm_y + masknorm_norm_z) / 3\n        embed_norm = (embed_norm_x + embed_norm_y + embed_norm_z) / 3\n        mask_embed_norm = (tot_embed_norm_x + tot_embed_norm_y + tot_embed_norm_z) / 3\n        dist_a = F.pairwise_distance(embedded_x, embedded_y, 2)\n        dist_b = F.pairwise_distance(embedded_x, embedded_z, 2)\n        return dist_a, dist_b, mask_norm, embed_norm, mask_embed_norm"
  }
]