Repository: tsb0601/EMP-SSL
Branch: main
Commit: 8fc3758617e0
Files: 13
Total size: 66.6 KB

Directory structure:
gitextract_sphqay9s/

├── README.md
├── dataset/
│   ├── aug.py
│   ├── aug4img.py
│   └── datasets.py
├── evaluate.py
├── func.py
├── lars.py
├── loss.py
├── main.py
├── mcr/
│   └── loss.py
├── model/
│   ├── model.py
│   └── resnet.py
└── requirements.text

================================================
FILE CONTENTS
================================================

================================================
FILE: README.md
================================================
# EMP-SSL: Towards Self-Supervised Learning in One Training Epoch

[![arXiv](https://img.shields.io/badge/arXiv-2304.03977-b31b1b.svg)](https://arxiv.org/abs/2304.03977)


![Training Pipeline](pipeline.png)


Authors: Shengbang Tong*, Yubei Chen*, Yi Ma, Yann LeCun

## Introduction
This repository contains the implementation for the paper "EMP-SSL: Towards Self-Supervised Learning in One Training Epoch." The paper introduces a simple but efficient self-supervised learning method called Extreme-Multi-Patch Self-Supervised Learning (EMP-SSL). EMP-SSL significantly reduces the number of training epochs required for convergence by increasing the number of fixed-size image patches taken from each image instance.
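The core idea, turning one image into many fixed-size patch views, can be sketched as follows. This is a minimal illustration using plain NumPy random crops; the repository itself uses torchvision's `RandomResizedCrop` plus color and blur augmentations (see `dataset/aug.py`):

```python
import numpy as np

def random_fixed_size_patches(image, num_patches, patch_size, rng=None):
    """Extract `num_patches` random crops of shape (patch_size, patch_size)
    from an H x W x C image array."""
    rng = np.random.default_rng() if rng is None else rng
    h, w = image.shape[:2]
    patches = []
    for _ in range(num_patches):
        top = rng.integers(0, h - patch_size + 1)
        left = rng.integers(0, w - patch_size + 1)
        patches.append(image[top:top + patch_size, left:left + patch_size])
    return patches

# A 32x32 RGB "image"; EMP-SSL draws e.g. 20 or 200 such patches per instance.
img = np.zeros((32, 32, 3), dtype=np.float32)
views = random_fixed_size_patches(img, num_patches=20, patch_size=16)
```

All patches of one image are then encouraged to map to similar representations, which is what makes many epochs unnecessary.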

## Preparing Training Data
CIFAR-10 and CIFAR-100 are downloaded automatically by the training script. ImageNet-100 is a specific subset of ImageNet; details can be found at this [link](https://github.com/HobbitLong/CMC/issues/21).

## Getting Started
The current implementation supports CIFAR-10, CIFAR-100 and ImageNet-100.

To get started with the EMP-SSL implementation, follow these instructions:

### 1. Clone this repository
```bash
git clone https://github.com/tsb0601/emp-ssl.git
cd emp-ssl
``` 
### 2. Install required packages
```
pip install -r requirements.txt
```
### 3. Training

#### Reproducing 1-epoch results

|                    | CIFAR-10<br>1 Epoch | CIFAR-100<br>1 Epoch | Tiny ImageNet<br>1 Epoch | ImageNet-100<br>1 Epoch |
|--------------------|:----------------------:|:-----------------------:|:----------------------------:|:--------------------------:|
| EMP-SSL (1 Epoch)  |         0.842          |          0.585          |             0.381             |            0.585           |

For CIFAR-10 or CIFAR-100:
```
python main.py --data cifar10 --epoch 2 --patch_sim 200 --arch 'resnet18-cifar' --num_patches 20 --lr 0.3
```
For ImageNet-100:
```
python main.py --data imagenet100 --epoch 2 --patch_sim 200 --arch 'resnet18-imagenet' --num_patches 20 --lr 0.3
```


#### Reproducing multi epochs results

|                      | CIFAR-10<br>1 Epoch | CIFAR-10<br>10 Epochs | CIFAR-10<br>30 Epochs | CIFAR-10<br>1000 Epochs | CIFAR-100<br>1 Epoch | CIFAR-100<br>10 Epochs | CIFAR-100<br>30 Epochs | CIFAR-100<br>1000 Epochs | Tiny ImageNet<br>10 Epochs | Tiny ImageNet<br>1000 Epochs |ImageNet-100<br>10 Epochs | ImageNet-100<br>400 Epochs |
|----------------------|:-------------------:|:---------------------:|:---------------------:|:-----------------------:|:--------------------:|:----------------------:|:----------------------:|:------------------------:| :------------------------:|:------------------------:|:------------------------:| :------------------------:|
| SimCLR               |        0.282        |         0.565         |         0.663         |          0.910          |         0.054        |         0.185          |         0.341          |          0.662           | - | 0.488 | - | 0.776
| BYOL                 |        0.249        |         0.489         |         0.684         |          0.926          |         0.043        |         0.150          |         0.349          |          0.708           | - | 0.510 | - | 0.802
| VICReg               |        0.406        |         0.697         |         0.781         |          0.921          |         0.079        |         0.319          |         0.479          |          0.685           | - | - | - | 0.792
| SwAV                 |        0.245        |         0.532         |         0.767         |          0.923          |         0.028        |         0.208          |         0.294          |          0.658           |- | - | - | 0.740
| ReSSL                |        0.245        |         0.256         |         0.525         |          0.914          |         0.033        |         0.122          |         0.247          |          0.674           |- | - | - | 0.769
| EMP-SSL (20 patches) |        0.806        |         0.907         |         0.931         |            -            |         0.551        |         0.678          |         0.724          |            -              | - | - | - | -
| EMP-SSL (200 patches)|        0.826*        |         0.915         |         0.934         |            -            |         0.577        |         0.701          |         0.733          |            -              | 0.515 | - | 0.789 | -

\* Here we change the learning rate schedule to decay over 30 epochs, so the 1-epoch accuracy is slightly lower than when optimizing specifically for 1-epoch training.

Change `num_patches` to adjust the number of patches used in EMP-SSL training.
```
python main.py --data cifar10 --epoch 30 --patch_sim 200 --arch 'resnet18-cifar' --num_patches 20 --lr 0.3
```



### 4. Evaluating
Because the model is trained only on fixed-size image patches, we adopt the bag-of-features evaluation from the intra-instance VICReg paper. Change `test_patches` to adjust the number of patches used in the bag-of-features model for different GPUs.
```
python evaluate.py --model_path 'path to your evaluated model' --test_patches 128
```
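The bag-of-features evaluation amounts to averaging the encoder's features over many patches of the same image before fitting the linear probe or running kNN. A minimal sketch of the averaging step (NumPy, placeholder feature values; the repository's `chunk_avg` helper does the equivalent in PyTorch):

```python
import numpy as np

def bag_of_features(patch_features):
    """Average per-patch features into one representation per image.

    patch_features: array of shape (num_patches, batch, dim), as produced by
    running every patch of a batch through the encoder.
    """
    return patch_features.mean(axis=0)

# 128 test patches, a batch of 4 images, 64-dim features
feats = np.random.randn(128, 4, 64)
image_repr = bag_of_features(feats)  # one averaged feature vector per image
```

Larger `test_patches` values give a better estimate of the per-image representation at the cost of more forward passes.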

## Acknowledgment
This repo is inspired by the [MCR2](https://github.com/Ma-Lab-Berkeley/MCR2), [solo-learn](https://github.com/vturrisi/solo-learn) and [NMCE](https://github.com/zengyi-li/NMCE-release) repositories.

## Citation
If you find this repository useful, please consider giving a star :star: and citation:

```
@article{tong2023empssl,
title={EMP-SSL: Towards Self-Supervised Learning in One Training Epoch},
author={Shengbang Tong and Yubei Chen and Yi Ma and Yann LeCun},
journal={arXiv preprint arXiv:2304.03977},
year={2023}
}
```


================================================
FILE: dataset/aug.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision.transforms as transforms
from PIL import Image, ImageFilter, ImageOps


def load_transforms(name):
    """Load data transformations.
    
    Note:
        - Gaussian Blur is defined at the bottom of this file.
    
    """
    _name = name.lower()
    if _name == "cifar_sup":
        normalize = transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
        aug_transform = transforms.Compose([
            transforms.RandomResizedCrop(32,scale=(0.765625, 0.765625),ratio=(1., 1.)),
            transforms.RandomCrop(32, padding=8),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize
        ])
        baseline_transform = transforms.Compose([
            transforms.RandomResizedCrop(32,scale=(0.765625, 0.765625),ratio=(1., 1.)),
            transforms.ToTensor(),normalize])

    elif _name == "cifar_patch":
        normalize = transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])
        aug_transform = transforms.Compose([
            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
            transforms.ToTensor(),
            normalize
        ])
        baseline_transform = transforms.Compose([
            transforms.ToTensor(), normalize])
        
    elif _name == "cifar_simclr_norm":
        normalize = transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
        aug_transform = transforms.Compose([
            transforms.RandomResizedCrop(32,scale=(0.08, 1.0)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
            normalize
        ])
        baseline_transform = transforms.Compose([
            transforms.ToTensor(),normalize])
    
    elif _name == "cifar_byol":
        normalize = transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
        aug_transform = transforms.Compose([
            transforms.RandomResizedCrop(
                    (32, 32),
                    scale=(0.2, 1.0),
                    interpolation=transforms.InterpolationMode.BICUBIC,
                ),
            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomApply([Solarization()], p=0.1),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            normalize
        ])
        baseline_transform = transforms.Compose([
#             transforms.RandomResizedCrop(32,scale=(0.765625, 0.765625),ratio=(1., 1.)),
            transforms.ToTensor(),normalize])

    else:
        raise NameError("{} not found in transform loader".format(name))
    return aug_transform, baseline_transform


class Solarization:
    """Solarization as a callable object."""

    def __call__(self, img: Image) -> Image:
        """Applies solarization to an input image.

        Args:
            img (Image): an image in the PIL.Image format.

        Returns:
            Image: a solarized image.
        """

        return ImageOps.solarize(img)

class GBlur(object):
    def __init__(self, p):
        self.p = p

    def __call__(self, img):
        if np.random.rand() < self.p:
            sigma = np.random.rand() * 1.9 + 0.1
            return img.filter(ImageFilter.GaussianBlur(sigma))
        else:
            return img


class AddGaussianNoise(object):
    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean
        
    def __call__(self, tensor):
        return tensor + torch.randn(tensor.size()) * self.std + self.mean
    
    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)


class ContrastiveLearningViewGenerator(object):
    def __init__(self, num_patch = 4):
    
        self.num_patch = num_patch
      
    def __call__(self, x):
    
    
        normalize = transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])
        
        aug_transform = transforms.Compose([
            transforms.RandomResizedCrop(32,scale=(0.25, 0.25), ratio=(1,1)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.2)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            GBlur(p=0.1),
            transforms.RandomApply([Solarization()], p=0.1),
            transforms.ToTensor(),  
            normalize
        ])
        augmented_x = [aug_transform(x) for i in range(self.num_patch)]
     
        return augmented_x


================================================
FILE: dataset/aug4img.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import torchvision.transforms as transforms
from PIL import Image, ImageFilter, ImageOps
from torchvision.transforms import InterpolationMode



class Solarization:
    """Solarization as a callable object."""

    def __call__(self, img: Image) -> Image:
        """Applies solarization to an input image.

        Args:
            img (Image): an image in the PIL.Image format.

        Returns:
            Image: a solarized image.
        """

        return ImageOps.solarize(img)

class GBlur(object):
    def __init__(self, p):
        self.p = p

    def __call__(self, img):
        if np.random.rand() < self.p:
            sigma = np.random.rand() * 1.9 + 0.1
            return img.filter(ImageFilter.GaussianBlur(sigma))
        else:
            return img


class AddGaussianNoise(object):
    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean
        
    def __call__(self, tensor):
        return tensor + torch.randn(tensor.size()) * self.std + self.mean
    
    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)


class ContrastiveLearningViewGenerator(object):
    def __init__(self, num_patch = 4):

        self.num_patch = num_patch
        
    def __call__(self, x):
        aug_transform =  transforms.Compose([
            transforms.RandomResizedCrop(
                224, scale=(0.25, 0.25), interpolation=InterpolationMode.BICUBIC
            ),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.2)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            GBlur(p=0.1),
            transforms.RandomApply([Solarization()], p=0.1),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            
        ])

        
        augmented_x = [aug_transform(x) for i in range(self.num_patch)]
       
        
        return augmented_x
 
        


================================================
FILE: dataset/datasets.py
================================================
import os
import numpy as np
import torchvision

def load_dataset(data_name, train=True, num_patch = 4, path="./data/"):
    """Loads a dataset for training and testing. If augmentloader is used, transform should be None.
    
    Parameters:
        data_name (str): name of the dataset
        transform_name (torchvision.transform): name of transform to be applied (see aug.py)
        use_baseline (bool): use baseline transform or augmentation transform
        train (bool): load training set or not
        contrastive (bool): whether to convert transform to multiview augmentation for contrastive learning.
        n_views (bool): number of views for contrastive learning
        path (str): path to dataset base path

    Returns:
        dataset (torch.data.dataset)
    """
    _name = data_name.lower()
    if _name == "imagenet":
        from .aug4img import ContrastiveLearningViewGenerator
    else:
        from .aug import ContrastiveLearningViewGenerator
      
    
    transform = ContrastiveLearningViewGenerator(num_patch = num_patch)
        
    if _name == "cifar10":
        trainset = torchvision.datasets.CIFAR10(root=os.path.join(path, "CIFAR10"), train=train, download=True, transform=transform)
        trainset.num_classes = 10
    elif _name == "cifar100":
        trainset = torchvision.datasets.CIFAR100(root=os.path.join(path, "CIFAR100"), train=train, download=True, transform=transform)
        trainset.num_classes = 100
    elif _name == "imagenet":
        if train:
            trainset = torchvision.datasets.ImageFolder(root="/home/peter/Data/ILSVRC2012/train100/",transform=transform)
            #trainset = torchvision.datasets.ImageFolder(root="/home/peter/Data/tiny-imagenet-200/train/",transform=transform)
        else:
            trainset = torchvision.datasets.ImageFolder(root="/home/peter/Data/ILSVRC2012/val100/",transform=transform)
            #trainset = torchvision.datasets.ImageFolder(root="/home/peter/Data/tiny-imagenet-200/val/",transform=transform)
        trainset.num_classes = 100  # ImageNet-100 has 100 classes (use 200 for the commented Tiny-ImageNet paths)
        
    else:
        raise NameError("{} not found in trainset loader".format(_name))
    return trainset

def sparse2coarse(targets):
    """CIFAR100 Coarse Labels. """
    coarse_targets = [ 4,  1, 14,  8,  0,  6,  7,  7, 18,  3,  3, 14,  9, 18,  7, 11,  3,
                       9,  7, 11,  6, 11,  5, 10,  7,  6, 13, 15,  3, 15,  0, 11,  1, 10,
                      12, 14, 16,  9, 11,  5,  5, 19,  8,  8, 15, 13, 14, 17, 18, 10, 16,
                       4, 17,  4,  2,  0, 17,  4, 18, 17, 10,  3,  2, 12, 12, 16, 12,  1,
                       9, 19,  2, 10,  0,  1, 16, 12,  9, 13, 15, 13, 16, 19,  2,  4,  6,
                      19,  5,  5,  8, 19, 18,  1,  2, 15,  6,  0, 17,  8, 14, 13]
    return np.array(coarse_targets)[targets]
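The fine-to-coarse lookup above is plain NumPy fancy indexing: the mapping array is indexed by the fine labels, producing one coarse label per target. The same idiom with a toy 5-class mapping (the mapping values here are made up for illustration):

```python
import numpy as np

# Toy mapping: fine class i -> coarse class toy_coarse[i]
toy_coarse = np.array([2, 0, 1, 0, 2])

fine_targets = [0, 3, 4, 1]
coarse = toy_coarse[np.array(fine_targets)]  # one coarse label per fine label
```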

================================================
FILE: evaluate.py
================================================
############
## Import ##
############
import argparse
import torch.nn as nn
from torch.utils.data import DataLoader
from model.model import encoder
from dataset.datasets import load_dataset
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm
import torch
import numpy as np
from func import WeightedKNNClassifier, linear

######################
## Parsing Argument ##
######################
parser = argparse.ArgumentParser(description='Evaluation')

parser.add_argument('--test_patches', type=int, default=128,
                    help='number of patches used in testing (default: 128)')  

parser.add_argument('--data', type=str, default="cifar10",
                    help='dataset (default: cifar10)')  
parser.add_argument('--arch', type=str, default="resnet18-cifar",
                    help='network architecture (default: resnet18-cifar)')

parser.add_argument('--lr', type=float, default=0.03,
                    help='learning rate for linear eval (default: 0.03)')        
parser.add_argument('--linear', type=bool, default=True,
                    help='use linear eval or not')
parser.add_argument('--knn', help='evaluate using kNN measuring cosine similarity', action='store_true')
parser.add_argument('--model_path', type=str, default="",
                    help='model directory for eval')

            
args = parser.parse_args()

######################
## Testing Accuracy ##
######################
test_patches = args.test_patches

def compute_accuracy(y_pred, y_true):
    """Compute accuracy by counting correct classification. """
    assert y_pred.shape == y_true.shape
    return 1 - np.count_nonzero(y_pred - y_true) / y_true.size

knn_classifier = WeightedKNNClassifier()


def chunk_avg(x,n_chunks=2,normalize=False):
    x_list = x.chunk(n_chunks,dim=0)
    x = torch.stack(x_list,dim=0)
    if not normalize:
        return x.mean(0)
    else:
        return F.normalize(x.mean(0),dim=1)


def test(net, train_loader, test_loader):
    
    train_z_full_list, train_y_list, test_z_full_list, test_y_list = [], [], [], []
    
    with torch.no_grad():
        for x, y in tqdm(train_loader):

            x = torch.cat(x, dim = 0)
            
            z_proj, z_pre = net(x, is_test=True)

            z_pre = chunk_avg(z_pre, test_patches)
            z_pre = z_pre.detach().cpu()
            
            
            train_z_full_list.append(z_pre)
            
            
            knn_classifier.update(train_features = z_pre, train_targets = y)

            train_y_list.append(y)
                
        for x, y in tqdm(test_loader):
            x = torch.cat(x, dim = 0)
            
            z_proj, z_pre = net(x, is_test=True)

            z_pre = chunk_avg(z_pre, test_patches)
            z_pre = z_pre.detach().cpu()
           
            test_z_full_list.append(z_pre)
       
            knn_classifier.update(test_features = z_pre, test_targets = y)

            test_y_list.append(y)
                
            
    train_features_full, train_labels, test_features_full, test_labels = torch.cat(train_z_full_list,dim=0), torch.cat(train_y_list,dim=0), torch.cat(test_z_full_list,dim=0), torch.cat(test_y_list,dim=0)
   
    if args.data == "cifar10":
        num_classes = 10
    elif args.data == "cifar100":
        num_classes = 100
    elif args.data == "tinyimagenet200":
        num_classes = 200
    elif args.data == "imagenet100":
        num_classes = 100
    elif args.data == "imagenet":
        num_classes = 1000
        
    if args.linear:
        print("Using Linear Eval to evaluate accuracy")
        linear(train_features_full, train_labels, test_features_full, test_labels, lr=args.lr, num_classes = num_classes)
    
    if args.knn:
        print("Using KNN to evaluate accuracy")
        top1, top5 = knn_classifier.compute()
        print("KNN (top1/top5):", top1, top5)
    


torch.multiprocessing.set_sharing_strategy('file_system')


# Get Dataset
memory_dataset = load_dataset(args.data, train=True, num_patch = test_patches)
memory_loader = DataLoader(memory_dataset, batch_size=50, shuffle=True, drop_last=True, num_workers=8)

test_data = load_dataset(args.data, train=False, num_patch = test_patches)
test_loader = DataLoader(test_data, batch_size=50, shuffle=True, num_workers=8)

# Load Model and Checkpoint
use_cuda = True
device = torch.device("cuda" if use_cuda else "cpu")
net = encoder(arch = args.arch)
net = nn.DataParallel(net)
save_dict = torch.load(args.model_path)
net.load_state_dict(save_dict,strict=False)
net.cuda()
net.eval()
test(net, memory_loader, test_loader)





================================================
FILE: func.py
================================================
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score
import torchvision
# import torch.nn


from torch import nn, optim
import torch.nn as nn
from torch.utils import data
from torch.utils.data import DataLoader



from typing import Tuple


import torch.nn.functional as F
from torchmetrics.metric import Metric


class WeightedKNNClassifier(Metric):
    def __init__(
        self,
        k: int = 20,
        T: float = 0.07,
        max_distance_matrix_size: int = int(5e6),
        distance_fx: str = "cosine",
        epsilon: float = 0.00001,
        dist_sync_on_step: bool = False,
    ):
        """Implements the weighted k-NN classifier used for evaluation.
        Args:
            k (int, optional): number of neighbors. Defaults to 20.
            T (float, optional): temperature for the exponential. Only used with cosine
                distance. Defaults to 0.07.
            max_distance_matrix_size (int, optional): maximum number of elements in the
                distance matrix. Defaults to 5e6.
            distance_fx (str, optional): Distance function. Accepted arguments: "cosine" or
                "euclidean". Defaults to "cosine".
            epsilon (float, optional): Small value for numerical stability. Only used with
                euclidean distance. Defaults to 0.00001.
            dist_sync_on_step (bool, optional): whether to sync distributed values at every
                step. Defaults to False.
        """

        super().__init__(dist_sync_on_step=dist_sync_on_step, compute_on_step=False)

        self.k = k
        self.T = T
        self.max_distance_matrix_size = max_distance_matrix_size
        self.distance_fx = distance_fx
        self.epsilon = epsilon

        self.add_state("train_features", default=[], persistent=False)
        self.add_state("train_targets", default=[], persistent=False)
        self.add_state("test_features", default=[], persistent=False)
        self.add_state("test_targets", default=[], persistent=False)

    def update(
        self,
        train_features: torch.Tensor = None,
        train_targets: torch.Tensor = None,
        test_features: torch.Tensor = None,
        test_targets: torch.Tensor = None,
    ):
        """Updates the memory banks. If train (test) features are passed as input, the
        corresponding train (test) targets must be passed as well.
        Args:
            train_features (torch.Tensor, optional): a batch of train features. Defaults to None.
            train_targets (torch.Tensor, optional): a batch of train targets. Defaults to None.
            test_features (torch.Tensor, optional): a batch of test features. Defaults to None.
            test_targets (torch.Tensor, optional): a batch of test targets. Defaults to None.
        """
        assert (train_features is None) == (train_targets is None)
        assert (test_features is None) == (test_targets is None)

        if train_features is not None:
            assert train_features.size(0) == train_targets.size(0)
            self.train_features.append(train_features.detach())
            self.train_targets.append(train_targets.detach())

        if test_features is not None:
            assert test_features.size(0) == test_targets.size(0)
            self.test_features.append(test_features.detach())
            self.test_targets.append(test_targets.detach())

    def set_tk(self, T, k):
        self.T = T
        self.k = k
        
    @torch.no_grad()
    def compute(self) -> Tuple[float]:
        """Computes weighted k-NN accuracy @1 and @5. If cosine distance is selected,
        the weight is computed using the exponential of the temperature scaled cosine
        distance of the samples. If euclidean distance is selected, the weight corresponds
        to the inverse of the euclidean distance.
        Returns:
            Tuple[float]: k-NN accuracy @1 and @5.
        """
        
        #print(self.T, self.k)

        train_features = torch.cat(self.train_features)
        train_targets = torch.cat(self.train_targets)
        test_features = torch.cat(self.test_features)
        test_targets = torch.cat(self.test_targets)

        if self.distance_fx == "cosine":
            train_features = F.normalize(train_features)
            test_features = F.normalize(test_features)

        num_classes = torch.unique(test_targets).numel()
        num_train_images = train_targets.size(0)
        num_test_images = test_targets.size(0)
        chunk_size = min(
            max(1, self.max_distance_matrix_size // num_train_images),
            num_test_images,
        )
        k = min(self.k, num_train_images)

        top1, top5, total = 0.0, 0.0, 0
        retrieval_one_hot = torch.zeros(k, num_classes).to(train_features.device)
        for idx in range(0, num_test_images, chunk_size):
            # get the features for test images
            features = test_features[idx : min((idx + chunk_size), num_test_images), :]
            targets = test_targets[idx : min((idx + chunk_size), num_test_images)]
            batch_size = targets.size(0)

            # calculate the dot product and compute top-k neighbors
            if self.distance_fx == "cosine":
                similarities = torch.mm(features, train_features.t())
            elif self.distance_fx == "euclidean":
                similarities = 1 / (torch.cdist(features, train_features) + self.epsilon)
            else:
                raise NotImplementedError

            similarities, indices = similarities.topk(k, largest=True, sorted=True)
            candidates = train_targets.view(1, -1).expand(batch_size, -1)
            retrieved_neighbors = torch.gather(candidates, 1, indices)

            retrieval_one_hot.resize_(batch_size * k, num_classes).zero_()
            retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1)

            if self.distance_fx == "cosine":
                similarities = similarities.clone().div_(self.T).exp_()

            probs = torch.sum(
                torch.mul(
                    retrieval_one_hot.view(batch_size, -1, num_classes),
                    similarities.view(batch_size, -1, 1),
                ),
                1,
            )
            _, predictions = probs.sort(1, True)

            # find the predictions that match the target
            correct = predictions.eq(targets.data.view(-1, 1))
            top1 = top1 + correct.narrow(1, 0, 1).sum().item()
            top5 = (
                top5 + correct.narrow(1, 0, min(5, k, correct.size(-1))).sum().item()
            )  # top5 does not make sense if k < 5
            total += targets.size(0)

        top1 = top1 * 100.0 / total
        top5 = top5 * 100.0 / total

        self.reset()

        return top1, top5
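The cosine branch above weights each retrieved neighbor by `exp(similarity / T)` and sums the weights per class. The same voting scheme in a tiny NumPy sketch (toy features and a hypothetical helper name, not the torchmetrics implementation):

```python
import numpy as np

def weighted_knn_predict(test_feat, train_feats, train_labels, k=3, T=0.07):
    """Predict one label by temperature-weighted cosine-similarity voting."""
    # cosine similarity via normalized dot products
    tf = test_feat / np.linalg.norm(test_feat)
    trf = train_feats / np.linalg.norm(train_feats, axis=1, keepdims=True)
    sims = trf @ tf
    top = np.argsort(sims)[::-1][:k]          # indices of the k nearest
    weights = np.exp(sims[top] / T)           # temperature-scaled weights
    scores = np.zeros(train_labels.max() + 1)
    for idx, w in zip(top, weights):
        scores[train_labels[idx]] += w        # accumulate votes per class
    return int(np.argmax(scores))

train_feats = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
train_labels = np.array([0, 0, 1, 1])
pred = weighted_knn_predict(np.array([0.95, 0.05]), train_feats, train_labels)
```

A small temperature `T` sharpens the weighting so that the closest neighbors dominate the vote.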







def linear(train_features, train_labels, test_features, test_labels, lr=0.0075, num_classes = 100):

    train_data = tensor_dataset(train_features,train_labels)
    test_data = tensor_dataset(test_features,test_labels)
    train_loader = DataLoader(train_data, batch_size=100, shuffle=True, drop_last=True, num_workers=2)
    test_loader = DataLoader(test_data, batch_size=100, shuffle=True, drop_last=False, num_workers=2)
    
    LL = nn.Linear(train_features.shape[1],num_classes)
    optimizer = torch.optim.SGD(LL.parameters(), lr=lr, momentum=0.9, weight_decay=5e-5)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 100)
    
    criterion = torch.nn.CrossEntropyLoss()
    
    test_acc_list = []
    for epoch in range(100):
        top1_train_accuracy = 0
        for counter, (x_batch, y_batch) in enumerate(train_loader):
            logits = LL(x_batch)
            loss = criterion(logits, y_batch)
            top1 = accuracy(logits, y_batch, topk=(1,))
            top1_train_accuracy += top1[0]

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
        scheduler.step() 

        top1_train_accuracy /= (counter + 1)

        top1_accuracy = 0
        top5_accuracy = 0
        for counter, (x_batch, y_batch) in enumerate(test_loader):
            logits = LL(x_batch)

            top1, top5 = accuracy(logits, y_batch, topk=(1,5))
            top1_accuracy += top1[0]
            top5_accuracy += top5[0]

        top1_accuracy /= (counter + 1)
        top5_accuracy /= (counter + 1)
        
        test_acc_list.append(top1_accuracy)
        
        print(f"Epoch {epoch}\tTop1 Train accuracy {top1_train_accuracy.item()}\tTop1 Test accuracy: {top1_accuracy.item()}\tTop5 test acc: {top5_accuracy.item()}")
    acc_vect = torch.tensor(test_acc_list)
    print('best linear test acc {}, last acc {}'.format(acc_vect.max().item(),acc_vect[-1].item()))
        


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
            
class tensor_dataset(data.Dataset):
    def __init__(self,x,y):
        self.x = x
        self.y = y
        self.length = x.shape[0]
    
    def __getitem__(self,indx):
        return self.x[indx], self.y[indx]
    
    def __len__(self):
        return self.length




def set_gamma(loss_fn,epoch,total_epoch=500,warmup_epoch=100,gamma_min=0.,gamma_max=1.0):
    #linearly ramps loss_fn.gamma from gamma_min to gamma_max over the last warmup_epoch epochs
    warmup_start = total_epoch - warmup_epoch
    warmup_end = total_epoch
    
    if warmup_start < epoch <= warmup_end:
        loss_fn.gamma = ((epoch - warmup_start)/(warmup_end - warmup_start))*(gamma_max - gamma_min) + gamma_min
    else:
        loss_fn.gamma = gamma_min

def warmup_lr(optimizer,epoch,base_lr,warmup_epoch=10):
    if epoch<warmup_epoch:
        optimizer.param_groups[0]['lr'] = base_lr*min(1.,(epoch+1)/warmup_epoch)
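A small trace of the warmup schedule (restated so the snippet runs standalone): the learning rate ramps linearly from base_lr/warmup_epoch to base_lr, then the helper stops touching it.

```python
import torch

def warmup_lr(optimizer, epoch, base_lr, warmup_epoch=10):
    # restated from above so the snippet runs standalone
    if epoch < warmup_epoch:
        optimizer.param_groups[0]['lr'] = base_lr * min(1., (epoch + 1) / warmup_epoch)

p = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([p], lr=0.3)
lrs = []
for epoch in range(12):
    warmup_lr(opt, epoch, base_lr=0.3, warmup_epoch=10)
    lrs.append(opt.param_groups[0]['lr'])
# epoch 0 -> 0.03, epoch 9 -> 0.3, epochs >= 10 leave lr unchanged
```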
        
        
def marginal_H(logits):
    #entropy of the batch-marginal class distribution, in bits (1.4426950 = 1/ln 2)
    bs = torch.tensor(logits.shape[0]).float()
    logps = torch.log_softmax(logits,dim=1)
    marginal_p = torch.logsumexp(logps - bs.log(),dim=0)
    H = (marginal_p.exp()*(-marginal_p)).sum()*(1.4426950)
    return H
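A quick check of the marginal entropy (restated so the snippet runs standalone): uniform predictions over C classes should give log2(C) bits.

```python
import torch

def marginal_H(logits):
    # restated from above so the snippet runs standalone
    bs = torch.tensor(logits.shape[0]).float()
    logps = torch.log_softmax(logits, dim=1)
    marginal_p = torch.logsumexp(logps - bs.log(), dim=0)
    H = (marginal_p.exp() * (-marginal_p)).sum() * (1.4426950)  # nats -> bits
    return H

# all-zero logits -> uniform softmax over 4 classes -> entropy log2(4) = 2 bits
H = marginal_H(torch.zeros(8, 4))
```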

def chunk_avg(x,n_chunks=2,normalize=False):
    x_list = x.chunk(n_chunks,dim=0)
    x = torch.stack(x_list,dim=0)
    if not normalize:
        return x.mean(0)
    else:
        return F.normalize(x.mean(0),dim=1)
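chunk_avg averages the n_chunks views that are stacked along the batch dimension; with two views it reduces to an elementwise mean (restated so the snippet runs standalone):

```python
import torch
import torch.nn.functional as F

def chunk_avg(x, n_chunks=2, normalize=False):
    # restated from above so the snippet runs standalone
    x_list = x.chunk(n_chunks, dim=0)
    x = torch.stack(x_list, dim=0)
    if not normalize:
        return x.mean(0)
    else:
        return F.normalize(x.mean(0), dim=1)

# two "views" of a batch of 3 samples, concatenated along the batch dimension
a = torch.tensor([[1., 0.], [2., 2.], [0., 4.]])
b = torch.tensor([[3., 0.], [0., 2.], [0., 0.]])
avg = chunk_avg(torch.cat([a, b], dim=0), n_chunks=2)
# avg equals the elementwise mean (a + b) / 2
```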

def cluster_match(cluster_mtx,label_mtx,n_classes=10,print_result=True):
    #verified to be consistent with the optimal-assignment-based algorithm
    cluster_indx = list(cluster_mtx.unique())
    assigned_label_list = []
    assigned_count = []
    while len(assigned_label_list) < n_classes and len(cluster_indx) > 0:
        max_label_list = []
        max_count_list = []
        for indx in cluster_indx:
            #calculate the highest number of matches
            mask = cluster_mtx==indx
            label_elements, counts = label_mtx[mask].unique(return_counts=True)
            for assigned_label in assigned_label_list:
                counts[label_elements==assigned_label] = 0
            max_count_list.append(counts.max())
            max_label_list.append(label_elements[counts.argmax()])

        max_label = torch.stack(max_label_list)
        max_count = torch.stack(max_count_list)
        assigned_label_list.append(max_label[max_count.argmax()])
        assigned_count.append(max_count.max())
        cluster_indx.pop(max_count.argmax())
    total_correct = torch.tensor(assigned_count).sum().item()
    total_sample = cluster_mtx.shape[0]
    acc = total_correct/total_sample
    if print_result:
        print('{}/{} ({}%) correct'.format(total_correct,total_sample,acc*100))
    return total_correct, total_sample, acc
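The greedy matching above can be checked on a toy case where cluster ids are just a permutation of the true labels (restated, with an explicit return, so the snippet runs standalone):

```python
import torch

def cluster_match(cluster_mtx, label_mtx, n_classes=10):
    # greedy cluster-to-label assignment, restated from above so the snippet runs standalone
    cluster_indx = list(cluster_mtx.unique())
    assigned_label_list = []
    assigned_count = []
    while len(assigned_label_list) < n_classes and len(cluster_indx) > 0:
        max_label_list = []
        max_count_list = []
        for indx in cluster_indx:
            mask = cluster_mtx == indx
            label_elements, counts = label_mtx[mask].unique(return_counts=True)
            for assigned_label in assigned_label_list:
                counts[label_elements == assigned_label] = 0  # labels can be used once
            max_count_list.append(counts.max())
            max_label_list.append(label_elements[counts.argmax()])
        max_label = torch.stack(max_label_list)
        max_count = torch.stack(max_count_list)
        assigned_label_list.append(max_label[max_count.argmax()])
        assigned_count.append(max_count.max())
        cluster_indx.pop(max_count.argmax())
    total_correct = torch.stack(assigned_count).sum().item()
    total_sample = cluster_mtx.shape[0]
    return total_correct, total_sample, total_correct / total_sample

# cluster ids are a permutation of the true labels -> perfect accuracy
clusters = torch.tensor([1, 1, 0, 0, 2, 2])
labels = torch.tensor([0, 0, 1, 1, 2, 2])
correct, total, acc = cluster_match(clusters, labels, n_classes=3)
```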

def cluster_merge_match(cluster_mtx,label_mtx,print_result=True):
    cluster_indx = list(cluster_mtx.unique())
    n_correct = 0
    for cluster_id in cluster_indx:
        label_elements, counts = label_mtx[cluster_mtx==cluster_id].unique(return_counts=True)
        n_correct += counts.max()
    total_sample = len(cluster_mtx)
    acc = n_correct.item()/total_sample
    if print_result:
        print('{}/{} ({}%) correct'.format(n_correct,total_sample,acc*100))
    return n_correct, total_sample, acc

    
def cluster_acc(test_loader,net,device,print_result=False,save_name_img='cluster_img',save_name_fig='pca_figure'):
    cluster_list = []
    label_list = []
    x_list = []
    z_list = []
    net.eval()
    for x, y in test_loader:
        with torch.no_grad():
            x, y = x.float().to(device), y.to(device)
            z, logit = net(x)
            if logit.sum() == 0:
                logit += torch.randn_like(logit)
            cluster_list.append(logit.max(dim=1)[1].cpu())
            label_list.append(y.cpu())
            x_list.append(x.cpu())
            z_list.append(z.cpu())
    net.train()
    cluster_mtx = torch.cat(cluster_list,dim=0)
    label_mtx = torch.cat(label_list,dim=0)
    x_mtx = torch.cat(x_list,dim=0)
    z_mtx = torch.cat(z_list,dim=0)
    _, _, acc_single = cluster_match(cluster_mtx,label_mtx,n_classes=label_mtx.max()+1,print_result=False)
    _, _, acc_merge = cluster_merge_match(cluster_mtx,label_mtx,print_result=False)
    NMI = normalized_mutual_info_score(label_mtx.numpy(),cluster_mtx.numpy())
    ARI = adjusted_rand_score(label_mtx.numpy(),cluster_mtx.numpy())
    if print_result:
        print('cluster match acc {}, cluster merge match acc {}, NMI {}, ARI {}'.format(acc_single,acc_merge,NMI,ARI))
    
    save_name_img += '_acc'+ str(acc_single)[2:5]
    save_cluster_imgs(cluster_mtx,x_mtx,save_name_img)
    save_latent_pca_figure(z_mtx,cluster_mtx,save_name_fig)
    
    return acc_single, acc_merge, NMI, ARI
    
def save_cluster_imgs(cluster_mtx,x_mtx,save_name,npercluster=100):
    cluster_indexs, counts = cluster_mtx.unique(return_counts=True)
    x_list = []
    counts_list = []
    for i, c_indx in enumerate(cluster_indexs):
        if counts[i]>npercluster:
            x_list.append(x_mtx[cluster_mtx==c_indx,:,:,:])
            counts_list.append(counts[i])

    n_clusters = len(counts_list)
    fig, axes = plt.subplots(n_clusters,1,dpi=80,figsize=(1.2*n_clusters, 3*n_clusters))
    for i, ax in enumerate(axes):
        img = torchvision.utils.make_grid(x_list[i][:npercluster],nrow=npercluster//5,normalize=True)
        ax.imshow(img.permute(1,2,0))
        ax.set_axis_off()

        ax.set_title('Cluster with {} images'.format(counts_list[i]))
    
    fig.savefig(save_name+'.pdf')
    plt.close(fig)
    
def save_latent_pca_figure(z_mtx,cluster_mtx,save_name):
    _, s_z_all, _ = z_mtx.svd()
    cluster_n = []
    cluster_s = []
    for cluster_indx in cluster_mtx.unique():
        _, s_cluster, _ = z_mtx[cluster_mtx==cluster_indx,:].svd()
        cluster_n.append((cluster_mtx==cluster_indx).sum().item())
        cluster_s.append(s_cluster/s_cluster.max())

    #make plot
    fig, ax = plt.subplots(1,2,figsize=(9, 3))
    ax[0].plot(s_z_all)
    for i, s_curve in enumerate(cluster_s):
        ax[1].plot(s_curve,label=cluster_n[i])
    ax[1].set_xlim(xmin=0,xmax=20)
    ax[1].legend()
    fig.savefig(save_name +'.pdf')
    plt.close(fig)
    
def analyze_latent(z_mtx,cluster_mtx):
    _, s_z_all, _ = z_mtx.svd()
    cluster_n = []
    cluster_s = []
    cluster_d = []
    for cluster_indx in cluster_mtx.unique():
        _, s_cluster, _ = z_mtx[cluster_mtx==cluster_indx,:].svd()
        s_cluster = s_cluster/s_cluster.max()
        cluster_n.append((cluster_mtx==cluster_indx).sum().item())
        cluster_s.append(s_cluster)
        #estimate the subspace dimension as the number of significant singular values
        cluster_d.append((s_cluster>0.01).sum())
    for i in range(len(cluster_n)):
        print('subspace {}, dimension {}, samples {}'.format(i,cluster_d[i],cluster_n[i]))

================================================
FILE: lars.py
================================================
import torch
import torch.optim as optim
from torch.optim.optimizer import Optimizer, required

class LARS(Optimizer):
    """
    Layer-wise adaptive rate scaling
    - Converted from Tensorflow to Pytorch from:
    https://github.com/google-research/simclr/blob/master/lars_optimizer.py
    - Based on:
    https://github.com/noahgolmant/pytorch-lars
    params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): base learning rate (\gamma_0)
        lr (int): Length / Number of layers we want to apply weight decay, else do not compute
        momentum (float, optional): momentum factor (default: 0.9)
        use_nesterov (bool, optional): flag to use nesterov momentum (default: False)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0)
            ("\beta")
        eta (float, optional): LARS coefficient (default: 0.001)
    - Based on Algorithm 1 of the following paper by You, Gitman, and Ginsburg.
    - Large Batch Training of Convolutional Networks:
        https://arxiv.org/abs/1708.03888
    """

    def __init__(self, params, lr, len_reduced, momentum=0.9, use_nesterov=False, weight_decay=0.0, classic_momentum=True, eta=0.001):

        self.epoch = 0
        defaults = dict(
            lr=lr,
            momentum=momentum,
            use_nesterov=use_nesterov,
            weight_decay=weight_decay,
            classic_momentum=classic_momentum,
            eta=eta,
            len_reduced=len_reduced
        )

        super(LARS, self).__init__(params, defaults)
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.use_nesterov = use_nesterov
        self.classic_momentum = classic_momentum
        self.eta = eta
        self.len_reduced = len_reduced

    def step(self, epoch=None, closure=None):

        loss = None

        if closure is not None:
            loss = closure()

        if epoch is None:
            epoch = self.epoch
            self.epoch += 1

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            eta = group['eta']
            learning_rate = group['lr']

            # TODO: Hacky
            counter = 0
            for p in group['params']:
                if p.grad is None:
                    continue

                param = p.data
                grad = p.grad.data

                param_state = self.state[p]

                # TODO: This really hacky way needs to be improved.
                # Note: excluded parameters are passed at the end of the list, so they are skipped here
                if counter < self.len_reduced:
                    grad += weight_decay * param

                # Create parameter for the momentum
                if "momentum_var" not in param_state:
                    next_v = param_state["momentum_var"] = torch.zeros_like(
                        p.data
                    )
                else:
                    next_v = param_state["momentum_var"]

                if self.classic_momentum:
                    trust_ratio = 1.0

                    # TODO: implementation of layer adaptation
                    w_norm = torch.norm(param)
                    g_norm = torch.norm(grad)

                    device = g_norm.device

                    trust_ratio = torch.where(w_norm.gt(0), torch.where(
                        g_norm.gt(0), (self.eta * w_norm / g_norm), torch.Tensor([1.0]).to(device)),
                                              torch.Tensor([1.0]).to(device)).item()

                    scaled_lr = learning_rate * trust_ratio
                    
                    grad_scaled = scaled_lr*grad
                    next_v.mul_(momentum).add_(grad_scaled)

                    if self.use_nesterov:
                        update = (self.momentum * next_v) + (scaled_lr * grad)
                    else:
                        update = next_v

                    p.data.add_(-update)

                # Not classic_momentum
                else:

                    next_v.mul_(momentum).add_(grad)

                    if self.use_nesterov:
                        update = (self.momentum * next_v) + (grad)

                    else:
                        update = next_v

                    trust_ratio = 1.0

                    # TODO: implementation of layer adaptation
                    w_norm = torch.norm(param)
                    v_norm = torch.norm(update)

                    device = v_norm.device

                    trust_ratio = torch.where(w_norm.gt(0), torch.where(
                        v_norm.gt(0), (self.eta * w_norm / v_norm), torch.Tensor([1.0]).to(device)),
                                              torch.Tensor([1.0]).to(device)).item()

                    scaled_lr = learning_rate * trust_ratio

                    p.data.add_(-scaled_lr * update)

                counter += 1

        return loss
    
#LARSWrapper from solo-learn repo...
class LARSWrapper:
    def __init__(
        self,
        optimizer: Optimizer,
        eta: float = 1e-3,
        clip: bool = False,
        eps: float = 1e-8,
        exclude_bias_n_norm: bool = False,
    ):
        """Wrapper that adds LARS scheduling to any optimizer.
        This helps stability with huge batch sizes.

        Args:
            optimizer (Optimizer): torch optimizer.
            eta (float, optional): trust coefficient. Defaults to 1e-3.
            clip (bool, optional): clip gradient values. Defaults to False.
            eps (float, optional): adaptive_lr stability coefficient. Defaults to 1e-8.
            exclude_bias_n_norm (bool, optional): exclude bias and normalization layers from lars.
                Defaults to False.
        """

        self.optim = optimizer
        self.eta = eta
        self.eps = eps
        self.clip = clip
        self.exclude_bias_n_norm = exclude_bias_n_norm

        # transfer optim methods
        self.state_dict = self.optim.state_dict
        self.load_state_dict = self.optim.load_state_dict
        self.zero_grad = self.optim.zero_grad
        self.add_param_group = self.optim.add_param_group

        self.__setstate__ = self.optim.__setstate__  # type: ignore
        self.__getstate__ = self.optim.__getstate__  # type: ignore
        self.__repr__ = self.optim.__repr__  # type: ignore

    @property
    def defaults(self):
        return self.optim.defaults

    @defaults.setter
    def defaults(self, defaults):
        self.optim.defaults = defaults

    @property  # type: ignore
    def __class__(self):
        return Optimizer

    @property
    def state(self):
        return self.optim.state

    @state.setter
    def state(self, state):
        self.optim.state = state

    @property
    def param_groups(self):
        return self.optim.param_groups

    @param_groups.setter
    def param_groups(self, value):
        self.optim.param_groups = value

    @torch.no_grad()
    def step(self, closure=None):
        weight_decays = []

        for group in self.optim.param_groups:
            weight_decay = group.get("weight_decay", 0)
            weight_decays.append(weight_decay)

            # reset weight decay
            group["weight_decay"] = 0

            # update the parameters
            for p in group["params"]:
                if p.grad is not None and (p.ndim != 1 or not self.exclude_bias_n_norm):
                    self.update_p(p, group, weight_decay)

        # update the optimizer
        self.optim.step(closure=closure)

        # return weight decay control to optimizer
        for group_idx, group in enumerate(self.optim.param_groups):
            group["weight_decay"] = weight_decays[group_idx]

    def update_p(self, p, group, weight_decay):
        # calculate new norms
        p_norm = torch.norm(p.data)
        g_norm = torch.norm(p.grad.data)

        if p_norm != 0 and g_norm != 0:
            # calculate new lr
            new_lr = (self.eta * p_norm) / (g_norm + p_norm * weight_decay + self.eps)

            # clip lr
            if self.clip:
                new_lr = min(new_lr / group["lr"], 1)

            # update params with clipped lr
            p.grad.data += weight_decay * p.data
            p.grad.data *= new_lr
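The per-parameter trust ratio that both LARS variants compute can be illustrated in isolation; `lars_adaptive_lr` below is a hypothetical helper mirroring the formula in `LARSWrapper.update_p`, not part of the repo's API:

```python
import torch

def lars_adaptive_lr(p, grad, weight_decay, eta=1e-3, eps=1e-8):
    # same formula as LARSWrapper.update_p: eta * ||w|| / (||g|| + wd * ||w|| + eps)
    p_norm = torch.norm(p)
    g_norm = torch.norm(grad)
    if p_norm == 0 or g_norm == 0:
        return 1.0  # no layer-wise scaling for zero weights or gradients
    return (eta * p_norm / (g_norm + weight_decay * p_norm + eps)).item()

w = torch.ones(10)           # ||w|| = sqrt(10)
g = torch.full((10,), 0.1)   # ||g|| = 0.1 * sqrt(10)
lr_scale = lars_adaptive_lr(w, g, weight_decay=0.0)
# eta * ||w|| / ||g|| = 1e-3 / 0.1 = 0.01
```

Layers with small gradients relative to their weight norm get a larger effective step, which is what stabilizes large-batch training.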

================================================
FILE: loss.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F

class contrastive_loss(nn.Module):
    def __init__(self):
        super().__init__()
        pass
    def forward(self,x,labels):
        #this function assumes that the positive logit is always the first element,
        #which is true here
        loss = -x[:,0] + torch.logsumexp(x[:,1:],dim=1)
        return loss.mean()

class SimCLR(nn.Module):
    def __init__(self,temperature=0.5,n_views=2,contrastive=False):
        super(SimCLR,self).__init__()
        self.temp = temperature
        self.n_views = n_views
        
        if contrastive:
            self.criterion = contrastive_loss()
        else:
            self.criterion = torch.nn.CrossEntropyLoss()
        
    def info_nce_loss(self,X):
        
        bs, n_dim = X.shape
        bs = int(bs/self.n_views)
        device = X.device
        
        
        labels = torch.cat([torch.arange(bs) for i in range(self.n_views)], dim=0)
        labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
        labels = labels.to(device)

        similarity_matrix = torch.matmul(X, X.T)
        # assert similarity_matrix.shape == (
        #     self.args.n_views * self.args.batch_size, self.args.n_views * self.args.batch_size)
        # assert similarity_matrix.shape == labels.shape

        # discard the main diagonal from both: labels and similarities matrix
        mask = torch.eye(labels.shape[0], dtype=torch.bool).to(device)
        labels = labels[~mask].view(labels.shape[0], -1)
        similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
        # assert similarity_matrix.shape == labels.shape

        # select and combine multiple positives
        positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)

        # select only the negatives
        negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)

        logits = torch.cat([positives, negatives], dim=1)
        labels = torch.zeros(logits.shape[0], dtype=torch.long).to(device)
        
        logits = logits / self.temp
        return logits, labels
        
    def forward(self,X):
        logits, labels = self.info_nce_loss(X)
        loss = self.criterion(logits, labels)
        return loss
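Minimal usage of the InfoNCE objective on L2-normalized embeddings; the class is restated below (default cross-entropy branch only) so the snippet runs standalone:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimCLR(nn.Module):
    # restated from above (cross-entropy branch only) so the snippet runs standalone
    def __init__(self, temperature=0.5, n_views=2):
        super().__init__()
        self.temp = temperature
        self.n_views = n_views
        self.criterion = nn.CrossEntropyLoss()

    def info_nce_loss(self, X):
        bs = X.shape[0] // self.n_views
        device = X.device
        labels = torch.cat([torch.arange(bs) for _ in range(self.n_views)], dim=0)
        labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float().to(device)
        similarity_matrix = torch.matmul(X, X.T)
        # drop the diagonal (self-similarity) from labels and similarities
        mask = torch.eye(labels.shape[0], dtype=torch.bool, device=device)
        labels = labels[~mask].view(labels.shape[0], -1)
        similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
        positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
        negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)
        logits = torch.cat([positives, negatives], dim=1) / self.temp
        targets = torch.zeros(logits.shape[0], dtype=torch.long, device=device)
        return logits, targets

    def forward(self, X):
        logits, targets = self.info_nce_loss(X)
        return self.criterion(logits, targets)

torch.manual_seed(0)
# 2 views x batch of 4: rows i and i+4 are the positive pair
X = F.normalize(torch.randn(8, 16), dim=1)
loss = SimCLR(temperature=0.5)(X)
```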

class Z_loss(nn.Module):
    def __init__(self,):
        super().__init__()
        pass
        
    def forward(self,z):
        z_list = z.chunk(2,dim=0)
        z_sim = F.cosine_similarity(z_list[0],z_list[1],dim=1).mean()
        z_sim_out = z_sim.clone().detach()
        return -z_sim, z_sim_out

class TotalCodingRate(nn.Module):
    def __init__(self, eps=0.01):
        super(TotalCodingRate, self).__init__()
        self.eps = eps
        
    def compute_discrimn_loss(self, W):
        """Discriminative Loss."""
        p, m = W.shape  #[d, B]
        I = torch.eye(p,device=W.device)
        scalar = p / (m * self.eps)
        logdet = torch.logdet(I + scalar * W.matmul(W.T))
        return logdet / 2.
    
    def forward(self,X):
        return - self.compute_discrimn_loss(X.T)
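TotalCodingRate returns the negative coding rate, so minimizing it expands the feature covariance; on random normalized features the value is strictly negative (restated so the snippet runs standalone):

```python
import torch

class TotalCodingRate(torch.nn.Module):
    # restated from above so the snippet runs standalone
    def __init__(self, eps=0.01):
        super().__init__()
        self.eps = eps

    def compute_discrimn_loss(self, W):
        p, m = W.shape  # [d, B]
        I = torch.eye(p, device=W.device)
        scalar = p / (m * self.eps)
        return torch.logdet(I + scalar * W.matmul(W.T)) / 2.

    def forward(self, X):
        return -self.compute_discrimn_loss(X.T)

torch.manual_seed(0)
Z = torch.nn.functional.normalize(torch.randn(64, 32), dim=1)  # batch of 64, d = 32
loss = TotalCodingRate(eps=0.2)(Z)
# logdet(I + c * Z^T Z) > 0, so the loss is negative
```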

class MaximalCodingRateReduction(torch.nn.Module):
    def __init__(self, eps=0.01, gamma=1):
        super(MaximalCodingRateReduction, self).__init__()
        self.eps = eps
        self.gamma = gamma
        
    def compute_discrimn_loss(self, W):
        """Discriminative Loss."""
        p, m = W.shape
        I = torch.eye(p,device=W.device)
        scalar = p / (m * self.eps)
        logdet = torch.logdet(I + scalar * W.matmul(W.T))
        return logdet / 2.
    
    def compute_compress_loss(self, W, Pi):
        p, m = W.shape
        k, _, _ = Pi.shape
        I = torch.eye(p,device=W.device).expand((k,p,p))
        trPi = Pi.sum(2) + 1e-8
        scale = (p/(trPi*self.eps)).view(k,1,1)
        
        W = W.view((1,p,m))
        log_det = torch.logdet(I + scale*W.mul(Pi).matmul(W.transpose(1,2)))
        compress_loss = (trPi.squeeze()*log_det/(2*m)).sum()
        return compress_loss
        
    def forward(self, X, Y, num_classes=None):
        #This function supports Y as an integer label vector or a membership probability matrix.
        if len(Y.shape)==1:
            #if Y is a label vector
            if num_classes is None:
                num_classes = Y.max() + 1
            Pi = torch.zeros((num_classes,1,Y.shape[0]),device=Y.device)
            for indx, label in enumerate(Y):
                Pi[label,0,indx] = 1
        else:
            #if Y is a probability matrix
            if num_classes is None:
                num_classes = Y.shape[1]
            Pi = Y.T.reshape((num_classes,1,-1))
            
        W = X.T
        discrimn_loss = self.compute_discrimn_loss(W)
        compress_loss = self.compute_compress_loss(W, Pi)
 
        total_loss = - discrimn_loss + self.gamma*compress_loss
        return total_loss, [discrimn_loss.item(), compress_loss.item()]
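MCR^2 trades the global coding rate against the per-class compression terms; a small run with integer labels (the class is restated, label-vector branch only, so the snippet runs standalone):

```python
import torch

class MaximalCodingRateReduction(torch.nn.Module):
    # restated from above (label-vector branch only) so the snippet runs standalone
    def __init__(self, eps=0.01, gamma=1):
        super().__init__()
        self.eps = eps
        self.gamma = gamma

    def compute_discrimn_loss(self, W):
        p, m = W.shape
        I = torch.eye(p, device=W.device)
        scalar = p / (m * self.eps)
        return torch.logdet(I + scalar * W.matmul(W.T)) / 2.

    def compute_compress_loss(self, W, Pi):
        p, m = W.shape
        k, _, _ = Pi.shape
        I = torch.eye(p, device=W.device).expand((k, p, p))
        trPi = Pi.sum(2) + 1e-8
        scale = (p / (trPi * self.eps)).view(k, 1, 1)
        W = W.view((1, p, m))
        log_det = torch.logdet(I + scale * W.mul(Pi).matmul(W.transpose(1, 2)))
        return (trPi.squeeze() * log_det / (2 * m)).sum()

    def forward(self, X, Y, num_classes=None):
        if num_classes is None:
            num_classes = int(Y.max()) + 1
        Pi = torch.zeros((num_classes, 1, Y.shape[0]), device=Y.device)
        for indx, label in enumerate(Y):
            Pi[label, 0, indx] = 1
        W = X.T
        discrimn_loss = self.compute_discrimn_loss(W)
        compress_loss = self.compute_compress_loss(W, Pi)
        total_loss = -discrimn_loss + self.gamma * compress_loss
        return total_loss, [discrimn_loss.item(), compress_loss.item()]

torch.manual_seed(0)
X = torch.nn.functional.normalize(torch.randn(40, 16), dim=1)
Y = torch.randint(0, 4, (40,))
total, (R, Rc) = MaximalCodingRateReduction(eps=0.2)(X, Y)
# both rates are positive; the objective is their (negated) difference
```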

================================================
FILE: main.py
================================================
############
## Import ##
############
import argparse
import torch.nn as nn
import torch.optim as optim
import os
from torch.utils.data import DataLoader
from model.model import encoder
from dataset.datasets import load_dataset
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
import torchvision.transforms.functional as FF
from tqdm import tqdm
import torch
from torchvision.datasets import CIFAR10
from loss import TotalCodingRate
from func import chunk_avg
from lars import LARS, LARSWrapper
from func import WeightedKNNClassifier
import torch.optim.lr_scheduler as lr_scheduler
from torch.cuda.amp import GradScaler, autocast

######################
## Parsing Argument ##
######################
parser = argparse.ArgumentParser(description='Unsupervised Learning')

parser.add_argument('--patch_sim', type=int, default=200,
                    help='coefficient of cosine similarity (default: 200)')
parser.add_argument('--tcr', type=int, default=1,
                    help='coefficient of tcr (default: 1)')
parser.add_argument('--num_patches', type=int, default=100,
                    help='number of patches used in EMP-SSL (default: 100)')
parser.add_argument('--arch', type=str, default="resnet18-cifar",
                    help='network architecture (default: resnet18-cifar)')
parser.add_argument('--bs', type=int, default=100,
                    help='batch size (default: 100)')
parser.add_argument('--lr', type=float, default=0.3,
                    help='learning rate (default: 0.3)')        
parser.add_argument('--eps', type=float, default=0.2,
                    help='eps for TCR (default: 0.2)') 
parser.add_argument('--msg', type=str, default="NONE",
                    help='additional message for description (default: NONE)')     
parser.add_argument('--dir', type=str, default="EMP-SSL-Training",
                    help='directory name (default: EMP-SSL-Training)')     
parser.add_argument('--data', type=str, default="cifar10",
                    help='data (default: cifar10)')          
parser.add_argument('--epoch', type=int, default=30,
                    help='max number of epochs to finish (default: 30)')  

args = parser.parse_args()

print(args)

num_patches = args.num_patches
dir_name = f"./logs/{args.dir}/patchsim{args.patch_sim}_numpatch{args.num_patches}_bs{args.bs}_lr{args.lr}_{args.msg}"



#####################
## Helper Function ##
#####################

def chunk_avg(x,n_chunks=2,normalize=False):
    x_list = x.chunk(n_chunks,dim=0)
    x = torch.stack(x_list,dim=0)
    if not normalize:
        return x.mean(0)
    else:
        return F.normalize(x.mean(0),dim=1)


class Similarity_Loss(nn.Module):
    def __init__(self, ):
        super().__init__()
        pass

    def forward(self, z_list, z_avg):
        num_patch = len(z_list)
        z_list = torch.stack(list(z_list), dim=0)
        #note: the passed-in z_avg is recomputed here from the stacked patch embeddings
        z_avg = z_list.mean(dim=0)

        z_sim = 0
        for i in range(num_patch):
            z_sim += F.cosine_similarity(z_list[i], z_avg, dim=1).mean()

        z_sim = z_sim/num_patch
        z_sim_out = z_sim.clone().detach()

        return -z_sim, z_sim_out
    
def cal_TCR(z, criterion, num_patches):
    z_list = z.chunk(num_patches,dim=0)
    loss = 0
    for i in range(num_patches):
        loss += criterion(z_list[i])
    loss = loss/num_patches
    return loss
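The full EMP-SSL objective combines the patch-to-mean cosine similarity with the per-patch TCR term. `emp_ssl_loss` below is a hypothetical one-function re-statement of that combination on dummy projections, not the repo's API:

```python
import torch
import torch.nn.functional as F

def emp_ssl_loss(z_proj, num_patches, patch_sim=200., tcr=1., eps=0.2):
    # num_patches patch embeddings per image, stacked along the batch dimension
    z_list = z_proj.chunk(num_patches, dim=0)
    z_stack = torch.stack(z_list, dim=0)   # [patches, B, d]
    z_avg = z_stack.mean(dim=0)            # per-image mean embedding
    # invariance term: cosine similarity of each patch to the mean
    z_sim = sum(F.cosine_similarity(z, z_avg, dim=1).mean() for z in z_list) / num_patches
    # TCR term (negative coding rate), averaged over patches
    d = z_proj.shape[1]
    def tcr_loss(z):
        b = z.shape[0]
        I = torch.eye(d)
        return -torch.logdet(I + (d / (b * eps)) * z.T @ z) / 2.
    loss_tcr = sum(tcr_loss(z) for z in z_list) / num_patches
    return patch_sim * (-z_sim) + tcr * loss_tcr

torch.manual_seed(0)
z = F.normalize(torch.randn(4 * 8, 32), dim=1)  # 4 patches x batch of 8, d = 32
loss = emp_ssl_loss(z, num_patches=4)
```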

######################
## Prepare Training ##
######################
torch.multiprocessing.set_sharing_strategy('file_system')

if args.data == "imagenet100" or args.data == "imagenet":
    train_dataset = load_dataset("imagenet", train=True, num_patch = num_patches)
    dataloader = DataLoader(train_dataset, batch_size=args.bs, shuffle=True, drop_last=True,num_workers=8)

else:
    train_dataset = load_dataset(args.data, train=True, num_patch = num_patches)
    dataloader = DataLoader(train_dataset, batch_size=args.bs, shuffle=True, drop_last=True,num_workers=16)


use_cuda = True
device = torch.device("cuda" if use_cuda else "cpu")
    
    
net = encoder(arch = args.arch)
net = nn.DataParallel(net)
net.cuda()


opt = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-4,nesterov=True)
opt = LARSWrapper(opt,eta=0.005,clip=True,exclude_bias_n_norm=True,)

scaler = GradScaler()
if args.data == "imagenet-100":
    num_converge = (150000//args.bs)*args.epoch
else:
    num_converge = (50000//args.bs)*args.epoch
    
scheduler = lr_scheduler.CosineAnnealingLR(opt, T_max=num_converge, eta_min=0,last_epoch=-1)

# Loss
contractive_loss = Similarity_Loss()
criterion = TotalCodingRate(eps=args.eps)


##############
## Training ##
##############
def main():
    for epoch in range(args.epoch):            
        for step, (data, label) in tqdm(enumerate(dataloader)):
            net.zero_grad()
            opt.zero_grad()
        
            data = torch.cat(data, dim=0) 
            data = data.cuda()
            z_proj = net(data)
            
            z_list = z_proj.chunk(num_patches, dim=0)
            z_avg = chunk_avg(z_proj, num_patches)
            
            
            #Contractive Loss
            loss_contract, _ = contractive_loss(z_list, z_avg)
            loss_TCR = cal_TCR(z_proj, criterion, num_patches)
            
            loss = args.patch_sim*loss_contract + args.tcr*loss_TCR
          
            loss.backward()
            opt.step()
            scheduler.step()
            

        model_dir = dir_name+"/save_models/"
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        torch.save(net.state_dict(), model_dir+str(epoch)+".pt")
        
    
        print("At epoch:", epoch, "loss similarity is", loss_contract.item(), ",loss TCR is:", (loss_TCR).item(), "and learning rate is:", opt.param_groups[0]['lr'])
       
                


if __name__ == '__main__':
    main()



================================================
FILE: mcr/loss.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F


class MCRGANloss(nn.Module):

    def __init__(self, gam1=1., gam2=1., gam3=1., eps=0.5, numclasses=1000, mode=0):
        super(MCRGANloss, self).__init__()

        self.num_class = numclasses
        self.train_mode = mode
        self.gam1 = gam1
        self.gam2 = gam2
        self.gam3 = gam3
        self.eps = eps

    def forward(self, Z, Z_bar, real_label, ith_inner_loop, num_inner_loop):

        # t = time.time()
        errD, empi = self.old_version(Z, Z_bar, real_label, ith_inner_loop, num_inner_loop)

        return errD, empi

    def old_version(self, Z, Z_bar, real_label, ith_inner_loop, num_inner_loop):

        if self.train_mode == 2:
            loss_z, _ = self.deltaR(Z, real_label, self.num_class)
            assert num_inner_loop >= 2
            if (ith_inner_loop + 1) % num_inner_loop != 0:
                # print(f"{ith_inner_loop + 1}/{num_inner_loop}")
                # print("calculate delta R(z)")
                return loss_z, None

            loss_h, _ = self.deltaR(Z_bar, real_label, self.num_class)

            empi = [loss_z, loss_h]
            term3 = 0.
            for i in range(self.num_class):
                new_Z = torch.cat((Z[real_label == i], Z_bar[real_label == i]), 0)
                new_label = torch.cat(
                    (torch.zeros_like(real_label[real_label == i]),
                     torch.ones_like(real_label[real_label == i]))
                )
                loss, em = self.deltaR(new_Z, new_label, 2)
                term3 += loss
            empi = empi + [term3]
            errD = self.gam1 * loss_z + self.gam2 * loss_h + self.gam3 * term3

        elif self.train_mode == 1:
            print("has been dropped")
            raise NotImplementedError()

        elif self.train_mode == 0:
            new_Z = torch.cat((Z, Z_bar), 0)
            new_label = torch.cat((torch.zeros_like(real_label), torch.ones_like(real_label)))
            errD, empi = self.deltaR(new_Z, new_label, 2)
        else:
            raise ValueError()

        return errD, empi

    def debug(self, Z, Z_bar, real_label):

        print("===========================")

    def compute_discrimn_loss(self, Z):
        """Theoretical Discriminative Loss."""
        d, n = Z.shape
        I = torch.eye(d).to(Z.device)
        scalar = d / (n * self.eps)
        logdet = torch.logdet(I + scalar * Z @ Z.T)
        return logdet / 2.

    def compute_compress_loss(self, Z, Pi):
        """Theoretical Compressive Loss."""
        d, n = Z.shape
        I = torch.eye(d).to(Z.device)
        compress_loss = []
        scalars = []
        for j in range(Pi.shape[1]):
            Z_ = Z[:, Pi[:, j] == 1]
            trPi = Pi[:, j].sum() + 1e-8
            scalar = d / (trPi * self.eps)
            log_det = torch.logdet(I + scalar * Z_ @ Z_.T)
            compress_loss.append(log_det)
            scalars.append(trPi / (2 * n))
        return compress_loss, scalars

    def deltaR(self, Z, Y, num_classes):
    
        if num_classes is None:
            num_classes = Y.max() + 1
            
        #print("classes:", num_classes)

        Pi = F.one_hot(Y, num_classes).to(Z.device)
        discrimn_loss = self.compute_discrimn_loss(Z.T)
        compress_loss, scalars = self.compute_compress_loss(Z.T, Pi)

        compress_term = 0.
        for z, s in zip(compress_loss, scalars):
            compress_term += s * z
        total_loss = discrimn_loss - compress_term

        return -total_loss, (discrimn_loss, compress_term, compress_loss, scalars)

    def gumb_compress_loss(self, Z, P):
        d, n = Z.shape
        I = torch.eye(d).to(Z.device)
        compress_loss = 0.
        for j in range(self.num_class):
        
            #P[:, j:j+1][P[:, j:j+1]<threshold] = 0 
            
            Z_ = Z * P[:, j:j+1]
            trPi = P[:, j].sum() + 1e-8
            scalar = d / (trPi * self.eps)
            log_det = torch.logdet(I + scalar * Z_ @ Z_.T)
            compress_loss += (trPi / (2 * n)) *log_det
        return compress_loss

    def pseudo_label_loss(self, Z, logits, thres = 1.4):
    
        logits = logits*thres

        P = F.gumbel_softmax(logits)

        discrimn_loss = self.compute_discrimn_loss(Z.T)
        compress_loss = self.gumb_compress_loss(Z, P)
        total_loss = discrimn_loss - compress_loss

        return -total_loss, (discrimn_loss, compress_loss)

================================================
FILE: model/model.py
================================================
import torch
import torch.nn.functional as F
import torch.nn as nn

from torchvision.models import resnet18, resnet34, resnet50

from .resnet import Resnet10CIFAR

def getmodel(arch):
    
    #backbone = resnet18()
    
    if arch == "resnet18-cifar":
        backbone = resnet18()
        backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 
        backbone.maxpool = nn.Identity()
        backbone.fc = nn.Identity()
        return backbone, 512  
    elif arch == "resnet18-imagenet":
        backbone = resnet18()    
        backbone.fc = nn.Identity()
        return backbone, 512
    elif arch == "resnet18-tinyimagenet":
        backbone = resnet18()    
        backbone.avgpool = nn.AdaptiveAvgPool2d(1)
        backbone.fc = nn.Identity()
        return backbone, 512
    else:
        raise NameError("{} not found in network architecture".format(arch))
  

class encoder(nn.Module):
    def __init__(self, z_dim=1024, hidden_dim=4096, norm_p=2, arch="resnet18-cifar"):
        super().__init__()

        backbone, feature_dim = getmodel(arch)
        self.backbone = backbone
        self.norm_p = norm_p
        self.pre_feature = nn.Sequential(nn.Linear(feature_dim, hidden_dim),
                                         nn.BatchNorm1d(hidden_dim),
                                         nn.ReLU())
        self.projection = nn.Sequential(nn.Linear(hidden_dim, hidden_dim),
                                        nn.BatchNorm1d(hidden_dim),
                                        nn.ReLU(),
                                        nn.Linear(hidden_dim, z_dim))

    def forward(self, x, is_test=False):
        feature = self.backbone(x)
        feature = self.pre_feature(feature)
        z = F.normalize(self.projection(feature), p=self.norm_p)

        if is_test:
            return z, feature
        else:
            return z
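A minimal sketch of the projector-plus-normalize path in `encoder.forward` (not part of the repository): the dimensions are the encoder defaults (`hidden_dim=4096`, `z_dim=1024`), and a random 512-d tensor stands in for the ResNet backbone features.

```python
# Hypothetical sketch: the pre_feature/projection head followed by F.normalize
# produces unit-norm embeddings, as in encoder.forward.
import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_dim, z_dim = 4096, 1024
pre_feature = nn.Sequential(nn.Linear(512, hidden_dim),
                            nn.BatchNorm1d(hidden_dim),
                            nn.ReLU())
projection = nn.Sequential(nn.Linear(hidden_dim, hidden_dim),
                           nn.BatchNorm1d(hidden_dim),
                           nn.ReLU(),
                           nn.Linear(hidden_dim, z_dim))

feature = torch.randn(4, 512)             # stand-in for backbone output
z = F.normalize(projection(pre_feature(feature)), p=2)

print(z.shape)                            # (4, 1024)
print(z.norm(dim=1))                      # every row has L2 norm 1
```

The unit norm matters downstream: the coding-rate losses operate on normalized embeddings, so scale cannot be used to cheat the objective.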

   
    

================================================
FILE: model/resnet.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision.models import resnet18, resnet34, resnet50

class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class Bottleneck(nn.Module):
    expansion = 4
    
    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion * planes,
                               kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, blocks_config, first_config, first_pool=False):
        super(ResNet, self).__init__()
        # first_config = [in_channels, out_channels, kernel_size, stride]
        [in_chan, chan, k, s] = first_config
        self.in_planes = chan
        self.conv1 = nn.Conv2d(in_chan, chan, kernel_size=k, stride=s,
                               padding=k//2, bias=False)
        self.bn1 = nn.BatchNorm2d(chan)
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) if first_pool else nn.Identity()
        self.layer1 = self._make_layer(block, blocks_config[0][0], blocks_config[0][1], stride=1)
        self.layer2 = self._make_layer(block, blocks_config[1][0], blocks_config[1][1], stride=2)
        self.layer3 = self._make_layer(block, blocks_config[2][0], blocks_config[2][1], stride=2)
        self.layer4 = self._make_layer(block, blocks_config[3][0], blocks_config[3][1], stride=2)
    
        
    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)
    
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.pool(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)

        # global average pool over the spatial dimensions
        feature = out.mean((2, 3))

        return feature
    

def Resnet10MNIST():
    block = BasicBlock
    blocks_config = [
        [64,1],[128,1],[256,1],[512,1]
    ]
    first_config = [1,64,3,1]
    return ResNet(block,blocks_config,first_config,first_pool=False)

def Resnet10CIFAR():
    block = BasicBlock
    blocks_config = [
        [32,1],[64,1],[128,1],[256,1] 
    ]
    first_config = [3,32,3,1]
    return ResNet(block,blocks_config,first_config,first_pool=True)
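An arithmetic check (not part of the repository) tracing spatial sizes through `Resnet10CIFAR` on a 32x32 input: the stride-1 3x3 stem preserves 32x32, then the first pool and layers 2-4 each halve the resolution. Only the first block of each layer strides; the remaining 3x3 stride-1 pad-1 blocks preserve size.

```python
# Hypothetical size-tracing sketch for Resnet10CIFAR (first_config=[3,32,3,1],
# first_pool=True, layer strides 1,2,2,2). All convs/pools use padding=1.
def conv_out(size, kernel, stride, padding):
    return (size + 2 * padding - kernel) // stride + 1

size = 32
size = conv_out(size, kernel=3, stride=1, padding=1)   # stem conv -> 32
size = conv_out(size, kernel=3, stride=2, padding=1)   # max pool  -> 16
for stride in (1, 2, 2, 2):                            # layer1..layer4
    size = conv_out(size, kernel=3, stride=stride, padding=1)

print(size)   # 2: forward() then averages the 2x2 map into a 256-d feature
```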

def Resnet18imgs():
    block = BasicBlock
    blocks_config = [
        [32,2],[64,2],[128,2],[256,2]
    ]
    first_config = [1,32,5,2]
    return ResNet(block,blocks_config,first_config,first_pool=True)

def Resnet18CIFAR():
    backbone = resnet18()
    backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    backbone.maxpool = nn.Identity()
    backbone.fc = nn.Identity()
    return backbone
    
def Resnet18STL10():
    block = BasicBlock
    blocks_config = [
        [64,2],[128,2],[256,2],[512,2]
    ]
    first_config = [3,64,5,2]
    return ResNet(block,blocks_config,first_config,first_pool=True)

def Resnet34CIFAR():
    backbone = resnet34()
    backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    backbone.maxpool = nn.Identity()
    backbone.fc = nn.Identity()
    return backbone

def Resnet34STL10():
    block = BasicBlock
    blocks_config = [
        [64,3],[128,4],[256,6],[512,3]
    ]
    first_config = [3,64,5,2]
    return ResNet(block,blocks_config,first_config,first_pool=True)

================================================
FILE: requirements.text
================================================
torch
torchvision
torchmetrics
numpy
tqdm
Pillow
matplotlib
scikit-learn
SYMBOL INDEX (122 symbols across 11 files)

FILE: dataset/aug.py
  function load_transforms (line 9) | def load_transforms(name):
  class Solarization (line 77) | class Solarization:
    method __call__ (line 80) | def __call__(self, img: Image) -> Image:
  class GBlur (line 92) | class GBlur(object):
    method __init__ (line 93) | def __init__(self, p):
    method __call__ (line 96) | def __call__(self, img):
  class AddGaussianNoise (line 104) | class AddGaussianNoise(object):
    method __init__ (line 105) | def __init__(self, mean=0., std=1.):
    method __call__ (line 109) | def __call__(self, tensor):
    method __repr__ (line 112) | def __repr__(self):
  class ContrastiveLearningViewGenerator (line 116) | class ContrastiveLearningViewGenerator(object):
    method __init__ (line 117) | def __init__(self, num_patch = 4):
    method __call__ (line 121) | def __call__(self, x):

FILE: dataset/aug4img.py
  class Solarization (line 12) | class Solarization:
    method __call__ (line 15) | def __call__(self, img: Image) -> Image:
  class GBlur (line 27) | class GBlur(object):
    method __init__ (line 28) | def __init__(self, p):
    method __call__ (line 31) | def __call__(self, img):
  class AddGaussianNoise (line 39) | class AddGaussianNoise(object):
    method __init__ (line 40) | def __init__(self, mean=0., std=1.):
    method __call__ (line 44) | def __call__(self, tensor):
    method __repr__ (line 47) | def __repr__(self):
  class ContrastiveLearningViewGenerator (line 51) | class ContrastiveLearningViewGenerator(object):
    method __init__ (line 52) | def __init__(self, num_patch = 4):
    method __call__ (line 56) | def __call__(self, x):

FILE: dataset/datasets.py
  function load_dataset (line 5) | def load_dataset(data_name, train=True, num_patch = 4, path="./data/"):
  function sparse2coarse (line 48) | def sparse2coarse(targets):

FILE: evaluate.py
  function compute_accuracy (line 60) | def compute_accuracy(y_pred, y_true):
  function chunk_avg (line 68) | def chunk_avg(x,n_chunks=2,normalize=False):
  function test (line 77) | def test(net, train_loader, test_loader):
  function chunk_avg (line 136) | def chunk_avg(x,n_chunks=2,normalize=False):

FILE: func.py
  class WeightedKNNClassifier (line 24) | class WeightedKNNClassifier(Metric):
    method __init__ (line 25) | def __init__(
    method update (line 62) | def update(
    method set_tk (line 90) | def set_tk(self, T, k):
    method compute (line 95) | def compute(self) -> Tuple[float]:
  function linear (line 181) | def linear(train_features, train_labels, test_features, test_labels, lr=...
  function accuracy (line 240) | def accuracy(output, target, topk=(1,)):
  class tensor_dataset (line 256) | class tensor_dataset(data.Dataset):
    method __init__ (line 257) | def __init__(self,x,y):
    method __getitem__ (line 262) | def __getitem__(self,indx):
    method __len__ (line 265) | def __len__(self):
  function set_gamma (line 271) | def set_gamma(loss_fn,epoch,total_epoch=500,warmup_epoch=100,gamma_min=0...
  function warmup_lr (line 280) | def warmup_lr(optimizer,epoch,base_lr,warmup_epoch=10):
  function marginal_H (line 285) | def marginal_H(logits):
  function chunk_avg (line 292) | def chunk_avg(x,n_chunks=2,normalize=False):
  function cluster_match (line 300) | def cluster_match(cluster_mtx,label_mtx,n_classes=10,print_result=True):
  function cluster_merge_match (line 330) | def cluster_merge_match(cluster_mtx,label_mtx,print_result=True):
  function cluster_acc (line 344) | def cluster_acc(test_loader,net,device,print_result=False,save_name_img=...
  function save_cluster_imgs (line 378) | def save_cluster_imgs(cluster_mtx,x_mtx,save_name,npercluster=100):
  function save_latent_pca_figure (line 399) | def save_latent_pca_figure(z_mtx,cluster_mtx,save_name):
  function analyze_latent (line 418) | def analyze_latent(z_mtx,cluster_mtx):

FILE: lars.py
  class LARS (line 5) | class LARS(Optimizer):
    method __init__ (line 26) | def __init__(self, params, lr, len_reduced, momentum=0.9, use_nesterov...
    method step (line 48) | def step(self, epoch=None, closure=None):
  class LARSWrapper (line 146) | class LARSWrapper:
    method __init__ (line 147) | def __init__(
    method defaults (line 184) | def defaults(self):
    method defaults (line 188) | def defaults(self, defaults):
    method __class__ (line 192) | def __class__(self):
    method state (line 196) | def state(self):
    method state (line 200) | def state(self, state):
    method param_groups (line 204) | def param_groups(self):
    method param_groups (line 208) | def param_groups(self, value):
    method step (line 212) | def step(self, closure=None):
    method update_p (line 234) | def update_p(self, p, group, weight_decay):

FILE: loss.py
  class contrastive_loss (line 5) | class contrastive_loss(nn.Module):
    method __init__ (line 6) | def __init__(self):
    method forward (line 9) | def forward(self,x,labels):
  class SimCLR (line 15) | class SimCLR(nn.Module):
    method __init__ (line 16) | def __init__(self,temperature=0.5,n_views=2,contrastive=False):
    method info_nce_loss (line 26) | def info_nce_loss(self,X):
    method forward (line 60) | def forward(self,X):
  class Z_loss (line 65) | class Z_loss(nn.Module):
    method __init__ (line 66) | def __init__(self,):
    method forward (line 70) | def forward(self,z):
  class TotalCodingRate (line 76) | class TotalCodingRate(nn.Module):
    method __init__ (line 77) | def __init__(self, eps=0.01):
    method compute_discrimn_loss (line 81) | def compute_discrimn_loss(self, W):
    method forward (line 89) | def forward(self,X):
  class MaximalCodingRateReduction (line 92) | class MaximalCodingRateReduction(torch.nn.Module):
    method __init__ (line 93) | def __init__(self, eps=0.01, gamma=1):
    method compute_discrimn_loss (line 98) | def compute_discrimn_loss(self, W):
    method compute_compress_loss (line 106) | def compute_compress_loss(self, W, Pi):
    method forward (line 118) | def forward(self, X, Y, num_classes=None):

FILE: main.py
  function chunk_avg (line 67) | def chunk_avg(x,n_chunks=2,normalize=False):
  class Similarity_Loss (line 76) | class Similarity_Loss(nn.Module):
    method __init__ (line 77) | def __init__(self, ):
    method forward (line 81) | def forward(self, z_list, z_avg):
  function cal_TCR (line 96) | def cal_TCR(z, criterion, num_patches):
  function main (line 146) | def main():

FILE: mcr/loss.py
  class MCRGANloss (line 6) | class MCRGANloss(nn.Module):
    method __init__ (line 8) | def __init__(self, gam1=1., gam2=1., gam3=1., eps=0.5, numclasses=1000...
    method forward (line 18) | def forward(self, Z, Z_bar, real_label, ith_inner_loop, num_inner_loop):
    method old_version (line 25) | def old_version(self, Z, Z_bar, real_label, ith_inner_loop, num_inner_...
    method debug (line 63) | def debug(self, Z, Z_bar, real_label):
    method compute_discrimn_loss (line 67) | def compute_discrimn_loss(self, Z):
    method compute_compress_loss (line 75) | def compute_compress_loss(self, Z, Pi):
    method deltaR (line 90) | def deltaR(self, Z, Y, num_classes):
    method gumb_compress_loss (line 108) | def gumb_compress_loss(self, Z, P):
    method pseudo_label_loss (line 123) | def pseudo_label_loss(self, Z, logits, thres = 1.4):

FILE: model/model.py
  function getmodel (line 9) | def getmodel(arch):
  class encoder (line 32) | class encoder(nn.Module):
    method __init__ (line 33) | def __init__(self,z_dim=1024,hidden_dim=4096, norm_p=2, arch = "resnet...
    method forward (line 46) | def forward(self, x, is_test = False):

FILE: model/resnet.py
  class BasicBlock (line 7) | class BasicBlock(nn.Module):
    method __init__ (line 9) | def __init__(self, in_planes, planes, stride=1):
    method forward (line 26) | def forward(self, x):
  class Bottleneck (line 33) | class Bottleneck(nn.Module):
    method __init__ (line 36) | def __init__(self, in_planes, planes, stride=1):
    method forward (line 55) | def forward(self, x):
  class ResNet (line 64) | class ResNet(nn.Module):
    method __init__ (line 65) | def __init__(self, block, blocks_config, first_config, first_pool=False):
    method _make_layer (line 80) | def _make_layer(self, block, planes, num_blocks, stride):
    method forward (line 88) | def forward(self, x):
  function Resnet10MNIST (line 103) | def Resnet10MNIST():
  function Resnet10CIFAR (line 111) | def Resnet10CIFAR():
  function Resnet18imgs (line 119) | def Resnet18imgs():
  function Resnet18CIFAR (line 127) | def Resnet18CIFAR():
  function Resnet18STL10 (line 134) | def Resnet18STL10():
  function Resnet34CIFAR (line 142) | def Resnet34CIFAR():
  function Resnet34STL10 (line 149) | def Resnet34STL10():

About this extraction

This page contains the full source code of the tsb0601/EMP-SSL GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 13 files (66.6 KB), approximately 17.3k tokens, and a symbol index with 122 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
