Repository: tsb0601/EMP-SSL
Branch: main
Commit: 8fc3758617e0
Files: 13
Total size: 66.6 KB

Directory structure:
gitextract_sphqay9s/
├── README.md
├── dataset/
│   ├── aug.py
│   ├── aug4img.py
│   └── datasets.py
├── evaluate.py
├── func.py
├── lars.py
├── loss.py
├── main.py
├── mcr/
│   └── loss.py
├── model/
│   ├── model.py
│   └── resnet.py
└── requirements.text

================================================
FILE CONTENTS
================================================

================================================
FILE: README.md
================================================
# EMP-SSL: Towards Self-Supervised Learning in One Training Epoch

[![arXiv](https://img.shields.io/badge/arXiv-2304.03977-b31b1b.svg)](https://arxiv.org/abs/2304.03977)

![Training Pipeline](pipeline.png)

Authors: Shengbang Tong*, Yubei Chen*, Yi Ma, Yann LeCun

## Introduction

This repository contains the implementation for the paper "EMP-SSL: Towards Self-Supervised Learning in One Training Epoch." The paper introduces a simple but efficient self-supervised learning method called Extreme-Multi-Patch Self-Supervised Learning (EMP-SSL). EMP-SSL significantly reduces the number of training epochs required for convergence by increasing the number of fixed-size image patches taken from each image instance.

## Preparing Training Data

CIFAR-10 and CIFAR-100 are downloaded automatically by the training script. ImageNet-100 is a special subset of ImageNet; details can be found at this [link](https://github.com/HobbitLong/CMC/issues/21).

## Getting Started

The current implementation supports CIFAR-10, CIFAR-100 and ImageNet-100. To get started with EMP-SSL, follow these instructions:

### 1. Clone this repository

```bash
git clone https://github.com/tsb0601/emp-ssl.git
cd emp-ssl
```

### 2. Install required packages

```
pip install -r requirements.txt
```

### 3. Training

#### Reproducing 1-epoch results

|                   | CIFAR-10<br>1 Epoch | CIFAR-100<br>1 Epoch | Tiny ImageNet<br>1 Epoch | ImageNet-100<br>1 Epoch |
|-------------------|:-------------------:|:--------------------:|:------------------------:|:-----------------------:|
| EMP-SSL (1 Epoch) | 0.842 | 0.585 | 0.381 | 0.585 |

For CIFAR-10 or CIFAR-100:

```
python main.py --data cifar10 --epoch 2 --patch_sim 200 --arch 'resnet18-cifar' --num_patches 20 --lr 0.3
```

For ImageNet-100:

```
python main.py --data imagenet100 --epoch 2 --patch_sim 200 --arch 'resnet18-imagenet' --num_patches 20 --lr 0.3
```

#### Reproducing multi-epoch results

| | CIFAR-10<br>1 Epoch | CIFAR-10<br>10 Epochs | CIFAR-10<br>30 Epochs | CIFAR-10<br>1000 Epochs | CIFAR-100<br>1 Epoch | CIFAR-100<br>10 Epochs | CIFAR-100<br>30 Epochs | CIFAR-100<br>1000 Epochs | Tiny ImageNet<br>10 Epochs | Tiny ImageNet<br>1000 Epochs | ImageNet-100<br>10 Epochs | ImageNet-100<br>400 Epochs |
|----------------------|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| SimCLR | 0.282 | 0.565 | 0.663 | 0.910 | 0.054 | 0.185 | 0.341 | 0.662 | - | 0.488 | - | 0.776 |
| BYOL | 0.249 | 0.489 | 0.684 | 0.926 | 0.043 | 0.150 | 0.349 | 0.708 | - | 0.510 | - | 0.802 |
| VICReg | 0.406 | 0.697 | 0.781 | 0.921 | 0.079 | 0.319 | 0.479 | 0.685 | - | - | - | 0.792 |
| SwAV | 0.245 | 0.532 | 0.767 | 0.923 | 0.028 | 0.208 | 0.294 | 0.658 | - | - | - | 0.740 |
| ReSSL | 0.245 | 0.256 | 0.525 | 0.914 | 0.033 | 0.122 | 0.247 | 0.674 | - | - | - | 0.769 |
| EMP-SSL (20 patches) | 0.806 | 0.907 | 0.931 | - | 0.551 | 0.678 | 0.724 | - | - | - | - | - |
| EMP-SSL (200 patches) | 0.826* | 0.915 | 0.934 | - | 0.577 | 0.701 | 0.733 | - | 0.515 | - | 0.789 | - |

\* Here the learning-rate schedule decays over 30 epochs, so the 1-epoch accuracy is slightly lower than when optimizing specifically for 1-epoch training.

Change `num_patches` to change the number of patches used in EMP-SSL training:

```
python main.py --data cifar10 --epoch 30 --patch_sim 200 --arch 'resnet18-cifar' --num_patches 20 --lr 0.3
```

### 4. Evaluating

Because our model is trained only on fixed-size image patches, we adopt the bag-of-features model from the intra-instance VICReg paper to evaluate performance. Change `test_patches` to adjust the number of patches used in the bag-of-features model for different GPUs:

```
python evaluate.py --model_path 'path to your evaluated model' --test_patches 128
```

## Acknowledgment

This repo is inspired by the [MCR2](https://github.com/Ma-Lab-Berkeley/MCR2), [solo-learn](https://github.com/vturrisi/solo-learn) and [NMCE](https://github.com/zengyi-li/NMCE-release) repos.
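The augmentation at the heart of the training recipe above is simply cropping many fixed-size patches from each image. A minimal torch-only sketch of that step — our simplified stand-in for `ContrastiveLearningViewGenerator` in `dataset/aug.py`, which additionally applies resizing, color jitter, grayscale, blur and solarization; the function name and plain-tensor interface here are our own:

```python
import torch

def random_patches(img, num_patches, size):
    # img: a CHW image tensor; returns num_patches random fixed-size crops.
    _, h, w = img.shape
    patches = []
    for _ in range(num_patches):
        top = torch.randint(0, h - size + 1, (1,)).item()
        left = torch.randint(0, w - size + 1, (1,)).item()
        patches.append(img[:, top:top + size, left:left + size])
    return patches

# 20 patches per image, mirroring --num_patches 20 in the commands above
views = random_patches(torch.randn(3, 32, 32), num_patches=20, size=16)
print(len(views), views[0].shape)  # 20 torch.Size([3, 16, 16])
```

Each of the `num_patches` views is then encoded and pushed toward the other views of the same image during training.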
## Citation

If you find this repository useful, please consider giving a star :star: and citing:

```
@article{tong2023empssl,
  title={EMP-SSL: Towards Self-Supervised Learning in One Training Epoch},
  author={Shengbang Tong and Yubei Chen and Yi Ma and Yann Lecun},
  journal={arXiv preprint arXiv:2304.03977},
  year={2023}
}
```

================================================
FILE: dataset/aug.py
================================================
import torch import torch.nn as nn import torch.nn.functional as F import numpy as np import torchvision.transforms as transforms from PIL import Image, ImageFilter, ImageOps def load_transforms(name): """Load data transformations. Note: - Gaussian Blur is defined at the bottom of this file. """ _name = name.lower() if _name == "cifar_sup": normalize = transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]) aug_transform = transforms.Compose([ transforms.RandomResizedCrop(32,scale=(0.765625, 0.765625),ratio=(1., 1.)), transforms.RandomCrop(32, padding=8), transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize ]) baseline_transform = transforms.Compose([ transforms.RandomResizedCrop(32,scale=(0.765625, 0.765625),ratio=(1., 1.)), transforms.ToTensor(),normalize]) elif _name == "cifar_patch": normalize = transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5]) aug_transform = transforms.Compose([ transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8), transforms.ToTensor(), normalize ]) baseline_transform = transforms.Compose([ transforms.ToTensor(), normalize]) elif _name == "cifar_simclr_norm": normalize = transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]) aug_transform = transforms.Compose([ transforms.RandomResizedCrop(32,scale=(0.08, 1.0)), transforms.RandomHorizontalFlip(p=0.5), transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8), transforms.RandomGrayscale(p=0.2), transforms.ToTensor(), normalize ]) baseline_transform = 
transforms.Compose([ transforms.ToTensor(),normalize]) elif _name == "cifar_byol": normalize = transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]) aug_transform = transforms.Compose([ transforms.RandomResizedCrop( (32, 32), scale=(0.2, 1.0), interpolation=transforms.InterpolationMode.BICUBIC, ), transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8), transforms.RandomGrayscale(p=0.2), transforms.RandomApply([Solarization()], p=0.1), transforms.RandomHorizontalFlip(p=0.5), transforms.ToTensor(), normalize ]) baseline_transform = transforms.Compose([ # transforms.RandomResizedCrop(32,scale=(0.765625, 0.765625),ratio=(1., 1.)), transforms.ToTensor(),normalize]) else: raise NameError("{} not found in transform loader".format(name)) return aug_transform, baseline_transform class Solarization: """Solarization as a callable object.""" def __call__(self, img: Image) -> Image: """Applies solarization to an input image. Args: img (Image): an image in the PIL.Image format. Returns: Image: a solarized image. 
""" return ImageOps.solarize(img) class GBlur(object): def __init__(self, p): self.p = p def __call__(self, img): if np.random.rand() < self.p: sigma = np.random.rand() * 1.9 + 0.1 return img.filter(ImageFilter.GaussianBlur(sigma)) else: return img class AddGaussianNoise(object): def __init__(self, mean=0., std=1.): self.std = std self.mean = mean def __call__(self, tensor): return tensor + torch.randn(tensor.size()) * self.std + self.mean def __repr__(self): return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) class ContrastiveLearningViewGenerator(object): def __init__(self, num_patch = 4): self.num_patch = num_patch def __call__(self, x): normalize = transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5]) aug_transform = transforms.Compose([ transforms.RandomResizedCrop(32,scale=(0.25, 0.25), ratio=(1,1)), transforms.RandomHorizontalFlip(p=0.5), transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.2)], p=0.8), transforms.RandomGrayscale(p=0.2), GBlur(p=0.1), transforms.RandomApply([Solarization()], p=0.1), transforms.ToTensor(), normalize ]) augmented_x = [aug_transform(x) for i in range(self.num_patch)] return augmented_x ================================================ FILE: dataset/aug4img.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F import numpy as np import torchvision.transforms as transforms from PIL import Image, ImageFilter, ImageOps from torchvision.transforms import InterpolationMode class Solarization: """Solarization as a callable object.""" def __call__(self, img: Image) -> Image: """Applies solarization to an input image. Args: img (Image): an image in the PIL.Image format. Returns: Image: a solarized image. 
""" return ImageOps.solarize(img) class GBlur(object): def __init__(self, p): self.p = p def __call__(self, img): if np.random.rand() < self.p: sigma = np.random.rand() * 1.9 + 0.1 return img.filter(ImageFilter.GaussianBlur(sigma)) else: return img class AddGaussianNoise(object): def __init__(self, mean=0., std=1.): self.std = std self.mean = mean def __call__(self, tensor): return tensor + torch.randn(tensor.size()) * self.std + self.mean def __repr__(self): return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) class ContrastiveLearningViewGenerator(object): def __init__(self, num_patch = 4): self.num_patch = num_patch def __call__(self, x): aug_transform = transforms.Compose([ transforms.RandomResizedCrop( 224, scale=(0.25, 0.25), interpolation=InterpolationMode.BICUBIC ), transforms.RandomHorizontalFlip(p=0.5), transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.2)], p=0.8), transforms.RandomGrayscale(p=0.2), GBlur(p=0.1), transforms.RandomApply([Solarization()], p=0.1), transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), ]) augmented_x = [aug_transform(x) for i in range(self.num_patch)] return augmented_x ================================================ FILE: dataset/datasets.py ================================================ import os import numpy as np import torchvision def load_dataset(data_name, train=True, num_patch = 4, path="./data/"): """Loads a dataset for training and testing. If augmentloader is used, transform should be None. Parameters: data_name (str): name of the dataset transform_name (torchvision.transform): name of transform to be applied (see aug.py) use_baseline (bool): use baseline transform or augmentation transform train (bool): load training set or not contrastive (bool): whether to convert transform to multiview augmentation for contrastive learning. 
n_views (bool): number of views for contrastive learning path (str): path to dataset base path Returns: dataset (torch.data.dataset) """ _name = data_name.lower() if _name == "imagenet": from .aug4img import ContrastiveLearningViewGenerator else: from .aug import ContrastiveLearningViewGenerator transform = ContrastiveLearningViewGenerator(num_patch = num_patch) if _name == "cifar10": trainset = torchvision.datasets.CIFAR10(root=os.path.join(path, "CIFAR10"), train=train, download=True, transform=transform) trainset.num_classes = 10 elif _name == "cifar100": trainset = torchvision.datasets.CIFAR100(root=os.path.join(path, "CIFAR100"), train=train, download=True, transform=transform) trainset.num_classes = 100 elif _name == "imagenet": if train: trainset = torchvision.datasets.ImageFolder(root="/home/peter/Data/ILSVRC2012/train100/",transform=transform) #trainset = torchvision.datasets.ImageFolder(root="/home/peter/Data/tiny-imagenet-200/train/",transform=transform) else: trainset = torchvision.datasets.ImageFolder(root="/home/peter/Data/ILSVRC2012/val100/",transform=transform) #trainset = torchvision.datasets.ImageFolder(root="/home/peter/Data/tiny-imagenet-200/val/",transform=transform) trainset.num_classes = 200 else: raise NameError("{} not found in trainset loader".format(_name)) return trainset def sparse2coarse(targets): """CIFAR100 Coarse Labels. 
""" coarse_targets = [ 4, 1, 14, 8, 0, 6, 7, 7, 18, 3, 3, 14, 9, 18, 7, 11, 3, 9, 7, 11, 6, 11, 5, 10, 7, 6, 13, 15, 3, 15, 0, 11, 1, 10, 12, 14, 16, 9, 11, 5, 5, 19, 8, 8, 15, 13, 14, 17, 18, 10, 16, 4, 17, 4, 2, 0, 17, 4, 18, 17, 10, 3, 2, 12, 12, 16, 12, 1, 9, 19, 2, 10, 0, 1, 16, 12, 9, 13, 15, 13, 16, 19, 2, 4, 6, 19, 5, 5, 8, 19, 18, 1, 2, 15, 6, 0, 17, 8, 14, 13] return np.array(coarse_targets)[targets] ================================================ FILE: evaluate.py ================================================ ############ ## Import ## ############ import argparse import torch.nn as nn from torch.utils.data import DataLoader from model.model import encoder from dataset.datasets import load_dataset import numpy as np import torch.nn.functional as F from tqdm import tqdm import torch import numpy as np from func import WeightedKNNClassifier, linear ###################### ## Parsing Argument ## ###################### import argparse parser = argparse.ArgumentParser(description='Evaluation') parser.add_argument('--test_patches', type=int, default=128, help='number of patches used in testing (default: 128)') parser.add_argument('--data', type=str, default="cifar10", help='dataset (default: cifar10)') parser.add_argument('--arch', type=str, default="resnet18-cifar", help='network architecture (default: resnet18-cifar)') parser.add_argument('--lr', type=float, default=0.03, help='learning rate for linear eval (default: 0.03)') parser.add_argument('--linear', type=bool, default=True, help='use linear eval or not') parser.add_argument('--knn', help='evaluate using kNN measuring cosine similarity', action='store_true') parser.add_argument('--model_path', type=str, default="", help='model directory for eval') args = parser.parse_args() ###################### ## Testing Accuracy ## ###################### test_patches = args.test_patches def compute_accuracy(y_pred, y_true): """Compute accuracy by counting correct classification. 
""" assert y_pred.shape == y_true.shape return 1 - np.count_nonzero(y_pred - y_true) / y_true.size knn_classifier = WeightedKNNClassifier() def chunk_avg(x,n_chunks=2,normalize=False): x_list = x.chunk(n_chunks,dim=0) x = torch.stack(x_list,dim=0) if not normalize: return x.mean(0) else: return F.normalize(x.mean(0),dim=1) def test(net, train_loader, test_loader): train_z_full_list, train_y_list, test_z_full_list, test_y_list = [], [], [], [] with torch.no_grad(): for x, y in tqdm(train_loader): x = torch.cat(x, dim = 0) z_proj, z_pre = net(x, is_test=True) z_pre = chunk_avg(z_pre, test_patches) z_pre = z_pre.detach().cpu() train_z_full_list.append(z_pre) knn_classifier.update(train_features = z_pre, train_targets = y) train_y_list.append(y) for x, y in tqdm(test_loader): x = torch.cat(x, dim = 0) z_proj, z_pre = net(x, is_test=True) z_pre = chunk_avg(z_pre, test_patches) z_pre = z_pre.detach().cpu() test_z_full_list.append(z_pre) knn_classifier.update(test_features = z_pre, test_targets = y) test_y_list.append(y) train_features_full, train_labels, test_features_full, test_labels = torch.cat(train_z_full_list,dim=0), torch.cat(train_y_list,dim=0), torch.cat(test_z_full_list,dim=0), torch.cat(test_y_list,dim=0) if args.data == "cifar10": num_classes = 10 elif args.data == "cifar100": num_classes = 100 elif args.data == "tinyimagenet200": num_classes = 200 elif args.data == "imagenet100": num_classes = 100 elif args.data == "imagenet": num_classes = 1000 if args.linear: print("Using Linear Eval to evaluate accuracy") linear(train_features_full, train_labels, test_features_full, test_labels, lr=args.lr, num_classes = num_classes) if args.knn: print("Using KNN to evaluate accuracy") top1, top5 = knn_classifier.compute() print("KNN (top1/top5):", top1, top5) def chunk_avg(x,n_chunks=2,normalize=False): x_list = x.chunk(n_chunks,dim=0) x = torch.stack(x_list,dim=0) if not normalize: return x.mean(0) else: return F.normalize(x.mean(0),dim=1) 
torch.multiprocessing.set_sharing_strategy('file_system') #Get Dataset if args.data == "imagenet100" or args.data == "imagenet": memory_dataset = load_dataset(args.data, train=True, num_patch = test_patches) memory_loader = DataLoader(memory_dataset, batch_size=50, shuffle=True, drop_last=True,num_workers=8) test_data = load_dataset(args.data, train=False, num_patch = test_patches) test_loader = DataLoader(test_data, batch_size=50, shuffle=True, num_workers=8) else: memory_dataset = load_dataset(args.data, train=True, num_patch = test_patches) memory_loader = DataLoader(memory_dataset, batch_size=50, shuffle=True, drop_last=True,num_workers=8) test_data = load_dataset(args.data, train=False, num_patch = test_patches) test_loader = DataLoader(test_data, batch_size=50, shuffle=True, num_workers=8) # Load Model and Checkpoint use_cuda = True device = torch.device("cuda" if use_cuda else "cpu") net = encoder(arch = args.arch) net = nn.DataParallel(net) save_dict = torch.load(args.model_path) net.load_state_dict(save_dict,strict=False) net.cuda() net.eval() test(net, memory_loader, test_loader) ================================================ FILE: func.py ================================================ import numpy as np import torch import torch.nn.functional as F import matplotlib.pyplot as plt from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score import torchvision # import torch.nn from torch import nn, optim import torch.nn as nn from torch.utils import data from torch.utils.data import DataLoader from typing import Tuple import torch.nn.functional as F from torchmetrics.metric import Metric class WeightedKNNClassifier(Metric): def __init__( self, k: int = 20, T: float = 0.07, max_distance_matrix_size: int = int(5e6), distance_fx: str = "cosine", epsilon: float = 0.00001, dist_sync_on_step: bool = False, ): """Implements the weighted k-NN classifier used for evaluation. Args: k (int, optional): number of neighbors. Defaults to 20. 
T (float, optional): temperature for the exponential. Only used with cosine distance. Defaults to 0.07. max_distance_matrix_size (int, optional): maximum number of elements in the distance matrix. Defaults to 5e6. distance_fx (str, optional): Distance function. Accepted arguments: "cosine" or "euclidean". Defaults to "cosine". epsilon (float, optional): Small value for numerical stability. Only used with euclidean distance. Defaults to 0.00001. dist_sync_on_step (bool, optional): whether to sync distributed values at every step. Defaults to False. """ super().__init__(dist_sync_on_step=dist_sync_on_step, compute_on_step=False) self.k = k self.T = T self.max_distance_matrix_size = max_distance_matrix_size self.distance_fx = distance_fx self.epsilon = epsilon self.add_state("train_features", default=[], persistent=False) self.add_state("train_targets", default=[], persistent=False) self.add_state("test_features", default=[], persistent=False) self.add_state("test_targets", default=[], persistent=False) def update( self, train_features: torch.Tensor = None, train_targets: torch.Tensor = None, test_features: torch.Tensor = None, test_targets: torch.Tensor = None, ): """Updates the memory banks. If train (test) features are passed as input, the corresponding train (test) targets must be passed as well. Args: train_features (torch.Tensor, optional): a batch of train features. Defaults to None. train_targets (torch.Tensor, optional): a batch of train targets. Defaults to None. test_features (torch.Tensor, optional): a batch of test features. Defaults to None. test_targets (torch.Tensor, optional): a batch of test targets. Defaults to None. 
""" assert (train_features is None) == (train_targets is None) assert (test_features is None) == (test_targets is None) if train_features is not None: assert train_features.size(0) == train_targets.size(0) self.train_features.append(train_features.detach()) self.train_targets.append(train_targets.detach()) if test_features is not None: assert test_features.size(0) == test_targets.size(0) self.test_features.append(test_features.detach()) self.test_targets.append(test_targets.detach()) def set_tk(self, T, k): self.T = T self.k = k @torch.no_grad() def compute(self) -> Tuple[float]: """Computes weighted k-NN accuracy @1 and @5. If cosine distance is selected, the weight is computed using the exponential of the temperature scaled cosine distance of the samples. If euclidean distance is selected, the weight corresponds to the inverse of the euclidean distance. Returns: Tuple[float]: k-NN accuracy @1 and @5. """ #print(self.T, self.k) train_features = torch.cat(self.train_features) train_targets = torch.cat(self.train_targets) test_features = torch.cat(self.test_features) test_targets = torch.cat(self.test_targets) if self.distance_fx == "cosine": train_features = F.normalize(train_features) test_features = F.normalize(test_features) num_classes = torch.unique(test_targets).numel() num_train_images = train_targets.size(0) num_test_images = test_targets.size(0) num_train_images = train_targets.size(0) chunk_size = min( max(1, self.max_distance_matrix_size // num_train_images), num_test_images, ) k = min(self.k, num_train_images) top1, top5, total = 0.0, 0.0, 0 retrieval_one_hot = torch.zeros(k, num_classes).to(train_features.device) for idx in range(0, num_test_images, chunk_size): # get the features for test images features = test_features[idx : min((idx + chunk_size), num_test_images), :] targets = test_targets[idx : min((idx + chunk_size), num_test_images)] batch_size = targets.size(0) # calculate the dot product and compute top-k neighbors if self.distance_fx == 
"cosine": similarities = torch.mm(features, train_features.t()) elif self.distance_fx == "euclidean": similarities = 1 / (torch.cdist(features, train_features) + self.epsilon) else: raise NotImplementedError similarities, indices = similarities.topk(k, largest=True, sorted=True) candidates = train_targets.view(1, -1).expand(batch_size, -1) retrieved_neighbors = torch.gather(candidates, 1, indices) retrieval_one_hot.resize_(batch_size * k, num_classes).zero_() retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1) if self.distance_fx == "cosine": similarities = similarities.clone().div_(self.T).exp_() probs = torch.sum( torch.mul( retrieval_one_hot.view(batch_size, -1, num_classes), similarities.view(batch_size, -1, 1), ), 1, ) _, predictions = probs.sort(1, True) # find the predictions that match the target correct = predictions.eq(targets.data.view(-1, 1)) top1 = top1 + correct.narrow(1, 0, 1).sum().item() top5 = ( top5 + correct.narrow(1, 0, min(5, k, correct.size(-1))).sum().item() ) # top5 does not make sense if k < 5 total += targets.size(0) top1 = top1 * 100.0 / total top5 = top5 * 100.0 / total self.reset() return top1, top5 def linear(train_features, train_labels, test_features, test_labels, lr=0.0075, num_classes = 100): train_data = tensor_dataset(train_features,train_labels) test_data = tensor_dataset(test_features,test_labels) train_loader = DataLoader(train_data, batch_size=100, shuffle=True, drop_last=True, num_workers=2) test_loader = DataLoader(test_data, batch_size=100, shuffle=True, drop_last=False, num_workers=2) LL = nn.Linear(train_features.shape[1],num_classes) optimizer = torch.optim.SGD(LL.parameters(), lr=lr, momentum=0.9, weight_decay=5e-5) scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 100) criterion = torch.nn.CrossEntropyLoss() test_acc_list = [] for epoch in range(100): top1_train_accuracy = 0 for counter, (x_batch, y_batch) in enumerate(train_loader): x_batch = x_batch y_batch = y_batch logits = LL(x_batch) 
loss = criterion(logits, y_batch) top1 = accuracy(logits, y_batch, topk=(1,)) top1_train_accuracy += top1[0] optimizer.zero_grad() loss.backward() optimizer.step() scheduler.step() top1_train_accuracy /= (counter + 1) top1_accuracy = 0 top5_accuracy = 0 for counter, (x_batch, y_batch) in enumerate(test_loader): x_batch = x_batch y_batch = y_batch logits = LL(x_batch) top1, top5 = accuracy(logits, y_batch, topk=(1,5)) top1_accuracy += top1[0] top5_accuracy += top5[0] top1_accuracy /= (counter + 1) top5_accuracy /= (counter + 1) test_acc_list.append(top1_accuracy) print(f"Epoch {epoch}\tTop1 Train accuracy {top1_train_accuracy.item()}\tTop1 Test accuracy: {top1_accuracy.item()}\tTop5 test acc: {top5_accuracy.item()}") acc_vect = torch.tensor(test_acc_list) print('best linear test acc {}, last acc {}'.format(acc_vect.max().item(),acc_vect[-1].item())) def accuracy(output, target, topk=(1,)): """Computes the accuracy over the k top predictions for the specified values of k""" with torch.no_grad(): maxk = max(topk) batch_size = target.size(0) _, pred = output.topk(maxk, 1, True, True) pred = pred.t() correct = pred.eq(target.view(1, -1).expand_as(pred)) res = [] for k in topk: correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) res.append(correct_k.mul_(100.0 / batch_size)) return res class tensor_dataset(data.Dataset): def __init__(self,x,y): self.x = x self.y = y self.length = x.shape[0] def __getitem__(self,indx): return self.x[indx], self.y[indx] def __len__(self): return self.length def set_gamma(loss_fn,epoch,total_epoch=500,warmup_epoch=100,gamma_min=0.,gamma_max=1.0): warmup_start = total_epoch - warmup_epoch warmup_end = total_epoch if warmup_start < epoch<=warmup_end: loss_fn.gamma = ((epoch - warmup_start)/(warmup_end - warmup_start))*(gamma_max - gamma_min) + gamma_min else: loss_fn.gamma = gamma_min def warmup_lr(optimizer,epoch,base_lr,warmup_epoch=10): if epoch < warmup_epoch: lr = base_lr*(epoch+1)/warmup_epoch for param_group in optimizer.param_groups: param_group['lr'] = lr def cluster_match(cluster_mtx,label_mtx,n_classes=10,print_result=True): cluster_indx = list(cluster_mtx.unique()) assigned_label_list = [] assigned_count = [] while len(cluster_indx) > 0: max_label_list = [] max_count_list = [] for indx in cluster_indx: #calculate 
highest number of matchs mask = cluster_mtx==indx label_elements, counts = label_mtx[mask].unique(return_counts=True) for assigned_label in assigned_label_list: counts[label_elements==assigned_label] = 0 max_count_list.append(counts.max()) max_label_list.append(label_elements[counts.argmax()]) max_label = torch.stack(max_label_list) max_count = torch.stack(max_count_list) assigned_label_list.append(max_label[max_count.argmax()]) assigned_count.append(max_count.max()) cluster_indx.pop(max_count.argmax()) total_correct = torch.tensor(assigned_count).sum().item() total_sample = cluster_mtx.shape[0] acc = total_correct/total_sample if print_result: print('{}/{} ({}%) correct'.format(total_correct,total_sample,acc*100)) else: return total_correct, total_sample, acc def cluster_merge_match(cluster_mtx,label_mtx,print_result=True): cluster_indx = list(cluster_mtx.unique()) n_correct = 0 for cluster_id in cluster_indx: label_elements, counts = label_mtx[cluster_mtx==cluster_id].unique(return_counts=True) n_correct += counts.max() total_sample = len(cluster_mtx) acc = n_correct.item()/total_sample if print_result: print('{}/{} ({}%) correct'.format(n_correct,total_sample,acc*100)) else: return n_correct, total_sample, acc def cluster_acc(test_loader,net,device,print_result=False,save_name_img='cluster_img',save_name_fig='pca_figure'): cluster_list = [] label_list = [] x_list = [] z_list = [] net.eval() for x, y in test_loader: with torch.no_grad(): x, y = x.float().to(device), y.to(device) z, logit = net(x) if logit.sum() == 0: logit += torch.randn_like(logit) cluster_list.append(logit.max(dim=1)[1].cpu()) label_list.append(y.cpu()) x_list.append(x.cpu()) z_list.append(z.cpu()) net.train() cluster_mtx = torch.cat(cluster_list,dim=0) label_mtx = torch.cat(label_list,dim=0) x_mtx = torch.cat(x_list,dim=0) z_mtx = torch.cat(z_list,dim=0) _, _, acc_single = cluster_match(cluster_mtx,label_mtx,n_classes=label_mtx.max()+1,print_result=False) _, _, acc_merge = 
cluster_merge_match(cluster_mtx,label_mtx,print_result=False) NMI = normalized_mutual_info_score(label_mtx.numpy(),cluster_mtx.numpy()) ARI = adjusted_rand_score(label_mtx.numpy(),cluster_mtx.numpy()) if print_result: print('cluster match acc {}, cluster merge match acc {}, NMI {}, ARI {}'.format(acc_single,acc_merge,NMI,ARI)) save_name_img += '_acc'+ str(acc_single)[2:5] save_cluster_imgs(cluster_mtx,x_mtx,save_name_img) save_latent_pca_figure(z_mtx,cluster_mtx,save_name_fig) return acc_single, acc_merge, NMI, ARI def save_cluster_imgs(cluster_mtx,x_mtx,save_name,npercluster=100): cluster_indexs, counts = cluster_mtx.unique(return_counts=True) x_list = [] counts_list = [] for i, c_indx in enumerate(cluster_indexs): if counts[i]>npercluster: x_list.append(x_mtx[cluster_mtx==c_indx,:,:,:]) counts_list.append(counts[i]) n_clusters = len(counts_list) fig, ax = plt.subplots(n_clusters,1,dpi=80,figsize=(1.2*n_clusters, 3*n_clusters)) for i, ax in enumerate(ax): img = torchvision.utils.make_grid(x_list[i][:npercluster],nrow=npercluster//5,normalize=True) ax.imshow(img.permute(1,2,0)) ax.set_axis_off() ax.set_title('Cluster with {} images'.format(counts_list[i])) fig.savefig(save_name+'.pdf') plt.close(fig) def save_latent_pca_figure(z_mtx,cluster_mtx,save_name): _, s_z_all, _ = z_mtx.svd() cluster_n = [] cluster_s = [] for cluster_indx in cluster_mtx.unique(): _, s_cluster, _ = z_mtx[cluster_mtx==cluster_indx,:].svd() cluster_n.append((cluster_mtx==cluster_indx).sum().item()) cluster_s.append(s_cluster/s_cluster.max()) #make plot fig, ax = plt.subplots(1,2,figsize=(9, 3)) ax[0].plot(s_z_all) for i, s_curve in enumerate(cluster_s): ax[1].plot(s_curve,label=cluster_n[i]) ax[1].set_xlim(xmin=0,xmax=20) ax[1].legend() fig.savefig(save_name +'.pdf') plt.close(fig) def analyze_latent(z_mtx,cluster_mtx): _, s_z_all, _ = z_mtx.svd() cluster_n = [] cluster_s = [] cluster_d = [] for cluster_indx in cluster_mtx.unique(): _, s_cluster, _ = z_mtx[cluster_mtx==cluster_indx,:].svd() 
s_cluster = s_cluster/s_cluster.max() cluster_n.append((cluster_mtx==cluster_indx).sum().item()) cluster_s.append(s_cluster) # print(list(cluster_s)) print(s_cluster) # s_diff = s_cluster[:-1] - s_cluster[1:] # cluster_d.append(s_diff.max(0)[1]) cluster_d.append((s_cluster>0.01).sum()) for i in range(len(cluster_n)): print('subspace {}, dimension {}, samples {}'.format(i,cluster_d[i],cluster_n[i])) ================================================ FILE: lars.py ================================================ import torch import torch.optim as optim from torch.optim.optimizer import Optimizer, required class LARS(Optimizer): """ Layer-wise adaptive rate scaling - Converted from Tensorflow to Pytorch from: https://github.com/google-research/simclr/blob/master/lars_optimizer.py - Based on: https://github.com/noahgolmant/pytorch-lars params (iterable): iterable of parameters to optimize or dicts defining parameter groups lr (float): base learning rate (\gamma_0) lr (int): Length / Number of layers we want to apply weight decay, else do not compute momentum (float, optional): momentum factor (default: 0.9) use_nesterov (bool, optional): flag to use nesterov momentum (default: False) weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0) ("\beta") eta (float, optional): LARS coefficient (default: 0.001) - Based on Algorithm 1 of the following paper by You, Gitman, and Ginsburg. 
- Large Batch Training of Convolutional Networks: https://arxiv.org/abs/1708.03888 """ def __init__(self, params, lr, len_reduced, momentum=0.9, use_nesterov=False, weight_decay=0.0, classic_momentum=True, eta=0.001): self.epoch = 0 defaults = dict( lr=lr, momentum=momentum, use_nesterov=use_nesterov, weight_decay=weight_decay, classic_momentum=classic_momentum, eta=eta, len_reduced=len_reduced ) super(LARS, self).__init__(params, defaults) self.lr = lr self.momentum = momentum self.weight_decay = weight_decay self.use_nesterov = use_nesterov self.classic_momentum = classic_momentum self.eta = eta self.len_reduced = len_reduced def step(self, epoch=None, closure=None): loss = None if closure is not None: loss = closure() if epoch is None: epoch = self.epoch self.epoch += 1 for group in self.param_groups: weight_decay = group['weight_decay'] momentum = group['momentum'] eta = group['eta'] learning_rate = group['lr'] # TODO: Hacky counter = 0 for p in group['params']: if p.grad is None: continue param = p.data grad = p.grad.data param_state = self.state[p] # TODO: This really hacky way needs to be improved. 
                # Note: excluded parameters are passed at the end of the list and are ignored.
                if counter < self.len_reduced:
                    grad += self.weight_decay * param

                # Create parameter for the momentum
                if "momentum_var" not in param_state:
                    next_v = param_state["momentum_var"] = torch.zeros_like(p.data)
                else:
                    next_v = param_state["momentum_var"]

                if self.classic_momentum:
                    trust_ratio = 1.0

                    # TODO: implementation of layer adaptation
                    w_norm = torch.norm(param)
                    g_norm = torch.norm(grad)

                    device = g_norm.get_device()
                    trust_ratio = torch.where(
                        w_norm.ge(0),
                        torch.where(
                            g_norm.ge(0),
                            (self.eta * w_norm / g_norm),
                            torch.Tensor([1.0]).to(device),
                        ),
                        torch.Tensor([1.0]).to(device),
                    ).item()

                    scaled_lr = learning_rate * trust_ratio
                    grad_scaled = scaled_lr * grad

                    next_v.mul_(momentum).add_(grad_scaled)

                    if self.use_nesterov:
                        update = (self.momentum * next_v) + (scaled_lr * grad)
                    else:
                        update = next_v

                    p.data.add_(-update)

                # Not classic_momentum
                else:
                    next_v.mul_(momentum).add_(grad)

                    if self.use_nesterov:
                        update = (self.momentum * next_v) + (grad)
                    else:
                        update = next_v

                    trust_ratio = 1.0

                    # TODO: implementation of layer adaptation
                    w_norm = torch.norm(param)
                    v_norm = torch.norm(update)

                    device = v_norm.get_device()
                    trust_ratio = torch.where(
                        w_norm.ge(0),
                        torch.where(
                            v_norm.ge(0),
                            (self.eta * w_norm / v_norm),
                            torch.Tensor([1.0]).to(device),
                        ),
                        torch.Tensor([1.0]).to(device),
                    ).item()

                    scaled_lr = learning_rate * trust_ratio
                    p.data.add_(-scaled_lr * update)

                counter += 1

        return loss


# LARSWrapper from the solo-learn repo.
class LARSWrapper:
    def __init__(
        self,
        optimizer: Optimizer,
        eta: float = 1e-3,
        clip: bool = False,
        eps: float = 1e-8,
        exclude_bias_n_norm: bool = False,
    ):
        """Wrapper that adds LARS scheduling to any optimizer.
        This helps stability with huge batch sizes.

        Args:
            optimizer (Optimizer): torch optimizer.
            eta (float, optional): trust coefficient. Defaults to 1e-3.
            clip (bool, optional): clip gradient values. Defaults to False.
            eps (float, optional): adaptive_lr stability coefficient. Defaults to 1e-8.
            exclude_bias_n_norm (bool, optional): exclude bias and normalization
                layers from LARS. Defaults to False.
        """

        self.optim = optimizer
        self.eta = eta
        self.eps = eps
        self.clip = clip
        self.exclude_bias_n_norm = exclude_bias_n_norm

        # transfer optim methods
        self.state_dict = self.optim.state_dict
        self.load_state_dict = self.optim.load_state_dict
        self.zero_grad = self.optim.zero_grad
        self.add_param_group = self.optim.add_param_group
        self.__setstate__ = self.optim.__setstate__  # type: ignore
        self.__getstate__ = self.optim.__getstate__  # type: ignore
        self.__repr__ = self.optim.__repr__  # type: ignore

    @property
    def defaults(self):
        return self.optim.defaults

    @defaults.setter
    def defaults(self, defaults):
        self.optim.defaults = defaults

    @property  # type: ignore
    def __class__(self):
        return Optimizer

    @property
    def state(self):
        return self.optim.state

    @state.setter
    def state(self, state):
        self.optim.state = state

    @property
    def param_groups(self):
        return self.optim.param_groups

    @param_groups.setter
    def param_groups(self, value):
        self.optim.param_groups = value

    @torch.no_grad()
    def step(self, closure=None):
        weight_decays = []

        for group in self.optim.param_groups:
            weight_decay = group.get("weight_decay", 0)
            weight_decays.append(weight_decay)

            # reset weight decay
            group["weight_decay"] = 0

            # update the parameters
            for p in group["params"]:
                if p.grad is not None and (p.ndim != 1 or not self.exclude_bias_n_norm):
                    self.update_p(p, group, weight_decay)

        # update the optimizer
        self.optim.step(closure=closure)

        # return weight decay control to optimizer
        for group_idx, group in enumerate(self.optim.param_groups):
            group["weight_decay"] = weight_decays[group_idx]

    def update_p(self, p, group, weight_decay):
        # calculate new norms
        p_norm = torch.norm(p.data)
        g_norm = torch.norm(p.grad.data)

        if p_norm != 0 and g_norm != 0:
            # calculate new lr
            new_lr = (self.eta * p_norm) / (g_norm + p_norm * weight_decay + self.eps)

            # clip lr
            if self.clip:
                new_lr = min(new_lr / group["lr"], 1)

            # update params with clipped lr
            p.grad.data += weight_decay * p.data
            p.grad.data *= new_lr



================================================
FILE: loss.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F


class contrastive_loss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, labels):
        # This function assumes that the positive logit is always the first element,
        # which is true here.
        loss = -x[:, 0] + torch.logsumexp(x[:, 1:], dim=1)
        return loss.mean()


class SimCLR(nn.Module):
    def __init__(self, temperature=0.5, n_views=2, contrastive=False):
        super(SimCLR, self).__init__()
        self.temp = temperature
        self.n_views = n_views
        if contrastive:
            self.criterion = contrastive_loss()
        else:
            self.criterion = torch.nn.CrossEntropyLoss()

    def info_nce_loss(self, X):
        bs, n_dim = X.shape
        bs = int(bs / self.n_views)
        device = X.device

        labels = torch.cat([torch.arange(bs) for i in range(self.n_views)], dim=0)
        labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
        labels = labels.to(device)

        similarity_matrix = torch.matmul(X, X.T)
        # assert similarity_matrix.shape == (
        #     self.args.n_views * self.args.batch_size, self.args.n_views * self.args.batch_size)
        # assert similarity_matrix.shape == labels.shape

        # discard the main diagonal from both: labels and similarity matrix
        mask = torch.eye(labels.shape[0], dtype=torch.bool).to(device)
        labels = labels[~mask].view(labels.shape[0], -1)
        similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
        # assert similarity_matrix.shape == labels.shape

        # select and combine multiple positives
        positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)

        # select only the negatives
        negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)

        logits = torch.cat([positives, negatives], dim=1)
        labels = torch.zeros(logits.shape[0], dtype=torch.long).to(device)

        logits = logits / self.temp
        return logits, labels

    def forward(self, X):
        logits, labels = self.info_nce_loss(X)
        loss = self.criterion(logits, labels)
        return loss


class Z_loss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, z):
        z_list = z.chunk(2, dim=0)
        z_sim = F.cosine_similarity(z_list[0], z_list[1], dim=1).mean()
        z_sim_out = z_sim.clone().detach()
        return -z_sim, z_sim_out


class TotalCodingRate(nn.Module):
    def __init__(self, eps=0.01):
        super(TotalCodingRate, self).__init__()
        self.eps = eps

    def compute_discrimn_loss(self, W):
        """Discriminative Loss."""
        p, m = W.shape  # [d, B]
        I = torch.eye(p, device=W.device)
        scalar = p / (m * self.eps)
        logdet = torch.logdet(I + scalar * W.matmul(W.T))
        return logdet / 2.

    def forward(self, X):
        return -self.compute_discrimn_loss(X.T)


class MaximalCodingRateReduction(torch.nn.Module):
    def __init__(self, eps=0.01, gamma=1):
        super(MaximalCodingRateReduction, self).__init__()
        self.eps = eps
        self.gamma = gamma

    def compute_discrimn_loss(self, W):
        """Discriminative Loss."""
        p, m = W.shape
        I = torch.eye(p, device=W.device)
        scalar = p / (m * self.eps)
        logdet = torch.logdet(I + scalar * W.matmul(W.T))
        return logdet / 2.

    def compute_compress_loss(self, W, Pi):
        p, m = W.shape
        k, _, _ = Pi.shape

        I = torch.eye(p, device=W.device).expand((k, p, p))
        trPi = Pi.sum(2) + 1e-8
        scale = (p / (trPi * self.eps)).view(k, 1, 1)

        W = W.view((1, p, m))
        log_det = torch.logdet(I + scale * W.mul(Pi).matmul(W.transpose(1, 2)))
        compress_loss = (trPi.squeeze() * log_det / (2 * m)).sum()
        return compress_loss

    def forward(self, X, Y, num_classes=None):
        # This function supports Y as an integer label vector or a membership-probability matrix.
        if len(Y.shape) == 1:  # if Y is a label vector
            if num_classes is None:
                num_classes = Y.max() + 1
            Pi = torch.zeros((num_classes, 1, Y.shape[0]), device=Y.device)
            for indx, label in enumerate(Y):
                Pi[label, 0, indx] = 1
        else:  # if Y is a probability matrix
            if num_classes is None:
                num_classes = Y.shape[1]
            Pi = Y.T.reshape((num_classes, 1, -1))

        W = X.T
        discrimn_loss = self.compute_discrimn_loss(W)
        compress_loss = self.compute_compress_loss(W, Pi)

        total_loss = -discrimn_loss + self.gamma * compress_loss
        return total_loss, [discrimn_loss.item(), compress_loss.item()]



================================================
FILE: main.py
================================================
############
## Import ##
############
import argparse
import torch.nn as nn
import torch.optim as optim
import os

from torch.utils.data import DataLoader
from model.model import encoder
from dataset.datasets import load_dataset
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
import torchvision.transforms.functional as FF
from tqdm import tqdm
import torch
from torchvision.datasets import CIFAR10
from loss import TotalCodingRate
from func import chunk_avg
from lars import LARS, LARSWrapper
from func import WeightedKNNClassifier
import torch.optim.lr_scheduler as lr_scheduler
from torch.cuda.amp import GradScaler, autocast

######################
## Parsing Argument ##
######################
parser = argparse.ArgumentParser(description='Unsupervised Learning')
parser.add_argument('--patch_sim', type=int, default=200,
                    help='coefficient of cosine similarity (default: 200)')
parser.add_argument('--tcr', type=int, default=1,
                    help='coefficient of tcr (default: 1)')
parser.add_argument('--num_patches', type=int, default=100,
                    help='number of patches used in EMP-SSL (default: 100)')
parser.add_argument('--arch', type=str, default="resnet18-cifar",
                    help='network architecture (default: resnet18-cifar)')
parser.add_argument('--bs', type=int, default=100,
                    help='batch size (default: 100)')
parser.add_argument('--lr', type=float, default=0.3,
                    help='learning rate (default: 0.3)')
parser.add_argument('--eps', type=float, default=0.2,
                    help='eps for TCR (default: 0.2)')
parser.add_argument('--msg', type=str, default="NONE",
                    help='additional message for description (default: NONE)')
parser.add_argument('--dir', type=str, default="EMP-SSL-Training",
                    help='directory name (default: EMP-SSL-Training)')
parser.add_argument('--data', type=str, default="cifar10",
                    help='data (default: cifar10)')
parser.add_argument('--epoch', type=int, default=30,
                    help='max number of epochs to finish (default: 30)')

args = parser.parse_args()
print(args)

num_patches = args.num_patches
dir_name = f"./logs/{args.dir}/patchsim{args.patch_sim}_numpatch{args.num_patches}_bs{args.bs}_lr{args.lr}_{args.msg}"

#####################
## Helper Function ##
#####################
def chunk_avg(x, n_chunks=2, normalize=False):
    x_list = x.chunk(n_chunks, dim=0)
    x = torch.stack(x_list, dim=0)
    if not normalize:
        return x.mean(0)
    else:
        return F.normalize(x.mean(0), dim=1)


class Similarity_Loss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, z_list, z_avg):
        z_sim = 0
        num_patch = len(z_list)
        z_list = torch.stack(list(z_list), dim=0)
        z_avg = z_list.mean(dim=0)

        z_sim = 0
        for i in range(num_patch):
            z_sim += F.cosine_similarity(z_list[i], z_avg, dim=1).mean()

        z_sim = z_sim / num_patch
        z_sim_out = z_sim.clone().detach()

        return -z_sim, z_sim_out


def cal_TCR(z, criterion, num_patches):
    z_list = z.chunk(num_patches, dim=0)
    loss = 0
    for i in range(num_patches):
        loss += criterion(z_list[i])
    loss = loss / num_patches
    return loss


######################
## Prepare Training ##
######################
torch.multiprocessing.set_sharing_strategy('file_system')

if args.data == "imagenet100" or args.data == "imagenet":
    train_dataset = load_dataset("imagenet", train=True, num_patch=num_patches)
    dataloader = DataLoader(train_dataset, batch_size=args.bs, shuffle=True,
                            drop_last=True, num_workers=8)
else:
    train_dataset = load_dataset(args.data, train=True, num_patch=num_patches)
    dataloader = DataLoader(train_dataset, batch_size=args.bs, shuffle=True,
                            drop_last=True, num_workers=16)

use_cuda = True
device = torch.device("cuda" if use_cuda else "cpu")

net = encoder(arch=args.arch)
net = nn.DataParallel(net)
net.cuda()

opt = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9,
                weight_decay=1e-4, nesterov=True)
opt = LARSWrapper(opt, eta=0.005, clip=True, exclude_bias_n_norm=True)

scaler = GradScaler()

if args.data == "imagenet-100":
    num_converge = (150000 // args.bs) * args.epoch
else:
    num_converge = (50000 // args.bs) * args.epoch

scheduler = lr_scheduler.CosineAnnealingLR(opt, T_max=num_converge, eta_min=0, last_epoch=-1)

# Loss
contractive_loss = Similarity_Loss()
criterion = TotalCodingRate(eps=args.eps)

##############
## Training ##
##############
def main():
    for epoch in range(args.epoch):
        for step, (data, label) in tqdm(enumerate(dataloader)):
            net.zero_grad()
            opt.zero_grad()

            data = torch.cat(data, dim=0)
            data = data.cuda()
            z_proj = net(data)

            z_list = z_proj.chunk(num_patches, dim=0)
            z_avg = chunk_avg(z_proj, num_patches)

            # Contractive Loss
            loss_contract, _ = contractive_loss(z_list, z_avg)
            loss_TCR = cal_TCR(z_proj, criterion, num_patches)

            loss = args.patch_sim * loss_contract + args.tcr * loss_TCR

            loss.backward()
            opt.step()
            scheduler.step()

        model_dir = dir_name + "/save_models/"
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        torch.save(net.state_dict(), model_dir + str(epoch) + ".pt")

        print("At epoch:", epoch, "loss similarity is", loss_contract.item(),
              ", loss TCR is:", loss_TCR.item(),
              "and learning rate is:", opt.param_groups[0]['lr'])


# Press the green button in the gutter to run the script.
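As a side note: the objective optimized in the training loop above can be exercised in isolation. The sketch below is not part of the repository; it re-derives `TotalCodingRate` as a plain function (`total_coding_rate` is my name, and `eps=0.2` matches the `--eps` default) and combines it with the patch-to-average cosine-similarity term, using random unit-normalized features in place of encoder outputs.

```python
import torch
import torch.nn.functional as F

def total_coding_rate(Z, eps=0.2):
    # Z: [B, d] features; same -1/2 * logdet(I + d/(B*eps) Z^T Z) form
    # as TotalCodingRate.forward in loss.py (which calls it on X.T).
    B, d = Z.shape
    I = torch.eye(d)
    return -0.5 * torch.logdet(I + (d / (B * eps)) * Z.T @ Z)

num_patches, bs, dim = 4, 8, 16
z_proj = F.normalize(torch.randn(num_patches * bs, dim), dim=1)

z_list = z_proj.chunk(num_patches, dim=0)   # one [bs, dim] chunk per patch
z_avg = torch.stack(z_list).mean(0)         # average embedding across patches

# invariance term: mean cosine similarity of each patch to the average (negated)
loss_sim = -sum(F.cosine_similarity(z, z_avg, dim=1).mean() for z in z_list) / num_patches
# TCR regularizer averaged over patches, as in cal_TCR
loss_tcr = sum(total_coding_rate(z) for z in z_list) / num_patches

loss = 200 * loss_sim + 1 * loss_tcr        # patch_sim=200, tcr=1 defaults
```

Note that `loss_tcr` is always negative here (the log-determinant of `I` plus a PSD matrix is nonnegative), so minimizing the combined loss pushes patch embeddings toward their average while expanding the coding rate of each patch batch.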
if __name__ == '__main__':
    main()

# See PyCharm help at https://www.jetbrains.com/help/pycharm/



================================================
FILE: mcr/loss.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F


class MCRGANloss(nn.Module):

    def __init__(self, gam1=1., gam2=1., gam3=1., eps=0.5, numclasses=1000, mode=0):
        super(MCRGANloss, self).__init__()

        self.num_class = numclasses
        self.train_mode = mode
        self.gam1 = gam1
        self.gam2 = gam2
        self.gam3 = gam3
        self.eps = eps

    def forward(self, Z, Z_bar, real_label, ith_inner_loop, num_inner_loop):
        # t = time.time()
        errD, empi = self.old_version(Z, Z_bar, real_label, ith_inner_loop, num_inner_loop)
        return errD, empi

    def old_version(self, Z, Z_bar, real_label, ith_inner_loop, num_inner_loop):
        if self.train_mode == 2:
            loss_z, _ = self.deltaR(Z, real_label, self.num_class)
            assert num_inner_loop >= 2
            if (ith_inner_loop + 1) % num_inner_loop != 0:
                # print(f"{ith_inner_loop + 1}/{num_inner_loop}")
                # print("calculate delta R(z)")
                return loss_z, None

            loss_h, _ = self.deltaR(Z_bar, real_label, self.num_class)
            empi = [loss_z, loss_h]

            term3 = 0.
            for i in range(self.num_class):
                new_Z = torch.cat((Z[real_label == i], Z_bar[real_label == i]), 0)
                new_label = torch.cat(
                    (torch.zeros_like(real_label[real_label == i]),
                     torch.ones_like(real_label[real_label == i]))
                )
                loss, em = self.deltaR(new_Z, new_label, 2)
                term3 += loss
            empi = empi + [term3]

            errD = self.gam1 * loss_z + self.gam2 * loss_h + self.gam3 * term3

        elif self.train_mode == 1:
            print("has been dropped")
            raise NotImplementedError()

        elif self.train_mode == 0:
            new_Z = torch.cat((Z, Z_bar), 0)
            new_label = torch.cat((torch.zeros_like(real_label), torch.ones_like(real_label)))
            errD, empi = self.deltaR(new_Z, new_label, 2)
        else:
            raise ValueError()

        return errD, empi

    def debug(self, Z, Z_bar, real_label):
        print("===========================")

    def compute_discrimn_loss(self, Z):
        """Theoretical Discriminative Loss."""
        d, n = Z.shape
        I = torch.eye(d).to(Z.device)
        scalar = d / (n * self.eps)
        logdet = torch.logdet(I + scalar * Z @ Z.T)
        return logdet / 2.

    def compute_compress_loss(self, Z, Pi):
        """Theoretical Compressive Loss."""
        d, n = Z.shape
        I = torch.eye(d).to(Z.device)
        compress_loss = []
        scalars = []
        for j in range(Pi.shape[1]):
            Z_ = Z[:, Pi[:, j] == 1]
            trPi = Pi[:, j].sum() + 1e-8
            scalar = d / (trPi * self.eps)
            log_det = torch.logdet(I + scalar * Z_ @ Z_.T)
            compress_loss.append(log_det)
            scalars.append(trPi / (2 * n))
        return compress_loss, scalars

    def deltaR(self, Z, Y, num_classes):
        if num_classes is None:
            num_classes = Y.max() + 1
        # print("classes:", num_classes)
        Pi = F.one_hot(Y, num_classes).to(Z.device)
        discrimn_loss = self.compute_discrimn_loss(Z.T)
        compress_loss, scalars = self.compute_compress_loss(Z.T, Pi)

        compress_term = 0.
        for z, s in zip(compress_loss, scalars):
            compress_term += s * z
        total_loss = discrimn_loss - compress_term

        return -total_loss, (discrimn_loss, compress_term, compress_loss, scalars)

    def gumb_compress_loss(self, Z, P):
        d, n = Z.shape
        I = torch.eye(d).to(Z.device)
        compress_loss = 0.
        for j in range(self.num_class):
            #P[:, j:j+1][P[:, j:j+1]
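For reference, the coding-rate-reduction quantity that `deltaR` above computes (before its sign flip for minimization) can be sketched standalone on toy labeled features. This is not repository code; `delta_r` is my name, and `eps=0.5` matches the `MCRGANloss` default.

```python
import torch
import torch.nn.functional as F

def delta_r(Z, y, num_classes, eps=0.5):
    # Z: [n, d] features, y: [n] integer labels.
    # Mirrors compute_discrimn_loss / compute_compress_loss in mcr/loss.py,
    # but returns +Delta R rather than its negative.
    n, d = Z.shape
    I = torch.eye(d)
    # R(Z): coding rate of the whole feature set
    expand = 0.5 * torch.logdet(I + (d / (n * eps)) * Z.T @ Z)
    # R^c(Z|Pi): per-class coding rates, weighted by class size
    compress = 0.
    for j in range(num_classes):
        Zj = Z[y == j]
        nj = Zj.shape[0] + 1e-8
        compress += (nj / (2 * n)) * torch.logdet(I + (d / (nj * eps)) * Zj.T @ Zj)
    return expand - compress  # Delta R = R - R^c

Z = F.normalize(torch.randn(20, 8), dim=1)
y = torch.randint(0, 2, (20,))
dr = delta_r(Z, y, 2)
```

`MCRGANloss.deltaR` returns `-delta_r` plus bookkeeping terms, since the GAN training loop minimizes the loss while the objective is to maximize the rate reduction.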