[
  {
    "path": "README.md",
    "content": "# EMP-SSL: Towards Self-Supervised Learning in One Training Epoch\n\n[![arXiv](https://img.shields.io/badge/arXiv-2304.03977-b31b1b.svg)](https://arxiv.org/abs/2304.03977)\n\n\n![Training Pipeline](pipeline.png)\n\n\nAuthors: Shengbang Tong*, Yubei Chen*, Yi Ma, Yann LeCun\n\n## Introduction\nThis repository contains the implementation for the paper \"EMP-SSL: Towards Self-Supervised Learning in One Training Epoch.\" The paper introduces a simplistic but efficient self-supervised learning method called Extreme-Multi-Patch Self-Supervised-Learning (EMP-SSL). EMP-SSL significantly reduces the training epochs required for convergence by increasing the number of fix size image patches from each image instance.\n\n## Preparing Training Data\nCifar10 and Cifar100 can be downloaded automatically in the script. ImageNet100 is a special subset of ImageNet. Details can be found in this [link](https://github.com/HobbitLong/CMC/issues/21).\n\n## Getting Started\nCurrent code implementation supports Cifar10, Cifar100 and ImageNet100.\n\nTo get started with the EMP-SSL implementation, follow these instructions:\n\n### 1. Clone this repository\n```bash\ngit clone https://github.com/tsb0601/emp-ssl.git\ncd emp-ssl\n``` \n### 2. Install required packages\n```\npip install -r requirements.txt\n```\n### 3. 
Training\n\n#### Reproducing 1-epoch results\n\n|                    | CIFAR-10<br>1 Epoch | CIFAR-100<br>1 Epoch | Tiny ImageNet<br>1 Epoch | ImageNet-100<br>1 Epoch |\n|--------------------|:----------------------:|:-----------------------:|:----------------------------:|:--------------------------:|\n| EMP-SSL (1 Epoch)  |         0.842          |          0.585          |             0.381             |            0.585           |\n\nFor CIFAR-10 or CIFAR-100\n```\npython main.py --data cifar10 --epoch 2 --patch_sim 200 --arch 'resnet18-cifar' --num_patches 20 --lr 0.3\n```\nFor ImageNet-100\n```\npython main.py --data imagenet100 --epoch 2 --patch_sim 200 --arch 'resnet18-imagenet' --num_patches 20 --lr 0.3\n```\n\n\n#### Reproducing multi-epoch results\n\n|                      | CIFAR-10<br>1 Epoch | CIFAR-10<br>10 Epochs | CIFAR-10<br>30 Epochs | CIFAR-10<br>1000 Epochs | CIFAR-100<br>1 Epoch | CIFAR-100<br>10 Epochs | CIFAR-100<br>30 Epochs | CIFAR-100<br>1000 Epochs | Tiny ImageNet<br>10 Epochs | Tiny ImageNet<br>1000 Epochs | ImageNet-100<br>10 Epochs | ImageNet-100<br>400 Epochs |\n|----------------------|:-------------------:|:---------------------:|:---------------------:|:-----------------------:|:--------------------:|:----------------------:|:----------------------:|:------------------------:| :------------------------:|:------------------------:|:------------------------:| :------------------------:|\n| SimCLR               |        0.282        |         0.565         |         0.663         |          0.910          |         0.054        |         0.185          |         0.341          |          0.662           | - | 0.488 | - | 0.776\n| BYOL                 |        0.249        |         0.489         |         0.684         |          0.926          |         0.043        |         0.150          |         0.349          |          0.708           | - | 0.510 | - | 0.802\n| VICReg               |        0.406        |         0.697         
|         0.781         |          0.921          |         0.079        |         0.319          |         0.479          |          0.685           | - | - | - | 0.792\n| SwAV                 |        0.245        |         0.532         |         0.767         |          0.923          |         0.028        |         0.208          |         0.294          |          0.658           |- | - | - | 0.740\n| ReSSL                |        0.245        |         0.256         |         0.525         |          0.914          |         0.033        |         0.122          |         0.247          |          0.674           |- | - | - | 0.769\n| EMP-SSL (20 patches) |        0.806        |         0.907         |         0.931         |            -            |         0.551        |         0.678          |         0.724          |            -              | - | - | - | -\n| EMP-SSL (200 patches)|        0.826*        |         0.915         |         0.934         |            -            |         0.577        |         0.701          |         0.733          |            -              | 0.515 | - | 0.789 | -\n\n\\* Here, we change the learning-rate schedule to decay over 30 epochs, so the 1-epoch accuracy is slightly lower than when optimizing specifically for 1-epoch training.\n\nChange `--num_patches` to set the number of patches used in EMP-SSL training.\n```\npython main.py --data cifar10 --epoch 30 --patch_sim 200 --arch 'resnet18-cifar' --num_patches 20 --lr 0.3\n```\n\n\n\n### 4. Evaluating\nBecause our model is trained only on fixed-size image patches, we adopt the bag-of-features model from the intra-instance VICReg paper to evaluate performance. 
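At evaluation time, each image is represented by the average of its patch features. This averaging step, as implemented by `chunk_avg` in `evaluate.py`, reduces to:

```python
import torch
import torch.nn.functional as F

def chunk_avg(x, n_chunks=2, normalize=False):
    """Average patch features back into one feature per image.

    x: tensor of shape (n_chunks * batch, dim), produced by stacking the
    features of n_chunks patches per image along dim 0.
    """
    x = torch.stack(x.chunk(n_chunks, dim=0), dim=0)  # (n_chunks, batch, dim)
    mean = x.mean(0)                                  # (batch, dim)
    return F.normalize(mean, dim=1) if normalize else mean

# Two images, two patches each, 2-D features:
feats = torch.tensor([[1., 2.], [3., 4.],   # patch 0 of images 0 and 1
                      [5., 6.], [7., 8.]])  # patch 1 of images 0 and 1
print(chunk_avg(feats, n_chunks=2))  # per-image means: [3., 4.] and [5., 6.]
```

The averaged features are then fed to the linear or k-NN classifier, so a larger `--test_patches` gives a better bag-of-features estimate at the cost of more memory.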
Adjust `--test_patches` to set the number of patches used in the bag-of-features model; lower it to fit smaller GPUs.\n```\npython evaluate.py --model_path 'path to your evaluated model' --test_patches 128\n```\n\n## Acknowledgment\nThis repo is inspired by the [MCR2](https://github.com/Ma-Lab-Berkeley/MCR2), [solo-learn](https://github.com/vturrisi/solo-learn) and [NMCE](https://github.com/zengyi-li/NMCE-release) repos.\n\n## Citation\nIf you find this repository useful, please consider giving a star :star: and a citation:\n\n```\n@article{tong2023empssl,\ntitle={EMP-SSL: Towards Self-Supervised Learning in One Training Epoch},\nauthor={Shengbang Tong and Yubei Chen and Yi Ma and Yann LeCun},\njournal={arXiv preprint arXiv:2304.03977},\nyear={2023}\n}\n```\n"
  },
  {
    "path": "dataset/aug.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\nimport torchvision.transforms as transforms\nfrom PIL import Image, ImageFilter, ImageOps\n\n\ndef load_transforms(name):\n    \"\"\"Load data transformations.\n    \n    Note:\n        - Gaussian Blur is defined at the bottom of this file.\n    \n    \"\"\"\n    _name = name.lower()\n    if _name == \"cifar_sup\":\n        normalize = transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])\n        aug_transform = transforms.Compose([\n            transforms.RandomResizedCrop(32,scale=(0.765625, 0.765625),ratio=(1., 1.)),\n            transforms.RandomCrop(32, padding=8),\n            transforms.RandomHorizontalFlip(),\n            transforms.ToTensor(),\n            normalize\n        ])\n        baseline_transform = transforms.Compose([\n            transforms.RandomResizedCrop(32,scale=(0.765625, 0.765625),ratio=(1., 1.)),\n            transforms.ToTensor(),normalize])\n\n    elif _name == \"cifar_patch\":\n        normalize = transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])\n        aug_transform = transforms.Compose([\n            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),\n            transforms.ToTensor(),\n            normalize\n        ])\n        baseline_transform = transforms.Compose([\n            transforms.ToTensor(), normalize])\n        \n    elif _name == \"cifar_simclr_norm\":\n        normalize = transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])\n        aug_transform = transforms.Compose([\n            transforms.RandomResizedCrop(32,scale=(0.08, 1.0)),\n            transforms.RandomHorizontalFlip(p=0.5),\n            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),\n            transforms.RandomGrayscale(p=0.2),\n            transforms.ToTensor(),\n            normalize\n        ])\n        baseline_transform = transforms.Compose([\n    
        transforms.ToTensor(),normalize])\n    \n    elif _name == \"cifar_byol\":\n        normalize = transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])\n        aug_transform = transforms.Compose([\n            transforms.RandomResizedCrop(\n                    (32, 32),\n                    scale=(0.2, 1.0),\n                    interpolation=transforms.InterpolationMode.BICUBIC,\n                ),\n            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8),\n            transforms.RandomGrayscale(p=0.2),\n            transforms.RandomApply([Solarization()], p=0.1),\n            transforms.RandomHorizontalFlip(p=0.5),\n            transforms.ToTensor(),\n            normalize\n        ])\n        baseline_transform = transforms.Compose([\n#             transforms.RandomResizedCrop(32,scale=(0.765625, 0.765625),ratio=(1., 1.)),\n            transforms.ToTensor(),normalize])\n\n    else:\n        raise NameError(\"{} not found in transform loader\".format(name))\n    return aug_transform, baseline_transform\n\n\nclass Solarization:\n    \"\"\"Solarization as a callable object.\"\"\"\n\n    def __call__(self, img: Image) -> Image:\n        \"\"\"Applies solarization to an input image.\n\n        Args:\n            img (Image): an image in the PIL.Image format.\n\n        Returns:\n            Image: a solarized image.\n        \"\"\"\n\n        return ImageOps.solarize(img)\n\nclass GBlur(object):\n    def __init__(self, p):\n        self.p = p\n\n    def __call__(self, img):\n        if np.random.rand() < self.p:\n            sigma = np.random.rand() * 1.9 + 0.1\n            return img.filter(ImageFilter.GaussianBlur(sigma))\n        else:\n            return img\n\n\nclass AddGaussianNoise(object):\n    def __init__(self, mean=0., std=1.):\n        self.std = std\n        self.mean = mean\n        \n    def __call__(self, tensor):\n        return tensor + torch.randn(tensor.size()) * self.std + self.mean\n    
\n    def __repr__(self):\n        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)\n\n\nclass ContrastiveLearningViewGenerator(object):\n    \"\"\"Generates `num_patch` augmented patch views of an input image.\"\"\"\n\n    def __init__(self, num_patch = 4):\n        self.num_patch = num_patch\n        normalize = transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])\n        # Build the patch augmentation pipeline once instead of on every call.\n        self.aug_transform = transforms.Compose([\n            transforms.RandomResizedCrop(32,scale=(0.25, 0.25), ratio=(1,1)),\n            transforms.RandomHorizontalFlip(p=0.5),\n            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.2)], p=0.8),\n            transforms.RandomGrayscale(p=0.2),\n            GBlur(p=0.1),\n            transforms.RandomApply([Solarization()], p=0.1),\n            transforms.ToTensor(),\n            normalize\n        ])\n\n    def __call__(self, x):\n        return [self.aug_transform(x) for _ in range(self.num_patch)]\n"
  },
  {
    "path": "dataset/aug4img.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\n\nimport torchvision.transforms as transforms\nfrom PIL import Image, ImageFilter, ImageOps\nfrom torchvision.transforms import InterpolationMode\n\n\n\nclass Solarization:\n    \"\"\"Solarization as a callable object.\"\"\"\n\n    def __call__(self, img: Image) -> Image:\n        \"\"\"Applies solarization to an input image.\n\n        Args:\n            img (Image): an image in the PIL.Image format.\n\n        Returns:\n            Image: a solarized image.\n        \"\"\"\n\n        return ImageOps.solarize(img)\n\nclass GBlur(object):\n    def __init__(self, p):\n        self.p = p\n\n    def __call__(self, img):\n        if np.random.rand() < self.p:\n            sigma = np.random.rand() * 1.9 + 0.1\n            return img.filter(ImageFilter.GaussianBlur(sigma))\n        else:\n            return img\n\n\nclass AddGaussianNoise(object):\n    def __init__(self, mean=0., std=1.):\n        self.std = std\n        self.mean = mean\n        \n    def __call__(self, tensor):\n        return tensor + torch.randn(tensor.size()) * self.std + self.mean\n    \n    def __repr__(self):\n        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)\n\n\nclass ContrastiveLearningViewGenerator(object):\n    def __init__(self, num_patch = 4):\n\n        self.num_patch = num_patch\n        \n    def __call__(self, x):\n        aug_transform =  transforms.Compose([\n            transforms.RandomResizedCrop(\n                224, scale=(0.25, 0.25), interpolation=InterpolationMode.BICUBIC\n            ),\n            transforms.RandomHorizontalFlip(p=0.5),\n            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.2)], p=0.8),\n            transforms.RandomGrayscale(p=0.2),\n            GBlur(p=0.1),\n            transforms.RandomApply([Solarization()], p=0.1),\n            transforms.ToTensor(),\n            
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),\n        ])\n\n        augmented_x = [aug_transform(x) for _ in range(self.num_patch)]\n        return augmented_x\n"
  },
  {
    "path": "dataset/datasets.py",
    "content": "import os\nimport numpy as np\nimport torchvision\n\ndef load_dataset(data_name, train=True, num_patch = 4, path=\"./data/\"):\n    \"\"\"Loads a dataset for training and testing. If augmentloader is used, transform should be None.\n    \n    Parameters:\n        data_name (str): name of the dataset\n        transform_name (torchvision.transform): name of transform to be applied (see aug.py)\n        use_baseline (bool): use baseline transform or augmentation transform\n        train (bool): load training set or not\n        contrastive (bool): whether to convert transform to multiview augmentation for contrastive learning.\n        n_views (bool): number of views for contrastive learning\n        path (str): path to dataset base path\n\n    Returns:\n        dataset (torch.data.dataset)\n    \"\"\"\n    _name = data_name.lower()\n    if _name == \"imagenet\":\n        from .aug4img import ContrastiveLearningViewGenerator\n    else:\n        from .aug import ContrastiveLearningViewGenerator\n      \n    \n    transform = ContrastiveLearningViewGenerator(num_patch = num_patch)\n        \n    if _name == \"cifar10\":\n        trainset = torchvision.datasets.CIFAR10(root=os.path.join(path, \"CIFAR10\"), train=train, download=True, transform=transform)\n        trainset.num_classes = 10\n    elif _name == \"cifar100\":\n        trainset = torchvision.datasets.CIFAR100(root=os.path.join(path, \"CIFAR100\"), train=train, download=True, transform=transform)\n        trainset.num_classes = 100\n    elif _name == \"imagenet\":\n        if train:\n            trainset = torchvision.datasets.ImageFolder(root=\"/home/peter/Data/ILSVRC2012/train100/\",transform=transform)\n            #trainset = torchvision.datasets.ImageFolder(root=\"/home/peter/Data/tiny-imagenet-200/train/\",transform=transform)\n        else:\n            trainset = torchvision.datasets.ImageFolder(root=\"/home/peter/Data/ILSVRC2012/val100/\",transform=transform)\n            #trainset = 
torchvision.datasets.ImageFolder(root=\"/home/peter/Data/tiny-imagenet-200/val/\",transform=transform)\n        trainset.num_classes = 100  # ImageNet-100 has 100 classes (200 applies to the commented-out Tiny ImageNet paths)\n\n    else:\n        raise NameError(\"{} not found in trainset loader\".format(_name))\n    return trainset\n\ndef sparse2coarse(targets):\n    \"\"\"CIFAR100 Coarse Labels. \"\"\"\n    coarse_targets = [ 4,  1, 14,  8,  0,  6,  7,  7, 18,  3,  3, 14,  9, 18,  7, 11,  3,\n                       9,  7, 11,  6, 11,  5, 10,  7,  6, 13, 15,  3, 15,  0, 11,  1, 10,\n                      12, 14, 16,  9, 11,  5,  5, 19,  8,  8, 15, 13, 14, 17, 18, 10, 16,\n                       4, 17,  4,  2,  0, 17,  4, 18, 17, 10,  3,  2, 12, 12, 16, 12,  1,\n                       9, 19,  2, 10,  0,  1, 16, 12,  9, 13, 15, 13, 16, 19,  2,  4,  6,\n                      19,  5,  5,  8, 19, 18,  1,  2, 15,  6,  0, 17,  8, 14, 13]\n    return np.array(coarse_targets)[targets]"
  },
  {
    "path": "evaluate.py",
    "content": "############\n## Import ##\n############\nimport argparse\nimport torch.nn as nn\nfrom torch.utils.data import DataLoader\nfrom model.model import encoder\nfrom dataset.datasets import load_dataset\nimport numpy as np\nimport torch.nn.functional as F\nfrom tqdm import tqdm\nimport torch\nimport numpy as np\nfrom func import WeightedKNNClassifier, linear\n\n######################\n## Parsing Argument ##\n######################\nimport argparse\nparser = argparse.ArgumentParser(description='Evaluation')\n\nparser.add_argument('--test_patches', type=int, default=128,\n                    help='number of patches used in testing (default: 128)')  \n\nparser.add_argument('--data', type=str, default=\"cifar10\",\n                    help='dataset (default: cifar10)')  \nparser.add_argument('--arch', type=str, default=\"resnet18-cifar\",\n                    help='network architecture (default: resnet18-cifar)')\n\nparser.add_argument('--lr', type=float, default=0.03,\n                    help='learning rate for linear eval (default: 0.03)')        \nparser.add_argument('--linear', type=bool, default=True,\n                    help='use linear eval or not')\nparser.add_argument('--knn', help='evaluate using kNN measuring cosine similarity', action='store_true')\nparser.add_argument('--model_path', type=str, default=\"\",\n                    help='model directory for eval')\n\n            \nargs = parser.parse_args()\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n######################\n## Testing Accuracy ##\n######################\ntest_patches = args.test_patches\n\ndef compute_accuracy(y_pred, y_true):\n    \"\"\"Compute accuracy by counting correct classification. 
\"\"\"\n    assert y_pred.shape == y_true.shape\n    return 1 - np.count_nonzero(y_pred - y_true) / y_true.size\n\nknn_classifier = WeightedKNNClassifier()\n\n\ndef chunk_avg(x,n_chunks=2,normalize=False):\n    x_list = x.chunk(n_chunks,dim=0)\n    x = torch.stack(x_list,dim=0)\n    if not normalize:\n        return x.mean(0)\n    else:\n        return F.normalize(x.mean(0),dim=1)\n\n\ndef test(net, train_loader, test_loader):\n    \n    train_z_full_list, train_y_list, test_z_full_list, test_y_list = [], [], [], []\n    \n    with torch.no_grad():\n        for x, y in tqdm(train_loader):\n\n            x = torch.cat(x, dim = 0)\n            \n            z_proj, z_pre = net(x, is_test=True)\n\n            z_pre = chunk_avg(z_pre, test_patches)\n            z_pre = z_pre.detach().cpu()\n            \n            \n            train_z_full_list.append(z_pre)\n            \n            \n            knn_classifier.update(train_features = z_pre, train_targets = y)\n\n            train_y_list.append(y)\n                \n        for x, y in tqdm(test_loader):\n            x = torch.cat(x, dim = 0)\n            \n            z_proj, z_pre = net(x, is_test=True)\n\n            z_pre = chunk_avg(z_pre, test_patches)\n            z_pre = z_pre.detach().cpu()\n           \n            test_z_full_list.append(z_pre)\n       \n            knn_classifier.update(test_features = z_pre, test_targets = y)\n\n            test_y_list.append(y)\n                \n            \n    train_features_full, train_labels, test_features_full, test_labels = torch.cat(train_z_full_list,dim=0), torch.cat(train_y_list,dim=0), torch.cat(test_z_full_list,dim=0), torch.cat(test_y_list,dim=0)\n   \n    if args.data == \"cifar10\":\n        num_classes = 10\n    elif args.data == \"cifar100\":\n        num_classes = 100\n    elif args.data == \"tinyimagenet200\":\n        num_classes = 200\n    elif args.data == \"imagenet100\":\n        num_classes = 100\n    elif args.data == \"imagenet\":\n        
num_classes = 1000\n    else:\n        raise NameError(\"{} not found in num_classes lookup\".format(args.data))\n\n    if args.linear:\n        print(\"Using Linear Eval to evaluate accuracy\")\n        linear(train_features_full, train_labels, test_features_full, test_labels, lr=args.lr, num_classes = num_classes)\n\n    if args.knn:\n        print(\"Using KNN to evaluate accuracy\")\n        top1, top5 = knn_classifier.compute()\n        print(\"KNN (top1/top5):\", top1, top5)\n\n\ntorch.multiprocessing.set_sharing_strategy('file_system')\n\n\n# Get Dataset\nmemory_dataset = load_dataset(args.data, train=True, num_patch = test_patches)\nmemory_loader = DataLoader(memory_dataset, batch_size=50, shuffle=True, drop_last=True, num_workers=8)\n\ntest_data = load_dataset(args.data, train=False, num_patch = test_patches)\ntest_loader = DataLoader(test_data, batch_size=50, shuffle=True, num_workers=8)\n\n# Load Model and Checkpoint\nnet = encoder(arch = args.arch)\nnet = nn.DataParallel(net)\nsave_dict = torch.load(args.model_path)\nnet.load_state_dict(save_dict, strict=False)\nnet.cuda()\nnet.eval()\ntest(net, memory_loader, test_loader)\n"
  },
  {
    "path": "func.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn.functional as F\nimport matplotlib.pyplot as plt\nfrom sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score\nimport torchvision\n# import torch.nn\n\n\nfrom torch import nn, optim\nimport torch.nn as nn\nfrom torch.utils import data\nfrom torch.utils.data import DataLoader\n\n\n\nfrom typing import Tuple\n\n\nimport torch.nn.functional as F\nfrom torchmetrics.metric import Metric\n\n\nclass WeightedKNNClassifier(Metric):\n    def __init__(\n        self,\n        k: int = 20,\n        T: float = 0.07,\n        max_distance_matrix_size: int = int(5e6),\n        distance_fx: str = \"cosine\",\n        epsilon: float = 0.00001,\n        dist_sync_on_step: bool = False,\n    ):\n        \"\"\"Implements the weighted k-NN classifier used for evaluation.\n        Args:\n            k (int, optional): number of neighbors. Defaults to 20.\n            T (float, optional): temperature for the exponential. Only used with cosine\n                distance. Defaults to 0.07.\n            max_distance_matrix_size (int, optional): maximum number of elements in the\n                distance matrix. Defaults to 5e6.\n            distance_fx (str, optional): Distance function. Accepted arguments: \"cosine\" or\n                \"euclidean\". Defaults to \"cosine\".\n            epsilon (float, optional): Small value for numerical stability. Only used with\n                euclidean distance. Defaults to 0.00001.\n            dist_sync_on_step (bool, optional): whether to sync distributed values at every\n                step. 
Defaults to False.\n        \"\"\"\n\n        super().__init__(dist_sync_on_step=dist_sync_on_step, compute_on_step=False)\n\n        self.k = k\n        self.T = T\n        self.max_distance_matrix_size = max_distance_matrix_size\n        self.distance_fx = distance_fx\n        self.epsilon = epsilon\n\n        self.add_state(\"train_features\", default=[], persistent=False)\n        self.add_state(\"train_targets\", default=[], persistent=False)\n        self.add_state(\"test_features\", default=[], persistent=False)\n        self.add_state(\"test_targets\", default=[], persistent=False)\n\n    def update(\n        self,\n        train_features: torch.Tensor = None,\n        train_targets: torch.Tensor = None,\n        test_features: torch.Tensor = None,\n        test_targets: torch.Tensor = None,\n    ):\n        \"\"\"Updates the memory banks. If train (test) features are passed as input, the\n        corresponding train (test) targets must be passed as well.\n        Args:\n            train_features (torch.Tensor, optional): a batch of train features. Defaults to None.\n            train_targets (torch.Tensor, optional): a batch of train targets. Defaults to None.\n            test_features (torch.Tensor, optional): a batch of test features. Defaults to None.\n            test_targets (torch.Tensor, optional): a batch of test targets. 
Defaults to None.\n        \"\"\"\n        assert (train_features is None) == (train_targets is None)\n        assert (test_features is None) == (test_targets is None)\n\n        if train_features is not None:\n            assert train_features.size(0) == train_targets.size(0)\n            self.train_features.append(train_features.detach())\n            self.train_targets.append(train_targets.detach())\n\n        if test_features is not None:\n            assert test_features.size(0) == test_targets.size(0)\n            self.test_features.append(test_features.detach())\n            self.test_targets.append(test_targets.detach())\n\n    def set_tk(self, T, k):\n        self.T = T\n        self.k = k\n\n    @torch.no_grad()\n    def compute(self) -> Tuple[float]:\n        \"\"\"Computes weighted k-NN accuracy @1 and @5. If cosine distance is selected,\n        the weight is computed using the exponential of the temperature scaled cosine\n        distance of the samples. If euclidean distance is selected, the weight corresponds\n        to the inverse of the euclidean distance.\n        Returns:\n            Tuple[float]: k-NN accuracy @1 and @5.\n        \"\"\"\n        train_features = torch.cat(self.train_features)\n        train_targets = torch.cat(self.train_targets)\n        test_features = torch.cat(self.test_features)\n        test_targets = torch.cat(self.test_targets)\n\n        if self.distance_fx == \"cosine\":\n            train_features = F.normalize(train_features)\n            test_features = F.normalize(test_features)\n\n        num_classes = torch.unique(test_targets).numel()\n        num_train_images = train_targets.size(0)\n        num_test_images = test_targets.size(0)\n        chunk_size = min(\n            max(1, self.max_distance_matrix_size // num_train_images),\n            num_test_images,\n        )\n        k = min(self.k, 
num_train_images)\n\n        top1, top5, total = 0.0, 0.0, 0\n        retrieval_one_hot = torch.zeros(k, num_classes).to(train_features.device)\n        for idx in range(0, num_test_images, chunk_size):\n            # get the features for test images\n            features = test_features[idx : min((idx + chunk_size), num_test_images), :]\n            targets = test_targets[idx : min((idx + chunk_size), num_test_images)]\n            batch_size = targets.size(0)\n\n            # calculate the dot product and compute top-k neighbors\n            if self.distance_fx == \"cosine\":\n                similarities = torch.mm(features, train_features.t())\n            elif self.distance_fx == \"euclidean\":\n                similarities = 1 / (torch.cdist(features, train_features) + self.epsilon)\n            else:\n                raise NotImplementedError\n\n            similarities, indices = similarities.topk(k, largest=True, sorted=True)\n            candidates = train_targets.view(1, -1).expand(batch_size, -1)\n            retrieved_neighbors = torch.gather(candidates, 1, indices)\n\n            retrieval_one_hot.resize_(batch_size * k, num_classes).zero_()\n            retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1)\n\n            if self.distance_fx == \"cosine\":\n                similarities = similarities.clone().div_(self.T).exp_()\n\n            probs = torch.sum(\n                torch.mul(\n                    retrieval_one_hot.view(batch_size, -1, num_classes),\n                    similarities.view(batch_size, -1, 1),\n                ),\n                1,\n            )\n            _, predictions = probs.sort(1, True)\n\n            # find the predictions that match the target\n            correct = predictions.eq(targets.data.view(-1, 1))\n            top1 = top1 + correct.narrow(1, 0, 1).sum().item()\n            top5 = (\n                top5 + correct.narrow(1, 0, min(5, k, correct.size(-1))).sum().item()\n            )  # top5 
does not make sense if k < 5\n            total += targets.size(0)\n\n        top1 = top1 * 100.0 / total\n        top5 = top5 * 100.0 / total\n\n        self.reset()\n\n        return top1, top5\n\n\ndef linear(train_features, train_labels, test_features, test_labels, lr=0.0075, num_classes = 100):\n    train_data = tensor_dataset(train_features,train_labels)\n    test_data = tensor_dataset(test_features,test_labels)\n    train_loader = DataLoader(train_data, batch_size=100, shuffle=True, drop_last=True, num_workers=2)\n    test_loader = DataLoader(test_data, batch_size=100, shuffle=True, drop_last=False, num_workers=2)\n\n    LL = nn.Linear(train_features.shape[1],num_classes)\n    optimizer = torch.optim.SGD(LL.parameters(), lr=lr, momentum=0.9, weight_decay=5e-5)\n    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 100)\n\n    criterion = torch.nn.CrossEntropyLoss()\n\n    test_acc_list = []\n    for epoch in range(100):\n        top1_train_accuracy = 0\n        for counter, (x_batch, y_batch) in enumerate(train_loader):\n            logits = LL(x_batch)\n            loss = criterion(logits, y_batch)\n            top1 = accuracy(logits, y_batch, topk=(1,))\n            top1_train_accuracy += top1[0]\n\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n\n        scheduler.step()\n\n        top1_train_accuracy /= (counter + 1)\n\n        top1_accuracy = 0\n        top5_accuracy = 0\n        for counter, (x_batch, y_batch) in enumerate(test_loader):\n            logits = LL(x_batch)\n\n            top1, top5 = accuracy(logits, y_batch, topk=(1,5))\n            top1_accuracy += top1[0]\n            top5_accuracy += top5[0]\n\n        top1_accuracy /= (counter + 1)\n        top5_accuracy /= (counter + 1)\n\n        
test_acc_list.append(top1_accuracy)\n        \n        print(f\"Epoch {epoch}\\tTop1 Train accuracy {top1_train_accuracy.item()}\\tTop1 Test accuracy: {top1_accuracy.item()}\\tTop5 test acc: {top5_accuracy.item()}\")\n    acc_vect = torch.tensor(test_acc_list)\n    print('best linear test acc {}, last acc {}'.format(acc_vect.max().item(),acc_vect[-1].item()))\n        \n\n\ndef accuracy(output, target, topk=(1,)):\n    \"\"\"Computes the accuracy over the k top predictions for the specified values of k\"\"\"\n    with torch.no_grad():\n        maxk = max(topk)\n        batch_size = target.size(0)\n\n        _, pred = output.topk(maxk, 1, True, True)\n        pred = pred.t()\n        correct = pred.eq(target.view(1, -1).expand_as(pred))\n\n        res = []\n        for k in topk:\n            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)\n            res.append(correct_k.mul_(100.0 / batch_size))\n        return res\n            \nclass tensor_dataset(data.Dataset):\n    def __init__(self,x,y):\n        self.x = x\n        self.y = y\n        self.length = x.shape[0]\n    \n    def __getitem__(self,indx):\n        return self.x[indx], self.y[indx]\n    \n    def __len__(self):\n        return self.length\n\n\n\n\ndef set_gamma(loss_fn,epoch,total_epoch=500,warmup_epoch=100,gamma_min=0.,gamma_max=1.0):\n    warmup_start = total_epoch - warmup_epoch\n    warmup_end = total_epoch\n    \n    if warmup_start < epoch<=warmup_end:\n        loss_fn.gamma = ((epoch - warmup_start)/(warmup_end - warmup_start))*(gamma_max - gamma_min) + gamma_min\n    else:\n        loss_fn.gamma = gamma_min\n\ndef warmup_lr(optimizer,epoch,base_lr,warmup_epoch=10):\n    if epoch<warmup_epoch:\n        optimizer.param_groups[0]['lr'] = base_lr*min(1.,(epoch+1)/warmup_epoch)\n        \n        \ndef marginal_H(logits):\n    bs = torch.tensor(logits.shape[0]).float()\n    logps = torch.log_softmax(logits,dim=1)\n    marginal_p = torch.logsumexp(logps - bs.log(),dim=0)\n    H = 
(marginal_p.exp()*(-marginal_p)).sum()*(1.4426950)\n    return H\n\ndef chunk_avg(x,n_chunks=2,normalize=False):\n    x_list = x.chunk(n_chunks,dim=0)\n    x = torch.stack(x_list,dim=0)\n    if not normalize:\n        return x.mean(0)\n    else:\n        return F.normalize(x.mean(0),dim=1)\n\ndef cluster_match(cluster_mtx,label_mtx,n_classes=10,print_result=True):\n    #verified to be consistent to optimimal assignment problem based algorithm\n    cluster_indx = list(cluster_mtx.unique())\n    assigned_label_list = []\n    assigned_count = []\n    while (len(assigned_label_list)<=n_classes) and len(cluster_indx)>0:\n        max_label_list = []\n        max_count_list = []\n        for indx in cluster_indx:\n            #calculate highest number of matchs\n            mask = cluster_mtx==indx\n            label_elements, counts = label_mtx[mask].unique(return_counts=True)\n            for assigned_label in assigned_label_list:\n                counts[label_elements==assigned_label] = 0\n            max_count_list.append(counts.max())\n            max_label_list.append(label_elements[counts.argmax()])\n\n        max_label = torch.stack(max_label_list)\n        max_count = torch.stack(max_count_list)\n        assigned_label_list.append(max_label[max_count.argmax()])\n        assigned_count.append(max_count.max())\n        cluster_indx.pop(max_count.argmax())\n    total_correct = torch.tensor(assigned_count).sum().item()\n    total_sample = cluster_mtx.shape[0]\n    acc = total_correct/total_sample\n    if print_result:\n        print('{}/{} ({}%) correct'.format(total_correct,total_sample,acc*100))\n    else:\n        return total_correct, total_sample, acc\n\ndef cluster_merge_match(cluster_mtx,label_mtx,print_result=True):\n    cluster_indx = list(cluster_mtx.unique())\n    n_correct = 0\n    for cluster_id in cluster_indx:\n        label_elements, counts = label_mtx[cluster_mtx==cluster_id].unique(return_counts=True)\n        n_correct += counts.max()\n    
total_sample = len(cluster_mtx)\n    acc = n_correct.item()/total_sample\n    if print_result:\n        print('{}/{} ({}%) correct'.format(n_correct,total_sample,acc*100))\n    else:\n        return n_correct, total_sample, acc\n\n    \ndef cluster_acc(test_loader,net,device,print_result=False,save_name_img='cluster_img',save_name_fig='pca_figure'):\n    cluster_list = []\n    label_list = []\n    x_list = []\n    z_list = []\n    net.eval()\n    for x, y in test_loader:\n        with torch.no_grad():\n            x, y = x.float().to(device), y.to(device)\n            z, logit = net(x)\n            if logit.sum() == 0:\n                logit += torch.randn_like(logit)\n            cluster_list.append(logit.max(dim=1)[1].cpu())\n            label_list.append(y.cpu())\n            x_list.append(x.cpu())\n            z_list.append(z.cpu())\n    net.train()\n    cluster_mtx = torch.cat(cluster_list,dim=0)\n    label_mtx = torch.cat(label_list,dim=0)\n    x_mtx = torch.cat(x_list,dim=0)\n    z_mtx = torch.cat(z_list,dim=0)\n    _, _, acc_single = cluster_match(cluster_mtx,label_mtx,n_classes=label_mtx.max()+1,print_result=False)\n    _, _, acc_merge = cluster_merge_match(cluster_mtx,label_mtx,print_result=False)\n    NMI = normalized_mutual_info_score(label_mtx.numpy(),cluster_mtx.numpy())\n    ARI = adjusted_rand_score(label_mtx.numpy(),cluster_mtx.numpy())\n    if print_result:\n        print('cluster match acc {}, cluster merge match acc {}, NMI {}, ARI {}'.format(acc_single,acc_merge,NMI,ARI))\n    \n    save_name_img += '_acc'+ str(acc_single)[2:5]\n    save_cluster_imgs(cluster_mtx,x_mtx,save_name_img)\n    save_latent_pca_figure(z_mtx,cluster_mtx,save_name_fig)\n    \n    return acc_single, acc_merge, NMI, ARI\n    \ndef save_cluster_imgs(cluster_mtx,x_mtx,save_name,npercluster=100):\n    cluster_indexs, counts = cluster_mtx.unique(return_counts=True)\n    x_list = []\n    counts_list = []\n    for i, c_indx in enumerate(cluster_indexs):\n        if 
counts[i]>npercluster:\n            x_list.append(x_mtx[cluster_mtx==c_indx,:,:,:])\n            counts_list.append(counts[i])\n\n    n_clusters = len(counts_list)\n    fig, ax = plt.subplots(n_clusters,1,dpi=80,figsize=(1.2*n_clusters, 3*n_clusters))\n    for i, ax in enumerate(ax):\n        img = torchvision.utils.make_grid(x_list[i][:npercluster],nrow=npercluster//5,normalize=True)\n        ax.imshow(img.permute(1,2,0))\n        ax.set_axis_off()\n\n        ax.set_title('Cluster with {} images'.format(counts_list[i]))\n    \n    fig.savefig(save_name+'.pdf')\n    plt.close(fig)\n    \ndef save_latent_pca_figure(z_mtx,cluster_mtx,save_name):\n    _, s_z_all, _ = z_mtx.svd()\n    cluster_n = []\n    cluster_s = []\n    for cluster_indx in cluster_mtx.unique():\n        _, s_cluster, _ = z_mtx[cluster_mtx==cluster_indx,:].svd()\n        cluster_n.append((cluster_mtx==cluster_indx).sum().item())\n        cluster_s.append(s_cluster/s_cluster.max())\n\n    #make plot\n    fig, ax = plt.subplots(1,2,figsize=(9, 3))\n    ax[0].plot(s_z_all)\n    for i, s_curve in enumerate(cluster_s):\n        ax[1].plot(s_curve,label=cluster_n[i])\n    ax[1].set_xlim(xmin=0,xmax=20)\n    ax[1].legend()\n    fig.savefig(save_name +'.pdf')\n    plt.close(fig)\n    \ndef analyze_latent(z_mtx,cluster_mtx):\n    _, s_z_all, _ = z_mtx.svd()\n    cluster_n = []\n    cluster_s = []\n    cluster_d = []\n    for cluster_indx in cluster_mtx.unique():\n        _, s_cluster, _ = z_mtx[cluster_mtx==cluster_indx,:].svd()\n        s_cluster = s_cluster/s_cluster.max()\n        cluster_n.append((cluster_mtx==cluster_indx).sum().item())\n        cluster_s.append(s_cluster)\n#         print(list(cluster_s))\n        print(s_cluster)\n#         s_diff = s_cluster[:-1] - s_cluster[1:]\n#         cluster_d.append(s_diff.max(0)[1])\n        cluster_d.append((s_cluster>0.01).sum())\n    for i in range(len(cluster_n)):\n        print('subspace {}, dimension {}, samples {}'.format(i,cluster_d[i],cluster_n[i]))"
  },
  {
    "path": "lars.py",
    "content": "import torch\nimport torch.optim as optim\nfrom torch.optim.optimizer import Optimizer, required\n\nclass LARS(Optimizer):\n    \"\"\"\n    Layer-wise adaptive rate scaling\n    - Converted from Tensorflow to Pytorch from:\n    https://github.com/google-research/simclr/blob/master/lars_optimizer.py\n    - Based on:\n    https://github.com/noahgolmant/pytorch-lars\n    params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups\n        lr (float): base learning rate (\\gamma_0)\n        len_reduced (int): number of leading parameter tensors that weight decay is applied to; the remaining parameters are excluded\n        momentum (float, optional): momentum factor (default: 0.9)\n        use_nesterov (bool, optional): flag to use nesterov momentum (default: False)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0)\n            (\"\\beta\")\n        eta (float, optional): LARS coefficient (default: 0.001)\n    - Based on Algorithm 1 of the following paper by You, Gitman, and Ginsburg.\n    - Large Batch Training of Convolutional Networks:\n        https://arxiv.org/abs/1708.03888\n    \"\"\"\n\n    def __init__(self, params, lr, len_reduced, momentum=0.9, use_nesterov=False, weight_decay=0.0, classic_momentum=True, eta=0.001):\n\n        self.epoch = 0\n        defaults = dict(\n            lr=lr,\n            momentum=momentum,\n            use_nesterov=use_nesterov,\n            weight_decay=weight_decay,\n            classic_momentum=classic_momentum,\n            eta=eta,\n            len_reduced=len_reduced\n        )\n\n        super(LARS, self).__init__(params, defaults)\n        self.lr = lr\n        self.momentum = momentum\n        self.weight_decay = weight_decay\n        self.use_nesterov = use_nesterov\n        self.classic_momentum = classic_momentum\n        self.eta = eta\n        self.len_reduced = len_reduced\n\n    def step(self, epoch=None, closure=None):\n\n        loss = None\n\n        if closure is not None:\n            loss = closure()\n\n        if epoch is None:\n            epoch = self.epoch\n            self.epoch += 1\n\n        for group in self.param_groups:\n            weight_decay = group['weight_decay']\n            momentum = group['momentum']\n            eta = group['eta']\n            learning_rate = group['lr']\n\n            # TODO: Hacky\n            counter = 0\n            for p in group['params']:\n                if p.grad is None:\n                    continue\n\n                param = p.data\n                grad = p.grad.data\n\n                param_state = self.state[p]\n\n                # TODO: This really hacky way needs to be improved.\n                # Note: parameters excluded from weight decay are passed at the end of the list and are skipped here\n                if counter < self.len_reduced:\n                    grad += self.weight_decay * param\n\n                # Create parameter for the momentum\n                if \"momentum_var\" not in param_state:\n                    next_v = param_state[\"momentum_var\"] = torch.zeros_like(\n                        p.data\n                    )\n                else:\n                    next_v = param_state[\"momentum_var\"]\n\n                if self.classic_momentum:\n                    trust_ratio = 1.0\n\n                    # TODO: implementation of layer adaptation\n                    w_norm = torch.norm(param)\n                    g_norm = torch.norm(grad)\n\n                    device = g_norm.get_device()\n\n                    trust_ratio = torch.where(w_norm.ge(0), torch.where(\n                        g_norm.ge(0), (self.eta * w_norm / g_norm), torch.Tensor([1.0]).to(device)),\n                                              torch.Tensor([1.0]).to(device)).item()\n\n                    scaled_lr = learning_rate * trust_ratio\n                    \n                    grad_scaled = scaled_lr*grad\n                    next_v.mul_(momentum).add_(grad_scaled)\n\n                    if 
self.use_nesterov:\n                        update = (self.momentum * next_v) + (scaled_lr * grad)\n                    else:\n                        update = next_v\n\n                    p.data.add_(-update)\n\n                # Not classic_momentum\n                else:\n\n                    next_v.mul_(momentum).add_(grad)\n\n                    if self.use_nesterov:\n                        update = (self.momentum * next_v) + (grad)\n\n                    else:\n                        update = next_v\n\n                    trust_ratio = 1.0\n\n                    # TODO: implementation of layer adaptation\n                    w_norm = torch.norm(param)\n                    v_norm = torch.norm(update)\n\n                    device = v_norm.get_device()\n\n                    trust_ratio = torch.where(w_norm.ge(0), torch.where(\n                        v_norm.ge(0), (self.eta * w_norm / v_norm), torch.Tensor([1.0]).to(device)),\n                                              torch.Tensor([1.0]).to(device)).item()\n\n                    scaled_lr = learning_rate * trust_ratio\n\n                    p.data.add_(-scaled_lr * update)\n\n                counter += 1\n\n        return loss\n    \n#LARSWrapper from solo-learn repo...\nclass LARSWrapper:\n    def __init__(\n        self,\n        optimizer: Optimizer,\n        eta: float = 1e-3,\n        clip: bool = False,\n        eps: float = 1e-8,\n        exclude_bias_n_norm: bool = False,\n    ):\n        \"\"\"Wrapper that adds LARS scheduling to any optimizer.\n        This helps stability with huge batch sizes.\n\n        Args:\n            optimizer (Optimizer): torch optimizer.\n            eta (float, optional): trust coefficient. Defaults to 1e-3.\n            clip (bool, optional): clip gradient values. Defaults to False.\n            eps (float, optional): adaptive_lr stability coefficient. 
Defaults to 1e-8.\n            exclude_bias_n_norm (bool, optional): exclude bias and normalization layers from lars.\n                Defaults to False.\n        \"\"\"\n\n        self.optim = optimizer\n        self.eta = eta\n        self.eps = eps\n        self.clip = clip\n        self.exclude_bias_n_norm = exclude_bias_n_norm\n\n        # transfer optim methods\n        self.state_dict = self.optim.state_dict\n        self.load_state_dict = self.optim.load_state_dict\n        self.zero_grad = self.optim.zero_grad\n        self.add_param_group = self.optim.add_param_group\n\n        self.__setstate__ = self.optim.__setstate__  # type: ignore\n        self.__getstate__ = self.optim.__getstate__  # type: ignore\n        self.__repr__ = self.optim.__repr__  # type: ignore\n\n    @property\n    def defaults(self):\n        return self.optim.defaults\n\n    @defaults.setter\n    def defaults(self, defaults):\n        self.optim.defaults = defaults\n\n    @property  # type: ignore\n    def __class__(self):\n        return Optimizer\n\n    @property\n    def state(self):\n        return self.optim.state\n\n    @state.setter\n    def state(self, state):\n        self.optim.state = state\n\n    @property\n    def param_groups(self):\n        return self.optim.param_groups\n\n    @param_groups.setter\n    def param_groups(self, value):\n        self.optim.param_groups = value\n\n    @torch.no_grad()\n    def step(self, closure=None):\n        weight_decays = []\n\n        for group in self.optim.param_groups:\n            weight_decay = group.get(\"weight_decay\", 0)\n            weight_decays.append(weight_decay)\n\n            # reset weight decay\n            group[\"weight_decay\"] = 0\n\n            # update the parameters\n            for p in group[\"params\"]:\n                if p.grad is not None and (p.ndim != 1 or not self.exclude_bias_n_norm):\n                    self.update_p(p, group, weight_decay)\n\n        # update the optimizer\n        
self.optim.step(closure=closure)\n\n        # return weight decay control to optimizer\n        for group_idx, group in enumerate(self.optim.param_groups):\n            group[\"weight_decay\"] = weight_decays[group_idx]\n\n    def update_p(self, p, group, weight_decay):\n        # calculate new norms\n        p_norm = torch.norm(p.data)\n        g_norm = torch.norm(p.grad.data)\n\n        if p_norm != 0 and g_norm != 0:\n            # calculate new lr\n            new_lr = (self.eta * p_norm) / (g_norm + p_norm * weight_decay + self.eps)\n\n            # clip lr\n            if self.clip:\n                new_lr = min(new_lr / group[\"lr\"], 1)\n\n            # update params with clipped lr\n            p.grad.data += weight_decay * p.data\n            p.grad.data *= new_lr"
  },
  {
    "path": "loss.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass contrastive_loss(nn.Module):\n    def __init__(self):\n        super().__init__()\n        pass\n    def forward(self,x,labels):\n        #this function assumes that the positive logit is always the first element,\n        #which is true here\n        loss = -x[:,0] + torch.logsumexp(x[:,1:],dim=1)\n        return loss.mean()\n\nclass SimCLR(nn.Module):\n    def __init__(self,temperature=0.5,n_views=2,contrastive=False):\n        super(SimCLR,self).__init__()\n        self.temp = temperature\n        self.n_views = n_views\n        \n        if contrastive:\n            self.criterion = contrastive_loss()\n        else:\n            self.criterion = torch.nn.CrossEntropyLoss()\n        \n    def info_nce_loss(self,X):\n        \n        bs, n_dim = X.shape\n        bs = int(bs/self.n_views)\n        device = X.device\n        \n        \n        labels = torch.cat([torch.arange(bs) for i in range(self.n_views)], dim=0)\n        labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()\n        labels = labels.to(device)\n\n        similarity_matrix = torch.matmul(X, X.T)\n        # assert similarity_matrix.shape == (\n        #     self.args.n_views * self.args.batch_size, self.args.n_views * self.args.batch_size)\n        # assert similarity_matrix.shape == labels.shape\n\n        # discard the main diagonal from both: labels and similarities matrix\n        mask = torch.eye(labels.shape[0], dtype=torch.bool).to(device)\n        labels = labels[~mask].view(labels.shape[0], -1)\n        similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)\n        # assert similarity_matrix.shape == labels.shape\n\n        # select and combine multiple positives\n        positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)\n\n        # select only the negatives\n        negatives = 
similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)\n\n        logits = torch.cat([positives, negatives], dim=1)\n        labels = torch.zeros(logits.shape[0], dtype=torch.long).to(device)\n        \n        logits = logits / self.temp\n        return logits, labels\n        \n    def forward(self,X):\n        logits, labels = self.info_nce_loss(X)\n        loss = self.criterion(logits, labels)\n        return loss\n\nclass Z_loss(nn.Module):\n    def __init__(self,):\n        super().__init__()\n        pass\n        \n    def forward(self,z):\n        z_list = z.chunk(2,dim=0)\n        z_sim = F.cosine_similarity(z_list[0],z_list[1],dim=1).mean()\n        z_sim_out = z_sim.clone().detach()\n        return -z_sim, z_sim_out\n\nclass TotalCodingRate(nn.Module):\n    def __init__(self, eps=0.01):\n        super(TotalCodingRate, self).__init__()\n        self.eps = eps\n        \n    def compute_discrimn_loss(self, W):\n        \"\"\"Discriminative Loss.\"\"\"\n        p, m = W.shape  #[d, B]\n        I = torch.eye(p,device=W.device)\n        scalar = p / (m * self.eps)\n        logdet = torch.logdet(I + scalar * W.matmul(W.T))\n        return logdet / 2.\n    \n    def forward(self,X):\n        return - self.compute_discrimn_loss(X.T)\n\nclass MaximalCodingRateReduction(torch.nn.Module):\n    def __init__(self, eps=0.01, gamma=1):\n        super(MaximalCodingRateReduction, self).__init__()\n        self.eps = eps\n        self.gamma = gamma\n        \n    def compute_discrimn_loss(self, W):\n        \"\"\"Discriminative Loss.\"\"\"\n        p, m = W.shape\n        I = torch.eye(p,device=W.device)\n        scalar = p / (m * self.eps)\n        logdet = torch.logdet(I + scalar * W.matmul(W.T))\n        return logdet / 2.\n    \n    def compute_compress_loss(self, W, Pi):\n        p, m = W.shape\n        k, _, _ = Pi.shape\n        I = torch.eye(p,device=W.device).expand((k,p,p))\n        trPi = Pi.sum(2) + 1e-8\n        scale = 
(p/(trPi*self.eps)).view(k,1,1)\n        \n        W = W.view((1,p,m))\n        log_det = torch.logdet(I + scale*W.mul(Pi).matmul(W.transpose(1,2)))\n        compress_loss = (trPi.squeeze()*log_det/(2*m)).sum()\n        return compress_loss\n        \n    def forward(self, X, Y, num_classes=None):\n        #This function supports Y as either an integer label vector or a membership probability matrix.\n        if len(Y.shape)==1:\n            #if Y is a label vector\n            if num_classes is None:\n                num_classes = Y.max() + 1\n            Pi = torch.zeros((num_classes,1,Y.shape[0]),device=Y.device)\n            for indx, label in enumerate(Y):\n                Pi[label,0,indx] = 1\n        else:\n            #if Y is a probability matrix\n            if num_classes is None:\n                num_classes = Y.shape[1]\n            Pi = Y.T.reshape((num_classes,1,-1))\n            \n        W = X.T\n        discrimn_loss = self.compute_discrimn_loss(W)\n        compress_loss = self.compute_compress_loss(W, Pi)\n \n        total_loss = - discrimn_loss + self.gamma*compress_loss\n        return total_loss, [discrimn_loss.item(), compress_loss.item()]"
  },
  {
    "path": "main.py",
    "content": "############\n## Import ##\n############\nimport argparse\nimport torch.nn as nn\nimport torch.optim as optim\nimport os\nfrom torch.utils.data import DataLoader\nfrom model.model import encoder\nfrom dataset.datasets import load_dataset\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport torch.nn.functional as F\nimport torchvision.transforms.functional as FF\nfrom tqdm import tqdm\nimport torch\nfrom torchvision.datasets import CIFAR10\nfrom loss import TotalCodingRate\nfrom func import chunk_avg, WeightedKNNClassifier\nfrom lars import LARS, LARSWrapper\nimport torch.optim.lr_scheduler as lr_scheduler\nfrom torch.cuda.amp import GradScaler, autocast\n\n######################\n## Parsing Argument ##\n######################\nparser = argparse.ArgumentParser(description='Unsupervised Learning')\n\nparser.add_argument('--patch_sim', type=int, default=200,\n                    help='coefficient of cosine similarity (default: 200)')\nparser.add_argument('--tcr', type=int, default=1,\n                    help='coefficient of tcr (default: 1)')\nparser.add_argument('--num_patches', type=int, default=100,\n                    help='number of patches used in EMP-SSL (default: 100)')\nparser.add_argument('--arch', type=str, default=\"resnet18-cifar\",\n                    help='network architecture (default: resnet18-cifar)')\nparser.add_argument('--bs', type=int, default=100,\n                    help='batch size (default: 100)')\nparser.add_argument('--lr', type=float, default=0.3,\n                    help='learning rate (default: 0.3)')\nparser.add_argument('--eps', type=float, default=0.2,\n                    help='eps for TCR (default: 0.2)')\nparser.add_argument('--msg', type=str, default=\"NONE\",\n                    help='additional message for description (default: NONE)')\nparser.add_argument('--dir', type=str, default=\"EMP-SSL-Training\",\n                    help='directory name 
(default: EMP-SSL-Training)')     \nparser.add_argument('--data', type=str, default=\"cifar10\",\n                    help='data (default: cifar10)')          \nparser.add_argument('--epoch', type=int, default=30,\n                    help='max number of epochs to finish (default: 30)')  \n\nargs = parser.parse_args()\n\nprint(args)\n\nnum_patches = args.num_patches\ndir_name = f\"./logs/{args.dir}/patchsim{args.patch_sim}_numpatch{args.num_patches}_bs{args.bs}_lr{args.lr}_{args.msg}\"\n\n\n\n#####################\n## Helper Function ##\n#####################\n\ndef chunk_avg(x,n_chunks=2,normalize=False):\n    x_list = x.chunk(n_chunks,dim=0)\n    x = torch.stack(x_list,dim=0)\n    if not normalize:\n        return x.mean(0)\n    else:\n        return F.normalize(x.mean(0),dim=1)\n\n\nclass Similarity_Loss(nn.Module):\n    def __init__(self, ):\n        super().__init__()\n        pass\n\n    def forward(self, z_list, z_avg):\n        z_sim = 0\n        num_patch = len(z_list)\n        z_list = torch.stack(list(z_list), dim=0)\n        z_avg = z_list.mean(dim=0)\n        \n        z_sim = 0\n        for i in range(num_patch):\n            z_sim += F.cosine_similarity(z_list[i], z_avg, dim=1).mean()\n            \n        z_sim = z_sim/num_patch\n        z_sim_out = z_sim.clone().detach()\n                \n        return -z_sim, z_sim_out\n    \ndef cal_TCR(z, criterion, num_patches):\n    z_list = z.chunk(num_patches,dim=0)\n    loss = 0\n    for i in range(num_patches):\n        loss += criterion(z_list[i])\n    loss = loss/num_patches\n    return loss\n\n######################\n## Prepare Training ##\n######################\ntorch.multiprocessing.set_sharing_strategy('file_system')\n\nif args.data == \"imagenet100\" or args.data == \"imagenet\":\n    train_dataset = load_dataset(\"imagenet\", train=True, num_patch = num_patches)\n    dataloader = DataLoader(train_dataset, batch_size=args.bs, shuffle=True, drop_last=True,num_workers=8)\n\nelse:\n    
train_dataset = load_dataset(args.data, train=True, num_patch = num_patches)\n    dataloader = DataLoader(train_dataset, batch_size=args.bs, shuffle=True, drop_last=True,num_workers=16)\n\n\nuse_cuda = True\ndevice = torch.device(\"cuda\" if use_cuda else \"cpu\")\n    \n    \nnet = encoder(arch = args.arch)\nnet = nn.DataParallel(net)\nnet.cuda()\n\n\nopt = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-4,nesterov=True)\nopt = LARSWrapper(opt,eta=0.005,clip=True,exclude_bias_n_norm=True,)\n\nscaler = GradScaler()\nif args.data == \"imagenet100\":\n    num_converge = (150000//args.bs)*args.epoch\nelse:\n    num_converge = (50000//args.bs)*args.epoch\n    \nscheduler = lr_scheduler.CosineAnnealingLR(opt, T_max=num_converge, eta_min=0,last_epoch=-1)\n\n# Loss\ncontractive_loss = Similarity_Loss()\ncriterion = TotalCodingRate(eps=args.eps)\n\n\n##############\n## Training ##\n##############\ndef main():\n    for epoch in range(args.epoch):            \n        for step, (data, label) in tqdm(enumerate(dataloader)):\n            net.zero_grad()\n            opt.zero_grad()\n        \n            data = torch.cat(data, dim=0) \n            data = data.cuda()\n            z_proj = net(data)\n            \n            z_list = z_proj.chunk(num_patches, dim=0)\n            z_avg = chunk_avg(z_proj, num_patches)\n            \n            \n            #Contractive Loss\n            loss_contract, _ = contractive_loss(z_list, z_avg)\n            loss_TCR = cal_TCR(z_proj, criterion, num_patches)\n            \n            loss = args.patch_sim*loss_contract + args.tcr*loss_TCR\n          \n            loss.backward()\n            opt.step()\n            scheduler.step()\n            \n\n        model_dir = dir_name+\"/save_models/\"\n        if not os.path.exists(model_dir):\n            os.makedirs(model_dir)\n        torch.save(net.state_dict(), model_dir+str(epoch)+\".pt\")\n        \n    \n        print(\"At epoch:\", epoch, \"loss similarity 
is\", loss_contract.item(), \", loss TCR is:\", (loss_TCR).item(), \"and learning rate is:\", opt.param_groups[0]['lr'])\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "mcr/loss.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass MCRGANloss(nn.Module):\n\n    def __init__(self, gam1=1., gam2=1., gam3=1., eps=0.5, numclasses=1000, mode=0):\n        super(MCRGANloss, self).__init__()\n\n        self.num_class = numclasses\n        self.train_mode = mode\n        self.gam1 = gam1\n        self.gam2 = gam2\n        self.gam3 = gam3\n        self.eps = eps\n\n    def forward(self, Z, Z_bar, real_label, ith_inner_loop, num_inner_loop):\n\n        # t = time.time()\n        errD, empi = self.old_version(Z, Z_bar, real_label, ith_inner_loop, num_inner_loop)\n\n        return errD, empi\n\n    def old_version(self, Z, Z_bar, real_label, ith_inner_loop, num_inner_loop):\n\n        if self.train_mode == 2:\n            loss_z, _ = self.deltaR(Z, real_label, self.num_class)\n            assert num_inner_loop >= 2\n            if (ith_inner_loop + 1) % num_inner_loop != 0:\n                # print(f\"{ith_inner_loop + 1}/{num_inner_loop}\")\n                # print(\"calculate delta R(z)\")\n                return loss_z, None\n\n            loss_h, _ = self.deltaR(Z_bar, real_label, self.num_class)\n\n            empi = [loss_z, loss_h]\n            term3 = 0.\n            for i in range(self.num_class):\n                new_Z = torch.cat((Z[real_label == i], Z_bar[real_label == i]), 0)\n                new_label = torch.cat(\n                    (torch.zeros_like(real_label[real_label == i]),\n                     torch.ones_like(real_label[real_label == i]))\n                )\n                loss, em = self.deltaR(new_Z, new_label, 2)\n                term3 += loss\n            empi = empi + [term3]\n            errD = self.gam1 * loss_z + self.gam2 * loss_h + self.gam3 * term3\n\n        elif self.train_mode == 1:\n            print(\"has been dropped\")\n            raise NotImplementedError()\n\n        elif self.train_mode == 0:\n            new_Z = torch.cat((Z, Z_bar), 0)\n            new_label = 
torch.cat((torch.zeros_like(real_label), torch.ones_like(real_label)))\n            errD, empi = self.deltaR(new_Z, new_label, 2)\n        else:\n            raise ValueError()\n\n        return errD, empi\n\n    def debug(self, Z, Z_bar, real_label):\n\n        print(\"===========================\")\n\n    def compute_discrimn_loss(self, Z):\n        \"\"\"Theoretical Discriminative Loss.\"\"\"\n        d, n = Z.shape\n        I = torch.eye(d).to(Z.device)\n        scalar = d / (n * self.eps)\n        logdet = torch.logdet(I + scalar * Z @ Z.T)\n        return logdet / 2.\n\n    def compute_compress_loss(self, Z, Pi):\n        \"\"\"Theoretical Compressive Loss.\"\"\"\n        d, n = Z.shape\n        I = torch.eye(d).to(Z.device)\n        compress_loss = []\n        scalars = []\n        for j in range(Pi.shape[1]):\n            Z_ = Z[:, Pi[:, j] == 1]\n            trPi = Pi[:, j].sum() + 1e-8\n            scalar = d / (trPi * self.eps)\n            log_det = torch.logdet(I + scalar * Z_ @ Z_.T)\n            compress_loss.append(log_det)\n            scalars.append(trPi / (2 * n))\n        return compress_loss, scalars\n\n    def deltaR(self, Z, Y, num_classes):\n    \n        if num_classes is None:\n            num_classes = Y.max() + 1\n            \n        #print(\"classes:\", num_classes)\n\n        Pi = F.one_hot(Y, num_classes).to(Z.device)\n        discrimn_loss = self.compute_discrimn_loss(Z.T)\n        compress_loss, scalars = self.compute_compress_loss(Z.T, Pi)\n\n        compress_term = 0.\n        for z, s in zip(compress_loss, scalars):\n            compress_term += s * z\n        total_loss = discrimn_loss - compress_term\n\n        return -total_loss, (discrimn_loss, compress_term, compress_loss, scalars)\n\n    def gumb_compress_loss(self, Z, P):\n        d, n = Z.shape\n        I = torch.eye(d).to(Z.device)\n        compress_loss = 0.\n        for j in range(self.num_class):\n        \n            #P[:, j:j+1][P[:, j:j+1]<threshold] = 0 \n      
      \n            Z_ = Z * P[:, j:j+1]\n            trPi = P[:, j].sum() + 1e-8\n            scalar = d / (trPi * self.eps)\n            log_det = torch.logdet(I + scalar * Z_ @ Z_.T)\n            compress_loss += (trPi / (2 * n)) *log_det\n        return compress_loss\n\n    def pseudo_label_loss(self, Z, logits, thres = 1.4):\n    \n        logits = logits*thres\n\n        P = F.gumbel_softmax(logits)\n\n        discrimn_loss = self.compute_discrimn_loss(Z.T)\n        compress_loss = self.gumb_compress_loss(Z, P)\n        total_loss = discrimn_loss - compress_loss\n\n        return -total_loss, (discrimn_loss, compress_loss)"
  },
  {
    "path": "model/model.py",
    "content": "import torch\nimport torch.nn.functional as F\nimport torch.nn as nn\n\nfrom torchvision.models import resnet18, resnet34, resnet50\n\nfrom .resnet import Resnet10CIFAR\n\ndef getmodel(arch):\n    \n    #backbone = resnet18()\n    \n    if arch == \"resnet18-cifar\":\n        backbone = resnet18()\n        backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) \n        backbone.maxpool = nn.Identity()\n        backbone.fc = nn.Identity()\n        return backbone, 512  \n    elif arch == \"resnet18-imagenet\":\n        backbone = resnet18()    \n        backbone.fc = nn.Identity()\n        return backbone, 512\n    elif arch == \"resnet18-tinyimagenet\":\n        backbone = resnet18()    \n        backbone.avgpool = nn.AdaptiveAvgPool2d(1)\n        backbone.fc = nn.Identity()\n        return backbone, 512\n    else:\n        raise NameError(\"{} not found in network architecture\".format(arch))\n  \n\nclass encoder(nn.Module): \n     def __init__(self,z_dim=1024,hidden_dim=4096, norm_p=2, arch = \"resnet18-cifar\"):\n        super().__init__()\n\n        backbone, feature_dim = getmodel(arch)\n        self.backbone = backbone\n        self.norm_p = norm_p\n        self.pre_feature = nn.Sequential(nn.Linear(feature_dim,hidden_dim),\n                                         nn.BatchNorm1d(hidden_dim),\n                                         nn.ReLU()\n                                        )\n        self.projection = nn.Sequential(nn.Linear(hidden_dim,hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU(), nn.Linear(hidden_dim,z_dim))\n        \n          \n     def forward(self, x, is_test = False):\n         \n        feature = self.backbone(x)\n        feature = self.pre_feature(feature)\n        z = F.normalize(self.projection(feature),p=self.norm_p)\n\n        if is_test:\n            return z, feature\n        else:\n            return z\n\n   \n    "
  },
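  {
    "path": "examples/encoder_smoke_test.py",
    "content": "# Hedged usage sketch, not part of the original repo: constructs the encoder\n# from model/model.py and checks the shapes it returns. The file name and the\n# batch size are illustrative assumptions.\nimport torch\n\nfrom model.model import encoder\n\nnet = encoder(z_dim=1024, hidden_dim=4096, arch='resnet18-cifar')\nnet.eval()  # avoid BatchNorm batch statistics for this shape check\n\nx = torch.randn(4, 3, 32, 32)  # a batch of CIFAR-sized patches\nwith torch.no_grad():\n    z, feature = net(x, is_test=True)\n\nprint(z.shape)        # torch.Size([4, 1024]), L2-normalized projections\nprint(feature.shape)  # torch.Size([4, 4096]), pre-projection features\n"
  },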
  {
    "path": "model/resnet.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom torchvision.models import resnet18, resnet34, resnet50\n\nclass BasicBlock(nn.Module):\n    expansion = 1\n    def __init__(self, in_planes, planes, stride=1):\n        super(BasicBlock, self).__init__()\n        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,\n                               stride=stride, padding=1, bias=False)\n        self.bn1 = nn.BatchNorm2d(planes)\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,\n                               stride=1, padding=1, bias=False)\n        self.bn2 = nn.BatchNorm2d(planes)\n\n        self.shortcut = nn.Sequential()\n        if stride != 1 or in_planes != self.expansion * planes:\n            self.shortcut = nn.Sequential(\n                nn.Conv2d(in_planes, self.expansion * planes,\n                          kernel_size=1, stride=stride, bias=False),\n                nn.BatchNorm2d(self.expansion * planes)\n            )\n\n    def forward(self, x):\n        out = F.relu(self.bn1(self.conv1(x)))\n        out = self.bn2(self.conv2(out))\n        out += self.shortcut(x)\n        out = F.relu(out)\n        return out\n\nclass Bottleneck(nn.Module):\n    expansion = 4\n    \n    def __init__(self, in_planes, planes, stride=1):\n        super(Bottleneck, self).__init__()\n        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)\n        self.bn1 = nn.BatchNorm2d(planes)\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,\n                               stride=stride, padding=1, bias=False)\n        self.bn2 = nn.BatchNorm2d(planes)\n        self.conv3 = nn.Conv2d(planes, self.expansion * planes,\n                               kernel_size=1, bias=False)\n        self.bn3 = nn.BatchNorm2d(self.expansion * planes)\n\n        self.shortcut = nn.Sequential()\n        if stride != 1 or in_planes != self.expansion * planes:\n            self.shortcut = nn.Sequential(\n          
      nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,\n                          stride=stride, bias=False),\n                nn.BatchNorm2d(self.expansion * planes)\n            )\n\n    def forward(self, x):\n        out = F.relu(self.bn1(self.conv1(x)))\n        out = F.relu(self.bn2(self.conv2(out)))\n        out = self.bn3(self.conv3(out))\n        out += self.shortcut(x)\n        out = F.relu(out)\n        return out\n\n\nclass ResNet(nn.Module):\n    def __init__(self, block, blocks_config, first_config, first_pool=False):\n        super(ResNet, self).__init__()\n        # first_config = [in_channels, out_channels, kernel_size, stride] for the stem conv\n        [in_chan, chan, k, s] = first_config\n        self.in_planes = chan\n        self.conv1 = nn.Conv2d(in_chan, chan, kernel_size=k, stride=s,\n                               padding=k//2, bias=False)\n        self.bn1 = nn.BatchNorm2d(chan)\n        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) if first_pool else nn.Identity()\n        self.layer1 = self._make_layer(block, blocks_config[0][0], blocks_config[0][1], stride=1)\n        self.layer2 = self._make_layer(block, blocks_config[1][0], blocks_config[1][1], stride=2)\n        self.layer3 = self._make_layer(block, blocks_config[2][0], blocks_config[2][1], stride=2)\n        self.layer4 = self._make_layer(block, blocks_config[3][0], blocks_config[3][1], stride=2)\n\n    def _make_layer(self, block, planes, num_blocks, stride):\n        strides = [stride] + [1] * (num_blocks - 1)\n        layers = []\n        for stride in strides:\n            layers.append(block(self.in_planes, planes, stride))\n            self.in_planes = planes * block.expansion\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        out = F.relu(self.bn1(self.conv1(x)))\n        out = self.pool(out)\n        out = self.layer1(out)\n        out = self.layer2(out)\n        out = self.layer3(out)\n        out = self.layer4(out)\n        # global average pool over the spatial dimensions\n        feature = out.mean((2, 3))\n        return feature\n\n\ndef Resnet10MNIST():\n    block = BasicBlock\n    blocks_config = [\n        [64, 1], [128, 1], [256, 1], [512, 1]\n    ]\n    first_config = [1, 64, 3, 1]\n    return ResNet(block, blocks_config, first_config, first_pool=False)\n\ndef Resnet10CIFAR():\n    block = BasicBlock\n    blocks_config = [\n        [32, 1], [64, 1], [128, 1], [256, 1]\n    ]\n    first_config = [3, 32, 3, 1]\n    return ResNet(block, blocks_config, first_config, first_pool=True)\n\ndef Resnet18imgs():\n    block = BasicBlock\n    blocks_config = [\n        [32, 2], [64, 2], [128, 2], [256, 2]\n    ]\n    first_config = [1, 32, 5, 2]\n    return ResNet(block, blocks_config, first_config, first_pool=True)\n\ndef Resnet18CIFAR():\n    backbone = resnet18()\n    backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)\n    backbone.maxpool = nn.Identity()\n    backbone.fc = nn.Identity()\n    return backbone\n\ndef Resnet18STL10():\n    block = BasicBlock\n    blocks_config = [\n        [64, 2], [128, 2], [256, 2], [512, 2]\n    ]\n    first_config = [3, 64, 5, 2]\n    return ResNet(block, blocks_config, first_config, first_pool=True)\n\ndef Resnet34CIFAR():\n    backbone = resnet34()\n    backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)\n    backbone.maxpool = nn.Identity()\n    backbone.fc = nn.Identity()\n    return backbone\n\ndef Resnet34STL10():\n    block = BasicBlock\n    blocks_config = [\n        [64, 3], [128, 4], [256, 6], [512, 3]\n    ]\n    first_config = [3, 64, 5, 2]\n    return ResNet(block, blocks_config, first_config, first_pool=True)"
  },
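  {
    "path": "examples/resnet_shapes_sketch.py",
    "content": "# Hedged usage sketch, not part of the original repo: builds two of the small\n# custom ResNets from model/resnet.py and checks their pooled feature widths.\n# The file name is an illustrative assumption.\nimport torch\n\nfrom model.resnet import Resnet10CIFAR, Resnet10MNIST\n\nwith torch.no_grad():\n    f = Resnet10CIFAR().eval()(torch.randn(2, 3, 32, 32))\n    print(f.shape)  # torch.Size([2, 256]): last stage has 256 channels\n\n    f = Resnet10MNIST().eval()(torch.randn(2, 1, 28, 28))\n    print(f.shape)  # torch.Size([2, 512])\n"
  },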
  {
    "path": "requirements.text",
    "content": "torch\ntorchvision\ntorchmetrics\nnumpy\ntqdm\nPillow\nmatplotlib\nscikit-learn"
  }
]