Repository: tsb0601/EMP-SSL
Branch: main
Commit: 8fc3758617e0
Files: 13
Total size: 66.6 KB
Directory structure:
gitextract_sphqay9s/
├── README.md
├── dataset/
│ ├── aug.py
│ ├── aug4img.py
│ └── datasets.py
├── evaluate.py
├── func.py
├── lars.py
├── loss.py
├── main.py
├── mcr/
│ └── loss.py
├── model/
│ ├── model.py
│ └── resnet.py
└── requirements.text
================================================
FILE CONTENTS
================================================
================================================
FILE: README.md
================================================
# EMP-SSL: Towards Self-Supervised Learning in One Training Epoch
[arXiv:2304.03977](https://arxiv.org/abs/2304.03977)

Authors: Shengbang Tong*, Yubei Chen*, Yi Ma, Yann LeCun
## Introduction
This repository contains the implementation of the paper "EMP-SSL: Towards Self-Supervised Learning in One Training Epoch." The paper introduces a simple but efficient self-supervised learning method, Extreme-Multi-Patch Self-Supervised Learning (EMP-SSL), which significantly reduces the number of training epochs required for convergence by increasing the number of fixed-size image patches taken from each image instance.
## Preparing Training Data
CIFAR-10 and CIFAR-100 are downloaded automatically by the training script. ImageNet100 is a special subset of ImageNet; details can be found at this [link](https://github.com/HobbitLong/CMC/issues/21).
## Getting Started
The current implementation supports CIFAR-10, CIFAR-100 and ImageNet100.
To get started with the EMP-SSL implementation, follow these instructions:
### 1. Clone this repository
```bash
git clone https://github.com/tsb0601/emp-ssl.git
cd emp-ssl
```
### 2. Install required packages
```
pip install -r requirements.txt
```
### 3. Training
#### Reproducing 1-epoch results
| | CIFAR-10<br>1 Epoch | CIFAR-100<br>1 Epoch | Tiny ImageNet<br>1 Epoch | ImageNet-100<br>1 Epoch |
|--------------------|:----------------------:|:-----------------------:|:----------------------------:|:--------------------------:|
| EMP-SSL (1 Epoch) | 0.842 | 0.585 | 0.381 | 0.585 |
For CIFAR10 or CIFAR100
```
python main.py --data cifar10 --epoch 2 --patch_sim 200 --arch 'resnet18-cifar' --num_patches 20 --lr 0.3
```
For ImageNet100
```
python main.py --data imagenet100 --epoch 2 --patch_sim 200 --arch 'resnet18-imagenet' --num_patches 20 --lr 0.3
```
#### Reproducing multi-epoch results
| | CIFAR-10<br>1 Epoch | CIFAR-10<br>10 Epochs | CIFAR-10<br>30 Epochs | CIFAR-10<br>1000 Epochs | CIFAR-100<br>1 Epoch | CIFAR-100<br>10 Epochs | CIFAR-100<br>30 Epochs | CIFAR-100<br>1000 Epochs | Tiny ImageNet<br>10 Epochs | Tiny ImageNet<br>1000 Epochs | ImageNet-100<br>10 Epochs | ImageNet-100<br>400 Epochs |
|----------------------|:-------------------:|:---------------------:|:---------------------:|:-----------------------:|:--------------------:|:----------------------:|:----------------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|
| SimCLR | 0.282 | 0.565 | 0.663 | 0.910 | 0.054 | 0.185 | 0.341 | 0.662 | - | 0.488 | - | 0.776
| BYOL | 0.249 | 0.489 | 0.684 | 0.926 | 0.043 | 0.150 | 0.349 | 0.708 | - | 0.510 | - | 0.802
| VICReg | 0.406 | 0.697 | 0.781 | 0.921 | 0.079 | 0.319 | 0.479 | 0.685 | - | - | - | 0.792
| SwAV | 0.245 | 0.532 | 0.767 | 0.923 | 0.028 | 0.208 | 0.294 | 0.658 |- | - | - | 0.740
| ReSSL | 0.245 | 0.256 | 0.525 | 0.914 | 0.033 | 0.122 | 0.247 | 0.674 |- | - | - | 0.769
| EMP-SSL (20 patches) | 0.806 | 0.907 | 0.931 | - | 0.551 | 0.678 | 0.724 | - | - | - | - | -
| EMP-SSL (200 patches)| 0.826* | 0.915 | 0.934 | - | 0.577 | 0.701 | 0.733 | - | 0.515 | - | 0.789 | -
\* Here we change the learning-rate schedule to decay over 30 epochs, so the 1-epoch accuracy is slightly lower than when optimizing specifically for 1-epoch training.
Change `--num_patches` to adjust the number of patches used in EMP-SSL training.
```
python main.py --data cifar10 --epoch 30 --patch_sim 200 --arch 'resnet18-cifar' --num_patches 20 --lr 0.3
```
### 4. Evaluating
Because our model is trained only on fixed-size image patches, we adopt the bag-of-features evaluation from the intra-instance VICReg paper. Adjust `--test_patches` to change the number of patches used in the bag-of-features model to fit different GPUs.
```
python evaluate.py --model_path 'path to your evaluated model' --test_patches 128
```
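The bag-of-features evaluation averages per-patch embeddings into a single vector per image before fitting the linear probe. A minimal sketch of that pooling step (the helper name `bag_of_features` is hypothetical; it mirrors `chunk_avg` in `evaluate.py`):

```python
import torch

def bag_of_features(patch_embeddings: torch.Tensor, num_patches: int) -> torch.Tensor:
    """Average per-patch embeddings into one vector per image.

    patch_embeddings: shape (num_patches * batch, dim), with patches stacked
    along the batch dimension as in evaluate.py.
    """
    chunks = patch_embeddings.chunk(num_patches, dim=0)   # num_patches tensors of (batch, dim)
    return torch.stack(chunks, dim=0).mean(dim=0)         # (batch, dim)

# Example: 4 patches for a batch of 2 images, 8-dim embeddings.
z = torch.randn(4 * 2, 8)
pooled = bag_of_features(z, num_patches=4)
print(pooled.shape)  # torch.Size([2, 8])
```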
## Acknowledgment
This repo is inspired by the [MCR2](https://github.com/Ma-Lab-Berkeley/MCR2), [solo-learn](https://github.com/vturrisi/solo-learn) and [NMCE](https://github.com/zengyi-li/NMCE-release) repositories.
## Citation
If you find this repository useful, please consider giving a star :star: and citation:
```
@article{tong2023empssl,
title={EMP-SSL: Towards Self-Supervised Learning in One Training Epoch},
author={Shengbang Tong and Yubei Chen and Yi Ma and Yann LeCun},
journal={arXiv preprint arXiv:2304.03977},
year={2023}
}
```
================================================
FILE: dataset/aug.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision.transforms as transforms
from PIL import Image, ImageFilter, ImageOps
def load_transforms(name):
"""Load data transformations.
Note:
- Gaussian Blur is defined at the bottom of this file.
"""
_name = name.lower()
if _name == "cifar_sup":
normalize = transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
aug_transform = transforms.Compose([
transforms.RandomResizedCrop(32,scale=(0.765625, 0.765625),ratio=(1., 1.)),
transforms.RandomCrop(32, padding=8),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize
])
baseline_transform = transforms.Compose([
transforms.RandomResizedCrop(32,scale=(0.765625, 0.765625),ratio=(1., 1.)),
transforms.ToTensor(),normalize])
elif _name == "cifar_patch":
normalize = transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])
aug_transform = transforms.Compose([
transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
transforms.ToTensor(),
normalize
])
baseline_transform = transforms.Compose([
transforms.ToTensor(), normalize])
elif _name == "cifar_simclr_norm":
normalize = transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
aug_transform = transforms.Compose([
transforms.RandomResizedCrop(32,scale=(0.08, 1.0)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
transforms.RandomGrayscale(p=0.2),
transforms.ToTensor(),
normalize
])
baseline_transform = transforms.Compose([
transforms.ToTensor(),normalize])
elif _name == "cifar_byol":
normalize = transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
aug_transform = transforms.Compose([
transforms.RandomResizedCrop(
(32, 32),
scale=(0.2, 1.0),
interpolation=transforms.InterpolationMode.BICUBIC,
),
transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8),
transforms.RandomGrayscale(p=0.2),
transforms.RandomApply([Solarization()], p=0.1),
transforms.RandomHorizontalFlip(p=0.5),
transforms.ToTensor(),
normalize
])
baseline_transform = transforms.Compose([
# transforms.RandomResizedCrop(32,scale=(0.765625, 0.765625),ratio=(1., 1.)),
transforms.ToTensor(),normalize])
else:
raise NameError("{} not found in transform loader".format(name))
return aug_transform, baseline_transform
class Solarization:
"""Solarization as a callable object."""
def __call__(self, img: Image) -> Image:
"""Applies solarization to an input image.
Args:
img (Image): an image in the PIL.Image format.
Returns:
Image: a solarized image.
"""
return ImageOps.solarize(img)
class GBlur(object):
def __init__(self, p):
self.p = p
def __call__(self, img):
if np.random.rand() < self.p:
sigma = np.random.rand() * 1.9 + 0.1
return img.filter(ImageFilter.GaussianBlur(sigma))
else:
return img
class AddGaussianNoise(object):
def __init__(self, mean=0., std=1.):
self.std = std
self.mean = mean
def __call__(self, tensor):
return tensor + torch.randn(tensor.size()) * self.std + self.mean
def __repr__(self):
return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
class ContrastiveLearningViewGenerator(object):
def __init__(self, num_patch = 4):
self.num_patch = num_patch
def __call__(self, x):
normalize = transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])
aug_transform = transforms.Compose([
transforms.RandomResizedCrop(32,scale=(0.25, 0.25), ratio=(1,1)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.2)], p=0.8),
transforms.RandomGrayscale(p=0.2),
GBlur(p=0.1),
transforms.RandomApply([Solarization()], p=0.1),
transforms.ToTensor(),
normalize
])
augmented_x = [aug_transform(x) for i in range(self.num_patch)]
return augmented_x
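The view generator above produces `num_patch` independently augmented crops of one image. A stripped-down, torch-only stand-in for illustration (random fixed-size crops only, no color augmentation — a sketch, not the repo's actual pipeline):

```python
import torch

def random_patches(img: torch.Tensor, num_patch: int, size: int = 16) -> list:
    """Minimal stand-in for ContrastiveLearningViewGenerator: num_patch
    random fixed-size crops of a CHW image tensor."""
    c, h, w = img.shape
    patches = []
    for _ in range(num_patch):
        top = torch.randint(0, h - size + 1, (1,)).item()
        left = torch.randint(0, w - size + 1, (1,)).item()
        patches.append(img[:, top:top + size, left:left + size])
    return patches

img = torch.rand(3, 32, 32)          # stands in for a CIFAR image
views = random_patches(img, num_patch=8)
print(len(views), views[0].shape)    # 8 torch.Size([3, 16, 16])
```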
================================================
FILE: dataset/aug4img.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision.transforms as transforms
from PIL import Image, ImageFilter, ImageOps
from torchvision.transforms import InterpolationMode
class Solarization:
"""Solarization as a callable object."""
def __call__(self, img: Image) -> Image:
"""Applies solarization to an input image.
Args:
img (Image): an image in the PIL.Image format.
Returns:
Image: a solarized image.
"""
return ImageOps.solarize(img)
class GBlur(object):
def __init__(self, p):
self.p = p
def __call__(self, img):
if np.random.rand() < self.p:
sigma = np.random.rand() * 1.9 + 0.1
return img.filter(ImageFilter.GaussianBlur(sigma))
else:
return img
class AddGaussianNoise(object):
def __init__(self, mean=0., std=1.):
self.std = std
self.mean = mean
def __call__(self, tensor):
return tensor + torch.randn(tensor.size()) * self.std + self.mean
def __repr__(self):
return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
class ContrastiveLearningViewGenerator(object):
def __init__(self, num_patch = 4):
self.num_patch = num_patch
def __call__(self, x):
aug_transform = transforms.Compose([
transforms.RandomResizedCrop(
224, scale=(0.25, 0.25), interpolation=InterpolationMode.BICUBIC
),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.2)], p=0.8),
transforms.RandomGrayscale(p=0.2),
GBlur(p=0.1),
transforms.RandomApply([Solarization()], p=0.1),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
augmented_x = [aug_transform(x) for i in range(self.num_patch)]
return augmented_x
================================================
FILE: dataset/datasets.py
================================================
import os
import numpy as np
import torchvision
def load_dataset(data_name, train=True, num_patch = 4, path="./data/"):
"""Loads a dataset for training or testing.
Parameters:
data_name (str): name of the dataset (cifar10, cifar100, or imagenet)
train (bool): whether to load the training split
num_patch (int): number of augmented patches generated per image (see aug.py / aug4img.py)
path (str): base path for downloaded datasets
Returns:
dataset (torch.utils.data.Dataset)
"""
_name = data_name.lower()
if _name == "imagenet":
from .aug4img import ContrastiveLearningViewGenerator
else:
from .aug import ContrastiveLearningViewGenerator
transform = ContrastiveLearningViewGenerator(num_patch = num_patch)
if _name == "cifar10":
trainset = torchvision.datasets.CIFAR10(root=os.path.join(path, "CIFAR10"), train=train, download=True, transform=transform)
trainset.num_classes = 10
elif _name == "cifar100":
trainset = torchvision.datasets.CIFAR100(root=os.path.join(path, "CIFAR100"), train=train, download=True, transform=transform)
trainset.num_classes = 100
elif _name == "imagenet":
if train:
trainset = torchvision.datasets.ImageFolder(root="/home/peter/Data/ILSVRC2012/train100/",transform=transform)
#trainset = torchvision.datasets.ImageFolder(root="/home/peter/Data/tiny-imagenet-200/train/",transform=transform)
else:
trainset = torchvision.datasets.ImageFolder(root="/home/peter/Data/ILSVRC2012/val100/",transform=transform)
#trainset = torchvision.datasets.ImageFolder(root="/home/peter/Data/tiny-imagenet-200/val/",transform=transform)
trainset.num_classes = 100  # ImageNet-100; use 200 for the tiny-imagenet paths above
else:
raise NameError("{} not found in trainset loader".format(_name))
return trainset
def sparse2coarse(targets):
"""CIFAR100 Coarse Labels. """
coarse_targets = [ 4, 1, 14, 8, 0, 6, 7, 7, 18, 3, 3, 14, 9, 18, 7, 11, 3,
9, 7, 11, 6, 11, 5, 10, 7, 6, 13, 15, 3, 15, 0, 11, 1, 10,
12, 14, 16, 9, 11, 5, 5, 19, 8, 8, 15, 13, 14, 17, 18, 10, 16,
4, 17, 4, 2, 0, 17, 4, 18, 17, 10, 3, 2, 12, 12, 16, 12, 1,
9, 19, 2, 10, 0, 1, 16, 12, 9, 13, 15, 13, 16, 19, 2, 4, 6,
19, 5, 5, 8, 19, 18, 1, 2, 15, 6, 0, 17, 8, 14, 13]
return np.array(coarse_targets)[targets]
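`sparse2coarse` maps each of CIFAR-100's 100 fine labels to one of 20 superclasses via NumPy fancy indexing; the same pattern in miniature, with a toy lookup table:

```python
import numpy as np

# Toy lookup table: fine label -> coarse label (5 fine classes, 2 coarse).
lookup = np.array([0, 0, 1, 1, 0])
fine_targets = np.array([4, 2, 0, 3])
coarse_targets = lookup[fine_targets]   # vectorized per-element lookup
print(coarse_targets)  # [0 1 0 1]
```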
================================================
FILE: evaluate.py
================================================
############
## Import ##
############
import argparse
import torch.nn as nn
from torch.utils.data import DataLoader
from model.model import encoder
from dataset.datasets import load_dataset
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm
import torch
from func import WeightedKNNClassifier, linear
######################
## Parsing Argument ##
######################
parser = argparse.ArgumentParser(description='Evaluation')
parser.add_argument('--test_patches', type=int, default=128,
help='number of patches used in testing (default: 128)')
parser.add_argument('--data', type=str, default="cifar10",
help='dataset (default: cifar10)')
parser.add_argument('--arch', type=str, default="resnet18-cifar",
help='network architecture (default: resnet18-cifar)')
parser.add_argument('--lr', type=float, default=0.03,
help='learning rate for linear eval (default: 0.03)')
parser.add_argument('--linear', type=bool, default=True,
help='use linear eval or not')
parser.add_argument('--knn', help='evaluate using kNN measuring cosine similarity', action='store_true')
parser.add_argument('--model_path', type=str, default="",
help='model directory for eval')
args = parser.parse_args()
######################
## Testing Accuracy ##
######################
test_patches = args.test_patches
def compute_accuracy(y_pred, y_true):
"""Compute accuracy by counting correct classification. """
assert y_pred.shape == y_true.shape
return 1 - np.count_nonzero(y_pred - y_true) / y_true.size
knn_classifier = WeightedKNNClassifier()
def chunk_avg(x,n_chunks=2,normalize=False):
x_list = x.chunk(n_chunks,dim=0)
x = torch.stack(x_list,dim=0)
if not normalize:
return x.mean(0)
else:
return F.normalize(x.mean(0),dim=1)
def test(net, train_loader, test_loader):
train_z_full_list, train_y_list, test_z_full_list, test_y_list = [], [], [], []
with torch.no_grad():
for x, y in tqdm(train_loader):
x = torch.cat(x, dim = 0)
z_proj, z_pre = net(x, is_test=True)
z_pre = chunk_avg(z_pre, test_patches)
z_pre = z_pre.detach().cpu()
train_z_full_list.append(z_pre)
knn_classifier.update(train_features = z_pre, train_targets = y)
train_y_list.append(y)
for x, y in tqdm(test_loader):
x = torch.cat(x, dim = 0)
z_proj, z_pre = net(x, is_test=True)
z_pre = chunk_avg(z_pre, test_patches)
z_pre = z_pre.detach().cpu()
test_z_full_list.append(z_pre)
knn_classifier.update(test_features = z_pre, test_targets = y)
test_y_list.append(y)
train_features_full, train_labels, test_features_full, test_labels = torch.cat(train_z_full_list,dim=0), torch.cat(train_y_list,dim=0), torch.cat(test_z_full_list,dim=0), torch.cat(test_y_list,dim=0)
if args.data == "cifar10":
num_classes = 10
elif args.data == "cifar100":
num_classes = 100
elif args.data == "tinyimagenet200":
num_classes = 200
elif args.data == "imagenet100":
num_classes = 100
elif args.data == "imagenet":
num_classes = 1000
if args.linear:
print("Using Linear Eval to evaluate accuracy")
linear(train_features_full, train_labels, test_features_full, test_labels, lr=args.lr, num_classes = num_classes)
if args.knn:
print("Using KNN to evaluate accuracy")
top1, top5 = knn_classifier.compute()
print("KNN (top1/top5):", top1, top5)
torch.multiprocessing.set_sharing_strategy('file_system')
#Get Dataset
memory_dataset = load_dataset(args.data, train=True, num_patch = test_patches)
memory_loader = DataLoader(memory_dataset, batch_size=50, shuffle=True, drop_last=True,num_workers=8)
test_data = load_dataset(args.data, train=False, num_patch = test_patches)
test_loader = DataLoader(test_data, batch_size=50, shuffle=True, num_workers=8)
# Load Model and Checkpoint
use_cuda = True
device = torch.device("cuda" if use_cuda else "cpu")
net = encoder(arch = args.arch)
net = nn.DataParallel(net)
save_dict = torch.load(args.model_path)
net.load_state_dict(save_dict,strict=False)
net.cuda()
net.eval()
test(net, memory_loader, test_loader)
================================================
FILE: func.py
================================================
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score
import torchvision
from torch import optim
from torch.utils import data
from torch.utils.data import DataLoader
from typing import Tuple
from torchmetrics.metric import Metric
class WeightedKNNClassifier(Metric):
def __init__(
self,
k: int = 20,
T: float = 0.07,
max_distance_matrix_size: int = int(5e6),
distance_fx: str = "cosine",
epsilon: float = 0.00001,
dist_sync_on_step: bool = False,
):
"""Implements the weighted k-NN classifier used for evaluation.
Args:
k (int, optional): number of neighbors. Defaults to 20.
T (float, optional): temperature for the exponential. Only used with cosine
distance. Defaults to 0.07.
max_distance_matrix_size (int, optional): maximum number of elements in the
distance matrix. Defaults to 5e6.
distance_fx (str, optional): Distance function. Accepted arguments: "cosine" or
"euclidean". Defaults to "cosine".
epsilon (float, optional): Small value for numerical stability. Only used with
euclidean distance. Defaults to 0.00001.
dist_sync_on_step (bool, optional): whether to sync distributed values at every
step. Defaults to False.
"""
super().__init__(dist_sync_on_step=dist_sync_on_step, compute_on_step=False)
self.k = k
self.T = T
self.max_distance_matrix_size = max_distance_matrix_size
self.distance_fx = distance_fx
self.epsilon = epsilon
self.add_state("train_features", default=[], persistent=False)
self.add_state("train_targets", default=[], persistent=False)
self.add_state("test_features", default=[], persistent=False)
self.add_state("test_targets", default=[], persistent=False)
def update(
self,
train_features: torch.Tensor = None,
train_targets: torch.Tensor = None,
test_features: torch.Tensor = None,
test_targets: torch.Tensor = None,
):
"""Updates the memory banks. If train (test) features are passed as input, the
corresponding train (test) targets must be passed as well.
Args:
train_features (torch.Tensor, optional): a batch of train features. Defaults to None.
train_targets (torch.Tensor, optional): a batch of train targets. Defaults to None.
test_features (torch.Tensor, optional): a batch of test features. Defaults to None.
test_targets (torch.Tensor, optional): a batch of test targets. Defaults to None.
"""
assert (train_features is None) == (train_targets is None)
assert (test_features is None) == (test_targets is None)
if train_features is not None:
assert train_features.size(0) == train_targets.size(0)
self.train_features.append(train_features.detach())
self.train_targets.append(train_targets.detach())
if test_features is not None:
assert test_features.size(0) == test_targets.size(0)
self.test_features.append(test_features.detach())
self.test_targets.append(test_targets.detach())
def set_tk(self, T, k):
self.T = T
self.k = k
@torch.no_grad()
def compute(self) -> Tuple[float]:
"""Computes weighted k-NN accuracy @1 and @5. If cosine distance is selected,
the weight is computed using the exponential of the temperature scaled cosine
distance of the samples. If euclidean distance is selected, the weight corresponds
to the inverse of the euclidean distance.
Returns:
Tuple[float]: k-NN accuracy @1 and @5.
"""
#print(self.T, self.k)
train_features = torch.cat(self.train_features)
train_targets = torch.cat(self.train_targets)
test_features = torch.cat(self.test_features)
test_targets = torch.cat(self.test_targets)
if self.distance_fx == "cosine":
train_features = F.normalize(train_features)
test_features = F.normalize(test_features)
num_classes = torch.unique(test_targets).numel()
num_train_images = train_targets.size(0)
num_test_images = test_targets.size(0)
chunk_size = min(
max(1, self.max_distance_matrix_size // num_train_images),
num_test_images,
)
k = min(self.k, num_train_images)
top1, top5, total = 0.0, 0.0, 0
retrieval_one_hot = torch.zeros(k, num_classes).to(train_features.device)
for idx in range(0, num_test_images, chunk_size):
# get the features for test images
features = test_features[idx : min((idx + chunk_size), num_test_images), :]
targets = test_targets[idx : min((idx + chunk_size), num_test_images)]
batch_size = targets.size(0)
# calculate the dot product and compute top-k neighbors
if self.distance_fx == "cosine":
similarities = torch.mm(features, train_features.t())
elif self.distance_fx == "euclidean":
similarities = 1 / (torch.cdist(features, train_features) + self.epsilon)
else:
raise NotImplementedError
similarities, indices = similarities.topk(k, largest=True, sorted=True)
candidates = train_targets.view(1, -1).expand(batch_size, -1)
retrieved_neighbors = torch.gather(candidates, 1, indices)
retrieval_one_hot.resize_(batch_size * k, num_classes).zero_()
retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1)
if self.distance_fx == "cosine":
similarities = similarities.clone().div_(self.T).exp_()
probs = torch.sum(
torch.mul(
retrieval_one_hot.view(batch_size, -1, num_classes),
similarities.view(batch_size, -1, 1),
),
1,
)
_, predictions = probs.sort(1, True)
# find the predictions that match the target
correct = predictions.eq(targets.data.view(-1, 1))
top1 = top1 + correct.narrow(1, 0, 1).sum().item()
top5 = (
top5 + correct.narrow(1, 0, min(5, k, correct.size(-1))).sum().item()
) # top5 does not make sense if k < 5
total += targets.size(0)
top1 = top1 * 100.0 / total
top5 = top5 * 100.0 / total
self.reset()
return top1, top5
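The cosine branch of `compute` can be illustrated standalone: each of the k nearest training samples votes for its class with weight `exp(sim / T)`. A toy, pure-torch sketch (no torchmetrics, simplified shapes):

```python
import torch
import torch.nn.functional as F

def weighted_knn_predict(train_z, train_y, test_z, k=3, T=0.07, num_classes=2):
    """Toy version of the cosine-weighted k-NN vote used in compute() above."""
    train_z = F.normalize(train_z, dim=1)
    test_z = F.normalize(test_z, dim=1)
    sims = test_z @ train_z.t()               # cosine similarity matrix
    sims, idx = sims.topk(k, dim=1)           # k nearest training samples
    weights = (sims / T).exp()                # temperature-scaled vote weights
    votes = torch.zeros(test_z.size(0), num_classes)
    votes.scatter_add_(1, train_y[idx], weights)
    return votes.argmax(dim=1)

# Two well-separated clusters: each test point follows its nearby cluster.
train_z = torch.tensor([[1., 0.], [0.9, 0.1], [0., 1.], [0.1, 0.9]])
train_y = torch.tensor([0, 0, 1, 1])
test_z = torch.tensor([[0.95, 0.05], [0.05, 0.95]])
print(weighted_knn_predict(train_z, train_y, test_z, k=2))  # tensor([0, 1])
```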
def linear(train_features, train_labels, test_features, test_labels, lr=0.0075, num_classes = 100):
train_data = tensor_dataset(train_features,train_labels)
test_data = tensor_dataset(test_features,test_labels)
train_loader = DataLoader(train_data, batch_size=100, shuffle=True, drop_last=True, num_workers=2)
test_loader = DataLoader(test_data, batch_size=100, shuffle=True, drop_last=False, num_workers=2)
LL = nn.Linear(train_features.shape[1],num_classes)
optimizer = torch.optim.SGD(LL.parameters(), lr=lr, momentum=0.9, weight_decay=5e-5)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 100)
criterion = torch.nn.CrossEntropyLoss()
test_acc_list = []
for epoch in range(100):
top1_train_accuracy = 0
for counter, (x_batch, y_batch) in enumerate(train_loader):
logits = LL(x_batch)
loss = criterion(logits, y_batch)
top1 = accuracy(logits, y_batch, topk=(1,))
top1_train_accuracy += top1[0]
optimizer.zero_grad()
loss.backward()
optimizer.step()
scheduler.step()
top1_train_accuracy /= (counter + 1)
top1_accuracy = 0
top5_accuracy = 0
for counter, (x_batch, y_batch) in enumerate(test_loader):
logits = LL(x_batch)
top1, top5 = accuracy(logits, y_batch, topk=(1,5))
top1_accuracy += top1[0]
top5_accuracy += top5[0]
top1_accuracy /= (counter + 1)
top5_accuracy /= (counter + 1)
test_acc_list.append(top1_accuracy)
print(f"Epoch {epoch}\tTop1 Train accuracy {top1_train_accuracy.item()}\tTop1 Test accuracy: {top1_accuracy.item()}\tTop5 test acc: {top5_accuracy.item()}")
acc_vect = torch.tensor(test_acc_list)
print('best linear test acc {}, last acc {}'.format(acc_vect.max().item(),acc_vect[-1].item()))
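`linear` above fits a single `nn.Linear` probe with SGD on frozen features; the same recipe in miniature on synthetic, linearly separable features (toy data, not the repo's embeddings):

```python
import torch
import torch.nn as nn

# Toy linear probe on synthetic, linearly separable 2-D "features".
torch.manual_seed(0)
x = torch.randn(200, 2)
y = (x[:, 0] > 0).long()                  # label = sign of the first coordinate
probe = nn.Linear(2, 2)
opt = torch.optim.SGD(probe.parameters(), lr=0.1, momentum=0.9)
loss_fn = nn.CrossEntropyLoss()
for _ in range(200):                      # full-batch gradient steps
    opt.zero_grad()
    loss_fn(probe(x), y).backward()
    opt.step()
acc = (probe(x).argmax(dim=1) == y).float().mean().item()
print(f"probe accuracy: {acc:.2f}")
```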
def accuracy(output, target, topk=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
with torch.no_grad():
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
class tensor_dataset(data.Dataset):
def __init__(self,x,y):
self.x = x
self.y = y
self.length = x.shape[0]
def __getitem__(self,indx):
return self.x[indx], self.y[indx]
def __len__(self):
return self.length
def set_gamma(loss_fn,epoch,total_epoch=500,warmup_epoch=100,gamma_min=0.,gamma_max=1.0):
warmup_start = total_epoch - warmup_epoch
warmup_end = total_epoch
if warmup_start < epoch<=warmup_end:
loss_fn.gamma = ((epoch - warmup_start)/(warmup_end - warmup_start))*(gamma_max - gamma_min) + gamma_min
else:
loss_fn.gamma = gamma_min
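`set_gamma` ramps `gamma` linearly from `gamma_min` to `gamma_max` over the last `warmup_epoch` epochs and leaves it at `gamma_min` before that; the schedule can be traced standalone (same arithmetic, no `loss_fn` object):

```python
def gamma_schedule(epoch, total_epoch=500, warmup_epoch=100, gamma_min=0.0, gamma_max=1.0):
    """Standalone trace of set_gamma's ramp."""
    warmup_start = total_epoch - warmup_epoch
    if warmup_start < epoch <= total_epoch:
        return (epoch - warmup_start) / warmup_epoch * (gamma_max - gamma_min) + gamma_min
    return gamma_min

# Flat until epoch 400, then linear ramp to 1.0 at epoch 500.
print([round(gamma_schedule(e), 2) for e in (300, 401, 450, 500)])  # [0.0, 0.01, 0.5, 1.0]
```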
def warmup_lr(optimizer,epoch,base_lr,warmup_epoch=10):
if epoch < warmup_epoch:
lr = base_lr*(epoch+1)/warmup_epoch
for param_group in optimizer.param_groups:
param_group['lr'] = lr
def cluster_match(cluster_mtx,label_mtx,n_classes=10,print_result=True):
cluster_indx = list(cluster_mtx.unique())
assigned_label_list = []
assigned_count = []
while (len(assigned_label_list) < n_classes) and len(cluster_indx) > 0:
max_label_list = []
max_count_list = []
for indx in cluster_indx:
#calculate highest number of matches
mask = cluster_mtx==indx
label_elements, counts = label_mtx[mask].unique(return_counts=True)
for assigned_label in assigned_label_list:
counts[label_elements==assigned_label] = 0
max_count_list.append(counts.max())
max_label_list.append(label_elements[counts.argmax()])
max_label = torch.stack(max_label_list)
max_count = torch.stack(max_count_list)
assigned_label_list.append(max_label[max_count.argmax()])
assigned_count.append(max_count.max())
cluster_indx.pop(max_count.argmax())
total_correct = torch.tensor(assigned_count).sum().item()
total_sample = cluster_mtx.shape[0]
acc = total_correct/total_sample
if print_result:
print('{}/{} ({}%) correct'.format(total_correct,total_sample,acc*100))
else:
return total_correct, total_sample, acc
def cluster_merge_match(cluster_mtx,label_mtx,print_result=True):
cluster_indx = list(cluster_mtx.unique())
n_correct = 0
for cluster_id in cluster_indx:
label_elements, counts = label_mtx[cluster_mtx==cluster_id].unique(return_counts=True)
n_correct += counts.max()
total_sample = len(cluster_mtx)
acc = n_correct.item()/total_sample
if print_result:
print('{}/{} ({}%) correct'.format(n_correct,total_sample,acc*100))
else:
return n_correct, total_sample, acc
def cluster_acc(test_loader,net,device,print_result=False,save_name_img='cluster_img',save_name_fig='pca_figure'):
cluster_list = []
label_list = []
x_list = []
z_list = []
net.eval()
for x, y in test_loader:
with torch.no_grad():
x, y = x.float().to(device), y.to(device)
z, logit = net(x)
if logit.sum() == 0:
logit += torch.randn_like(logit)
cluster_list.append(logit.max(dim=1)[1].cpu())
label_list.append(y.cpu())
x_list.append(x.cpu())
z_list.append(z.cpu())
net.train()
cluster_mtx = torch.cat(cluster_list,dim=0)
label_mtx = torch.cat(label_list,dim=0)
x_mtx = torch.cat(x_list,dim=0)
z_mtx = torch.cat(z_list,dim=0)
_, _, acc_single = cluster_match(cluster_mtx,label_mtx,n_classes=label_mtx.max()+1,print_result=False)
_, _, acc_merge = cluster_merge_match(cluster_mtx,label_mtx,print_result=False)
NMI = normalized_mutual_info_score(label_mtx.numpy(),cluster_mtx.numpy())
ARI = adjusted_rand_score(label_mtx.numpy(),cluster_mtx.numpy())
if print_result:
print('cluster match acc {}, cluster merge match acc {}, NMI {}, ARI {}'.format(acc_single,acc_merge,NMI,ARI))
save_name_img += '_acc'+ str(acc_single)[2:5]
save_cluster_imgs(cluster_mtx,x_mtx,save_name_img)
save_latent_pca_figure(z_mtx,cluster_mtx,save_name_fig)
return acc_single, acc_merge, NMI, ARI
def save_cluster_imgs(cluster_mtx,x_mtx,save_name,npercluster=100):
cluster_indexs, counts = cluster_mtx.unique(return_counts=True)
x_list = []
counts_list = []
for i, c_indx in enumerate(cluster_indexs):
if counts[i]>npercluster:
x_list.append(x_mtx[cluster_mtx==c_indx,:,:,:])
counts_list.append(counts[i])
n_clusters = len(counts_list)
fig, ax = plt.subplots(n_clusters,1,dpi=80,figsize=(1.2*n_clusters, 3*n_clusters))
for i, ax in enumerate(ax):
img = torchvision.utils.make_grid(x_list[i][:npercluster],nrow=npercluster//5,normalize=True)
ax.imshow(img.permute(1,2,0))
ax.set_axis_off()
ax.set_title('Cluster with {} images'.format(counts_list[i]))
fig.savefig(save_name+'.pdf')
plt.close(fig)
def save_latent_pca_figure(z_mtx,cluster_mtx,save_name):
_, s_z_all, _ = z_mtx.svd()
cluster_n = []
cluster_s = []
for cluster_indx in cluster_mtx.unique():
_, s_cluster, _ = z_mtx[cluster_mtx==cluster_indx,:].svd()
cluster_n.append((cluster_mtx==cluster_indx).sum().item())
cluster_s.append(s_cluster/s_cluster.max())
#make plot
fig, ax = plt.subplots(1,2,figsize=(9, 3))
ax[0].plot(s_z_all)
for i, s_curve in enumerate(cluster_s):
ax[1].plot(s_curve,label=cluster_n[i])
ax[1].set_xlim(xmin=0,xmax=20)
ax[1].legend()
fig.savefig(save_name +'.pdf')
plt.close(fig)
def analyze_latent(z_mtx,cluster_mtx):
_, s_z_all, _ = z_mtx.svd()
cluster_n = []
cluster_s = []
cluster_d = []
for cluster_indx in cluster_mtx.unique():
_, s_cluster, _ = z_mtx[cluster_mtx==cluster_indx,:].svd()
s_cluster = s_cluster/s_cluster.max()
cluster_n.append((cluster_mtx==cluster_indx).sum().item())
cluster_s.append(s_cluster)
# print(list(cluster_s))
print(s_cluster)
# s_diff = s_cluster[:-1] - s_cluster[1:]
# cluster_d.append(s_diff.max(0)[1])
cluster_d.append((s_cluster>0.01).sum())
for i in range(len(cluster_n)):
print('subspace {}, dimension {}, samples {}'.format(i,cluster_d[i],cluster_n[i]))
================================================
FILE: lars.py
================================================
import torch
import torch.optim as optim
from torch.optim.optimizer import Optimizer, required
class LARS(Optimizer):
"""
Layer-wise adaptive rate scaling
- Converted from Tensorflow to Pytorch from:
https://github.com/google-research/simclr/blob/master/lars_optimizer.py
- Based on:
https://github.com/noahgolmant/pytorch-lars
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float): base learning rate (\gamma_0)
len_reduced (int): number of leading parameters to apply weight decay to; the remaining parameters are excluded
momentum (float, optional): momentum factor (default: 0.9)
use_nesterov (bool, optional): flag to use nesterov momentum (default: False)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0)
("\beta")
eta (float, optional): LARS coefficient (default: 0.001)
- Based on Algorithm 1 of the following paper by You, Gitman, and Ginsburg.
- Large Batch Training of Convolutional Networks:
https://arxiv.org/abs/1708.03888
"""
def __init__(self, params, lr, len_reduced, momentum=0.9, use_nesterov=False, weight_decay=0.0, classic_momentum=True, eta=0.001):
self.epoch = 0
defaults = dict(
lr=lr,
momentum=momentum,
use_nesterov=use_nesterov,
weight_decay=weight_decay,
classic_momentum=classic_momentum,
eta=eta,
len_reduced=len_reduced
)
super(LARS, self).__init__(params, defaults)
self.lr = lr
self.momentum = momentum
self.weight_decay = weight_decay
self.use_nesterov = use_nesterov
self.classic_momentum = classic_momentum
self.eta = eta
self.len_reduced = len_reduced
def step(self, epoch=None, closure=None):
loss = None
if closure is not None:
loss = closure()
if epoch is None:
epoch = self.epoch
self.epoch += 1
for group in self.param_groups:
weight_decay = group['weight_decay']
momentum = group['momentum']
eta = group['eta']
learning_rate = group['lr']
# TODO: Hacky
counter = 0
for p in group['params']:
if p.grad is None:
continue
param = p.data
grad = p.grad.data
param_state = self.state[p]
# TODO: This really hacky way needs to be improved.
# Note: excluded parameters are passed at the end of the list, so they are ignored here
if counter < self.len_reduced:
grad += weight_decay * param  # use this parameter group's weight decay
# Create parameter for the momentum
if "momentum_var" not in param_state:
next_v = param_state["momentum_var"] = torch.zeros_like(
p.data
)
else:
next_v = param_state["momentum_var"]
if self.classic_momentum:
trust_ratio = 1.0
# TODO: implementation of layer adaptation
w_norm = torch.norm(param)
g_norm = torch.norm(grad)
device = g_norm.get_device()
trust_ratio = torch.where(w_norm.gt(0), torch.where(
g_norm.gt(0), (self.eta * w_norm / g_norm), torch.Tensor([1.0]).to(device)),
torch.Tensor([1.0]).to(device)).item()  # .gt guards against zero norms (norms are never negative)
scaled_lr = learning_rate * trust_ratio
grad_scaled = scaled_lr*grad
next_v.mul_(momentum).add_(grad_scaled)
if self.use_nesterov:
update = (self.momentum * next_v) + (scaled_lr * grad)
else:
update = next_v
p.data.add_(-update)
# Not classic_momentum
else:
next_v.mul_(momentum).add_(grad)
if self.use_nesterov:
update = (self.momentum * next_v) + (grad)
else:
update = next_v
trust_ratio = 1.0
# TODO: implementation of layer adaptation
w_norm = torch.norm(param)
v_norm = torch.norm(update)
device = v_norm.get_device()
trust_ratio = torch.where(w_norm.gt(0), torch.where(
v_norm.gt(0), (self.eta * w_norm / v_norm), torch.Tensor([1.0]).to(device)),
torch.Tensor([1.0]).to(device)).item()  # .gt guards against zero norms (norms are never negative)
scaled_lr = learning_rate * trust_ratio
p.data.add_(-scaled_lr * update)
counter += 1
return loss
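The layer-wise scaling in `LARS.step` above boils down to one ratio: each layer's learning rate is multiplied by `eta * ||w|| / ||g||`, falling back to 1.0 when either norm is zero. A minimal pure-Python sketch with made-up numbers (the vectors and `base_lr` here are illustrative assumptions, not repo values):

```python
# Hedged sketch of the LARS trust ratio: scale the base learning rate
# per layer by eta * ||w|| / ||g||, defaulting to 1.0 for zero norms.
import math

def trust_ratio(w, g, eta=0.001):
    w_norm = math.sqrt(sum(x * x for x in w))
    g_norm = math.sqrt(sum(x * x for x in g))
    if w_norm > 0 and g_norm > 0:
        return eta * w_norm / g_norm
    return 1.0

base_lr = 0.3
ratio = trust_ratio([3.0, 4.0], [0.0, 0.5])  # norms 5.0 and 0.5
scaled_lr = base_lr * ratio                  # per-layer learning rate
print(ratio, scaled_lr)
```

Layers with small gradients relative to their weights get a larger effective step, which is what stabilizes large-batch training here.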
#LARSWrapper from solo-learn repo...
class LARSWrapper:
def __init__(
self,
optimizer: Optimizer,
eta: float = 1e-3,
clip: bool = False,
eps: float = 1e-8,
exclude_bias_n_norm: bool = False,
):
"""Wrapper that adds LARS scheduling to any optimizer.
This helps stability with huge batch sizes.
Args:
optimizer (Optimizer): torch optimizer.
eta (float, optional): trust coefficient. Defaults to 1e-3.
clip (bool, optional): clip gradient values. Defaults to False.
eps (float, optional): adaptive_lr stability coefficient. Defaults to 1e-8.
exclude_bias_n_norm (bool, optional): exclude bias and normalization layers from lars.
Defaults to False.
"""
self.optim = optimizer
self.eta = eta
self.eps = eps
self.clip = clip
self.exclude_bias_n_norm = exclude_bias_n_norm
# transfer optim methods
self.state_dict = self.optim.state_dict
self.load_state_dict = self.optim.load_state_dict
self.zero_grad = self.optim.zero_grad
self.add_param_group = self.optim.add_param_group
self.__setstate__ = self.optim.__setstate__ # type: ignore
self.__getstate__ = self.optim.__getstate__ # type: ignore
self.__repr__ = self.optim.__repr__ # type: ignore
@property
def defaults(self):
return self.optim.defaults
@defaults.setter
def defaults(self, defaults):
self.optim.defaults = defaults
@property # type: ignore
def __class__(self):
return Optimizer
@property
def state(self):
return self.optim.state
@state.setter
def state(self, state):
self.optim.state = state
@property
def param_groups(self):
return self.optim.param_groups
@param_groups.setter
def param_groups(self, value):
self.optim.param_groups = value
@torch.no_grad()
def step(self, closure=None):
weight_decays = []
for group in self.optim.param_groups:
weight_decay = group.get("weight_decay", 0)
weight_decays.append(weight_decay)
# reset weight decay
group["weight_decay"] = 0
# update the parameters
for p in group["params"]:
if p.grad is not None and (p.ndim != 1 or not self.exclude_bias_n_norm):
self.update_p(p, group, weight_decay)
# update the optimizer
self.optim.step(closure=closure)
# return weight decay control to optimizer
for group_idx, group in enumerate(self.optim.param_groups):
group["weight_decay"] = weight_decays[group_idx]
def update_p(self, p, group, weight_decay):
# calculate new norms
p_norm = torch.norm(p.data)
g_norm = torch.norm(p.grad.data)
if p_norm != 0 and g_norm != 0:
# calculate new lr
new_lr = (self.eta * p_norm) / (g_norm + p_norm * weight_decay + self.eps)
# clip lr
if self.clip:
new_lr = min(new_lr / group["lr"], 1)
# update params with clipped lr
p.grad.data += weight_decay * p.data
p.grad.data *= new_lr
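`LARSWrapper.update_p` above folds weight decay into the gradient and rescales it by an adaptive rate `eta * ||p|| / (||g|| + ||p|| * weight_decay + eps)` before handing the step to the wrapped optimizer. A minimal NumPy sketch of that gradient rewrite (the input vectors are synthetic assumptions):

```python
# Hedged NumPy sketch of LARSWrapper.update_p: compute the adaptive
# rate from parameter and gradient norms, then rescale the gradient
# (with weight decay folded in). Zero norms leave the gradient as-is.
import numpy as np

def lars_scale_grad(p, g, eta=1e-3, weight_decay=0.0, eps=1e-8):
    p_norm = np.linalg.norm(p)
    g_norm = np.linalg.norm(g)
    if p_norm == 0 or g_norm == 0:
        return g  # skip scaling, as update_p does for zero norms
    new_lr = (eta * p_norm) / (g_norm + p_norm * weight_decay + eps)
    return (g + weight_decay * p) * new_lr

p = np.array([3.0, 4.0])   # ||p|| = 5.0
g = np.array([0.0, 0.5])   # ||g|| = 0.5
print(lars_scale_grad(p, g))
```

Because only the gradient is rescaled, the wrapped optimizer's own `lr`, momentum, and state machinery still apply unchanged on top.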
================================================
FILE: loss.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
class contrastive_loss(nn.Module):
def __init__(self):
super().__init__()
pass
def forward(self,x,labels):
#this function assumes that the positive logit is always the first element,
#which is true here
loss = -x[:,0] + torch.logsumexp(x[:,1:],dim=1)
return loss.mean()
class SimCLR(nn.Module):
def __init__(self,temperature=0.5,n_views=2,contrastive=False):
super(SimCLR,self).__init__()
self.temp = temperature
self.n_views = n_views
if contrastive:
self.criterion = contrastive_loss()
else:
self.criterion = torch.nn.CrossEntropyLoss()
def info_nce_loss(self,X):
bs, n_dim = X.shape
bs = int(bs/self.n_views)
device = X.device
labels = torch.cat([torch.arange(bs) for i in range(self.n_views)], dim=0)
labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
labels = labels.to(device)
similarity_matrix = torch.matmul(X, X.T)
# assert similarity_matrix.shape == (
# self.args.n_views * self.args.batch_size, self.args.n_views * self.args.batch_size)
# assert similarity_matrix.shape == labels.shape
# discard the main diagonal from both: labels and similarities matrix
mask = torch.eye(labels.shape[0], dtype=torch.bool).to(device)
labels = labels[~mask].view(labels.shape[0], -1)
similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
# assert similarity_matrix.shape == labels.shape
# select and combine multiple positives
positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
# select only the negatives
negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)
logits = torch.cat([positives, negatives], dim=1)
labels = torch.zeros(logits.shape[0], dtype=torch.long).to(device)
logits = logits / self.temp
return logits, labels
def forward(self,X):
logits, labels = self.info_nce_loss(X)
loss = self.criterion(logits, labels)
return loss
class Z_loss(nn.Module):
def __init__(self,):
super().__init__()
pass
def forward(self,z):
z_list = z.chunk(2,dim=0)
z_sim = F.cosine_similarity(z_list[0],z_list[1],dim=1).mean()
z_sim_out = z_sim.clone().detach()
return -z_sim, z_sim_out
class TotalCodingRate(nn.Module):
def __init__(self, eps=0.01):
super(TotalCodingRate, self).__init__()
self.eps = eps
def compute_discrimn_loss(self, W):
"""Discriminative Loss."""
p, m = W.shape #[d, B]
I = torch.eye(p,device=W.device)
scalar = p / (m * self.eps)
logdet = torch.logdet(I + scalar * W.matmul(W.T))
return logdet / 2.
def forward(self,X):
return - self.compute_discrimn_loss(X.T)
class MaximalCodingRateReduction(torch.nn.Module):
def __init__(self, eps=0.01, gamma=1):
super(MaximalCodingRateReduction, self).__init__()
self.eps = eps
self.gamma = gamma
def compute_discrimn_loss(self, W):
"""Discriminative Loss."""
p, m = W.shape
I = torch.eye(p,device=W.device)
scalar = p / (m * self.eps)
logdet = torch.logdet(I + scalar * W.matmul(W.T))
return logdet / 2.
def compute_compress_loss(self, W, Pi):
p, m = W.shape
k, _, _ = Pi.shape
I = torch.eye(p,device=W.device).expand((k,p,p))
trPi = Pi.sum(2) + 1e-8
scale = (p/(trPi*self.eps)).view(k,1,1)
W = W.view((1,p,m))
log_det = torch.logdet(I + scale*W.mul(Pi).matmul(W.transpose(1,2)))
compress_loss = (trPi.squeeze()*log_det/(2*m)).sum()
return compress_loss
def forward(self, X, Y, num_classes=None):
#This function supports Y as an integer label vector or a membership probability matrix.
if len(Y.shape)==1:
#if Y is a label vector
if num_classes is None:
num_classes = Y.max() + 1
Pi = torch.zeros((num_classes,1,Y.shape[0]),device=Y.device)
for indx, label in enumerate(Y):
Pi[label,0,indx] = 1
else:
#if Y is a probability matrix
if num_classes is None:
num_classes = Y.shape[1]
Pi = Y.T.reshape((num_classes,1,-1))
W = X.T
discrimn_loss = self.compute_discrimn_loss(W)
compress_loss = self.compute_compress_loss(W, Pi)
total_loss = - discrimn_loss + self.gamma*compress_loss
return total_loss, [discrimn_loss.item(), compress_loss.item()]
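`TotalCodingRate.compute_discrimn_loss` above evaluates R(W) = 1/2 · logdet(I + d/(n·eps) · W Wᵀ) for features W of shape [d, n]; `forward` returns its negative so that minimizing the loss expands the representation. A hedged NumPy check of the same formula on synthetic unit-norm features (the data and shapes here are assumptions for illustration):

```python
# NumPy sketch of the total coding rate: higher-rank feature matrices
# yield a larger logdet, so minimizing -R(W) pushes features to fill
# the space. slogdet is used for numerical stability.
import numpy as np

def coding_rate(W, eps=0.01):
    d, n = W.shape
    scalar = d / (n * eps)
    sign, logdet = np.linalg.slogdet(np.eye(d) + scalar * W @ W.T)
    return logdet / 2.0

rng = np.random.default_rng(0)
W = rng.normal(size=(4, 16))
W = W / np.linalg.norm(W, axis=0, keepdims=True)  # unit-norm columns
print(coding_rate(W))  # full-rank features: large coding rate
W_collapsed = np.ones((4, 16)) / 2.0              # rank-1, unit-norm columns
print(coding_rate(W_collapsed))                   # collapsed features: smaller
```

The rank-1 case mimics representation collapse; its coding rate is strictly smaller, which is exactly what the TCR term penalizes during training.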
================================================
FILE: main.py
================================================
############
## Import ##
############
import argparse
import torch.nn as nn
import torch.optim as optim
import os
from torch.utils.data import DataLoader
from model.model import encoder
from dataset.datasets import load_dataset
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
import torchvision.transforms.functional as FF
from tqdm import tqdm
import torch
from torchvision.datasets import CIFAR10
from loss import TotalCodingRate
from func import chunk_avg
from lars import LARS, LARSWrapper
from func import WeightedKNNClassifier
import torch.optim.lr_scheduler as lr_scheduler
from torch.cuda.amp import GradScaler, autocast
######################
## Parsing Argument ##
######################
parser = argparse.ArgumentParser(description='Unsupervised Learning')
parser.add_argument('--patch_sim', type=int, default=200,
help='coefficient of cosine similarity (default: 200)')
parser.add_argument('--tcr', type=int, default=1,
help='coefficient of tcr (default: 1)')
parser.add_argument('--num_patches', type=int, default=100,
help='number of patches used in EMP-SSL (default: 100)')
parser.add_argument('--arch', type=str, default="resnet18-cifar",
help='network architecture (default: resnet18-cifar)')
parser.add_argument('--bs', type=int, default=100,
help='batch size (default: 100)')
parser.add_argument('--lr', type=float, default=0.3,
help='learning rate (default: 0.3)')
parser.add_argument('--eps', type=float, default=0.2,
help='eps for TCR (default: 0.2)')
parser.add_argument('--msg', type=str, default="NONE",
help='additional message for description (default: NONE)')
parser.add_argument('--dir', type=str, default="EMP-SSL-Training",
help='directory name (default: EMP-SSL-Training)')
parser.add_argument('--data', type=str, default="cifar10",
help='data (default: cifar10)')
parser.add_argument('--epoch', type=int, default=30,
help='max number of epochs to finish (default: 30)')
args = parser.parse_args()
print(args)
num_patches = args.num_patches
dir_name = f"./logs/{args.dir}/patchsim{args.patch_sim}_numpatch{args.num_patches}_bs{args.bs}_lr{args.lr}_{args.msg}"
#####################
## Helper Function ##
#####################
def chunk_avg(x,n_chunks=2,normalize=False):
x_list = x.chunk(n_chunks,dim=0)
x = torch.stack(x_list,dim=0)
if not normalize:
return x.mean(0)
else:
return F.normalize(x.mean(0),dim=1)
class Similarity_Loss(nn.Module):
def __init__(self, ):
super().__init__()
pass
def forward(self, z_list, z_avg):
num_patch = len(z_list)
z_list = torch.stack(list(z_list), dim=0)
z_avg = z_list.mean(dim=0)  # recomputed from z_list; the z_avg argument is effectively unused
z_sim = 0
for i in range(num_patch):
z_sim += F.cosine_similarity(z_list[i], z_avg, dim=1).mean()
z_sim = z_sim/num_patch
z_sim_out = z_sim.clone().detach()
return -z_sim, z_sim_out
def cal_TCR(z, criterion, num_patches):
z_list = z.chunk(num_patches,dim=0)
loss = 0
for i in range(num_patches):
loss += criterion(z_list[i])
loss = loss/num_patches
return loss
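The two helpers above implement the patch-similarity term: chunk the stacked embeddings of shape (num_patches·bs, dim) into per-patch groups, average them per image, and score each patch by cosine similarity to that average (the TCR term is averaged over the same chunks separately). A NumPy sketch with synthetic inputs (the shapes and data are illustrative assumptions):

```python
# NumPy sketch of the patch-similarity objective: each patch embedding
# is compared to the mean embedding of its image's patches; training
# minimizes the negative of this average cosine similarity.
import numpy as np

def patch_similarity(z, num_patches):
    z_list = np.split(z, num_patches, axis=0)  # num_patches arrays of [bs, dim]
    z_avg = np.mean(z_list, axis=0)            # [bs, dim] per-image average
    sims = []
    for z_i in z_list:
        num = (z_i * z_avg).sum(axis=1)
        den = np.linalg.norm(z_i, axis=1) * np.linalg.norm(z_avg, axis=1)
        sims.append(np.mean(num / den))
    return float(np.mean(sims))

rng = np.random.default_rng(0)
z = rng.normal(size=(20 * 8, 32))  # 20 patches, batch of 8, 32-dim features
print(patch_similarity(z, 20))     # random features: similarity near 0
```

When every patch of an image maps to the same embedding, the similarity reaches its maximum of 1; the coefficient `args.patch_sim` weighs this against the TCR term so the representation stays invariant without collapsing.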
######################
## Prepare Training ##
######################
torch.multiprocessing.set_sharing_strategy('file_system')
if args.data == "imagenet100" or args.data == "imagenet":
train_dataset = load_dataset("imagenet", train=True, num_patch = num_patches)
dataloader = DataLoader(train_dataset, batch_size=args.bs, shuffle=True, drop_last=True,num_workers=8)
else:
train_dataset = load_dataset(args.data, train=True, num_patch = num_patches)
dataloader = DataLoader(train_dataset, batch_size=args.bs, shuffle=True, drop_last=True,num_workers=16)
use_cuda = True
device = torch.device("cuda" if use_cuda else "cpu")
net = encoder(arch = args.arch)
net = nn.DataParallel(net)
net.cuda()
opt = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-4,nesterov=True)
opt = LARSWrapper(opt,eta=0.005,clip=True,exclude_bias_n_norm=True,)
scaler = GradScaler()
if args.data == "imagenet100" or args.data == "imagenet":
num_converge = (150000//args.bs)*args.epoch
else:
num_converge = (50000//args.bs)*args.epoch
scheduler = lr_scheduler.CosineAnnealingLR(opt, T_max=num_converge, eta_min=0,last_epoch=-1)
# Loss
contractive_loss = Similarity_Loss()
criterion = TotalCodingRate(eps=args.eps)
##############
## Training ##
##############
def main():
for epoch in range(args.epoch):
for step, (data, label) in tqdm(enumerate(dataloader)):
net.zero_grad()
opt.zero_grad()
data = torch.cat(data, dim=0)
data = data.cuda()
z_proj = net(data)
z_list = z_proj.chunk(num_patches, dim=0)
z_avg = chunk_avg(z_proj, num_patches)
#Contractive Loss
loss_contract, _ = contractive_loss(z_list, z_avg)
loss_TCR = cal_TCR(z_proj, criterion, num_patches)
loss = args.patch_sim*loss_contract + args.tcr*loss_TCR
loss.backward()
opt.step()
scheduler.step()
model_dir = dir_name+"/save_models/"
if not os.path.exists(model_dir):
os.makedirs(model_dir)
torch.save(net.state_dict(), model_dir+str(epoch)+".pt")
print("At epoch:", epoch, "loss similarity is", loss_contract.item(), ",loss TCR is:", (loss_TCR).item(), "and learning rate is:", opt.param_groups[0]['lr'])
if __name__ == '__main__':
main()
================================================
FILE: mcr/loss.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
class MCRGANloss(nn.Module):
def __init__(self, gam1=1., gam2=1., gam3=1., eps=0.5, numclasses=1000, mode=0):
super(MCRGANloss, self).__init__()
self.num_class = numclasses
self.train_mode = mode
self.gam1 = gam1
self.gam2 = gam2
self.gam3 = gam3
self.eps = eps
def forward(self, Z, Z_bar, real_label, ith_inner_loop, num_inner_loop):
# t = time.time()
errD, empi = self.old_version(Z, Z_bar, real_label, ith_inner_loop, num_inner_loop)
return errD, empi
def old_version(self, Z, Z_bar, real_label, ith_inner_loop, num_inner_loop):
if self.train_mode == 2:
loss_z, _ = self.deltaR(Z, real_label, self.num_class)
assert num_inner_loop >= 2
if (ith_inner_loop + 1) % num_inner_loop != 0:
# print(f"{ith_inner_loop + 1}/{num_inner_loop}")
# print("calculate delta R(z)")
return loss_z, None
loss_h, _ = self.deltaR(Z_bar, real_label, self.num_class)
empi = [loss_z, loss_h]
term3 = 0.
for i in range(self.num_class):
new_Z = torch.cat((Z[real_label == i], Z_bar[real_label == i]), 0)
new_label = torch.cat(
(torch.zeros_like(real_label[real_label == i]),
torch.ones_like(real_label[real_label == i]))
)
loss, em = self.deltaR(new_Z, new_label, 2)
term3 += loss
empi = empi + [term3]
errD = self.gam1 * loss_z + self.gam2 * loss_h + self.gam3 * term3
elif self.train_mode == 1:
print("has been dropped")
raise NotImplementedError()
elif self.train_mode == 0:
new_Z = torch.cat((Z, Z_bar), 0)
new_label = torch.cat((torch.zeros_like(real_label), torch.ones_like(real_label)))
errD, empi = self.deltaR(new_Z, new_label, 2)
else:
raise ValueError()
return errD, empi
def debug(self, Z, Z_bar, real_label):
print("===========================")
def compute_discrimn_loss(self, Z):
"""Theoretical Discriminative Loss."""
d, n = Z.shape
I = torch.eye(d).to(Z.device)
scalar = d / (n * self.eps)
logdet = torch.logdet(I + scalar * Z @ Z.T)
return logdet / 2.
def compute_compress_loss(self, Z, Pi):
"""Theoretical Compressive Loss."""
d, n = Z.shape
I = torch.eye(d).to(Z.device)
compress_loss = []
scalars = []
for j in range(Pi.shape[1]):
Z_ = Z[:, Pi[:, j] == 1]
trPi = Pi[:, j].sum() + 1e-8
scalar = d / (trPi * self.eps)
log_det = torch.logdet(I + scalar * Z_ @ Z_.T)
compress_loss.append(log_det)
scalars.append(trPi / (2 * n))
return compress_loss, scalars
def deltaR(self, Z, Y, num_classes):
if num_classes is None:
num_classes = Y.max() + 1
#print("classes:", num_classes)
Pi = F.one_hot(Y, num_classes).to(Z.device)
discrimn_loss = self.compute_discrimn_loss(Z.T)
compress_loss, scalars = self.compute_compress_loss(Z.T, Pi)
compress_term = 0.
for z, s in zip(compress_loss, scalars):
compress_term += s * z
total_loss = discrimn_loss - compress_term
return -total_loss, (discrimn_loss, compress_term, compress_loss, scalars)
def gumb_compress_loss(self, Z, P):
d, n = Z.shape
I = torch.eye(d).to(Z.device)
compress_loss = 0.
for j in range(self.num_class):
#P[:, j:j+1][P[:, j:j+1]