Repository: tsb0601/EMP-SSL
Branch: main
Commit: 8fc3758617e0
Files: 13
Total size: 66.6 KB
Directory structure:
gitextract_sphqay9s/
├── README.md
├── dataset/
│ ├── aug.py
│ ├── aug4img.py
│ └── datasets.py
├── evaluate.py
├── func.py
├── lars.py
├── loss.py
├── main.py
├── mcr/
│ └── loss.py
├── model/
│ ├── model.py
│ └── resnet.py
└── requirements.text
================================================
FILE CONTENTS
================================================
================================================
FILE: README.md
================================================
# EMP-SSL: Towards Self-Supervised Learning in One Training Epoch
[arXiv:2304.03977](https://arxiv.org/abs/2304.03977)

Authors: Shengbang Tong*, Yubei Chen*, Yi Ma, Yann LeCun
## Introduction
This repository contains the implementation of the paper "EMP-SSL: Towards Self-Supervised Learning in One Training Epoch." The paper introduces a simple but efficient self-supervised learning method, Extreme-Multi-Patch Self-Supervised Learning (EMP-SSL), which significantly reduces the number of training epochs required for convergence by increasing the number of fixed-size image patches drawn from each image instance.
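The core augmentation step — drawing many fixed-size patches from a single image — can be sketched as follows. This is a simplified NumPy illustration for intuition only; the repo's actual pipeline in `dataset/aug.py` uses torchvision's `RandomResizedCrop` plus color augmentations:

```python
import numpy as np

def random_fixed_patches(img, patch=16, n_patches=20, seed=None):
    """Draw n_patches random fixed-size crops from an (H, W, C) image array."""
    rng = np.random.default_rng(seed)
    h, w = img.shape[:2]
    patches = []
    for _ in range(n_patches):
        # pick a random top-left corner so the crop stays inside the image
        top = rng.integers(0, h - patch + 1)
        left = rng.integers(0, w - patch + 1)
        patches.append(img[top:top + patch, left:left + patch])
    return patches

# A 32x32 "CIFAR-sized" image yields 20 co-occurring 16x16 views of one instance.
img = np.random.rand(32, 32, 3)
views = random_fixed_patches(img, patch=16, n_patches=20, seed=0)
```

EMP-SSL then trains the encoder so that the representations of all these views of the same instance agree, which is what enables convergence in very few epochs.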
## Preparing Training Data
CIFAR-10 and CIFAR-100 are downloaded automatically by the training script. ImageNet-100 is a specific subset of ImageNet; details can be found in this [link](https://github.com/HobbitLong/CMC/issues/21).
## Getting Started
The current implementation supports CIFAR-10, CIFAR-100, and ImageNet-100.
To get started with the EMP-SSL implementation, follow these instructions:
### 1. Clone this repository
```bash
git clone https://github.com/tsb0601/emp-ssl.git
cd emp-ssl
```
### 2. Install required packages
```
pip install -r requirements.txt
```
### 3. Training
#### Reproducing 1-epoch results
| | CIFAR-10<br>1 Epoch | CIFAR-100<br>1 Epoch | Tiny ImageNet<br>1 Epoch | ImageNet-100<br>1 Epoch |
|--------------------|:----------------------:|:-----------------------:|:----------------------------:|:--------------------------:|
| EMP-SSL (1 Epoch) | 0.842 | 0.585 | 0.381 | 0.585 |
For CIFAR-10 or CIFAR-100:
```
python main.py --data cifar10 --epoch 2 --patch_sim 200 --arch 'resnet18-cifar' --num_patches 20 --lr 0.3
```
For ImageNet-100:
```
python main.py --data imagenet100 --epoch 2 --patch_sim 200 --arch 'resnet18-imagenet' --num_patches 20 --lr 0.3
```
#### Reproducing multi epochs results
| | CIFAR-10<br>1 Epoch | CIFAR-10<br>10 Epochs | CIFAR-10<br>30 Epochs | CIFAR-10<br>1000 Epochs | CIFAR-100<br>1 Epoch | CIFAR-100<br>10 Epochs | CIFAR-100<br>30 Epochs | CIFAR-100<br>1000 Epochs | Tiny ImageNet<br>10 Epochs | Tiny ImageNet<br>1000 Epochs |ImageNet-100<br>10 Epochs | ImageNet-100<br>400 Epochs |
|----------------------|:-------------------:|:---------------------:|:---------------------:|:-----------------------:|:--------------------:|:----------------------:|:----------------------:|:------------------------:| :------------------------:|:------------------------:|:------------------------:| :------------------------:|
| SimCLR | 0.282 | 0.565 | 0.663 | 0.910 | 0.054 | 0.185 | 0.341 | 0.662 | - | 0.488 | - | 0.776
| BYOL | 0.249 | 0.489 | 0.684 | 0.926 | 0.043 | 0.150 | 0.349 | 0.708 | - | 0.510 | - | 0.802
| VICReg | 0.406 | 0.697 | 0.781 | 0.921 | 0.079 | 0.319 | 0.479 | 0.685 | - | - | - | 0.792
| SwAV | 0.245 | 0.532 | 0.767 | 0.923 | 0.028 | 0.208 | 0.294 | 0.658 |- | - | - | 0.740
| ReSSL | 0.245 | 0.256 | 0.525 | 0.914 | 0.033 | 0.122 | 0.247 | 0.674 |- | - | - | 0.769
| EMP-SSL (20 patches) | 0.806 | 0.907 | 0.931 | - | 0.551 | 0.678 | 0.724 | - | - | - | - | -
| EMP-SSL (200 patches)| 0.826* | 0.915 | 0.934 | - | 0.577 | 0.701 | 0.733 | - | 0.515 | - | 0.789 | -
\* Here, we changed the learning-rate schedule to decay over 30 epochs, so the 1-epoch accuracy is slightly lower than when optimizing specifically for 1-epoch training.
Change `--num_patches` to adjust the number of patches used in EMP-SSL training:
```
python main.py --data cifar10 --epoch 30 --patch_sim 200 --arch 'resnet18-cifar' --num_patches 20 --lr 0.3
```
### 4. Evaluating
Because the model is trained only on fixed-size image patches, we adopt the bag-of-features model from the intra-instance VICReg paper to evaluate performance. Change `--test_patches` to adjust the number of patches used in the bag-of-features model to fit different GPUs.
```
python evaluate.py --model_path 'path to your evaluated model' --test_patches 128
```
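Concretely, the bag-of-features step averages the encoder's outputs over the `test_patches` views of each image; the `chunk_avg` helper in `evaluate.py` does this on batches where patch features are concatenated along the batch dimension. A NumPy sketch of that aggregation (mirroring, not reproducing, the repo's torch version):

```python
import numpy as np

def chunk_avg(x, n_chunks=2, normalize=False):
    """Average patch features stacked along dim 0: (n_chunks*B, D) -> (B, D)."""
    # split the stacked batch back into per-patch chunks, then average them
    chunks = np.stack(np.split(x, n_chunks, axis=0), axis=0)  # (n_chunks, B, D)
    avg = chunks.mean(axis=0)
    if normalize:
        avg = avg / np.linalg.norm(avg, axis=1, keepdims=True)
    return avg

# 3 patches for a batch of 2 images, feature dim 2
feats = np.arange(12, dtype=float).reshape(6, 2)
bag = chunk_avg(feats, n_chunks=3)  # -> [[4., 5.], [6., 7.]]
```

The resulting per-image averaged feature is then fed to the linear classifier or the weighted k-NN evaluator.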
## Acknowledgment
This repo is inspired by the [MCR2](https://github.com/Ma-Lab-Berkeley/MCR2), [solo-learn](https://github.com/vturrisi/solo-learn) and [NMCE](https://github.com/zengyi-li/NMCE-release) repositories.
## Citation
If you find this repository useful, please consider giving a star :star: and citation:
```
@article{tong2023empssl,
title={EMP-SSL: Towards Self-Supervised Learning in One Training Epoch},
author={Shengbang Tong and Yubei Chen and Yi Ma and Yann LeCun},
journal={arXiv preprint arXiv:2304.03977},
year={2023}
}
```
================================================
FILE: dataset/aug.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision.transforms as transforms
from PIL import Image, ImageFilter, ImageOps
def load_transforms(name):
"""Load data transformations.
Note:
- Gaussian Blur is defined at the bottom of this file.
"""
_name = name.lower()
if _name == "cifar_sup":
normalize = transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
aug_transform = transforms.Compose([
transforms.RandomResizedCrop(32,scale=(0.765625, 0.765625),ratio=(1., 1.)),
transforms.RandomCrop(32, padding=8),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize
])
baseline_transform = transforms.Compose([
transforms.RandomResizedCrop(32,scale=(0.765625, 0.765625),ratio=(1., 1.)),
transforms.ToTensor(),normalize])
elif _name == "cifar_patch":
normalize = transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])
aug_transform = transforms.Compose([
transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
transforms.ToTensor(),
normalize
])
baseline_transform = transforms.Compose([
transforms.ToTensor(), normalize])
elif _name == "cifar_simclr_norm":
normalize = transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
aug_transform = transforms.Compose([
transforms.RandomResizedCrop(32,scale=(0.08, 1.0)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
transforms.RandomGrayscale(p=0.2),
transforms.ToTensor(),
normalize
])
baseline_transform = transforms.Compose([
transforms.ToTensor(),normalize])
elif _name == "cifar_byol":
normalize = transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
aug_transform = transforms.Compose([
transforms.RandomResizedCrop(
(32, 32),
scale=(0.2, 1.0),
interpolation=transforms.InterpolationMode.BICUBIC,
),
transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8),
transforms.RandomGrayscale(p=0.2),
transforms.RandomApply([Solarization()], p=0.1),
transforms.RandomHorizontalFlip(p=0.5),
transforms.ToTensor(),
normalize
])
baseline_transform = transforms.Compose([
# transforms.RandomResizedCrop(32,scale=(0.765625, 0.765625),ratio=(1., 1.)),
transforms.ToTensor(),normalize])
else:
raise NameError("{} not found in transform loader".format(name))
return aug_transform, baseline_transform
class Solarization:
"""Solarization as a callable object."""
def __call__(self, img: Image) -> Image:
"""Applies solarization to an input image.
Args:
img (Image): an image in the PIL.Image format.
Returns:
Image: a solarized image.
"""
return ImageOps.solarize(img)
class GBlur(object):
def __init__(self, p):
self.p = p
def __call__(self, img):
if np.random.rand() < self.p:
sigma = np.random.rand() * 1.9 + 0.1
return img.filter(ImageFilter.GaussianBlur(sigma))
else:
return img
class AddGaussianNoise(object):
def __init__(self, mean=0., std=1.):
self.std = std
self.mean = mean
def __call__(self, tensor):
return tensor + torch.randn(tensor.size()) * self.std + self.mean
def __repr__(self):
return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
class ContrastiveLearningViewGenerator(object):
def __init__(self, num_patch = 4):
self.num_patch = num_patch
def __call__(self, x):
normalize = transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])
aug_transform = transforms.Compose([
transforms.RandomResizedCrop(32,scale=(0.25, 0.25), ratio=(1,1)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.2)], p=0.8),
transforms.RandomGrayscale(p=0.2),
GBlur(p=0.1),
transforms.RandomApply([Solarization()], p=0.1),
transforms.ToTensor(),
normalize
])
augmented_x = [aug_transform(x) for i in range(self.num_patch)]
return augmented_x
================================================
FILE: dataset/aug4img.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision.transforms as transforms
from PIL import Image, ImageFilter, ImageOps
from torchvision.transforms import InterpolationMode
class Solarization:
"""Solarization as a callable object."""
def __call__(self, img: Image) -> Image:
"""Applies solarization to an input image.
Args:
img (Image): an image in the PIL.Image format.
Returns:
Image: a solarized image.
"""
return ImageOps.solarize(img)
class GBlur(object):
def __init__(self, p):
self.p = p
def __call__(self, img):
if np.random.rand() < self.p:
sigma = np.random.rand() * 1.9 + 0.1
return img.filter(ImageFilter.GaussianBlur(sigma))
else:
return img
class AddGaussianNoise(object):
def __init__(self, mean=0., std=1.):
self.std = std
self.mean = mean
def __call__(self, tensor):
return tensor + torch.randn(tensor.size()) * self.std + self.mean
def __repr__(self):
return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
class ContrastiveLearningViewGenerator(object):
def __init__(self, num_patch = 4):
self.num_patch = num_patch
def __call__(self, x):
aug_transform = transforms.Compose([
transforms.RandomResizedCrop(
224, scale=(0.25, 0.25), interpolation=InterpolationMode.BICUBIC
),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.2)], p=0.8),
transforms.RandomGrayscale(p=0.2),
GBlur(p=0.1),
transforms.RandomApply([Solarization()], p=0.1),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
augmented_x = [aug_transform(x) for i in range(self.num_patch)]
return augmented_x
================================================
FILE: dataset/datasets.py
================================================
import os
import numpy as np
import torchvision
def load_dataset(data_name, train=True, num_patch = 4, path="./data/"):
"""Loads a dataset for training or testing.
Parameters:
data_name (str): name of the dataset ("cifar10", "cifar100", or "imagenet")
train (bool): load the training split if True, otherwise the test split
num_patch (int): number of augmented patches generated per image (see aug.py / aug4img.py)
path (str): base path for downloaded datasets
Returns:
dataset (torch.utils.data.Dataset)
"""
_name = data_name.lower()
if _name == "imagenet":
from .aug4img import ContrastiveLearningViewGenerator
else:
from .aug import ContrastiveLearningViewGenerator
transform = ContrastiveLearningViewGenerator(num_patch = num_patch)
if _name == "cifar10":
trainset = torchvision.datasets.CIFAR10(root=os.path.join(path, "CIFAR10"), train=train, download=True, transform=transform)
trainset.num_classes = 10
elif _name == "cifar100":
trainset = torchvision.datasets.CIFAR100(root=os.path.join(path, "CIFAR100"), train=train, download=True, transform=transform)
trainset.num_classes = 100
elif _name == "imagenet":
if train:
trainset = torchvision.datasets.ImageFolder(root="/home/peter/Data/ILSVRC2012/train100/",transform=transform)
#trainset = torchvision.datasets.ImageFolder(root="/home/peter/Data/tiny-imagenet-200/train/",transform=transform)
else:
trainset = torchvision.datasets.ImageFolder(root="/home/peter/Data/ILSVRC2012/val100/",transform=transform)
#trainset = torchvision.datasets.ImageFolder(root="/home/peter/Data/tiny-imagenet-200/val/",transform=transform)
trainset.num_classes = 100  # ImageNet-100; the commented Tiny-ImageNet paths above would need 200
else:
raise NameError("{} not found in trainset loader".format(_name))
return trainset
def sparse2coarse(targets):
"""CIFAR100 Coarse Labels. """
coarse_targets = [ 4, 1, 14, 8, 0, 6, 7, 7, 18, 3, 3, 14, 9, 18, 7, 11, 3,
9, 7, 11, 6, 11, 5, 10, 7, 6, 13, 15, 3, 15, 0, 11, 1, 10,
12, 14, 16, 9, 11, 5, 5, 19, 8, 8, 15, 13, 14, 17, 18, 10, 16,
4, 17, 4, 2, 0, 17, 4, 18, 17, 10, 3, 2, 12, 12, 16, 12, 1,
9, 19, 2, 10, 0, 1, 16, 12, 9, 13, 15, 13, 16, 19, 2, 4, 6,
19, 5, 5, 8, 19, 18, 1, 2, 15, 6, 0, 17, 8, 14, 13]
return np.array(coarse_targets)[targets]
================================================
FILE: evaluate.py
================================================
############
## Import ##
############
import argparse
import torch.nn as nn
from torch.utils.data import DataLoader
from model.model import encoder
from dataset.datasets import load_dataset
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm
import torch
from func import WeightedKNNClassifier, linear
######################
## Parsing Argument ##
######################
parser = argparse.ArgumentParser(description='Evaluation')
parser.add_argument('--test_patches', type=int, default=128,
help='number of patches used in testing (default: 128)')
parser.add_argument('--data', type=str, default="cifar10",
help='dataset (default: cifar10)')
parser.add_argument('--arch', type=str, default="resnet18-cifar",
help='network architecture (default: resnet18-cifar)')
parser.add_argument('--lr', type=float, default=0.03,
help='learning rate for linear eval (default: 0.03)')
parser.add_argument('--linear', type=bool, default=True,
help='use linear eval or not (note: argparse with type=bool treats any non-empty string as True)')
parser.add_argument('--knn', help='evaluate using kNN measuring cosine similarity', action='store_true')
parser.add_argument('--model_path', type=str, default="",
help='model directory for eval')
args = parser.parse_args()
######################
## Testing Accuracy ##
######################
test_patches = args.test_patches
def compute_accuracy(y_pred, y_true):
"""Compute accuracy by counting correct classification. """
assert y_pred.shape == y_true.shape
return 1 - np.count_nonzero(y_pred - y_true) / y_true.size
knn_classifier = WeightedKNNClassifier()
def chunk_avg(x,n_chunks=2,normalize=False):
x_list = x.chunk(n_chunks,dim=0)
x = torch.stack(x_list,dim=0)
if not normalize:
return x.mean(0)
else:
return F.normalize(x.mean(0),dim=1)
def test(net, train_loader, test_loader):
train_z_full_list, train_y_list, test_z_full_list, test_y_list = [], [], [], []
with torch.no_grad():
for x, y in tqdm(train_loader):
x = torch.cat(x, dim = 0)
z_proj, z_pre = net(x, is_test=True)
z_pre = chunk_avg(z_pre, test_patches)
z_pre = z_pre.detach().cpu()
train_z_full_list.append(z_pre)
knn_classifier.update(train_features = z_pre, train_targets = y)
train_y_list.append(y)
for x, y in tqdm(test_loader):
x = torch.cat(x, dim = 0)
z_proj, z_pre = net(x, is_test=True)
z_pre = chunk_avg(z_pre, test_patches)
z_pre = z_pre.detach().cpu()
test_z_full_list.append(z_pre)
knn_classifier.update(test_features = z_pre, test_targets = y)
test_y_list.append(y)
train_features_full, train_labels, test_features_full, test_labels = torch.cat(train_z_full_list,dim=0), torch.cat(train_y_list,dim=0), torch.cat(test_z_full_list,dim=0), torch.cat(test_y_list,dim=0)
if args.data == "cifar10":
num_classes = 10
elif args.data == "cifar100":
num_classes = 100
elif args.data == "tinyimagenet200":
num_classes = 200
elif args.data == "imagenet100":
num_classes = 100
elif args.data == "imagenet":
num_classes = 1000
if args.linear:
print("Using Linear Eval to evaluate accuracy")
linear(train_features_full, train_labels, test_features_full, test_labels, lr=args.lr, num_classes = num_classes)
if args.knn:
print("Using KNN to evaluate accuracy")
top1, top5 = knn_classifier.compute()
print("KNN (top1/top5):", top1, top5)
torch.multiprocessing.set_sharing_strategy('file_system')
#Get Dataset
# the same loader settings apply to all supported datasets
memory_dataset = load_dataset(args.data, train=True, num_patch = test_patches)
memory_loader = DataLoader(memory_dataset, batch_size=50, shuffle=True, drop_last=True, num_workers=8)
test_data = load_dataset(args.data, train=False, num_patch = test_patches)
test_loader = DataLoader(test_data, batch_size=50, shuffle=True, num_workers=8)
# Load Model and Checkpoint
use_cuda = True
device = torch.device("cuda" if use_cuda else "cpu")
net = encoder(arch = args.arch)
net = nn.DataParallel(net)
save_dict = torch.load(args.model_path)
net.load_state_dict(save_dict,strict=False)
net.cuda()
net.eval()
test(net, memory_loader, test_loader)
================================================
FILE: func.py
================================================
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score
import torchvision
from torch import nn, optim
from torch.utils import data
from torch.utils.data import DataLoader
from typing import Tuple
from torchmetrics.metric import Metric
class WeightedKNNClassifier(Metric):
def __init__(
self,
k: int = 20,
T: float = 0.07,
max_distance_matrix_size: int = int(5e6),
distance_fx: str = "cosine",
epsilon: float = 0.00001,
dist_sync_on_step: bool = False,
):
"""Implements the weighted k-NN classifier used for evaluation.
Args:
k (int, optional): number of neighbors. Defaults to 20.
T (float, optional): temperature for the exponential. Only used with cosine
distance. Defaults to 0.07.
max_distance_matrix_size (int, optional): maximum number of elements in the
distance matrix. Defaults to 5e6.
distance_fx (str, optional): Distance function. Accepted arguments: "cosine" or
"euclidean". Defaults to "cosine".
epsilon (float, optional): Small value for numerical stability. Only used with
euclidean distance. Defaults to 0.00001.
dist_sync_on_step (bool, optional): whether to sync distributed values at every
step. Defaults to False.
"""
super().__init__(dist_sync_on_step=dist_sync_on_step, compute_on_step=False)
self.k = k
self.T = T
self.max_distance_matrix_size = max_distance_matrix_size
self.distance_fx = distance_fx
self.epsilon = epsilon
self.add_state("train_features", default=[], persistent=False)
self.add_state("train_targets", default=[], persistent=False)
self.add_state("test_features", default=[], persistent=False)
self.add_state("test_targets", default=[], persistent=False)
def update(
self,
train_features: torch.Tensor = None,
train_targets: torch.Tensor = None,
test_features: torch.Tensor = None,
test_targets: torch.Tensor = None,
):
"""Updates the memory banks. If train (test) features are passed as input, the
corresponding train (test) targets must be passed as well.
Args:
train_features (torch.Tensor, optional): a batch of train features. Defaults to None.
train_targets (torch.Tensor, optional): a batch of train targets. Defaults to None.
test_features (torch.Tensor, optional): a batch of test features. Defaults to None.
test_targets (torch.Tensor, optional): a batch of test targets. Defaults to None.
"""
assert (train_features is None) == (train_targets is None)
assert (test_features is None) == (test_targets is None)
if train_features is not None:
assert train_features.size(0) == train_targets.size(0)
self.train_features.append(train_features.detach())
self.train_targets.append(train_targets.detach())
if test_features is not None:
assert test_features.size(0) == test_targets.size(0)
self.test_features.append(test_features.detach())
self.test_targets.append(test_targets.detach())
def set_tk(self, T, k):
self.T = T
self.k = k
@torch.no_grad()
def compute(self) -> Tuple[float]:
"""Computes weighted k-NN accuracy @1 and @5. If cosine distance is selected,
the weight is computed using the exponential of the temperature scaled cosine
distance of the samples. If euclidean distance is selected, the weight corresponds
to the inverse of the euclidean distance.
Returns:
Tuple[float]: k-NN accuracy @1 and @5.
"""
#print(self.T, self.k)
train_features = torch.cat(self.train_features)
train_targets = torch.cat(self.train_targets)
test_features = torch.cat(self.test_features)
test_targets = torch.cat(self.test_targets)
if self.distance_fx == "cosine":
train_features = F.normalize(train_features)
test_features = F.normalize(test_features)
num_classes = torch.unique(test_targets).numel()
num_train_images = train_targets.size(0)
num_test_images = test_targets.size(0)
chunk_size = min(
max(1, self.max_distance_matrix_size // num_train_images),
num_test_images,
)
k = min(self.k, num_train_images)
top1, top5, total = 0.0, 0.0, 0
retrieval_one_hot = torch.zeros(k, num_classes).to(train_features.device)
for idx in range(0, num_test_images, chunk_size):
# get the features for test images
features = test_features[idx : min((idx + chunk_size), num_test_images), :]
targets = test_targets[idx : min((idx + chunk_size), num_test_images)]
batch_size = targets.size(0)
# calculate the dot product and compute top-k neighbors
if self.distance_fx == "cosine":
similarities = torch.mm(features, train_features.t())
elif self.distance_fx == "euclidean":
similarities = 1 / (torch.cdist(features, train_features) + self.epsilon)
else:
raise NotImplementedError
similarities, indices = similarities.topk(k, largest=True, sorted=True)
candidates = train_targets.view(1, -1).expand(batch_size, -1)
retrieved_neighbors = torch.gather(candidates, 1, indices)
retrieval_one_hot.resize_(batch_size * k, num_classes).zero_()
retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1)
if self.distance_fx == "cosine":
similarities = similarities.clone().div_(self.T).exp_()
probs = torch.sum(
torch.mul(
retrieval_one_hot.view(batch_size, -1, num_classes),
similarities.view(batch_size, -1, 1),
),
1,
)
_, predictions = probs.sort(1, True)
# find the predictions that match the target
correct = predictions.eq(targets.data.view(-1, 1))
top1 = top1 + correct.narrow(1, 0, 1).sum().item()
top5 = (
top5 + correct.narrow(1, 0, min(5, k, correct.size(-1))).sum().item()
) # top5 does not make sense if k < 5
total += targets.size(0)
top1 = top1 * 100.0 / total
top5 = top5 * 100.0 / total
self.reset()
return top1, top5
def linear(train_features, train_labels, test_features, test_labels, lr=0.0075, num_classes = 100):
train_data = tensor_dataset(train_features,train_labels)
test_data = tensor_dataset(test_features,test_labels)
train_loader = DataLoader(train_data, batch_size=100, shuffle=True, drop_last=True, num_workers=2)
test_loader = DataLoader(test_data, batch_size=100, shuffle=True, drop_last=False, num_workers=2)
LL = nn.Linear(train_features.shape[1],num_classes)
optimizer = torch.optim.SGD(LL.parameters(), lr=lr, momentum=0.9, weight_decay=5e-5)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 100)
criterion = torch.nn.CrossEntropyLoss()
test_acc_list = []
for epoch in range(100):
top1_train_accuracy = 0
for counter, (x_batch, y_batch) in enumerate(train_loader):
x_batch = x_batch
y_batch = y_batch
logits = LL(x_batch)
loss = criterion(logits, y_batch)
top1 = accuracy(logits, y_batch, topk=(1,))
top1_train_accuracy += top1[0]
optimizer.zero_grad()
loss.backward()
optimizer.step()
scheduler.step()
top1_train_accuracy /= (counter + 1)
top1_accuracy = 0
top5_accuracy = 0
for counter, (x_batch, y_batch) in enumerate(test_loader):
x_batch = x_batch
y_batch = y_batch
logits = LL(x_batch)
top1, top5 = accuracy(logits, y_batch, topk=(1,5))
top1_accuracy += top1[0]
top5_accuracy += top5[0]
top1_accuracy /= (counter + 1)
top5_accuracy /= (counter + 1)
test_acc_list.append(top1_accuracy)
print(f"Epoch {epoch}\tTop1 Train accuracy {top1_train_accuracy.item()}\tTop1 Test accuracy: {top1_accuracy.item()}\tTop5 test acc: {top5_accuracy.item()}")
acc_vect = torch.tensor(test_acc_list)
print('best linear test acc {}, last acc {}'.format(acc_vect.max().item(),acc_vect[-1].item()))
def accuracy(output, target, topk=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
with torch.no_grad():
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
class tensor_dataset(data.Dataset):
def __init__(self,x,y):
self.x = x
self.y = y
self.length = x.shape[0]
def __getitem__(self,indx):
return self.x[indx], self.y[indx]
def __len__(self):
return self.length
def set_gamma(loss_fn,epoch,total_epoch=500,warmup_epoch=100,gamma_min=0.,gamma_max=1.0):
warmup_start = total_epoch - warmup_epoch
warmup_end = total_epoch
if warmup_start < epoch<=warmup_end:
loss_fn.gamma = ((epoch - warmup_start)/(warmup_end - warmup_start))*(gamma_max - gamma_min) + gamma_min
else:
loss_fn.gamma = gamma_min
def warmup_lr(optimizer,epoch,base_lr,warmup_epoch=10):
if epoch<warmup_epoch:
optimizer.param_groups[0]['lr'] = base_lr*min(1.,(epoch+1)/warmup_epoch)
def marginal_H(logits):
bs = torch.tensor(logits.shape[0]).float()
logps = torch.log_softmax(logits,dim=1)
marginal_p = torch.logsumexp(logps - bs.log(),dim=0)
H = (marginal_p.exp()*(-marginal_p)).sum()*(1.4426950)
return H
def chunk_avg(x,n_chunks=2,normalize=False):
x_list = x.chunk(n_chunks,dim=0)
x = torch.stack(x_list,dim=0)
if not normalize:
return x.mean(0)
else:
return F.normalize(x.mean(0),dim=1)
def cluster_match(cluster_mtx,label_mtx,n_classes=10,print_result=True):
#verified to be consistent with the optimal-assignment-based algorithm
cluster_indx = list(cluster_mtx.unique())
assigned_label_list = []
assigned_count = []
while (len(assigned_label_list)<=n_classes) and len(cluster_indx)>0:
max_label_list = []
max_count_list = []
for indx in cluster_indx:
#calculate highest number of matchs
mask = cluster_mtx==indx
label_elements, counts = label_mtx[mask].unique(return_counts=True)
for assigned_label in assigned_label_list:
counts[label_elements==assigned_label] = 0
max_count_list.append(counts.max())
max_label_list.append(label_elements[counts.argmax()])
max_label = torch.stack(max_label_list)
max_count = torch.stack(max_count_list)
assigned_label_list.append(max_label[max_count.argmax()])
assigned_count.append(max_count.max())
cluster_indx.pop(max_count.argmax())
total_correct = torch.tensor(assigned_count).sum().item()
total_sample = cluster_mtx.shape[0]
acc = total_correct/total_sample
if print_result:
print('{}/{} ({}%) correct'.format(total_correct,total_sample,acc*100))
else:
return total_correct, total_sample, acc
def cluster_merge_match(cluster_mtx,label_mtx,print_result=True):
cluster_indx = list(cluster_mtx.unique())
n_correct = 0
for cluster_id in cluster_indx:
label_elements, counts = label_mtx[cluster_mtx==cluster_id].unique(return_counts=True)
n_correct += counts.max()
total_sample = len(cluster_mtx)
acc = n_correct.item()/total_sample
if print_result:
print('{}/{} ({}%) correct'.format(n_correct,total_sample,acc*100))
else:
return n_correct, total_sample, acc
def cluster_acc(test_loader,net,device,print_result=False,save_name_img='cluster_img',save_name_fig='pca_figure'):
cluster_list = []
label_list = []
x_list = []
z_list = []
net.eval()
for x, y in test_loader:
with torch.no_grad():
x, y = x.float().to(device), y.to(device)
z, logit = net(x)
if logit.sum() == 0:
logit += torch.randn_like(logit)
cluster_list.append(logit.max(dim=1)[1].cpu())
label_list.append(y.cpu())
x_list.append(x.cpu())
z_list.append(z.cpu())
net.train()
cluster_mtx = torch.cat(cluster_list,dim=0)
label_mtx = torch.cat(label_list,dim=0)
x_mtx = torch.cat(x_list,dim=0)
z_mtx = torch.cat(z_list,dim=0)
_, _, acc_single = cluster_match(cluster_mtx,label_mtx,n_classes=label_mtx.max()+1,print_result=False)
_, _, acc_merge = cluster_merge_match(cluster_mtx,label_mtx,print_result=False)
NMI = normalized_mutual_info_score(label_mtx.numpy(),cluster_mtx.numpy())
ARI = adjusted_rand_score(label_mtx.numpy(),cluster_mtx.numpy())
if print_result:
print('cluster match acc {}, cluster merge match acc {}, NMI {}, ARI {}'.format(acc_single,acc_merge,NMI,ARI))
save_name_img += '_acc'+ str(acc_single)[2:5]
save_cluster_imgs(cluster_mtx,x_mtx,save_name_img)
save_latent_pca_figure(z_mtx,cluster_mtx,save_name_fig)
return acc_single, acc_merge, NMI, ARI
def save_cluster_imgs(cluster_mtx,x_mtx,save_name,npercluster=100):
cluster_indexs, counts = cluster_mtx.unique(return_counts=True)
x_list = []
counts_list = []
for i, c_indx in enumerate(cluster_indexs):
if counts[i]>npercluster:
x_list.append(x_mtx[cluster_mtx==c_indx,:,:,:])
counts_list.append(counts[i])
n_clusters = len(counts_list)
fig, ax = plt.subplots(n_clusters,1,dpi=80,figsize=(1.2*n_clusters, 3*n_clusters))
for i, ax in enumerate(ax):
img = torchvision.utils.make_grid(x_list[i][:npercluster],nrow=npercluster//5,normalize=True)
ax.imshow(img.permute(1,2,0))
ax.set_axis_off()
ax.set_title('Cluster with {} images'.format(counts_list[i]))
fig.savefig(save_name+'.pdf')
plt.close(fig)
def save_latent_pca_figure(z_mtx,cluster_mtx,save_name):
_, s_z_all, _ = z_mtx.svd()
cluster_n = []
cluster_s = []
for cluster_indx in cluster_mtx.unique():
_, s_cluster, _ = z_mtx[cluster_mtx==cluster_indx,:].svd()
cluster_n.append((cluster_mtx==cluster_indx).sum().item())
cluster_s.append(s_cluster/s_cluster.max())
#make plot
fig, ax = plt.subplots(1,2,figsize=(9, 3))
ax[0].plot(s_z_all)
for i, s_curve in enumerate(cluster_s):
ax[1].plot(s_curve,label=cluster_n[i])
ax[1].set_xlim(xmin=0,xmax=20)
ax[1].legend()
fig.savefig(save_name +'.pdf')
plt.close(fig)
def analyze_latent(z_mtx,cluster_mtx):
_, s_z_all, _ = z_mtx.svd()
cluster_n = []
cluster_s = []
cluster_d = []
for cluster_indx in cluster_mtx.unique():
_, s_cluster, _ = z_mtx[cluster_mtx==cluster_indx,:].svd()
s_cluster = s_cluster/s_cluster.max()
cluster_n.append((cluster_mtx==cluster_indx).sum().item())
cluster_s.append(s_cluster)
# print(list(cluster_s))
print(s_cluster)
# s_diff = s_cluster[:-1] - s_cluster[1:]
# cluster_d.append(s_diff.max(0)[1])
cluster_d.append((s_cluster>0.01).sum())
for i in range(len(cluster_n)):
print('subspace {}, dimension {}, samples {}'.format(i,cluster_d[i],cluster_n[i]))
================================================
FILE: lars.py
================================================
import torch
import torch.optim as optim
from torch.optim.optimizer import Optimizer, required
class LARS(Optimizer):
"""
Layer-wise adaptive rate scaling
- Converted from Tensorflow to Pytorch from:
https://github.com/google-research/simclr/blob/master/lars_optimizer.py
- Based on:
https://github.com/noahgolmant/pytorch-lars
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float): base learning rate (\gamma_0)
len_reduced (int): number of parameters, counted from the start of the list, to which weight decay is applied; the remaining parameters are excluded
momentum (float, optional): momentum factor (default: 0.9)
use_nesterov (bool, optional): flag to use nesterov momentum (default: False)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0)
("\beta")
eta (float, optional): LARS coefficient (default: 0.001)
- Based on Algorithm 1 of the following paper by You, Gitman, and Ginsburg.
- Large Batch Training of Convolutional Networks:
https://arxiv.org/abs/1708.03888
"""
def __init__(self, params, lr, len_reduced, momentum=0.9, use_nesterov=False, weight_decay=0.0, classic_momentum=True, eta=0.001):
self.epoch = 0
defaults = dict(
lr=lr,
momentum=momentum,
use_nesterov=use_nesterov,
weight_decay=weight_decay,
classic_momentum=classic_momentum,
eta=eta,
len_reduced=len_reduced
)
super(LARS, self).__init__(params, defaults)
self.lr = lr
self.momentum = momentum
self.weight_decay = weight_decay
self.use_nesterov = use_nesterov
self.classic_momentum = classic_momentum
self.eta = eta
self.len_reduced = len_reduced
def step(self, epoch=None, closure=None):
loss = None
if closure is not None:
loss = closure()
if epoch is None:
epoch = self.epoch
self.epoch += 1
for group in self.param_groups:
weight_decay = group['weight_decay']
momentum = group['momentum']
eta = group['eta']
learning_rate = group['lr']
# TODO: Hacky
counter = 0
for p in group['params']:
if p.grad is None:
continue
param = p.data
grad = p.grad.data
param_state = self.state[p]
# TODO: This hacky way needs to be improved.
# Note: excluded parameters are passed at the end of the list, so they are ignored here
if counter < self.len_reduced:
grad += weight_decay * param
# Create parameter for the momentum
if "momentum_var" not in param_state:
next_v = param_state["momentum_var"] = torch.zeros_like(
p.data
)
else:
next_v = param_state["momentum_var"]
if self.classic_momentum:
trust_ratio = 1.0
# TODO: implementation of layer adaptation
w_norm = torch.norm(param)
g_norm = torch.norm(grad)
device = g_norm.get_device()
trust_ratio = torch.where(w_norm.gt(0), torch.where(
g_norm.gt(0), (eta * w_norm / g_norm), torch.Tensor([1.0]).to(device)),
torch.Tensor([1.0]).to(device)).item()
scaled_lr = learning_rate * trust_ratio
grad_scaled = scaled_lr*grad
next_v.mul_(momentum).add_(grad_scaled)
if self.use_nesterov:
update = (self.momentum * next_v) + (scaled_lr * grad)
else:
update = next_v
p.data.add_(-update)
# Not classic_momentum
else:
next_v.mul_(momentum).add_(grad)
if self.use_nesterov:
update = (self.momentum * next_v) + (grad)
else:
update = next_v
trust_ratio = 1.0
# TODO: implementation of layer adaptation
w_norm = torch.norm(param)
v_norm = torch.norm(update)
device = v_norm.get_device()
trust_ratio = torch.where(w_norm.gt(0), torch.where(
v_norm.gt(0), (eta * w_norm / v_norm), torch.Tensor([1.0]).to(device)),
torch.Tensor([1.0]).to(device)).item()
scaled_lr = learning_rate * trust_ratio
p.data.add_(-scaled_lr * update)
counter += 1
return loss
#LARSWrapper from solo-learn repo...
class LARSWrapper:
def __init__(
self,
optimizer: Optimizer,
eta: float = 1e-3,
clip: bool = False,
eps: float = 1e-8,
exclude_bias_n_norm: bool = False,
):
"""Wrapper that adds LARS scheduling to any optimizer.
This helps stability with huge batch sizes.
Args:
optimizer (Optimizer): torch optimizer.
eta (float, optional): trust coefficient. Defaults to 1e-3.
clip (bool, optional): clip gradient values. Defaults to False.
eps (float, optional): adaptive_lr stability coefficient. Defaults to 1e-8.
exclude_bias_n_norm (bool, optional): exclude bias and normalization layers from lars.
Defaults to False.
"""
self.optim = optimizer
self.eta = eta
self.eps = eps
self.clip = clip
self.exclude_bias_n_norm = exclude_bias_n_norm
# transfer optim methods
self.state_dict = self.optim.state_dict
self.load_state_dict = self.optim.load_state_dict
self.zero_grad = self.optim.zero_grad
self.add_param_group = self.optim.add_param_group
self.__setstate__ = self.optim.__setstate__ # type: ignore
self.__getstate__ = self.optim.__getstate__ # type: ignore
self.__repr__ = self.optim.__repr__ # type: ignore
@property
def defaults(self):
return self.optim.defaults
@defaults.setter
def defaults(self, defaults):
self.optim.defaults = defaults
@property # type: ignore
def __class__(self):
return Optimizer
@property
def state(self):
return self.optim.state
@state.setter
def state(self, state):
self.optim.state = state
@property
def param_groups(self):
return self.optim.param_groups
@param_groups.setter
def param_groups(self, value):
self.optim.param_groups = value
@torch.no_grad()
def step(self, closure=None):
weight_decays = []
for group in self.optim.param_groups:
weight_decay = group.get("weight_decay", 0)
weight_decays.append(weight_decay)
# reset weight decay
group["weight_decay"] = 0
# update the parameters
for p in group["params"]:
if p.grad is not None and (p.ndim != 1 or not self.exclude_bias_n_norm):
self.update_p(p, group, weight_decay)
# update the optimizer
self.optim.step(closure=closure)
# return weight decay control to optimizer
for group_idx, group in enumerate(self.optim.param_groups):
group["weight_decay"] = weight_decays[group_idx]
def update_p(self, p, group, weight_decay):
# calculate new norms
p_norm = torch.norm(p.data)
g_norm = torch.norm(p.grad.data)
if p_norm != 0 and g_norm != 0:
# calculate new lr
new_lr = (self.eta * p_norm) / (g_norm + p_norm * weight_decay + self.eps)
# clip lr
if self.clip:
new_lr = min(new_lr / group["lr"], 1)
# update params with clipped lr
p.grad.data += weight_decay * p.data
p.grad.data *= new_lr
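The trust ratio computed in `update_p` is the core of LARS: the per-parameter learning-rate scale is `eta * ||p|| / (||g|| + ||p|| * wd + eps)`. A standalone sketch of that computation (the function name is ours, not the repo's):

```python
import torch

def lars_adaptive_lr(p, g, weight_decay=0.0, eta=1e-3, eps=1e-8):
    # per-parameter trust ratio, mirroring LARSWrapper.update_p
    p_norm, g_norm = torch.norm(p), torch.norm(g)
    if p_norm == 0 or g_norm == 0:
        return torch.tensor(1.0)  # leave the gradient unscaled
    return (eta * p_norm) / (g_norm + p_norm * weight_decay + eps)

p = torch.ones(10)         # ||p|| = sqrt(10)
g = 0.5 * torch.ones(10)   # ||g|| = ||p|| / 2
ratio = lars_adaptive_lr(p, g)
# ratio ≈ eta * ||p|| / ||g|| = 2e-3
```

Layers whose gradients are large relative to their weights get scaled down, which is what stabilizes the large effective batch formed by many patches per image.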
================================================
FILE: loss.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
class contrastive_loss(nn.Module):
def __init__(self):
super().__init__()
pass
def forward(self,x,labels):
#this function assumes that the positive logit is always the first element,
#which is true here
loss = -x[:,0] + torch.logsumexp(x[:,1:],dim=1)
return loss.mean()
class SimCLR(nn.Module):
def __init__(self,temperature=0.5,n_views=2,contrastive=False):
super(SimCLR,self).__init__()
self.temp = temperature
self.n_views = n_views
if contrastive:
self.criterion = contrastive_loss()
else:
self.criterion = torch.nn.CrossEntropyLoss()
def info_nce_loss(self,X):
bs, n_dim = X.shape
bs = int(bs/self.n_views)
device = X.device
labels = torch.cat([torch.arange(bs) for i in range(self.n_views)], dim=0)
labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
labels = labels.to(device)
similarity_matrix = torch.matmul(X, X.T)
# assert similarity_matrix.shape == (
# self.args.n_views * self.args.batch_size, self.args.n_views * self.args.batch_size)
# assert similarity_matrix.shape == labels.shape
# discard the main diagonal from both: labels and similarities matrix
mask = torch.eye(labels.shape[0], dtype=torch.bool).to(device)
labels = labels[~mask].view(labels.shape[0], -1)
similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
# assert similarity_matrix.shape == labels.shape
# select and combine multiple positives
positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
# select only the negatives
negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)
logits = torch.cat([positives, negatives], dim=1)
labels = torch.zeros(logits.shape[0], dtype=torch.long).to(device)
logits = logits / self.temp
return logits, labels
def forward(self,X):
logits, labels = self.info_nce_loss(X)
loss = self.criterion(logits, labels)
return loss
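The masking in `info_nce_loss` can be traced on a toy batch: with `bs=2` and `n_views=2`, each row of the off-diagonal label matrix contains exactly one positive (the other view of the same image). A small sketch of just that bookkeeping:

```python
import torch

bs, n_views = 2, 2
labels = torch.cat([torch.arange(bs) for _ in range(n_views)])   # [0, 1, 0, 1]
labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()    # 4x4 match matrix
mask = torch.eye(labels.shape[0], dtype=torch.bool)
labels = labels[~mask].view(labels.shape[0], -1)                 # drop self-matches
# labels is 4x3; each row flags the single remaining view of the same image
```

After this masking, cross-entropy against index 0 (positives concatenated first) gives the standard InfoNCE objective.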
class Z_loss(nn.Module):
def __init__(self,):
super().__init__()
pass
def forward(self,z):
z_list = z.chunk(2,dim=0)
z_sim = F.cosine_similarity(z_list[0],z_list[1],dim=1).mean()
z_sim_out = z_sim.clone().detach()
return -z_sim, z_sim_out
class TotalCodingRate(nn.Module):
def __init__(self, eps=0.01):
super(TotalCodingRate, self).__init__()
self.eps = eps
def compute_discrimn_loss(self, W):
"""Discriminative Loss."""
p, m = W.shape #[d, B]
I = torch.eye(p,device=W.device)
scalar = p / (m * self.eps)
logdet = torch.logdet(I + scalar * W.matmul(W.T))
return logdet / 2.
def forward(self,X):
return - self.compute_discrimn_loss(X.T)
class MaximalCodingRateReduction(torch.nn.Module):
def __init__(self, eps=0.01, gamma=1):
super(MaximalCodingRateReduction, self).__init__()
self.eps = eps
self.gamma = gamma
def compute_discrimn_loss(self, W):
"""Discriminative Loss."""
p, m = W.shape
I = torch.eye(p,device=W.device)
scalar = p / (m * self.eps)
logdet = torch.logdet(I + scalar * W.matmul(W.T))
return logdet / 2.
def compute_compress_loss(self, W, Pi):
p, m = W.shape
k, _, _ = Pi.shape
I = torch.eye(p,device=W.device).expand((k,p,p))
trPi = Pi.sum(2) + 1e-8
scale = (p/(trPi*self.eps)).view(k,1,1)
W = W.view((1,p,m))
log_det = torch.logdet(I + scale*W.mul(Pi).matmul(W.transpose(1,2)))
compress_loss = (trPi.squeeze()*log_det/(2*m)).sum()
return compress_loss
def forward(self, X, Y, num_classes=None):
#This function supports Y as an integer label vector or a membership probability matrix.
if len(Y.shape)==1:
#if Y is a label vector
if num_classes is None:
num_classes = Y.max() + 1
Pi = torch.zeros((num_classes,1,Y.shape[0]),device=Y.device)
for indx, label in enumerate(Y):
Pi[label,0,indx] = 1
else:
#if Y is a probability matrix
if num_classes is None:
num_classes = Y.shape[1]
Pi = Y.T.reshape((num_classes,1,-1))
W = X.T
discrimn_loss = self.compute_discrimn_loss(W)
compress_loss = self.compute_compress_loss(W, Pi)
total_loss = - discrimn_loss + self.gamma*compress_loss
return total_loss, [discrimn_loss.item(), compress_loss.item()]
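`TotalCodingRate` is the log-det volume of the feature covariance; maximizing it (the loss negates it) spreads embeddings over the sphere. A toy evaluation of the same formula, under an assumed `eps=0.2` matching the training script's default:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
X = F.normalize(torch.randn(16, 4), dim=1)   # 16 unit-norm embeddings in R^4

def coding_rate(X, eps=0.2):
    W = X.T                                  # [d, m], as in compute_discrimn_loss
    p, m = W.shape
    scalar = p / (m * eps)
    return torch.logdet(torch.eye(p) + scalar * W @ W.T) / 2.

R = coding_rate(X)          # positive: I + PSD matrix has eigenvalues >= 1
loss = -R                   # what TotalCodingRate.forward returns
```

Collapsed embeddings (all rows identical) make `W @ W.T` rank one and shrink `R`, which is exactly the degenerate solution the TCR term penalizes.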
================================================
FILE: main.py
================================================
############
## Import ##
############
import argparse
import torch.nn as nn
import torch.optim as optim
import os
from torch.utils.data import DataLoader
from model.model import encoder
from dataset.datasets import load_dataset
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
import torchvision.transforms.functional as FF
from tqdm import tqdm
import torch
from torchvision.datasets import CIFAR10
from loss import TotalCodingRate
from func import chunk_avg
from lars import LARS, LARSWrapper
from func import WeightedKNNClassifier
import torch.optim.lr_scheduler as lr_scheduler
from torch.cuda.amp import GradScaler, autocast
######################
## Parsing Argument ##
######################
parser = argparse.ArgumentParser(description='Unsupervised Learning')
parser.add_argument('--patch_sim', type=int, default=200,
help='coefficient of cosine similarity (default: 200)')
parser.add_argument('--tcr', type=int, default=1,
help='coefficient of tcr (default: 1)')
parser.add_argument('--num_patches', type=int, default=100,
help='number of patches used in EMP-SSL (default: 100)')
parser.add_argument('--arch', type=str, default="resnet18-cifar",
help='network architecture (default: resnet18-cifar)')
parser.add_argument('--bs', type=int, default=100,
help='batch size (default: 100)')
parser.add_argument('--lr', type=float, default=0.3,
help='learning rate (default: 0.3)')
parser.add_argument('--eps', type=float, default=0.2,
help='eps for TCR (default: 0.2)')
parser.add_argument('--msg', type=str, default="NONE",
help='additional message for description (default: NONE)')
parser.add_argument('--dir', type=str, default="EMP-SSL-Training",
help='directory name (default: EMP-SSL-Training)')
parser.add_argument('--data', type=str, default="cifar10",
help='data (default: cifar10)')
parser.add_argument('--epoch', type=int, default=30,
help='max number of epochs to finish (default: 30)')
args = parser.parse_args()
print(args)
num_patches = args.num_patches
dir_name = f"./logs/{args.dir}/patchsim{args.patch_sim}_numpatch{args.num_patches}_bs{args.bs}_lr{args.lr}_{args.msg}"
#####################
## Helper Function ##
#####################
def chunk_avg(x,n_chunks=2,normalize=False):
x_list = x.chunk(n_chunks,dim=0)
x = torch.stack(x_list,dim=0)
if not normalize:
return x.mean(0)
else:
return F.normalize(x.mean(0),dim=1)
class Similarity_Loss(nn.Module):
def __init__(self):
super().__init__()
def forward(self, z_list, z_avg):
num_patch = len(z_list)
z_list = torch.stack(list(z_list), dim=0)
z_avg = z_list.mean(dim=0) #recomputed here; the z_avg argument is effectively ignored
z_sim = 0
for i in range(num_patch):
z_sim += F.cosine_similarity(z_list[i], z_avg, dim=1).mean()
z_sim = z_sim/num_patch
z_sim_out = z_sim.clone().detach()
return -z_sim, z_sim_out
def cal_TCR(z, criterion, num_patches):
z_list = z.chunk(num_patches,dim=0)
loss = 0
for i in range(num_patches):
loss += criterion(z_list[i])
loss = loss/num_patches
return loss
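The patch bookkeeping above relies on `torch.cat(data, dim=0)` stacking views block-wise, so image `i`'s patches sit at rows `i, i+bs, i+2*bs, ...` of `z_proj`. A small shape check under assumed toy sizes (4 patches, batch 2, 8 dimensions):

```python
import torch

num_patches, bs, d = 4, 2, 8
z_proj = torch.randn(num_patches * bs, d)        # stand-in for net(data)

z_list = z_proj.chunk(num_patches, dim=0)        # num_patches blocks of [bs, d]
z_avg = torch.stack(z_list, dim=0).mean(0)       # chunk_avg without normalization
# z_avg[i] averages rows i, i+bs, i+2*bs, ... i.e. all patches of image i
```

`Similarity_Loss` then pulls each patch embedding toward this per-image average, while `cal_TCR` applies the coding-rate term to each patch block separately.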
######################
## Prepare Training ##
######################
torch.multiprocessing.set_sharing_strategy('file_system')
if args.data == "imagenet100" or args.data == "imagenet":
train_dataset = load_dataset("imagenet", train=True, num_patch = num_patches)
dataloader = DataLoader(train_dataset, batch_size=args.bs, shuffle=True, drop_last=True,num_workers=8)
else:
train_dataset = load_dataset(args.data, train=True, num_patch = num_patches)
dataloader = DataLoader(train_dataset, batch_size=args.bs, shuffle=True, drop_last=True,num_workers=16)
use_cuda = True
device = torch.device("cuda" if use_cuda else "cpu")
net = encoder(arch = args.arch)
net = nn.DataParallel(net)
net.cuda()
opt = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-4,nesterov=True)
opt = LARSWrapper(opt,eta=0.005,clip=True,exclude_bias_n_norm=True,)
scaler = GradScaler()
if args.data == "imagenet100" or args.data == "imagenet":
num_converge = (150000//args.bs)*args.epoch
else:
num_converge = (50000//args.bs)*args.epoch
scheduler = lr_scheduler.CosineAnnealingLR(opt, T_max=num_converge, eta_min=0,last_epoch=-1)
# Loss
contractive_loss = Similarity_Loss()
criterion = TotalCodingRate(eps=args.eps)
##############
## Training ##
##############
def main():
for epoch in range(args.epoch):
for step, (data, label) in tqdm(enumerate(dataloader)):
net.zero_grad()
opt.zero_grad()
data = torch.cat(data, dim=0)
data = data.cuda()
z_proj = net(data)
z_list = z_proj.chunk(num_patches, dim=0)
z_avg = chunk_avg(z_proj, num_patches)
#Contractive Loss
loss_contract, _ = contractive_loss(z_list, z_avg)
loss_TCR = cal_TCR(z_proj, criterion, num_patches)
loss = args.patch_sim*loss_contract + args.tcr*loss_TCR
loss.backward()
opt.step()
scheduler.step()
model_dir = dir_name+"/save_models/"
if not os.path.exists(model_dir):
os.makedirs(model_dir)
torch.save(net.state_dict(), model_dir+str(epoch)+".pt")
print("At epoch:", epoch, "loss similarity is", loss_contract.item(), ",loss TCR is:", (loss_TCR).item(), "and learning rate is:", opt.param_groups[0]['lr'])
if __name__ == '__main__':
main()
================================================
FILE: mcr/loss.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
class MCRGANloss(nn.Module):
def __init__(self, gam1=1., gam2=1., gam3=1., eps=0.5, numclasses=1000, mode=0):
super(MCRGANloss, self).__init__()
self.num_class = numclasses
self.train_mode = mode
self.gam1 = gam1
self.gam2 = gam2
self.gam3 = gam3
self.eps = eps
def forward(self, Z, Z_bar, real_label, ith_inner_loop, num_inner_loop):
# t = time.time()
errD, empi = self.old_version(Z, Z_bar, real_label, ith_inner_loop, num_inner_loop)
return errD, empi
def old_version(self, Z, Z_bar, real_label, ith_inner_loop, num_inner_loop):
if self.train_mode == 2:
loss_z, _ = self.deltaR(Z, real_label, self.num_class)
assert num_inner_loop >= 2
if (ith_inner_loop + 1) % num_inner_loop != 0:
# print(f"{ith_inner_loop + 1}/{num_inner_loop}")
# print("calculate delta R(z)")
return loss_z, None
loss_h, _ = self.deltaR(Z_bar, real_label, self.num_class)
empi = [loss_z, loss_h]
term3 = 0.
for i in range(self.num_class):
new_Z = torch.cat((Z[real_label == i], Z_bar[real_label == i]), 0)
new_label = torch.cat(
(torch.zeros_like(real_label[real_label == i]),
torch.ones_like(real_label[real_label == i]))
)
loss, em = self.deltaR(new_Z, new_label, 2)
term3 += loss
empi = empi + [term3]
errD = self.gam1 * loss_z + self.gam2 * loss_h + self.gam3 * term3
elif self.train_mode == 1:
print("has been dropped")
raise NotImplementedError()
elif self.train_mode == 0:
new_Z = torch.cat((Z, Z_bar), 0)
new_label = torch.cat((torch.zeros_like(real_label), torch.ones_like(real_label)))
errD, empi = self.deltaR(new_Z, new_label, 2)
else:
raise ValueError()
return errD, empi
def debug(self, Z, Z_bar, real_label):
print("===========================")
def compute_discrimn_loss(self, Z):
"""Theoretical Discriminative Loss."""
d, n = Z.shape
I = torch.eye(d).to(Z.device)
scalar = d / (n * self.eps)
logdet = torch.logdet(I + scalar * Z @ Z.T)
return logdet / 2.
def compute_compress_loss(self, Z, Pi):
"""Theoretical Compressive Loss."""
d, n = Z.shape
I = torch.eye(d).to(Z.device)
compress_loss = []
scalars = []
for j in range(Pi.shape[1]):
Z_ = Z[:, Pi[:, j] == 1]
trPi = Pi[:, j].sum() + 1e-8
scalar = d / (trPi * self.eps)
log_det = torch.logdet(I + scalar * Z_ @ Z_.T)
compress_loss.append(log_det)
scalars.append(trPi / (2 * n))
return compress_loss, scalars
def deltaR(self, Z, Y, num_classes):
if num_classes is None:
num_classes = Y.max() + 1
#print("classes:", num_classes)
Pi = F.one_hot(Y, num_classes).to(Z.device)
discrimn_loss = self.compute_discrimn_loss(Z.T)
compress_loss, scalars = self.compute_compress_loss(Z.T, Pi)
compress_term = 0.
for z, s in zip(compress_loss, scalars):
compress_term += s * z
total_loss = discrimn_loss - compress_term
return -total_loss, (discrimn_loss, compress_term, compress_loss, scalars)
def gumb_compress_loss(self, Z, P):
d, n = Z.shape
I = torch.eye(d).to(Z.device)
compress_loss = 0.
for j in range(self.num_class):
#P[:, j:j+1][P[:, j:j+1]<threshold] = 0
Z_ = Z * P[:, j:j+1]
trPi = P[:, j].sum() + 1e-8
scalar = d / (trPi * self.eps)
log_det = torch.logdet(I + scalar * Z_ @ Z_.T)
compress_loss += (trPi / (2 * n)) *log_det
return compress_loss
def pseudo_label_loss(self, Z, logits, thres = 1.4):
logits = logits*thres
P = F.gumbel_softmax(logits)
discrimn_loss = self.compute_discrimn_loss(Z.T)
compress_loss = self.gumb_compress_loss(Z, P)
total_loss = discrimn_loss - compress_loss
return -total_loss, (discrimn_loss, compress_loss)
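`deltaR` builds its membership matrix with `F.one_hot`, one column per class, so `Pi[:, j].sum()` is the class-j sample count that weights each compression term. A minimal illustration:

```python
import torch
import torch.nn.functional as F

Y = torch.tensor([0, 1, 0, 1, 1])
Pi = F.one_hot(Y, num_classes=2)   # shape [5, 2], one column per class
counts = Pi.sum(dim=0)             # samples per class: [2, 3]
# compute_compress_loss selects Z[:, Pi[:, j] == 1] and scales by counts[j] / (2 * n)
```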
================================================
FILE: model/model.py
================================================
import torch
import torch.nn.functional as F
import torch.nn as nn
from torchvision.models import resnet18, resnet34, resnet50
from .resnet import Resnet10CIFAR
def getmodel(arch):
#backbone = resnet18()
if arch == "resnet18-cifar":
backbone = resnet18()
backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
backbone.maxpool = nn.Identity()
backbone.fc = nn.Identity()
return backbone, 512
elif arch == "resnet18-imagenet":
backbone = resnet18()
backbone.fc = nn.Identity()
return backbone, 512
elif arch == "resnet18-tinyimagenet":
backbone = resnet18()
backbone.avgpool = nn.AdaptiveAvgPool2d(1)
backbone.fc = nn.Identity()
return backbone, 512
else:
raise NameError("{} not found in network architecture".format(arch))
class encoder(nn.Module):
def __init__(self,z_dim=1024,hidden_dim=4096, norm_p=2, arch = "resnet18-cifar"):
super().__init__()
backbone, feature_dim = getmodel(arch)
self.backbone = backbone
self.norm_p = norm_p
self.pre_feature = nn.Sequential(nn.Linear(feature_dim,hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU()
)
self.projection = nn.Sequential(nn.Linear(hidden_dim,hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU(), nn.Linear(hidden_dim,z_dim))
def forward(self, x, is_test = False):
feature = self.backbone(x)
feature = self.pre_feature(feature)
z = F.normalize(self.projection(feature),p=self.norm_p)
if is_test:
return z, feature
else:
return z
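The projector head ends with an L2 normalization, so every embedding the losses see is unit norm. A standalone check of that invariant, with the backbone output replaced by random features (dimensions match the defaults above):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

feature_dim, hidden_dim, z_dim = 512, 4096, 1024
pre_feature = nn.Sequential(nn.Linear(feature_dim, hidden_dim),
                            nn.BatchNorm1d(hidden_dim), nn.ReLU())
projection = nn.Sequential(nn.Linear(hidden_dim, hidden_dim),
                           nn.BatchNorm1d(hidden_dim), nn.ReLU(),
                           nn.Linear(hidden_dim, z_dim))

feat = torch.randn(4, feature_dim)               # stand-in for backbone(x)
z = F.normalize(projection(pre_feature(feat)), p=2)
# every row of z lies on the unit sphere
```

Unit-norm projections are what make the cosine-similarity term and the coding-rate term operate on a common scale.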
================================================
FILE: model/resnet.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet18, resnet34, resnet50
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, in_planes, planes, stride=1):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion * planes,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion * planes)
)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
out = F.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, in_planes, planes, stride=1):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
stride=stride, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, self.expansion * planes,
kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(self.expansion * planes)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
stride=stride, bias=False),
nn.BatchNorm2d(self.expansion * planes)
)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = F.relu(self.bn2(self.conv2(out)))
out = self.bn3(self.conv3(out))
out += self.shortcut(x)
out = F.relu(out)
return out
class ResNet(nn.Module):
def __init__(self, block, blocks_config, first_config, first_pool=False):
super(ResNet, self).__init__()
#format of first_config
[in_chan, chan, k, s] = first_config
self.in_planes = chan
self.conv1 = nn.Conv2d(in_chan, chan, kernel_size=k, stride=s,
padding=k//2, bias=False)
self.bn1 = nn.BatchNorm2d(chan)
self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) if first_pool else nn.Identity()
self.layer1 = self._make_layer(block, blocks_config[0][0], blocks_config[0][1], stride=1)
self.layer2 = self._make_layer(block, blocks_config[1][0], blocks_config[1][1], stride=2)
self.layer3 = self._make_layer(block, blocks_config[2][0], blocks_config[2][1], stride=2)
self.layer4 = self._make_layer(block, blocks_config[3][0], blocks_config[3][1], stride=2)
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.pool(out)
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
feature = out.mean((2,3))
return feature
def Resnet10MNIST():
block = BasicBlock
blocks_config = [
[64,1],[128,1],[256,1],[512,1]
]
first_config = [1,64,3,1]
return ResNet(block,blocks_config,first_config,first_pool=False)
def Resnet10CIFAR():
block = BasicBlock
blocks_config = [
[32,1],[64,1],[128,1],[256,1]
]
first_config = [3,32,3,1]
return ResNet(block,blocks_config,first_config,first_pool=True)
def Resnet18imgs():
block = BasicBlock
blocks_config = [
[32,2],[64,2],[128,2],[256,2]
]
first_config = [1,32,5,2]
return ResNet(block,blocks_config,first_config,first_pool=True)
def Resnet18CIFAR():
backbone = resnet18()
backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
backbone.maxpool = nn.Identity()
backbone.fc = nn.Identity()
return backbone
def Resnet18STL10():
block = BasicBlock
blocks_config = [
[64,2],[128,2],[256,2],[512,2]
]
first_config = [3,64,5,2]
return ResNet(block,blocks_config,first_config,first_pool=True)
def Resnet34CIFAR():
backbone = resnet34()
backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
backbone.maxpool = nn.Identity()
backbone.fc = nn.Identity()
return backbone
def Resnet34STL10():
block = BasicBlock
blocks_config = [
[64,3],[128,4],[256,6],[512,3]
]
first_config = [3,64,5,2]
return ResNet(block,blocks_config,first_config,first_pool=True)
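`_make_layer` gives only the first block of each stage its downsampling stride; the remaining blocks run at stride 1. The schedule in isolation (helper name is ours, for illustration):

```python
def layer_strides(num_blocks, stride):
    # mirrors: strides = [stride] + [1] * (num_blocks - 1)
    return [stride] + [1] * (num_blocks - 1)

# a ResNet-18-style stage of 2 blocks downsampling by 2:
print(layer_strides(2, 2))   # -> [2, 1]
```

This is why `layer2`-`layer4` each halve the spatial resolution exactly once regardless of how many blocks a stage contains.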
================================================
FILE: requirements.text
================================================
torch
torchvision
torchmetrics
numpy
tqdm
Pillow
matplotlib
scikit-learn
gitextract_sphqay9s/ ├── README.md ├── dataset/ │ ├── aug.py │ ├── aug4img.py │ └── datasets.py ├── evaluate.py ├── func.py ├── lars.py ├── loss.py ├── main.py ├── mcr/ │ └── loss.py ├── model/ │ ├── model.py │ └── resnet.py └── requirements.text
SYMBOL INDEX (122 symbols across 11 files)
FILE: dataset/aug.py
function load_transforms (line 9) | def load_transforms(name):
class Solarization (line 77) | class Solarization:
method __call__ (line 80) | def __call__(self, img: Image) -> Image:
class GBlur (line 92) | class GBlur(object):
method __init__ (line 93) | def __init__(self, p):
method __call__ (line 96) | def __call__(self, img):
class AddGaussianNoise (line 104) | class AddGaussianNoise(object):
method __init__ (line 105) | def __init__(self, mean=0., std=1.):
method __call__ (line 109) | def __call__(self, tensor):
method __repr__ (line 112) | def __repr__(self):
class ContrastiveLearningViewGenerator (line 116) | class ContrastiveLearningViewGenerator(object):
method __init__ (line 117) | def __init__(self, num_patch = 4):
method __call__ (line 121) | def __call__(self, x):
FILE: dataset/aug4img.py
class Solarization (line 12) | class Solarization:
method __call__ (line 15) | def __call__(self, img: Image) -> Image:
class GBlur (line 27) | class GBlur(object):
method __init__ (line 28) | def __init__(self, p):
method __call__ (line 31) | def __call__(self, img):
class AddGaussianNoise (line 39) | class AddGaussianNoise(object):
method __init__ (line 40) | def __init__(self, mean=0., std=1.):
method __call__ (line 44) | def __call__(self, tensor):
method __repr__ (line 47) | def __repr__(self):
class ContrastiveLearningViewGenerator (line 51) | class ContrastiveLearningViewGenerator(object):
method __init__ (line 52) | def __init__(self, num_patch = 4):
method __call__ (line 56) | def __call__(self, x):
FILE: dataset/datasets.py
function load_dataset (line 5) | def load_dataset(data_name, train=True, num_patch = 4, path="./data/"):
function sparse2coarse (line 48) | def sparse2coarse(targets):
FILE: evaluate.py
function compute_accuracy (line 60) | def compute_accuracy(y_pred, y_true):
function chunk_avg (line 68) | def chunk_avg(x,n_chunks=2,normalize=False):
function test (line 77) | def test(net, train_loader, test_loader):
function chunk_avg (line 136) | def chunk_avg(x,n_chunks=2,normalize=False):
FILE: func.py
class WeightedKNNClassifier (line 24) | class WeightedKNNClassifier(Metric):
method __init__ (line 25) | def __init__(
method update (line 62) | def update(
method set_tk (line 90) | def set_tk(self, T, k):
method compute (line 95) | def compute(self) -> Tuple[float]:
function linear (line 181) | def linear(train_features, train_labels, test_features, test_labels, lr=...
function accuracy (line 240) | def accuracy(output, target, topk=(1,)):
class tensor_dataset (line 256) | class tensor_dataset(data.Dataset):
method __init__ (line 257) | def __init__(self,x,y):
method __getitem__ (line 262) | def __getitem__(self,indx):
method __len__ (line 265) | def __len__(self):
function set_gamma (line 271) | def set_gamma(loss_fn,epoch,total_epoch=500,warmup_epoch=100,gamma_min=0...
function warmup_lr (line 280) | def warmup_lr(optimizer,epoch,base_lr,warmup_epoch=10):
function marginal_H (line 285) | def marginal_H(logits):
function chunk_avg (line 292) | def chunk_avg(x,n_chunks=2,normalize=False):
function cluster_match (line 300) | def cluster_match(cluster_mtx,label_mtx,n_classes=10,print_result=True):
function cluster_merge_match (line 330) | def cluster_merge_match(cluster_mtx,label_mtx,print_result=True):
function cluster_acc (line 344) | def cluster_acc(test_loader,net,device,print_result=False,save_name_img=...
function save_cluster_imgs (line 378) | def save_cluster_imgs(cluster_mtx,x_mtx,save_name,npercluster=100):
function save_latent_pca_figure (line 399) | def save_latent_pca_figure(z_mtx,cluster_mtx,save_name):
function analyze_latent (line 418) | def analyze_latent(z_mtx,cluster_mtx):
FILE: lars.py
class LARS (line 5) | class LARS(Optimizer):
method __init__ (line 26) | def __init__(self, params, lr, len_reduced, momentum=0.9, use_nesterov...
method step (line 48) | def step(self, epoch=None, closure=None):
class LARSWrapper (line 146) | class LARSWrapper:
method __init__ (line 147) | def __init__(
method defaults (line 184) | def defaults(self):
method defaults (line 188) | def defaults(self, defaults):
method __class__ (line 192) | def __class__(self):
method state (line 196) | def state(self):
method state (line 200) | def state(self, state):
method param_groups (line 204) | def param_groups(self):
method param_groups (line 208) | def param_groups(self, value):
method step (line 212) | def step(self, closure=None):
method update_p (line 234) | def update_p(self, p, group, weight_decay):
FILE: loss.py
class contrastive_loss (line 5) | class contrastive_loss(nn.Module):
method __init__ (line 6) | def __init__(self):
method forward (line 9) | def forward(self,x,labels):
class SimCLR (line 15) | class SimCLR(nn.Module):
method __init__ (line 16) | def __init__(self,temperature=0.5,n_views=2,contrastive=False):
method info_nce_loss (line 26) | def info_nce_loss(self,X):
method forward (line 60) | def forward(self,X):
class Z_loss (line 65) | class Z_loss(nn.Module):
method __init__ (line 66) | def __init__(self,):
method forward (line 70) | def forward(self,z):
class TotalCodingRate (line 76) | class TotalCodingRate(nn.Module):
method __init__ (line 77) | def __init__(self, eps=0.01):
method compute_discrimn_loss (line 81) | def compute_discrimn_loss(self, W):
method forward (line 89) | def forward(self,X):
class MaximalCodingRateReduction (line 92) | class MaximalCodingRateReduction(torch.nn.Module):
method __init__ (line 93) | def __init__(self, eps=0.01, gamma=1):
method compute_discrimn_loss (line 98) | def compute_discrimn_loss(self, W):
method compute_compress_loss (line 106) | def compute_compress_loss(self, W, Pi):
method forward (line 118) | def forward(self, X, Y, num_classes=None):
FILE: main.py
function chunk_avg (line 67) | def chunk_avg(x,n_chunks=2,normalize=False):
class Similarity_Loss (line 76) | class Similarity_Loss(nn.Module):
method __init__ (line 77) | def __init__(self, ):
method forward (line 81) | def forward(self, z_list, z_avg):
function cal_TCR (line 96) | def cal_TCR(z, criterion, num_patches):
function main (line 146) | def main():
FILE: mcr/loss.py
class MCRGANloss (line 6) | class MCRGANloss(nn.Module):
method __init__ (line 8) | def __init__(self, gam1=1., gam2=1., gam3=1., eps=0.5, numclasses=1000...
method forward (line 18) | def forward(self, Z, Z_bar, real_label, ith_inner_loop, num_inner_loop):
method old_version (line 25) | def old_version(self, Z, Z_bar, real_label, ith_inner_loop, num_inner_...
method debug (line 63) | def debug(self, Z, Z_bar, real_label):
method compute_discrimn_loss (line 67) | def compute_discrimn_loss(self, Z):
method compute_compress_loss (line 75) | def compute_compress_loss(self, Z, Pi):
method deltaR (line 90) | def deltaR(self, Z, Y, num_classes):
method gumb_compress_loss (line 108) | def gumb_compress_loss(self, Z, P):
method pseudo_label_loss (line 123) | def pseudo_label_loss(self, Z, logits, thres = 1.4):
FILE: model/model.py
function getmodel (line 9) | def getmodel(arch):
class encoder (line 32) | class encoder(nn.Module):
method __init__ (line 33) | def __init__(self,z_dim=1024,hidden_dim=4096, norm_p=2, arch = "resnet...
method forward (line 46) | def forward(self, x, is_test = False):
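The `encoder` class indexed above (backbone plus projector, with `z_dim`, `hidden_dim`, and `norm_p` arguments) can be sketched as a standard SSL encoder: a feature backbone followed by an MLP projection head whose output is Lp-normalized. The layer sizes and the two-stage head below are assumptions consistent with the indexed signature, not the repository's exact architecture.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder(nn.Module):
    """Sketch: backbone -> Linear-BN-ReLU pre-feature -> linear projection,
    with the projection output normalized to the unit Lp sphere."""

    def __init__(self, backbone, feature_dim, z_dim=1024,
                 hidden_dim=4096, norm_p=2):
        super().__init__()
        self.backbone = backbone
        self.pre_feature = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
        )
        self.projection = nn.Linear(hidden_dim, z_dim)
        self.norm_p = norm_p

    def forward(self, x):
        h = self.backbone(x)
        feat = self.pre_feature(h)
        z = F.normalize(self.projection(feat), p=self.norm_p, dim=1)
        return z, feat
```

Any backbone that maps the input to a `(B, feature_dim)` tensor can be dropped in; the projection output always has unit norm per sample.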
FILE: model/resnet.py
class BasicBlock (line 7) | class BasicBlock(nn.Module):
method __init__ (line 9) | def __init__(self, in_planes, planes, stride=1):
method forward (line 26) | def forward(self, x):
class Bottleneck (line 33) | class Bottleneck(nn.Module):
method __init__ (line 36) | def __init__(self, in_planes, planes, stride=1):
method forward (line 55) | def forward(self, x):
class ResNet (line 64) | class ResNet(nn.Module):
method __init__ (line 65) | def __init__(self, block, blocks_config, first_config, first_pool=False):
method _make_layer (line 80) | def _make_layer(self, block, planes, num_blocks, stride):
method forward (line 88) | def forward(self, x):
function Resnet10MNIST (line 103) | def Resnet10MNIST():
function Resnet10CIFAR (line 111) | def Resnet10CIFAR():
function Resnet18imgs (line 119) | def Resnet18imgs():
function Resnet18CIFAR (line 127) | def Resnet18CIFAR():
function Resnet18STL10 (line 134) | def Resnet18STL10():
function Resnet34CIFAR (line 142) | def Resnet34CIFAR():
function Resnet34STL10 (line 149) | def Resnet34STL10():