Repository: chenziwenhaoshuai/Vision-KAN
Branch: main
Commit: be2902606a13
Files: 19
Total size: 107.1 KB
Directory structure:
gitextract_7m6sx636/
├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── augment.py
├── datasets.py
├── ekan.py
├── engine.py
├── fasterkan.py
├── hubconf.py
├── kit.py
├── losses.py
├── main.py
├── minimal_example.py
├── models_kan.py
├── pyproject.toml
├── requirements.txt
├── samplers.py
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
.idea
.ipynb_checkpoints/
Datasets/
__pycache__/
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2024 chenziwenhaoshuai, tommarvoloriddle, chouheiwa
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# Vision-KAN 🚀
Welcome to **Vision-KAN**! We are exploring the exciting possibility of [KAN](https://github.com/KindXiaoming/pykan) replacing MLP in Vision Transformer. Due to GPU resource constraints, this project may experience delays, but we'll keep you updated with any new developments here! 📅✨
## Installation 🛠️
To install this package, simply run:
```bash
pip install VisionKAN
```
## Minimal Example 💡
Here's a quick example to get you started:
```python
from VisionKAN import create_model, train_one_epoch, evaluate
KAN_model = create_model(
    model_name='deit_tiny_patch16_224_KAN',
    pretrained=False,
    hdim_kan=192,
    num_classes=100,
    drop_rate=0.0,
    drop_path_rate=0.05,
    img_size=224,
    batch_size=144
)
```
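The imported `evaluate` helper follows the DeiT signature `evaluate(data_loader, model, device)`. Here is a minimal, hypothetical evaluation sketch; the CIFAR-100 loader below is only an illustration, and any `torch.utils.data.DataLoader` that yields `(image, label)` batches at the model's input size works:
```python
import torch
from torchvision import datasets, transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
KAN_model = KAN_model.to(device)

# Standard ImageNet-style preprocessing at the model's 224x224 input size.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
val_set = datasets.CIFAR100('./Datasets', train=False, download=True, transform=transform)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=144, shuffle=False, num_workers=4)

stats = evaluate(val_loader, KAN_model, device)  # prints running Acc@1 / Acc@5
print(stats['acc1'], stats['acc5'])
```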
## Performance Overview 📊
### Baseline Models
| Dataset | MLP Hidden Dim | Model | Date | Epoch | Top-1 | Top-5 | Checkpoint |
|---------|----------------|---------------------|------|-------|-------|-------|------------|
| ImageNet 1k | 768 | DeiT-tiny (baseline) | - | 300 | 72.2 | 91.1 | - |
| CIFAR-100 | 192 | DeiT-tiny (baseline) | 2024.5.25 | 300(stop) | 84.94 | 96.53 | [Checkpoint](https://drive.google.com/drive/folders/1hPrnfI5CKMgwM6lgSrFUwvMQYsjtjg3A?usp=drive_link) |
| CIFAR-100 | 384 | DeiT-small (baseline) | 2024.5.25 | 300(stop) | 86.49 | 96.17 | [Checkpoint](https://drive.google.com/drive/folders/1ZSl2ojZUQRkIsZzJ0w5rahOTAv4IiZCt?usp=drive_link) |
| CIFAR-100 | 768 | DeiT-base (baseline) | 2024.5.25 | 300(stop) | 86.54 | 96.16 | [Checkpoint](https://drive.google.com/drive/folders/14kLdJDy11zv_mC35JvbcPCdoXvrHspNK?usp=sharing) |
### Vision-KAN Models
| Dataset | KAN Hidden Dim | Model | Date | Epoch | Top-1 | Top-5 | Checkpoint |
|---------|----------------|-----------|----------|-----------|-------|-------|------------|
| ImageNet 1k | 20 | Vision-KAN | 2024.5.16 | 37(stop) | 36.34 | 61.48 | - |
| ImageNet 1k | 192 | Vision-KAN | 2024.5.25 | 346(stop) | 64.87 | 86.14 | [Checkpoint](https://pan.baidu.com/s/117ox7oh6zzXLwPMmQ6od1Q?pwd=y1vw) |
| ImageNet 1k | 768 | Vision-KAN | 2024.6.2 | 154(training) | 62.90 | 85.03 | - |
| CIFAR-100 | 192 | Vision-KAN | 2024.5.25 | 300(stop) | 73.17 | 93.307 | [Checkpoint](https://drive.google.com/drive/folders/19WPq6bZ9NgX-WxD7qXSTKiHc5D6P8jQP?usp=sharing) |
| CIFAR-100 | 384 | Vision-KAN | 2024.5.25 | 300(stop) | 78.69 | 94.73 | [Checkpoint](https://drive.google.com/drive/folders/1Uhj4yV0HZRQkPFUerxy88B19N1eDdgsc?usp=drive_link) |
| CIFAR-100 | 768 | Vision-KAN | 2024.5.29 | 300(stop) | 79.82 | 95.42 | [Checkpoint](https://drive.google.com/drive/folders/1FT55_6tDO_a135sQKBDn409fDdXvCi4N?usp=drive_link) |
## Latest News 📰
- **5.7.2024**: Released the current Vision KAN code! 🚀 We used efficient KAN to replace the MLP layer in the Transformer block and are pre-training the Tiny model on ImageNet 1k. Updates will be reflected in the table.
- **5.14.2024**: The model is starting to converge! We’re using [192, 20, 192] for input, hidden, and output dimensions.
- **5.15.2024**: Switched from [efficient kan](https://github.com/Blealtan/efficient-kan) to [faster kan](https://github.com/AthanasiosDelis/faster-kan) to double the training speed! 🚀
- **5.16.2024**: Convergence appears to be bottlenecked; considering adjusting the KAN hidden layer size from 20 to 192.
- **5.22.2024**: Fixed Timm version dependency issues and cleaned up the code! 🧹
- **5.24.2024**: Loss decline is slowing, nearing final results! 🔍
- **5.25.2024**: The model with a KAN hidden dimension of 192 is approaching convergence! 🎉 Released the best checkpoint of VisionKAN.
## Architecture 🏗️
We utilized [DeiT](https://github.com/facebookresearch/deit) as the baseline for Vision KAN development. Huge thanks to Meta and MIT for their incredible work! 🙌
## Star History 🌟
[Star History Chart](https://star-history.com/#chenziwenhaoshuai/Vision-KAN&Date)
## Citation 📑
If you are using our work, please cite:
```bibtex
@misc{VisionKAN2024,
  author = {Ziwen Chen and Gundavarapu and WU DI},
  title = {Vision-KAN: Exploring the Possibility of KAN Replacing MLP in Vision Transformer},
  year = {2024},
  howpublished = {\url{https://github.com/chenziwenhaoshuai/Vision-KAN.git}},
}
```
================================================
FILE: __init__.py
================================================
# __init__.py
from .models_kan import create_model
from .engine import train_one_epoch, evaluate
__all__ = ['create_model', 'train_one_epoch', 'evaluate']
================================================
FILE: augment.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
"""
3Augment implementation
Data-augmentation (DA) based on dino DA (https://github.com/facebookresearch/dino)
and timm DA(https://github.com/rwightman/pytorch-image-models)
"""
import torch
from torchvision import transforms
from timm.data.transforms import str_to_pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor
import numpy as np
from torchvision import datasets, transforms
import random
from PIL import ImageFilter, ImageOps
import torchvision.transforms.functional as TF
class GaussianBlur(object):
"""
Apply Gaussian Blur to the PIL image.
"""
def __init__(self, p=0.1, radius_min=0.1, radius_max=2.):
self.prob = p
self.radius_min = radius_min
self.radius_max = radius_max
def __call__(self, img):
do_it = random.random() <= self.prob
if not do_it:
return img
img = img.filter(
ImageFilter.GaussianBlur(
radius=random.uniform(self.radius_min, self.radius_max)
)
)
return img
class Solarization(object):
"""
Apply Solarization to the PIL image.
"""
def __init__(self, p=0.2):
self.p = p
def __call__(self, img):
if random.random() < self.p:
return ImageOps.solarize(img)
else:
return img
class gray_scale(object):
"""
    Apply Grayscale to the PIL image.
"""
def __init__(self, p=0.2):
self.p = p
self.transf = transforms.Grayscale(3)
def __call__(self, img):
if random.random() < self.p:
return self.transf(img)
else:
return img
class horizontal_flip(object):
"""
    Apply horizontal flip to the PIL image.
"""
def __init__(self, p=0.2,activate_pred=False):
self.p = p
self.transf = transforms.RandomHorizontalFlip(p=1.0)
def __call__(self, img):
if random.random() < self.p:
return self.transf(img)
else:
return img
def new_data_aug_generator(args = None):
img_size = args.input_size
remove_random_resized_crop = args.src
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
primary_tfl = []
scale=(0.08, 1.0)
interpolation='bicubic'
if remove_random_resized_crop:
primary_tfl = [
transforms.Resize(img_size, interpolation=3),
transforms.RandomCrop(img_size, padding=4,padding_mode='reflect'),
transforms.RandomHorizontalFlip()
]
else:
primary_tfl = [
RandomResizedCropAndInterpolation(
img_size, scale=scale, interpolation=interpolation),
transforms.RandomHorizontalFlip()
]
secondary_tfl = [transforms.RandomChoice([gray_scale(p=1.0),
Solarization(p=1.0),
GaussianBlur(p=1.0)])]
if args.color_jitter is not None and not args.color_jitter==0:
secondary_tfl.append(transforms.ColorJitter(args.color_jitter, args.color_jitter, args.color_jitter))
final_tfl = [
transforms.ToTensor(),
transforms.Normalize(
mean=torch.tensor(mean),
std=torch.tensor(std))
]
return transforms.Compose(primary_tfl+secondary_tfl+final_tfl)
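# --- Usage sketch (not part of the original file) ----------------------------
# new_data_aug_generator() only reads input_size, src and color_jitter from the
# args object, so a plain Namespace is enough to build the 3Augment training
# pipeline outside of main.py. The values below are illustrative.
if __name__ == '__main__':
    from argparse import Namespace
    from PIL import Image

    demo_args = Namespace(input_size=224, src=False, color_jitter=0.3)
    train_transform = new_data_aug_generator(demo_args)
    dummy = Image.new('RGB', (256, 256))
    print(train_transform(dummy).shape)  # torch.Size([3, 224, 224])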
================================================
FILE: datasets.py
================================================
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
import os
import json
from torchvision import datasets, transforms
from torchvision.datasets.folder import ImageFolder, default_loader
from torchvision.transforms import InterpolationMode
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform
class INatDataset(ImageFolder):
def __init__(self, root, train=True, year=2018, transform=None, target_transform=None,
category='name', loader=default_loader):
self.transform = transform
self.loader = loader
self.target_transform = target_transform
self.year = year
# assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name']
path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json')
with open(path_json) as json_file:
data = json.load(json_file)
with open(os.path.join(root, 'categories.json')) as json_file:
data_catg = json.load(json_file)
path_json_for_targeter = os.path.join(root, f"train{year}.json")
with open(path_json_for_targeter) as json_file:
data_for_targeter = json.load(json_file)
targeter = {}
indexer = 0
for elem in data_for_targeter['annotations']:
king = []
king.append(data_catg[int(elem['category_id'])][category])
if king[0] not in targeter.keys():
targeter[king[0]] = indexer
indexer += 1
self.nb_classes = len(targeter)
self.samples = []
for elem in data['images']:
cut = elem['file_name'].split('/')
target_current = int(cut[2])
path_current = os.path.join(root, cut[0], cut[2], cut[3])
categors = data_catg[target_current]
target_current_true = targeter[categors[category]]
self.samples.append((path_current, target_current_true))
# __getitem__ and __len__ inherited from ImageFolder
def build_dataset(is_train, args):
transform = build_transform(is_train, args)
if args.data_set == 'CIFAR':
dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform)
nb_classes = 100
elif args.data_set == 'IMNET':
root = os.path.join(args.data_path, 'train' if is_train else 'val')
dataset = datasets.ImageFolder(root, transform=transform)
nb_classes = 1000
elif args.data_set == 'TinyIMNET':
root = os.path.join(args.data_path, 'tiny-imagenet-200/train' if is_train else 'tiny-imagenet-200/val')
dataset = datasets.ImageFolder(root, transform=transform)
nb_classes = 100
elif args.data_set == 'INAT':
dataset = INatDataset(args.data_path, train=is_train, year=2018,
category=args.inat_category, transform=transform)
nb_classes = dataset.nb_classes
elif args.data_set == 'INAT19':
dataset = INatDataset(args.data_path, train=is_train, year=2019,
category=args.inat_category, transform=transform)
nb_classes = dataset.nb_classes
return dataset, nb_classes
def build_transform(is_train, args):
resize_im = args.input_size > 32
if is_train:
# this should always dispatch to transforms_imagenet_train
transform = create_transform(
input_size=args.input_size,
is_training=True,
color_jitter=args.color_jitter,
auto_augment=args.aa,
interpolation=args.train_interpolation,
re_prob=args.reprob,
re_mode=args.remode,
re_count=args.recount,
)
if not resize_im:
# replace RandomResizedCropAndInterpolation with
# RandomCrop
transform.transforms[0] = transforms.RandomCrop(
args.input_size, padding=4)
return transform
t = []
if resize_im:
size = int(args.input_size / args.eval_crop_ratio)
t.append(
transforms.Resize(size, interpolation=InterpolationMode.BICUBIC), # to maintain same ratio w.r.t. 224 images
)
t.append(transforms.CenterCrop(args.input_size))
t.append(transforms.ToTensor())
t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
return transforms.Compose(t)
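# --- Usage sketch (not part of the original file) ----------------------------
# build_dataset() reads everything from an args-like object, so a simple
# Namespace is enough outside of main.py. The values below are illustrative and
# mirror the defaults in main.py; note that the CIFAR branch does not pass
# download=True, so CIFAR-100 must already exist under args.data_path. The
# training transform would additionally need color_jitter, aa,
# train_interpolation, reprob, remode and recount.
if __name__ == '__main__':
    from argparse import Namespace

    demo_args = Namespace(data_set='CIFAR', data_path='./Datasets',
                          input_size=224, eval_crop_ratio=1.0)
    dataset_val, nb_classes = build_dataset(is_train=False, args=demo_args)
    print(len(dataset_val), nb_classes)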
================================================
FILE: ekan.py
================================================
import torch
import torch.nn.functional as F
import math
class KANLinear(torch.nn.Module):
def __init__(
self,
in_features,
out_features,
grid_size=5,
spline_order=3,
scale_noise=0.1,
scale_base=1.0,
scale_spline=1.0,
enable_standalone_scale_spline=False,
base_activation=torch.nn.SiLU,
grid_eps=0.02,
grid_range=[-1, 1],
):
super(KANLinear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.grid_size = grid_size
self.spline_order = spline_order
h = (grid_range[1] - grid_range[0]) / grid_size
grid = (
(
torch.arange(-spline_order, grid_size + spline_order + 1) * h
+ grid_range[0]
)
.expand(in_features, -1)
.contiguous()
)
self.register_buffer("grid", grid)
self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
self.spline_weight = torch.nn.Parameter(
torch.Tensor(out_features, in_features, grid_size + spline_order)
)
if enable_standalone_scale_spline:
self.spline_scaler = torch.nn.Parameter(
torch.Tensor(out_features, in_features)
)
self.scale_noise = scale_noise
self.scale_base = scale_base
self.scale_spline = scale_spline
self.enable_standalone_scale_spline = enable_standalone_scale_spline
self.base_activation = base_activation()
self.grid_eps = grid_eps
self.reset_parameters()
def reset_parameters(self):
torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
with torch.no_grad():
noise = (
(
torch.rand(self.grid_size + 1, self.in_features, self.out_features)
- 1 / 2
)
* self.scale_noise
/ self.grid_size
)
self.spline_weight.data.copy_(
(self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
* self.curve2coeff(
self.grid.T[self.spline_order : -self.spline_order],
noise,
)
)
if self.enable_standalone_scale_spline:
# torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)
def b_splines(self, x: torch.Tensor):
"""
Compute the B-spline bases for the given input tensor.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, in_features).
Returns:
torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
"""
assert x.dim() == 2 and x.size(1) == self.in_features
grid: torch.Tensor = (
self.grid
) # (in_features, grid_size + 2 * spline_order + 1)
x = x.unsqueeze(-1)
bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
for k in range(1, self.spline_order + 1):
bases = (
(x - grid[:, : -(k + 1)])
/ (grid[:, k:-1] - grid[:, : -(k + 1)])
* bases[:, :, :-1]
) + (
(grid[:, k + 1 :] - x)
/ (grid[:, k + 1 :] - grid[:, 1:(-k)])
* bases[:, :, 1:]
)
assert bases.size() == (
x.size(0),
self.in_features,
self.grid_size + self.spline_order,
)
return bases.contiguous()
def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
"""
Compute the coefficients of the curve that interpolates the given points.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, in_features).
y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).
Returns:
torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
"""
assert x.dim() == 2 and x.size(1) == self.in_features
assert y.size() == (x.size(0), self.in_features, self.out_features)
A = self.b_splines(x).transpose(
0, 1
) # (in_features, batch_size, grid_size + spline_order)
B = y.transpose(0, 1) # (in_features, batch_size, out_features)
solution = torch.linalg.lstsq(
A, B
).solution # (in_features, grid_size + spline_order, out_features)
result = solution.permute(
2, 0, 1
) # (out_features, in_features, grid_size + spline_order)
assert result.size() == (
self.out_features,
self.in_features,
self.grid_size + self.spline_order,
)
return result.contiguous()
@property
def scaled_spline_weight(self):
return self.spline_weight * (
self.spline_scaler.unsqueeze(-1)
if self.enable_standalone_scale_spline
else 1.0
)
def forward(self, x: torch.Tensor):
assert x.dim() == 2 and x.size(1) == self.in_features
base_output = F.linear(self.base_activation(x), self.base_weight)
spline_output = F.linear(
self.b_splines(x).view(x.size(0), -1),
self.scaled_spline_weight.view(self.out_features, -1),
)
return base_output + spline_output
@torch.no_grad()
def update_grid(self, x: torch.Tensor, margin=0.01):
assert x.dim() == 2 and x.size(1) == self.in_features
batch = x.size(0)
splines = self.b_splines(x) # (batch, in, coeff)
splines = splines.permute(1, 0, 2) # (in, batch, coeff)
orig_coeff = self.scaled_spline_weight # (out, in, coeff)
orig_coeff = orig_coeff.permute(1, 2, 0) # (in, coeff, out)
unreduced_spline_output = torch.bmm(splines, orig_coeff) # (in, batch, out)
unreduced_spline_output = unreduced_spline_output.permute(
1, 0, 2
) # (batch, in, out)
# sort each channel individually to collect data distribution
x_sorted = torch.sort(x, dim=0)[0]
grid_adaptive = x_sorted[
torch.linspace(
0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
)
]
uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
grid_uniform = (
torch.arange(
self.grid_size + 1, dtype=torch.float32, device=x.device
).unsqueeze(1)
* uniform_step
+ x_sorted[0]
- margin
)
grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
grid = torch.concatenate(
[
grid[:1]
- uniform_step
* torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
grid,
grid[-1:]
+ uniform_step
* torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
],
dim=0,
)
self.grid.copy_(grid.T)
self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))
def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
"""
Compute the regularization loss.
        This is a rough approximation of the original L1 regularization described in
        the paper, since the exact formulation requires computing absolute values and
        entropy over the expanded (batch, in_features, out_features) intermediate
        tensor, which is hidden behind F.linear in a memory-efficient implementation.
        The L1 term is therefore computed as the mean absolute value of the spline
        weights. The authors' implementation also includes this term in addition to
        the sample-based regularization.
"""
l1_fake = self.spline_weight.abs().mean(-1)
regularization_loss_activation = l1_fake.sum()
p = l1_fake / regularization_loss_activation
regularization_loss_entropy = -torch.sum(p * p.log())
return (
regularize_activation * regularization_loss_activation
+ regularize_entropy * regularization_loss_entropy
)
class KAN(torch.nn.Module):
def __init__(
self,
layers_hidden,
grid_size=5,
spline_order=3,
scale_noise=0.1,
scale_base=1.0,
scale_spline=1.0,
base_activation=torch.nn.SiLU,
grid_eps=0.02,
grid_range=[-1, 1],
):
super(KAN, self).__init__()
self.grid_size = grid_size
self.spline_order = spline_order
self.layers = torch.nn.ModuleList()
for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):
self.layers.append(
KANLinear(
in_features,
out_features,
grid_size=grid_size,
spline_order=spline_order,
scale_noise=scale_noise,
scale_base=scale_base,
scale_spline=scale_spline,
base_activation=base_activation,
grid_eps=grid_eps,
grid_range=grid_range,
)
)
def forward(self, x: torch.Tensor, update_grid=False):
for layer in self.layers:
if update_grid:
layer.update_grid(x)
x = layer(x)
return x
def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
return sum(
layer.regularization_loss(regularize_activation, regularize_entropy)
for layer in self.layers
)
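# --- Usage sketch (not part of the original file) ----------------------------
# KANLinear expects 2-D input of shape (batch, in_features), so token tensors
# are flattened before the call (as done in kit.py). The sizes below are
# illustrative; [192, 20, 192] matches the Vision-KAN tiny configuration
# mentioned in the README.
if __name__ == '__main__':
    kan = KAN([192, 20, 192])
    tokens = torch.randn(8 * 197, 192)      # (batch * seq_len, dim) for DeiT-tiny
    out = kan(tokens)
    print(out.shape)                        # torch.Size([1576, 192])
    print(kan.regularization_loss().item())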
================================================
FILE: engine.py
================================================
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
"""
Train and eval functions used in main.py
"""
import math
import sys
from typing import Iterable, Optional
import torch
from timm.data import Mixup
from timm.utils import accuracy, ModelEma
from losses import DistillationLoss
import utils
def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
data_loader: Iterable, optimizer: torch.optim.Optimizer,
device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
set_training_mode=True, args = None):
model.train(set_training_mode)
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
header = 'Epoch: [{}]'.format(epoch)
print_freq = 50
if args.cosub:
criterion = torch.nn.BCEWithLogitsLoss()
for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
samples = samples.to(device, non_blocking=True)
targets = targets.to(device, non_blocking=True)
if mixup_fn is not None:
samples, targets = mixup_fn(samples, targets)
if args.cosub:
samples = torch.cat((samples,samples),dim=0)
if args.bce_loss:
targets = targets.gt(0.0).type(targets.dtype)
with torch.cuda.amp.autocast():
outputs = model(samples)
if not args.cosub:
loss = criterion(samples, outputs, targets)
else:
outputs = torch.split(outputs, outputs.shape[0]//2, dim=0)
loss = 0.25 * criterion(outputs[0], targets)
loss = loss + 0.25 * criterion(outputs[1], targets)
loss = loss + 0.25 * criterion(outputs[0], outputs[1].detach().sigmoid())
loss = loss + 0.25 * criterion(outputs[1], outputs[0].detach().sigmoid())
loss_value = loss.item()
if not math.isfinite(loss_value):
print("Loss is {}, stopping training".format(loss_value))
sys.exit(1)
optimizer.zero_grad()
# this attribute is added by timm on one optimizer (adahessian)
is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
loss_scaler(loss, optimizer, clip_grad=max_norm,
parameters=model.parameters(), create_graph=is_second_order)
torch.cuda.synchronize()
if model_ema is not None:
model_ema.update(model)
metric_logger.update(loss=loss_value)
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
# gather the stats from all processes
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
@torch.no_grad()
def evaluate(data_loader, model, device):
criterion = torch.nn.CrossEntropyLoss()
metric_logger = utils.MetricLogger(delimiter=" ")
header = 'Test:'
# switch to evaluation mode
model.eval()
for images, target in metric_logger.log_every(data_loader, 10, header):
images = images.to(device, non_blocking=True)
target = target.to(device, non_blocking=True)
# compute output
with torch.cuda.amp.autocast():
output = model(images)
loss = criterion(output, target)
acc1, acc5 = accuracy(output, target, topk=(1, 5))
batch_size = images.shape[0]
metric_logger.update(loss=loss.item())
metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
# gather the stats from all processes
metric_logger.synchronize_between_processes()
print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
.format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
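# --- Usage sketch (not part of the original file) ----------------------------
# evaluate() only needs a DataLoader yielding (images, targets), a model and a
# device. The toy model and random tensors below are purely illustrative so the
# loop runs without a real dataset; with CUDA available, mixed precision is
# used exactly as in the real evaluation.
if __name__ == '__main__':
    from torch.utils.data import DataLoader, TensorDataset

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    toy_model = torch.nn.Sequential(torch.nn.Flatten(),
                                    torch.nn.Linear(3 * 32 * 32, 10)).to(device)
    fake_set = TensorDataset(torch.randn(64, 3, 32, 32), torch.randint(0, 10, (64,)))
    stats = evaluate(DataLoader(fake_set, batch_size=16), toy_model, device)
    print(stats)  # {'loss': ..., 'acc1': ..., 'acc5': ...}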
================================================
FILE: fasterkan.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import *
class SplineLinear(nn.Linear):
def __init__(self, in_features: int, out_features: int, init_scale: float = 0.1, **kw) -> None:
self.init_scale = init_scale
super().__init__(in_features, out_features, bias=False, **kw)
def reset_parameters(self) -> None:
nn.init.xavier_uniform_(self.weight) # Using Xavier Uniform initialization
class ReflectionalSwitchFunction(nn.Module):
def __init__(
self,
grid_min: float = -2.,
grid_max: float = 2.,
num_grids: int = 8,
exponent: int = 2,
denominator: float = 0.33, # larger denominators lead to smoother basis
):
super().__init__()
grid = torch.linspace(grid_min, grid_max, num_grids)
self.grid = torch.nn.Parameter(grid, requires_grad=False)
self.denominator = denominator # or (grid_max - grid_min) / (num_grids - 1)
# self.exponent = exponent
self.inv_denominator = 1 / self.denominator # Cache the inverse of the denominator
def forward(self, x):
diff = (x[..., None] - self.grid)
diff_mul = diff.mul(self.inv_denominator)
diff_tanh = torch.tanh(diff_mul)
diff_pow = -diff_tanh.mul(diff_tanh)
diff_pow += 1
# diff_pow *= 0.667
return diff_pow # Replace pow with multiplication for squaring
class FasterKANLayer(nn.Module):
def __init__(
self,
input_dim: int,
output_dim: int,
grid_min: float = -2.,
grid_max: float = 2.,
num_grids: int = 8,
exponent: int = 2,
denominator: float = 0.33,
use_base_update: bool = True,
base_activation=F.silu,
spline_weight_init_scale: float = 0.1,
) -> None:
super().__init__()
self.layernorm = nn.LayerNorm(input_dim)
self.rbf = ReflectionalSwitchFunction(grid_min, grid_max, num_grids, exponent, denominator)
self.spline_linear = SplineLinear(input_dim * num_grids, output_dim, spline_weight_init_scale)
# self.use_base_update = use_base_update
# if use_base_update:
# self.base_activation = base_activation
# self.base_linear = nn.Linear(input_dim, output_dim)
def forward(self, x, time_benchmark=False):
if not time_benchmark:
spline_basis = self.rbf(self.layernorm(x)).view(x.shape[0], -1)
# print("spline_basis:", spline_basis.shape)
else:
spline_basis = self.rbf(x).view(x.shape[0], -1)
# print("spline_basis:", spline_basis.shape)
# print("-------------------------")
# ret = 0
ret = self.spline_linear(spline_basis)
# print("spline_basis.shape[:-2]:", spline_basis.shape[:-2])
# print("*spline_basis.shape[:-2]:", *spline_basis.shape[:-2])
# print("spline_basis.view(*spline_basis.shape[:-2], -1):", spline_basis.view(*spline_basis.shape[:-2], -1).shape)
# print("ret:", ret.shape)
# print("-------------------------")
# if self.use_base_update:
# base = self.base_linear(self.base_activation(x))
# print("self.base_activation(x):", self.base_activation(x).shape)
# print("base:", base.shape)
# print("@@@@@@@@@")
# ret += base
return ret
# spline_basis = spline_basis.reshape(x.shape[0], -1) # Reshape to [batch_size, input_dim * num_grids]
# print("spline_basis:", spline_basis.shape)
# spline_weight = self.spline_weight.view(-1, self.spline_weight.shape[0]) # Reshape to [input_dim * num_grids, output_dim]
# print("spline_weight:", spline_weight.shape)
# spline = torch.matmul(spline_basis, spline_weight) # Resulting shape: [batch_size, output_dim]
# print("-------------------------")
# print("Base shape:", base.shape)
# print("Spline shape:", spline.shape)
# print("@@@@@@@@@")
class FasterKAN(nn.Module):
def __init__(
self,
layers_hidden: List[int],
grid_min: float = -2.,
grid_max: float = 2.,
num_grids: int = 8,
exponent: int = 2,
denominator: float = 0.33,
use_base_update: bool = True,
base_activation=F.silu,
spline_weight_init_scale: float = 0.667,
) -> None:
super().__init__()
self.layers = nn.ModuleList([
FasterKANLayer(
in_dim, out_dim,
grid_min=grid_min,
grid_max=grid_max,
num_grids=num_grids,
exponent=exponent,
denominator=denominator,
use_base_update=use_base_update,
base_activation=base_activation,
spline_weight_init_scale=spline_weight_init_scale,
) for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])
])
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
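# --- Usage sketch (not part of the original file) ----------------------------
# FasterKANLayer flattens the grid dimension into the feature dimension before
# the SplineLinear, so it also expects 2-D input of shape (batch, input_dim).
# The sizes below are illustrative only.
if __name__ == '__main__':
    fkan = FasterKAN([192, 20, 192])
    tokens = torch.randn(16, 192)
    print(fkan(tokens).shape)  # torch.Size([16, 192])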
================================================
FILE: hubconf.py
================================================
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
from models_kan import *
from kit import *
#from patchconvnet_models import *
dependencies = ["torch", "torchvision", "timm"]
================================================
FILE: kit.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
import torch
import torch.nn as nn
from functools import partial
from timm.models.vision_transformer import Mlp, PatchEmbed, _cfg
from ekan import KAN
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
__all__ = [
'deit_tiny_patch16_LS'
]
class Attention(nn.Module):
# taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
# taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, Attention_block=Attention, Mlp_block=Mlp
, init_values=1e-4):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention_block(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
# self.mlp = Mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.kan = KAN([dim, 20, dim])
def forward(self, x):
b,t,d = x.shape
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.kan(self.norm2(x).reshape(-1,x.shape[-1])).reshape(b,t,d))
return x
class Layer_scale_init_Block(nn.Module):
# taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# with slight modifications
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, Attention_block=Attention, Mlp_block=Mlp
, init_values=1e-4):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention_block(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.kan = KAN([dim, 20, dim])
self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
def forward(self, x):
b,t,d = x.shape
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
# x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
x = x + self.drop_path(self.kan(self.norm2(x).reshape(-1,x.shape[-1])).reshape(b,t,d))
return x
class Layer_scale_init_Block_paralx2(nn.Module):
# taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# with slight modifications
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, Attention_block=Attention, Mlp_block=Mlp
, init_values=1e-4):
super().__init__()
self.norm1 = norm_layer(dim)
self.norm11 = norm_layer(dim)
self.attn = Attention_block(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.attn1 = Attention_block(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
self.norm21 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.mlp1 = Mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
self.gamma_1_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
self.gamma_2_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
def forward(self, x):
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x))) + self.drop_path(
self.gamma_1_1 * self.attn1(self.norm11(x)))
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + self.drop_path(
self.gamma_2_1 * self.mlp1(self.norm21(x)))
return x
class Block_paralx2(nn.Module):
# taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# with slight modifications
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, Attention_block=Attention, Mlp_block=Mlp
, init_values=1e-4):
super().__init__()
self.norm1 = norm_layer(dim)
self.norm11 = norm_layer(dim)
self.attn = Attention_block(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.attn1 = Attention_block(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
self.norm21 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.mlp1 = Mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x))) + self.drop_path(self.attn1(self.norm11(x)))
x = x + self.drop_path(self.mlp(self.norm2(x))) + self.drop_path(self.mlp1(self.norm21(x)))
return x
class hMLP_stem(nn.Module):
""" hMLP_stem: https://arxiv.org/pdf/2203.09795.pdf
taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
with slight modifications
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=nn.SyncBatchNorm):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = torch.nn.Sequential(*[nn.Conv2d(in_chans, embed_dim // 4, kernel_size=4, stride=4),
norm_layer(embed_dim // 4),
nn.GELU(),
nn.Conv2d(embed_dim // 4, embed_dim // 4, kernel_size=2, stride=2),
norm_layer(embed_dim // 4),
nn.GELU(),
nn.Conv2d(embed_dim // 4, embed_dim, kernel_size=2, stride=2),
norm_layer(embed_dim),
])
def forward(self, x):
B, C, H, W = x.shape
x = self.proj(x).flatten(2).transpose(1, 2)
return x
class vit_models(nn.Module):
""" Vision Transformer with LayerScale (https://arxiv.org/abs/2103.17239) support
taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
with slight modifications
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0., norm_layer=nn.LayerNorm, global_pool=None,
block_layers=Block,
Patch_layer=PatchEmbed, act_layer=nn.GELU,
Attention_block=Attention, Mlp_block=Mlp,
dpr_constant=True, init_scale=1e-4,
mlp_ratio_clstk=4.0, **kwargs):
super().__init__()
self.dropout_rate = drop_rate
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim
self.patch_embed = Patch_layer(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
dpr = [drop_path_rate for i in range(depth)]
self.blocks = nn.ModuleList([
block_layers(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=0.0, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
act_layer=act_layer, Attention_block=Attention_block, Mlp_block=Mlp_block, init_values=init_scale)
for i in range(depth)])
self.norm = norm_layer(embed_dim)
self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')]
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
def get_classifier(self):
return self.head
def get_num_layers(self):
return len(self.blocks)
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
B = x.shape[0]
x = self.patch_embed(x)
cls_tokens = self.cls_token.expand(B, -1, -1)
x = x + self.pos_embed
x = torch.cat((cls_tokens, x), dim=1)
for i, blk in enumerate(self.blocks):
x = blk(x)
x = self.norm(x)
return x[:, 0]
def forward(self, x):
x = self.forward_features(x)
if self.dropout_rate:
x = F.dropout(x, p=float(self.dropout_rate), training=self.training)
x = self.head(x)
return x
# DeiT III: Revenge of the ViT (https://arxiv.org/abs/2204.07118)
@register_model
def deit_tiny_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
model = vit_models(
img_size=img_size, patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block, **kwargs)
return model
@register_model
def deit_small_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
model = vit_models(
img_size=img_size, patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block, **kwargs)
model.default_cfg = _cfg()
if pretrained:
name = 'https://dl.fbaipublicfiles.com/deit/deit_3_small_' + str(img_size) + '_'
if pretrained_21k:
name += '21k.pth'
else:
name += '1k.pth'
checkpoint = torch.hub.load_state_dict_from_url(
url=name,
map_location="cpu", check_hash=True
)
model.load_state_dict(checkpoint["model"])
return model
@register_model
def deit_medium_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
model = vit_models(
patch_size=16, embed_dim=512, depth=12, num_heads=8, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block, **kwargs)
model.default_cfg = _cfg()
if pretrained:
name = 'https://dl.fbaipublicfiles.com/deit/deit_3_medium_' + str(img_size) + '_'
if pretrained_21k:
name += '21k.pth'
else:
name += '1k.pth'
checkpoint = torch.hub.load_state_dict_from_url(
url=name,
map_location="cpu", check_hash=True
)
model.load_state_dict(checkpoint["model"])
return model
@register_model
def deit_base_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
model = vit_models(
img_size=img_size, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block, **kwargs)
if pretrained:
name = 'https://dl.fbaipublicfiles.com/deit/deit_3_base_' + str(img_size) + '_'
if pretrained_21k:
name += '21k.pth'
else:
name += '1k.pth'
checkpoint = torch.hub.load_state_dict_from_url(
url=name,
map_location="cpu", check_hash=True
)
model.load_state_dict(checkpoint["model"])
return model
@register_model
def deit_large_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
model = vit_models(
img_size=img_size, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block, **kwargs)
if pretrained:
name = 'https://dl.fbaipublicfiles.com/deit/deit_3_large_' + str(img_size) + '_'
if pretrained_21k:
name += '21k.pth'
else:
name += '1k.pth'
checkpoint = torch.hub.load_state_dict_from_url(
url=name,
map_location="cpu", check_hash=True
)
model.load_state_dict(checkpoint["model"])
return model
@register_model
def deit_huge_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
model = vit_models(
img_size=img_size, patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block, **kwargs)
if pretrained:
name = 'https://dl.fbaipublicfiles.com/deit/deit_3_huge_' + str(img_size) + '_'
if pretrained_21k:
name += '21k_v1.pth'
else:
name += '1k_v1.pth'
checkpoint = torch.hub.load_state_dict_from_url(
url=name,
map_location="cpu", check_hash=True
)
model.load_state_dict(checkpoint["model"])
return model
@register_model
def deit_huge_patch14_52_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
model = vit_models(
img_size=img_size, patch_size=14, embed_dim=1280, depth=52, num_heads=16, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block, **kwargs)
return model
@register_model
def deit_huge_patch14_26x2_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
model = vit_models(
img_size=img_size, patch_size=14, embed_dim=1280, depth=26, num_heads=16, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block_paralx2, **kwargs)
return model
@register_model
def deit_Giant_48x2_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
model = vit_models(
img_size=img_size, patch_size=14, embed_dim=1664, depth=48, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block_paralx2, **kwargs)
return model
@register_model
def deit_giant_40x2_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
model = vit_models(
img_size=img_size, patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block_paralx2, **kwargs)
return model
@register_model
def deit_Giant_48_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
model = vit_models(
img_size=img_size, patch_size=14, embed_dim=1664, depth=48, num_heads=16, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block, **kwargs)
return model
@register_model
def deit_giant_40_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
model = vit_models(
img_size=img_size, patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block, **kwargs)
# model.default_cfg = _cfg()
return model
# Models from Three things everyone should know about Vision Transformers (https://arxiv.org/pdf/2203.09795.pdf)
@register_model
def deit_small_patch16_36_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
model = vit_models(
img_size=img_size, patch_size=16, embed_dim=384, depth=36, num_heads=6, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block, **kwargs)
return model
@register_model
def deit_small_patch16_36(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
model = vit_models(
img_size=img_size, patch_size=16, embed_dim=384, depth=36, num_heads=6, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
return model
@register_model
def deit_small_patch16_18x2_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
model = vit_models(
img_size=img_size, patch_size=16, embed_dim=384, depth=18, num_heads=6, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block_paralx2, **kwargs)
return model
@register_model
def deit_small_patch16_18x2(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
model = vit_models(
img_size=img_size, patch_size=16, embed_dim=384, depth=18, num_heads=6, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Block_paralx2, **kwargs)
return model
@register_model
def deit_base_patch16_18x2_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
model = vit_models(
img_size=img_size, patch_size=16, embed_dim=768, depth=18, num_heads=12, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block_paralx2, **kwargs)
return model
@register_model
def deit_base_patch16_18x2(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
model = vit_models(
img_size=img_size, patch_size=16, embed_dim=768, depth=18, num_heads=12, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Block_paralx2, **kwargs)
return model
@register_model
def deit_base_patch16_36x1_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
model = vit_models(
img_size=img_size, patch_size=16, embed_dim=768, depth=36, num_heads=12, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block, **kwargs)
return model
@register_model
def deit_base_patch16_36x1(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
model = vit_models(
img_size=img_size, patch_size=16, embed_dim=768, depth=36, num_heads=12, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
return model
def create_model(model_name,**kwargs):
if model_name in __all__:
create_fn = globals()[model_name]
model = create_fn(**kwargs)
model.default_cfg = _cfg()
return model
else:
raise RuntimeError('Unknown model (%s)' % model_name)
if __name__ == '__main__':
batch_size = 5
model = deit_tiny_patch16_LS().cuda()
print(model(torch.randn(5, 3, 224, 224).cuda()).shape)
================================================
FILE: losses.py
================================================
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
"""
Implements the knowledge distillation loss
"""
import torch
from torch.nn import functional as F
class DistillationLoss(torch.nn.Module):
"""
This module wraps a standard criterion and adds an extra knowledge distillation loss by
taking a teacher model prediction and using it as additional supervision.
"""
def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
distillation_type: str, alpha: float, tau: float):
super().__init__()
self.base_criterion = base_criterion
self.teacher_model = teacher_model
assert distillation_type in ['none', 'soft', 'hard']
self.distillation_type = distillation_type
self.alpha = alpha
self.tau = tau
def forward(self, inputs, outputs, labels):
"""
Args:
            inputs: The original inputs that are fed to the teacher model
outputs: the outputs of the model to be trained. It is expected to be
either a Tensor, or a Tuple[Tensor, Tensor], with the original output
in the first position and the distillation predictions as the second output
labels: the labels for the base criterion
"""
outputs_kd = None
if not isinstance(outputs, torch.Tensor):
# assume that the model outputs a tuple of [outputs, outputs_kd]
outputs, outputs_kd = outputs
base_loss = self.base_criterion(outputs, labels)
if self.distillation_type == 'none':
return base_loss
if outputs_kd is None:
raise ValueError("When knowledge distillation is enabled, the model is "
"expected to return a Tuple[Tensor, Tensor] with the output of the "
"class_token and the dist_token")
        # don't backprop through the teacher
with torch.no_grad():
teacher_outputs = self.teacher_model(inputs)
if self.distillation_type == 'soft':
T = self.tau
# taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
# with slight modifications
distillation_loss = F.kl_div(
F.log_softmax(outputs_kd / T, dim=1),
#We provide the teacher's targets in log probability because we use log_target=True
#(as recommended in pytorch https://github.com/pytorch/pytorch/blob/9324181d0ac7b4f7949a574dbc3e8be30abe7041/torch/nn/functional.py#L2719)
#but it is possible to give just the probabilities and set log_target=False. In our experiments we tried both.
F.log_softmax(teacher_outputs / T, dim=1),
reduction='sum',
log_target=True
) * (T * T) / outputs_kd.numel()
#We divide by outputs_kd.numel() to have the legacy PyTorch behavior.
            #But we also experimented with output_kd.size(0)
#see issue 61(https://github.com/facebookresearch/deit/issues/61) for more details
elif self.distillation_type == 'hard':
distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1))
loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
return loss
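# --- Usage sketch (not part of the original file) ----------------------------
# With distillation_type='none' the wrapper simply returns the base criterion's
# loss, which is how main.py uses it when no teacher is configured. The tensors
# below are illustrative.
if __name__ == '__main__':
    base = torch.nn.CrossEntropyLoss()
    criterion = DistillationLoss(base, teacher_model=None,
                                 distillation_type='none', alpha=0.5, tau=1.0)
    inputs = torch.randn(4, 3, 224, 224)    # only used when distillation is enabled
    outputs = torch.randn(4, 100)
    labels = torch.randint(0, 100, (4,))
    print(criterion(inputs, outputs, labels).item())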
================================================
FILE: main.py
================================================
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
import argparse
import datetime
import numpy as np
import time
import torch
import torch.backends.cudnn as cudnn
import json
from pathlib import Path
from timm.data import Mixup
# from timm.models import create_model
from models_kan import create_model
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.scheduler import create_scheduler
from timm.optim import create_optimizer
from timm.utils import NativeScaler, get_state_dict, ModelEma
from datasets import build_dataset
from engine import train_one_epoch, evaluate
from losses import DistillationLoss
from samplers import RASampler
from augment import new_data_aug_generator
# import models
# import models_v2
import kit
import utils
def get_args_parser():
parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False)
parser.add_argument('--batch-size', default=144, type=int)
parser.add_argument('--epochs', default=300, type=int)
parser.add_argument('--bce-loss', action='store_true')
parser.add_argument('--unscale-lr', action='store_true',default=True)
# Model parameters
parser.add_argument('--model', default='deit_tiny_patch16_224', type=str, metavar='MODEL',
help='Name of model to train')
parser.add_argument('--input-size', default=224, type=int, help='images input size')
parser.add_argument('--hdim_kan', default=192, type=int, help='hidden dimension for KAN')
parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
help='Dropout rate (default: 0.)')
parser.add_argument('--drop-path', type=float, default=0.05, metavar='PCT',
                        help='Drop path rate (default: 0.05)')
parser.add_argument('--model-ema', action='store_true')
parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
parser.set_defaults(model_ema=True)
parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='')
parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='')
# Optimizer parameters
parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw")')
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
help='Optimizer Epsilon (default: 1e-8)')
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
help='Optimizer Betas (default: None, use opt default)')
parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
help='SGD momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.05,
help='weight decay (default: 0.05)')
# Learning rate schedule parameters
parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                        help='LR scheduler (default: "cosine")')
parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
help='learning rate (default: 5e-4)')
parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
help='learning rate noise on/off epoch percentages')
parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
help='learning rate noise limit percent (default: 0.67)')
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
help='learning rate noise std-dev (default: 1.0)')
parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
help='warmup learning rate (default: 1e-6)')
parser.add_argument('--min-lr', type=float, default=1e-6, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (default: 1e-6)')
parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
help='epoch interval to decay LR')
parser.add_argument('--warmup-epochs', type=int, default=1, metavar='N',
help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                        help='patience epochs for Plateau LR scheduler (default: 10)')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
help='LR decay rate (default: 0.1)')
# Augmentation parameters
parser.add_argument('--color-jitter', type=float, default=0.3, metavar='PCT',
help='Color jitter factor (default: 0.3)')
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original" (default: rand-m9-mstd0.5-inc1)')
    parser.add_argument('--smoothing', type=float, default=0., help='Label smoothing (default: 0.)')
parser.add_argument('--train-interpolation', type=str, default='bicubic',
help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
parser.add_argument('--repeated-aug', action='store_true')
parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
parser.set_defaults(repeated_aug=True)
parser.add_argument('--train-mode', action='store_true')
parser.add_argument('--no-train-mode', action='store_false', dest='train_mode')
parser.set_defaults(train_mode=True)
parser.add_argument('--ThreeAugment', default=True, action='store_true') #3augment
parser.add_argument('--src', action='store_true') #simple random crop
# * Random Erase params
parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
help='Random erase prob (default: 0.0)')
parser.add_argument('--remode', type=str, default='pixel',
help='Random erase mode (default: "pixel")')
parser.add_argument('--recount', type=int, default=1,
help='Random erase count (default: 1)')
parser.add_argument('--resplit', action='store_true', default=False,
help='Do not random erase first (clean) augmentation split')
# * Mixup params
parser.add_argument('--mixup', type=float, default=0.8,
help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
parser.add_argument('--cutmix', type=float, default=1.0,
help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
parser.add_argument('--mixup-prob', type=float, default=1.0,
help='Probability of performing mixup or cutmix when either/both is enabled')
parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
help='Probability of switching to cutmix when both mixup and cutmix enabled')
parser.add_argument('--mixup-mode', type=str, default='batch',
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
# Distillation parameters
parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL',
help='Name of teacher model to train (default: "regnety_160")')
parser.add_argument('--teacher-path', type=str, default='')
parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
parser.add_argument('--distillation-alpha', default=0.5, type=float, help="")
parser.add_argument('--distillation-tau', default=1.0, type=float, help="")
# * Cosub params
parser.add_argument('--cosub', action='store_true')
# * Finetuning params
parser.add_argument('--finetune', default='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth', help='finetune from checkpoint')
parser.add_argument('--attn-only', action='store_true')
# Dataset parameters
parser.add_argument('--data-path', default='/scratch/sg7729/KAN/Vision-KAN/Datasets', type=str,
help='dataset path')
parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19','TinyIMNET'],
type=str, help='Image Net dataset path')
parser.add_argument('--inat-category', default='name',
choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
type=str, help='semantic granularity')
parser.add_argument('--output_dir', default='./output',
help='path where to save, empty for no saving')
parser.add_argument('--device', default='cuda',
help='device to use for training / testing')
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--resume', default='', help='resume from checkpoint')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
help='start epoch')
parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
parser.add_argument('--eval-crop-ratio', default=1, type=float, help="Crop ratio for evaluation")
parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
parser.add_argument('--num_workers', default=10, type=int)
parser.add_argument('--pin-mem', action='store_true',
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
help='')
parser.set_defaults(pin_mem=True)
# distributed training parameters
parser.add_argument('--distributed', action='store_true', default=False, help='Enabling distributed training')
# parser.add_argument('--world_size', default=1, type=int,
# help='number of distributed processes')
# parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
return parser
def main(args):
# utils.init_distributed_mode(args)
print(args)
if args.distillation_type != 'none' and args.finetune and not args.eval:
raise NotImplementedError("Finetuning with distillation not yet supported")
device = torch.device(args.device)
# fix the seed for reproducibility
seed = args.seed + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
# random.seed(seed)
cudnn.benchmark = True
dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
dataset_val, _ = build_dataset(is_train=False, args=args)
# if args.distributed:
# num_tasks = utils.get_world_size()
# global_rank = utils.get_rank()
# if args.repeated_aug:
# sampler_train = RASampler(
# dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
# )
# else:
# sampler_train = torch.utils.data.DistributedSampler(
# dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
# )
# if args.dist_eval:
# if len(dataset_val) % num_tasks != 0:
# print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
# 'This will slightly alter validation results as extra duplicate entries are added to achieve '
# 'equal num of samples per-process.')
# sampler_val = torch.utils.data.DistributedSampler(
# dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
# else:
# sampler_val = torch.utils.data.SequentialSampler(dataset_val)
# else:
sampler_train = torch.utils.data.RandomSampler(dataset_train)
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
data_loader_train = torch.utils.data.DataLoader(
dataset_train, sampler=sampler_train,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_mem,
drop_last=True,
)
if args.ThreeAugment:
data_loader_train.dataset.transform = new_data_aug_generator(args)
data_loader_val = torch.utils.data.DataLoader(
dataset_val, sampler=sampler_val,
batch_size=int(1 * args.batch_size),
num_workers=args.num_workers,
pin_memory=args.pin_mem,
drop_last=True
)
mixup_fn = None
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
if mixup_active:
mixup_fn = Mixup(
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
label_smoothing=args.smoothing, num_classes=args.nb_classes)
print(f"Creating model: {args.model}")
model = create_model(
args.model,
pretrained=False,
num_classes=args.nb_classes,
drop_rate=args.drop,
drop_path_rate=args.drop_path,
# drop_block_rate=None,
img_size=args.input_size,
batch_size=args.batch_size
)
if args.finetune:
if args.finetune.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(
args.finetune, map_location='cpu', check_hash=True)
else:
checkpoint = torch.load(args.finetune, map_location='cpu')
checkpoint_model = checkpoint['model']
state_dict = model.state_dict()
for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']:
if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
print(f"Removing key {k} from pretrained checkpoint")
del checkpoint_model[k]
# interpolate position embedding
pos_embed_checkpoint = checkpoint_model['pos_embed']
embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = model.patch_embed.num_patches
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
# height (== width) for the checkpoint position embedding
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
# height (== width) for the new position embedding
new_size = int(num_patches ** 0.5)
# class_token and dist_token are kept unchanged
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
checkpoint_model['pos_embed'] = new_pos_embed
model.load_state_dict(checkpoint_model, strict=False)
if args.attn_only:
for name_p,p in model.named_parameters():
if '.attn.' in name_p:
p.requires_grad = True
else:
p.requires_grad = False
try:
model.head.weight.requires_grad = True
model.head.bias.requires_grad = True
except:
model.fc.weight.requires_grad = True
model.fc.bias.requires_grad = True
try:
model.pos_embed.requires_grad = True
except:
print('no position encoding')
try:
for p in model.patch_embed.parameters():
p.requires_grad = False
except:
print('no patch embed')
model.to(device)
model_ema = None
if args.model_ema:
# Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
model_ema = ModelEma(
model,
decay=args.model_ema_decay,
device='cpu' if args.model_ema_force_cpu else '',
resume='')
model_without_ddp = model
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('number of params:', n_parameters)
if not args.unscale_lr:
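# linear LR scaling rule: scale the base LR in proportion to the effective global batch size (batch_size * world_size) relative to a reference batch of 512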
linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
args.lr = linear_scaled_lr
optimizer = create_optimizer(args, model_without_ddp)
loss_scaler = NativeScaler()
lr_scheduler, _ = create_scheduler(args, optimizer)
criterion = LabelSmoothingCrossEntropy()
if mixup_active:
# smoothing is handled with mixup label transform
criterion = SoftTargetCrossEntropy()
elif args.smoothing:
criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
else:
criterion = torch.nn.CrossEntropyLoss()
if args.bce_loss:
criterion = torch.nn.BCEWithLogitsLoss()
teacher_model = None
if args.distillation_type != 'none':
assert args.teacher_path, 'need to specify teacher-path when using distillation'
print(f"Creating teacher model: {args.teacher_model}")
teacher_model = create_model(
args.teacher_model,
pretrained=False,
num_classes=args.nb_classes,
global_pool='avg',
)
if args.teacher_path.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(
args.teacher_path, map_location='cpu', check_hash=True)
else:
checkpoint = torch.load(args.teacher_path, map_location='cpu')
teacher_model.load_state_dict(checkpoint['model'])
teacher_model.to(device)
teacher_model.eval()
# wrap the criterion in our custom DistillationLoss, which
# just dispatches to the original criterion if args.distillation_type is 'none'
criterion = DistillationLoss(
criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau
)
output_dir = Path(args.output_dir)
if args.resume:
if args.resume.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(
args.resume, map_location='cpu', check_hash=True)
else:
checkpoint = torch.load(args.resume, map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model'])
if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
args.start_epoch = checkpoint['epoch'] + 1
if args.model_ema:
utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
if 'scaler' in checkpoint:
loss_scaler.load_state_dict(checkpoint['scaler'])
lr_scheduler.step(args.start_epoch)
if args.eval:
test_stats = evaluate(data_loader_val, model, device)
print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
return
print(f"Start training for {args.epochs} epochs")
start_time = time.time()
max_accuracy = 0.0
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
data_loader_train.sampler.set_epoch(epoch)
train_stats = train_one_epoch(
model, criterion, data_loader_train,
optimizer, device, epoch, loss_scaler,
args.clip_grad, model_ema, mixup_fn,
set_training_mode=args.train_mode, # keep in eval mode for deit finetuning / train mode for training and deit III finetuning
args = args,
)
lr_scheduler.step(epoch)
if args.output_dir:
checkpoint_paths = [output_dir / 'checkpoint.pth']
for checkpoint_path in checkpoint_paths:
utils.save_on_master({
'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(),
'epoch': epoch,
'model_ema': get_state_dict(model_ema),
'scaler': loss_scaler.state_dict(),
'args': args,
}, checkpoint_path)
test_stats = evaluate(data_loader_val, model, device)
print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
if max_accuracy < test_stats["acc1"]:
max_accuracy = test_stats["acc1"]
if args.output_dir:
checkpoint_paths = [output_dir / 'best_checkpoint.pth']
for checkpoint_path in checkpoint_paths:
utils.save_on_master({
'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(),
'epoch': epoch,
'model_ema': get_state_dict(model_ema),
'scaler': loss_scaler.state_dict(),
'args': args,
}, checkpoint_path)
print(f'Max accuracy: {max_accuracy:.2f}%')
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
**{f'test_{k}': v for k, v in test_stats.items()},
'epoch': epoch,
'n_parameters': n_parameters}
if args.output_dir and utils.is_main_process():
with (output_dir / "log.txt").open("a") as f:
f.write(json.dumps(log_stats) + "\n")
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
if __name__ == '__main__':
parser = argparse.ArgumentParser('DeiT training and evaluation script', parents=[get_args_parser()])
args = parser.parse_args()
if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
main(args)
================================================
FILE: minimal_example.py
================================================
from models_kan import create_model
import torch.optim as optim
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
from engine import train_one_epoch, evaluate
KAN_model = create_model(
model_name='deit_tiny_patch16_224_KAN',
pretrained=False,
hdim_kan=192,
num_classes=10,
drop_rate=0.0,
drop_path_rate=0.05,
img_size=32,
batch_size=144
)
# dataset CIFAR10
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=144,
shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=144, shuffle=False, num_workers=2)
classes = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')
#optimizer
optimizer = optim.SGD(KAN_model.parameters(), lr=0.001, momentum=0.9)
criterion = torch.nn.CrossEntropyLoss()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
KAN_model.to(device)
# training loop (train_one_epoch from engine.py is imported above and could be used instead; see the sketch at the end of this file)
for epoch in range(2): # loop over the dataset multiple times
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
# get the inputs; data is a list of [inputs, labels]
inputs, labels = data[0].to(device), data[1].to(device)
# zero the parameter gradients
optimizer.zero_grad()
# forward + backward + optimize
outputs = KAN_model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# print statistics
running_loss += loss.item()
if i % 2000 == 1999: # print every 2000 mini-batches
print('[%d, %5d] loss: %.3f' %
(epoch + 1, i + 1, running_loss / 2000))
running_loss = 0.0
# evaluate
test_stats = evaluate(testloader, KAN_model, device=device)
print(f"Accuracy of the network on the {len(testset)} test images: {test_stats['acc1']:.1f}%")
print('Finished Training')
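# Alternative to the manual loop above: train_one_epoch from engine.py (already imported),
# called the same way main.py calls it. This is only a sketch: it assumes a timm
# NativeScaler and an argparse-style `args` namespace (as produced by main.get_args_parser()),
# neither of which this minimal example constructs.
# from timm.utils import NativeScaler
# loss_scaler = NativeScaler()
# train_stats = train_one_epoch(KAN_model, criterion, trainloader, optimizer, device,
#                               epoch, loss_scaler, None, None, None,
#                               set_training_mode=True, args=args)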
================================================
FILE: models_kan.py
================================================
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
import torch
import torch.nn as nn
from functools import partial
from timm.models.vision_transformer import VisionTransformer, _cfg, Block, Attention
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_, DropPath
from fasterkan import FasterKAN as KAN
__all__KAN = [
'deit_base_patch16_224_KAN', 'deit_small_patch16_224_KAN',
'deit_base_patch16_384_KAN', 'deit_tiny_patch16_224_KAN',
'deit_tiny_distilled_patch16_224_KAN', 'deit_base_distilled_patch16_224_KAN',
'deit_small_distilled_patch16_224_KAN', 'deit_base_distilled_patch16_384_KAN']
__all__ViT = [
'deit_base_patch16_224_ViT', 'deit_small_patch16_224_ViT',
'deit_base_patch16_384_ViT', 'deit_tiny_patch16_224_ViT',
'deit_tiny_distilled_patch16_224_ViT', 'deit_base_distilled_patch16_224_ViT',
'deit_small_distilled_patch16_224_ViT', 'deit_base_distilled_patch16_384_ViT']
class kanBlock(Block):
def __init__(self, dim, num_heads=8, hdim_kan=192, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__(dim, num_heads)
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
# self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.kan = KAN([dim, hdim_kan, dim])
def forward(self, x):
b, t, d = x.shape
x = x + self.drop_path(self.attn(self.norm1(x)))
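# the KAN replaces the transformer MLP sub-layer; it expects 2-D input, so the tokens are flattened to (b*t, d), passed through the KAN, and reshaped back to (b, t, d)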
x = x + self.drop_path(self.kan(self.norm2(x).reshape(-1, x.shape[-1])).reshape(b, t, d))
return x
class VisionKAN(VisionTransformer):
def __init__(self, *args, num_heads=8, batch_size=16, **kwargs):
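# batch_size is accepted (and swallowed) here because callers such as create_model pass it along; it is not used further inside this class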
if 'hdim_kan' in kwargs:
self.hdim_kan = kwargs['hdim_kan']
del kwargs['hdim_kan']
else:
self.hdim_kan = 192
super().__init__(*args, **kwargs)
self.num_heads = num_heads
# For newer version timm they don't save the depth to self.depth, so we need to check it
try:
self.depth
except AttributeError:
if 'depth' in kwargs:
self.depth = kwargs['depth']
else:
self.depth = 12
block_list = [
kanBlock(dim=self.embed_dim, num_heads=self.num_heads, hdim_kan=self.hdim_kan)
for i in range(self.depth)
]
# check the origin type of the block is torch.nn.modules.container.Sequential
# if the origin type is torch.nn.modules.container.Sequential, then we need to convert it to a list
if type(self.blocks) == nn.Sequential:
self.blocks = nn.Sequential(*block_list)
elif type(self.blocks) == nn.ModuleList:
self.blocks = nn.ModuleList(block_list)
class DistilledVisionTransformer(VisionTransformer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
num_patches = self.patch_embed.num_patches
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
trunc_normal_(self.dist_token, std=.02)
trunc_normal_(self.pos_embed, std=.02)
self.head_dist.apply(self._init_weights)
def forward_features(self, x):
# taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# with slight modifications to add the dist_token
B = x.shape[0]
x = self.patch_embed(x)
cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
dist_token = self.dist_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, dist_token, x), dim=1)
x = x + self.pos_embed
x = self.pos_drop(x)
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return x[:, 0], x[:, 1]
def forward(self, x):
x, x_dist = self.forward_features(x)
x = self.head(x)
x_dist = self.head_dist(x_dist)
if self.training:
return x, x_dist
else:
# during inference, return the average of both classifier predictions
return (x + x_dist) / 2
def create_kan(model_name, pretrained, **kwargs):
if model_name == 'deit_tiny_patch16_224_KAN':
model = VisionKAN(
patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = _cfg()
if pretrained:
checkpoint = torch.hub.load_state_dict_from_url(
url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
map_location="cpu", check_hash=True
)
model.load_state_dict(checkpoint["model"])
return model
elif model_name == 'deit_small_patch16_224_KAN':
model = VisionKAN(
patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = _cfg()
if pretrained:
checkpoint = torch.hub.load_state_dict_from_url(
url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
map_location="cpu", check_hash=True
)
model.load_state_dict(checkpoint["model"])
return model
elif model_name == 'deit_base_patch16_224_KAN':
model = VisionKAN(
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = _cfg()
if pretrained:
checkpoint = torch.hub.load_state_dict_from_url(
url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
map_location="cpu", check_hash=True
)
model.load_state_dict(checkpoint["model"])
return model
elif model_name == 'deit_base_patch16_384_KAN':
model = VisionKAN(
img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = _cfg()
if pretrained:
checkpoint = torch.hub.load_state_dict_from_url(
url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth",
map_location="cpu", check_hash=True
)
model.load_state_dict(checkpoint["model"])
return model
elif model_name == 'deit_tiny_distilled_patch16_224_KAN':
raise RuntimeError('Distilled models are not yet implemented in KAN')
elif model_name == 'deit_small_distilled_patch16_224_KAN':
raise RuntimeError('Distilled models are not yet implemented in KAN')
elif model_name == 'deit_base_distilled_patch16_224_KAN':
raise RuntimeError('Distilled models are not yet implemented in KAN')
def create_ViT(model_name, pretrained, **kwargs):
if 'batch_size' in kwargs:
del kwargs['batch_size']
if model_name == 'deit_tiny_patch16_224_ViT':
model = VisionTransformer(
patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = _cfg()
if pretrained:
checkpoint = torch.hub.load_state_dict_from_url(
url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
map_location="cpu", check_hash=True
)
model.load_state_dict(checkpoint["model"])
return model
elif model_name == 'deit_small_patch16_224_ViT':
model = VisionTransformer(
patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = _cfg()
if pretrained:
checkpoint = torch.hub.load_state_dict_from_url(
url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
map_location="cpu", check_hash=True
)
model.load_state_dict(checkpoint["model"])
return model
elif model_name == 'deit_base_patch16_224_ViT':
model = VisionTransformer(
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = _cfg()
if pretrained:
checkpoint = torch.hub.load_state_dict_from_url(
url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
map_location="cpu", check_hash=True
)
model.load_state_dict(checkpoint["model"])
return model
elif model_name == 'deit_base_patch16_384_ViT':
model = VisionTransformer(
img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = _cfg()
if pretrained:
checkpoint = torch.hub.load_state_dict_from_url(
url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth",
map_location="cpu", check_hash=True
)
model.load_state_dict(checkpoint["model"])
return model
elif model_name == 'deit_tiny_distilled_patch16_224_ViT':
model = DistilledVisionTransformer(
patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = _cfg()
if pretrained:
checkpoint = torch.hub.load_state_dict_from_url(
url="https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth",
map_location="cpu", check_hash=True
)
model.load_state_dict(checkpoint["model"])
return model
elif model_name == 'deit_small_distilled_patch16_224_ViT':
model = DistilledVisionTransformer(
patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = _cfg()
if pretrained:
checkpoint = torch.hub.load_state_dict_from_url(
url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth",
map_location="cpu", check_hash=True
)
model.load_state_dict(checkpoint["model"])
return model
elif model_name == 'deit_base_distilled_patch16_224_ViT':
model = DistilledVisionTransformer(
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = _cfg()
if pretrained:
checkpoint = torch.hub.load_state_dict_from_url(
url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth",
map_location="cpu", check_hash=True
)
model.load_state_dict(checkpoint["model"])
return model
elif model_name == 'deit_base_distilled_patch16_384_ViT':
model = DistilledVisionTransformer(
img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = _cfg()
if pretrained:
checkpoint = torch.hub.load_state_dict_from_url(
url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth",
map_location="cpu", check_hash=True
)
model.load_state_dict(checkpoint["model"])
return model
def create_model(model_name,**kwargs):
pretrained = kwargs['pretrained'] if 'pretrained' in kwargs else False
if 'pretrained' in kwargs:
del kwargs['pretrained']
print(kwargs)
if model_name in __all__KAN:
model = create_kan(model_name, pretrained, **kwargs)
model.default_cfg = _cfg()
return model
elif model_name in __all__ViT:
model = create_ViT(model_name, pretrained, **kwargs)
model.default_cfg = _cfg()
return model
else:
raise RuntimeError('Unknown model (%s)' % model_name)
if __name__ == '__main__':
model = create_model('deit_tiny_patch16_224_KAN', pretrained=False).cuda()
img = torch.randn(5, 3, 224, 224).cuda()
out = model(img)
print(out.shape)
================================================
FILE: pyproject.toml
================================================
# pyproject.toml
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "viskan"
version = "0.1.0"
description = "Viskan"
authors = ["Saaketh Koundinya <three2saki@gmail.com>"]
[tool.poetry.dependencies]
python = "^3.6"
torch = ">=1.13.1"
torchvision = ">=0.8.1"
timm = ">=0.3.2"
[tool.poetry.dev-dependencies]
# Add any development dependencies here if needed
================================================
FILE: requirements.txt
================================================
torch==1.13.1
torchvision==0.8.1
timm==0.3.2
================================================
FILE: samplers.py
================================================
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
import torch
import torch.distributed as dist
import math
class RASampler(torch.utils.data.Sampler):
"""Sampler that restricts data loading to a subset of the dataset for distributed,
with repeated augmentation.
It ensures that each augmented version of a sample is visible to a
different process (GPU); a usage sketch appears at the end of this file.
Heavily based on torch.utils.data.DistributedSampler
"""
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, num_repeats: int = 3):
if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = dist.get_world_size()
if rank is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = dist.get_rank()
if num_repeats < 1:
raise ValueError("num_repeats should be greater than 0")
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.num_repeats = num_repeats
self.epoch = 0
self.num_samples = int(math.ceil(len(self.dataset) * self.num_repeats / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
# self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
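# the line below rounds the dataset length down to a multiple of 256 before splitting it across replicas, so every rank yields the same fixed number of indices per epoch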
self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
self.shuffle = shuffle
def __iter__(self):
if self.shuffle:
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
indices = torch.randperm(len(self.dataset), generator=g)
else:
indices = torch.arange(start=0, end=len(self.dataset))
# add extra samples to make it evenly divisible
indices = torch.repeat_interleave(indices, repeats=self.num_repeats, dim=0).tolist()
padding_size: int = self.total_size - len(indices)
if padding_size > 0:
indices += indices[:padding_size]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices[:self.num_selected_samples])
def __len__(self):
return self.num_selected_samples
def set_epoch(self, epoch):
self.epoch = epoch
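# A minimal usage sketch for RASampler, assuming a single process (num_replicas=1, rank=0)
# and torchvision's FakeData as a stand-in dataset; in main.py the sampler is instead built
# with the world size and rank obtained from torch.distributed.
if __name__ == '__main__':
    import torchvision
    import torchvision.transforms as T
    dataset = torchvision.datasets.FakeData(size=512, image_size=(3, 224, 224), transform=T.ToTensor())
    sampler = RASampler(dataset, num_replicas=1, rank=0, shuffle=True, num_repeats=3)
    loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=64, drop_last=True)
    sampler.set_epoch(0)  # call once per epoch so the shuffle order changes between epochs
    print(len(sampler), 'indices selected for this rank per epoch')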
================================================
FILE: utils.py
================================================
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
"""
Misc functions, including distributed helpers.
Mostly copy-paste from torchvision references.
"""
import io
import os
import time
from collections import defaultdict, deque
import datetime
import torch
import torch.distributed as dist
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
if not is_dist_avail_and_initialized():
return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
dist.barrier()
dist.all_reduce(t)
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
return self.total / self.count
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median,
avg=self.avg,
global_avg=self.global_avg,
max=self.max,
value=self.value)
class MetricLogger(object):
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, attr))
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append(
"{}: {}".format(name, str(meter))
)
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, meter):
self.meters[name] = meter
def log_every(self, iterable, print_freq, header=None):
i = 0
if not header:
header = ''
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt='{avg:.4f}')
data_time = SmoothedValue(fmt='{avg:.4f}')
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
log_msg = [
header,
'[{0' + space_fmt + '}/{1}]',
'eta: {eta}',
'{meters}',
'time: {time}',
'data: {data}'
]
if torch.cuda.is_available():
log_msg.append('max mem: {memory:.0f}')
log_msg = self.delimiter.join(log_msg)
MB = 1024.0 * 1024.0
for obj in iterable:
data_time.update(time.time() - end)
yield obj
iter_time.update(time.time() - end)
if i % print_freq == 0 or i == len(iterable) - 1:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB))
else:
print(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time)))
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('{} Total time: {} ({:.4f} s / it)'.format(
header, total_time_str, total_time / len(iterable)))
def _load_checkpoint_for_ema(model_ema, checkpoint):
"""
Workaround for ModelEma._load_checkpoint to accept an already-loaded object
"""
mem_file = io.BytesIO()
torch.save({'state_dict_ema':checkpoint}, mem_file)
mem_file.seek(0)
model_ema._load_checkpoint(mem_file)
def setup_for_distributed(is_master):
"""
This function disables printing when not in master process
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop('force', False)
if is_master or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
def save_on_master(*args, **kwargs):
if is_main_process():
torch.save(*args, **kwargs)
def init_distributed_mode(args):
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.gpu = int(os.environ['LOCAL_RANK'])
elif 'SLURM_PROCID' in os.environ:
args.rank = int(os.environ['SLURM_PROCID'])
args.gpu = args.rank % torch.cuda.device_count()
else:
print('Not using distributed mode')
args.distributed = False
return
args.distributed = True
torch.cuda.set_device(args.gpu)
args.dist_backend = 'nccl'
print('| distributed init (rank {}): {}'.format(
args.rank, args.dist_url), flush=True)
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
world_size=args.world_size, rank=args.rank)
torch.distributed.barrier()
setup_for_distributed(args.rank == 0)
================================================
SYMBOL INDEX (140 symbols across 11 files)
================================================
FILE: augment.py
class GaussianBlur (line 24) | class GaussianBlur(object):
method __init__ (line 28) | def __init__(self, p=0.1, radius_min=0.1, radius_max=2.):
method __call__ (line 33) | def __call__(self, img):
class Solarization (line 45) | class Solarization(object):
method __init__ (line 49) | def __init__(self, p=0.2):
method __call__ (line 52) | def __call__(self, img):
class gray_scale (line 58) | class gray_scale(object):
method __init__ (line 62) | def __init__(self, p=0.2):
method __call__ (line 66) | def __call__(self, img):
class horizontal_flip (line 74) | class horizontal_flip(object):
method __init__ (line 78) | def __init__(self, p=0.2,activate_pred=False):
method __call__ (line 82) | def __call__(self, img):
function new_data_aug_generator (line 90) | def new_data_aug_generator(args = None):
FILE: datasets.py
class INatDataset (line 13) | class INatDataset(ImageFolder):
method __init__ (line 14) | def __init__(self, root, train=True, year=2018, transform=None, target...
function build_dataset (line 56) | def build_dataset(is_train, args):
function build_transform (line 82) | def build_transform(is_train, args):
FILE: ekan.py
class KANLinear (line 6) | class KANLinear(torch.nn.Module):
method __init__ (line 7) | def __init__(
method reset_parameters (line 56) | def reset_parameters(self):
method b_splines (line 78) | def b_splines(self, x: torch.Tensor):
method curve2coeff (line 113) | def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
method scaled_spline_weight (line 146) | def scaled_spline_weight(self):
method forward (line 153) | def forward(self, x: torch.Tensor):
method update_grid (line 164) | def update_grid(self, x: torch.Tensor, margin=0.01):
method regularization_loss (line 212) | def regularization_loss(self, regularize_activation=1.0, regularize_en...
class KAN (line 235) | class KAN(torch.nn.Module):
method __init__ (line 236) | def __init__(
method forward (line 269) | def forward(self, x: torch.Tensor, update_grid=False):
method regularization_loss (line 276) | def regularization_loss(self, regularize_activation=1.0, regularize_en...
FILE: engine.py
function train_one_epoch (line 19) | def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
function evaluate (line 83) | def evaluate(data_loader, model, device):
FILE: fasterkan.py
class SplineLinear (line 8) | class SplineLinear(nn.Linear):
method __init__ (line 9) | def __init__(self, in_features: int, out_features: int, init_scale: fl...
method reset_parameters (line 13) | def reset_parameters(self) -> None:
class ReflectionalSwitchFunction (line 17) | class ReflectionalSwitchFunction(nn.Module):
method __init__ (line 18) | def __init__(
method forward (line 33) | def forward(self, x):
class FasterKANLayer (line 43) | class FasterKANLayer(nn.Module):
method __init__ (line 44) | def __init__(
method forward (line 66) | def forward(self, x, time_benchmark=False):
class FasterKAN (line 103) | class FasterKAN(nn.Module):
method __init__ (line 104) | def __init__(
method forward (line 131) | def forward(self, x):
FILE: kit.py
class Attention (line 17) | class Attention(nn.Module):
method __init__ (line 19) | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, at...
method forward (line 30) | def forward(self, x):
class Block (line 47) | class Block(nn.Module):
method __init__ (line 49) | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_sc...
method forward (line 63) | def forward(self, x):
class Layer_scale_init_Block (line 70) | class Layer_scale_init_Block(nn.Module):
method __init__ (line 73) | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_sc...
method forward (line 89) | def forward(self, x):
class Layer_scale_init_Block_paralx2 (line 97) | class Layer_scale_init_Block_paralx2(nn.Module):
method __init__ (line 100) | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_sc...
method forward (line 122) | def forward(self, x):
class Block_paralx2 (line 130) | class Block_paralx2(nn.Module):
method __init__ (line 133) | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_sc...
method forward (line 151) | def forward(self, x):
class hMLP_stem (line 157) | class hMLP_stem(nn.Module):
method __init__ (line 163) | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=...
method forward (line 181) | def forward(self, x):
class vit_models (line 187) | class vit_models(nn.Module):
method __init__ (line 193) | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classe...
method _init_weights (line 233) | def _init_weights(self, m):
method no_weight_decay (line 243) | def no_weight_decay(self):
method get_classifier (line 246) | def get_classifier(self):
method get_num_layers (line 249) | def get_num_layers(self):
method reset_classifier (line 252) | def reset_classifier(self, num_classes, global_pool=''):
method forward_features (line 256) | def forward_features(self, x):
method forward (line 272) | def forward(self, x):
function deit_tiny_patch16_LS (line 286) | def deit_tiny_patch16_LS(pretrained=False, img_size=224, pretrained_21k=...
function deit_small_patch16_LS (line 295) | def deit_small_patch16_LS(pretrained=False, img_size=224, pretrained_21k...
function deit_medium_patch16_LS (line 317) | def deit_medium_patch16_LS(pretrained=False, img_size=224, pretrained_21...
function deit_base_patch16_LS (line 338) | def deit_base_patch16_LS(pretrained=False, img_size=224, pretrained_21k=...
function deit_large_patch16_LS (line 358) | def deit_large_patch16_LS(pretrained=False, img_size=224, pretrained_21k...
function deit_huge_patch14_LS (line 378) | def deit_huge_patch14_LS(pretrained=False, img_size=224, pretrained_21k=...
function deit_huge_patch14_52_LS (line 398) | def deit_huge_patch14_52_LS(pretrained=False, img_size=224, pretrained_2...
function deit_huge_patch14_26x2_LS (line 407) | def deit_huge_patch14_26x2_LS(pretrained=False, img_size=224, pretrained...
function deit_Giant_48x2_patch14_LS (line 416) | def deit_Giant_48x2_patch14_LS(pretrained=False, img_size=224, pretraine...
function deit_giant_40x2_patch14_LS (line 425) | def deit_giant_40x2_patch14_LS(pretrained=False, img_size=224, pretraine...
function deit_Giant_48_patch14_LS (line 433) | def deit_Giant_48_patch14_LS(pretrained=False, img_size=224, pretrained_...
function deit_giant_40_patch14_LS (line 441) | def deit_giant_40_patch14_LS(pretrained=False, img_size=224, pretrained_...
function deit_small_patch16_36_LS (line 453) | def deit_small_patch16_36_LS(pretrained=False, img_size=224, pretrained_...
function deit_small_patch16_36 (line 462) | def deit_small_patch16_36(pretrained=False, img_size=224, pretrained_21k...
function deit_small_patch16_18x2_LS (line 471) | def deit_small_patch16_18x2_LS(pretrained=False, img_size=224, pretraine...
function deit_small_patch16_18x2 (line 480) | def deit_small_patch16_18x2(pretrained=False, img_size=224, pretrained_2...
function deit_base_patch16_18x2_LS (line 489) | def deit_base_patch16_18x2_LS(pretrained=False, img_size=224, pretrained...
function deit_base_patch16_18x2 (line 498) | def deit_base_patch16_18x2(pretrained=False, img_size=224, pretrained_21...
function deit_base_patch16_36x1_LS (line 507) | def deit_base_patch16_36x1_LS(pretrained=False, img_size=224, pretrained...
function deit_base_patch16_36x1 (line 516) | def deit_base_patch16_36x1(pretrained=False, img_size=224, pretrained_21...
function create_model (line 522) | def create_model(model_name,**kwargs):
FILE: losses.py
class DistillationLoss (line 10) | class DistillationLoss(torch.nn.Module):
method __init__ (line 15) | def __init__(self, base_criterion: torch.nn.Module, teacher_model: tor...
method forward (line 25) | def forward(self, inputs, outputs, labels):
FILE: main.py
function get_args_parser (line 34) | def get_args_parser():
function main (line 195) | def main(args):
FILE: models_kan.py
class kanBlock (line 26) | class kanBlock(Block):
method __init__ (line 28) | def __init__(self, dim, num_heads=8, hdim_kan=192, mlp_ratio=4., qkv_b...
method forward (line 41) | def forward(self, x):
class VisionKAN (line 48) | class VisionKAN(VisionTransformer):
method __init__ (line 49) | def __init__(self, *args, num_heads=8, batch_size=16, **kwargs):
class DistilledVisionTransformer (line 80) | class DistilledVisionTransformer(VisionTransformer):
method __init__ (line 81) | def __init__(self, *args, **kwargs):
method forward_features (line 92) | def forward_features(self, x):
method forward (line 111) | def forward(self, x):
function create_kan (line 121) | def create_kan(model_name, pretrained, **kwargs):
function create_ViT (line 184) | def create_ViT(model_name, pretrained, **kwargs):
function create_model (line 292) | def create_model(model_name,**kwargs):
FILE: samplers.py
class RASampler (line 8) | class RASampler(torch.utils.data.Sampler):
method __init__ (line 16) | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True...
method __iter__ (line 38) | def __iter__(self):
method __len__ (line 60) | def __len__(self):
method set_epoch (line 63) | def set_epoch(self, epoch):
FILE: utils.py
class SmoothedValue (line 18) | class SmoothedValue(object):
method __init__ (line 23) | def __init__(self, window_size=20, fmt=None):
method update (line 31) | def update(self, value, n=1):
method synchronize_between_processes (line 36) | def synchronize_between_processes(self):
method median (line 50) | def median(self):
method avg (line 55) | def avg(self):
method global_avg (line 60) | def global_avg(self):
method max (line 64) | def max(self):
method value (line 68) | def value(self):
method __str__ (line 71) | def __str__(self):
class MetricLogger (line 80) | class MetricLogger(object):
method __init__ (line 81) | def __init__(self, delimiter="\t"):
method update (line 85) | def update(self, **kwargs):
method __getattr__ (line 92) | def __getattr__(self, attr):
method __str__ (line 100) | def __str__(self):
method synchronize_between_processes (line 108) | def synchronize_between_processes(self):
method add_meter (line 112) | def add_meter(self, name, meter):
method log_every (line 115) | def log_every(self, iterable, print_freq, header=None):
function _load_checkpoint_for_ema (line 162) | def _load_checkpoint_for_ema(model_ema, checkpoint):
function setup_for_distributed (line 172) | def setup_for_distributed(is_master):
function is_dist_avail_and_initialized (line 187) | def is_dist_avail_and_initialized():
function get_world_size (line 195) | def get_world_size():
function get_rank (line 201) | def get_rank():
function is_main_process (line 207) | def is_main_process():
function save_on_master (line 211) | def save_on_master(*args, **kwargs):
function init_distributed_mode (line 216) | def init_distributed_mode(args):