[
  {
    "path": ".gitignore",
    "content": ".idea\n.ipynb_checkpoints/\nDatasets/\n__pycache__/\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2024 chenziwenhaoshuai, tommarvoloriddle, chouheiwa\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE."
  },
  {
    "path": "README.md",
    "content": "# Vision-KAN 🚀\n\nWelcome to **Vision-KAN**! We are exploring the exciting possibility of [KAN](https://github.com/KindXiaoming/pykan) replacing MLP in Vision Transformer. Due to GPU resource constraints, this project may experience delays, but we'll keep you updated with any new developments here! 📅✨\n\n## Installation 🛠️\n\nTo install this package, simply run:\n\n```bash\npip install VisionKAN\n```\n\n## Minimal Example 💡\n\nHere's a quick example to get you started:\n\n```python\nfrom VisionKAN import create_model, train_one_epoch, evaluate\n\nKAN_model = create_model(\n    model_name='deit_tiny_patch16_224_KAN',\n    pretrained=False,\n    hdim_kan=192,\n    num_classes=100,\n    drop_rate=0.0,\n    drop_path_rate=0.05,\n    img_size=224,\n    batch_size=144\n)\n```\n\n## Performance Overview 📊\n\n### Baseline Models\n\n| Dataset | MLP Hidden Dim | Model               | Date | Epoch | Top-1 | Top-5 | Checkpoint |\n|---------|----------------|---------------------|------|-------|-------|-------|------------|\n| ImageNet 1k | 768          | DeiT-tiny (baseline) | -    | 300   | 72.2  | 91.1  | -          |\n| CIFAR-100   | 192          | DeiT-tiny (baseline) | 2024.5.25 | 300(stop) | 84.94 | 96.53 | [Checkpoint](https://drive.google.com/drive/folders/1hPrnfI5CKMgwM6lgSrFUwvMQYsjtjg3A?usp=drive_link) |\n| CIFAR-100   | 384          | DeiT-small (baseline) | 2024.5.25 | 300(stop) | 86.49 | 96.17 | [Checkpoint](https://drive.google.com/drive/folders/1ZSl2ojZUQRkIsZzJ0w5rahOTAv4IiZCt?usp=drive_link) |\n| CIFAR-100   | 768          | DeiT-base (baseline)  | 2024.5.25 | 300(stop) | 86.54 | 96.16 | [Checkpoint](https://drive.google.com/drive/folders/14kLdJDy11zv_mC35JvbcPCdoXvrHspNK?usp=sharing) |\n\n### Vision-KAN Models\n\n| Dataset | KAN Hidden Dim | Model     | Date     | Epoch     | Top-1 | Top-5 | Checkpoint |\n|---------|----------------|-----------|----------|-----------|-------|-------|------------|\n| ImageNet 1k | 20         | Vision-KAN | 2024.5.16 | 37(stop)  | 36.34 | 61.48 | -          |\n| ImageNet 1k | 192        | Vision-KAN | 2024.5.25 | 346(stop) | 64.87 | 86.14 | [Checkpoint](https://pan.baidu.com/s/117ox7oh6zzXLwPMmQ6od1Q?pwd=y1vw) |\n| ImageNet 1k | 768        | Vision-KAN | 2024.6.2  | 154(training) | 62.90 | 85.03 | -          |\n| CIFAR-100   | 192        | Vision-KAN | 2024.5.25 | 300(stop) | 73.17 | 93.307 | [Checkpoint](https://drive.google.com/drive/folders/19WPq6bZ9NgX-WxD7qXSTKiHc5D6P8jQP?usp=sharing) |\n| CIFAR-100   | 384        | Vision-KAN | 2024.5.25 | 300(stop) | 78.69 | 94.73 | [Checkpoint](https://drive.google.com/drive/folders/1Uhj4yV0HZRQkPFUerxy88B19N1eDdgsc?usp=drive_link) |\n| CIFAR-100   | 768        | Vision-KAN | 2024.5.29 | 300(stop) | 79.82 | 95.42 | [Checkpoint](https://drive.google.com/drive/folders/1FT55_6tDO_a135sQKBDn409fDdXvCi4N?usp=drive_link) |\n\n## Latest News 📰\n\n- **5.7.2024**: Released the current Vision KAN code! 🚀 We used efficient KAN to replace the MLP layer in the Transformer block and are pre-training the Tiny model on ImageNet 1k. Updates will be reflected in the table.\n- **5.14.2024**: The model is starting to converge! We’re using [192, 20, 192] for input, hidden, and output dimensions.\n- **5.15.2024**: Switched from [efficient kan](https://github.com/Blealtan/efficient-kan) to [faster kan](https://github.com/AthanasiosDelis/faster-kan) to double the training speed! 
🚀\n- **5.16.2024**: Convergence appears to be bottlenecked; considering adjusting the KAN hidden dimension from 20 to 192.\n- **5.22.2024**: Fixed timm version dependency issues and cleaned up the code! 🧹\n- **5.24.2024**: Loss decline is slowing, nearing final results! 🔍\n- **5.25.2024**: The model with a 192-dimensional KAN hidden layer is approaching convergence! 🎉 Released the best checkpoint of VisionKAN.\n\n## Architecture 🏗️\n\nWe utilized [DeiT](https://github.com/facebookresearch/deit) as the baseline for Vision-KAN development. Huge thanks to Meta and MIT for their incredible work! 🙌\n\n## Star History 🌟\n\n[![Star History Chart](https://api.star-history.com/svg?repos=chenziwenhaoshuai/Vision-KAN&type=Date)](https://star-history.com/#chenziwenhaoshuai/Vision-KAN&Date)\n\n## Citation 📑\n\nIf you use our work, please cite:\n\n```bibtex\n@misc{VisionKAN2024,\n  author = {Ziwen Chen and Gundavarapu and WU DI},\n  title = {Vision-KAN: Exploring the Possibility of KAN Replacing MLP in Vision Transformer},\n  year = {2024},\n  howpublished = {\\url{https://github.com/chenziwenhaoshuai/Vision-KAN.git}},\n}\n```\n"
  },
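  {
    "path": "examples/minimal_eval_example.py",
    "content": "# NOTE: This file is an illustrative sketch, not part of the original Vision-KAN release.\n# It expands the README's 'Minimal Example' by showing how the exported evaluate()\n# helper (see engine.py) could be run on CIFAR-100. The file name, the CIFAR-100\n# transform, and the data path are assumptions, not the authors' setup.\nimport torch\nfrom torch.utils.data import DataLoader\nfrom torchvision import datasets, transforms\n\nfrom VisionKAN import create_model, evaluate\n\n# Build the model exactly as in the README minimal example.\nKAN_model = create_model(\n    model_name='deit_tiny_patch16_224_KAN',\n    pretrained=False,\n    hdim_kan=192,\n    num_classes=100,\n    drop_rate=0.0,\n    drop_path_rate=0.05,\n    img_size=224,\n    batch_size=144,\n)\n\n# Assumed evaluation transform: resize/center-crop to 224 with ImageNet statistics,\n# mirroring build_transform() in datasets.py.\neval_transform = transforms.Compose([\n    transforms.Resize(256),\n    transforms.CenterCrop(224),\n    transforms.ToTensor(),\n    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),\n])\n\nif __name__ == '__main__':\n    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n    val_set = datasets.CIFAR100('Datasets/', train=False, download=True, transform=eval_transform)\n    val_loader = DataLoader(val_set, batch_size=144, shuffle=False, num_workers=4)\n    # evaluate() is defined in engine.py as evaluate(data_loader, model, device).\n    stats = evaluate(val_loader, KAN_model.to(device), device)\n    print(stats)\n"
  },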
  {
    "path": "__init__.py",
    "content": "# __init__.py\n\nfrom .models_kan import create_model\nfrom .engine import train_one_epoch, evaluate\n\n__all__ = ['create_model', 'train_one_epoch', 'evaluate']"
  },
  {
    "path": "augment.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n\"\"\"\n3Augment implementation\nData-augmentation (DA) based on dino DA (https://github.com/facebookresearch/dino)\nand timm DA(https://github.com/rwightman/pytorch-image-models)\n\"\"\"\nimport torch\nfrom torchvision import transforms\n\nfrom timm.data.transforms import str_to_pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor\n\nimport numpy as np\nfrom torchvision import datasets, transforms\nimport random\n\n\n\nfrom PIL import ImageFilter, ImageOps\nimport torchvision.transforms.functional as TF\n\n\nclass GaussianBlur(object):\n    \"\"\"\n    Apply Gaussian Blur to the PIL image.\n    \"\"\"\n    def __init__(self, p=0.1, radius_min=0.1, radius_max=2.):\n        self.prob = p\n        self.radius_min = radius_min\n        self.radius_max = radius_max\n\n    def __call__(self, img):\n        do_it = random.random() <= self.prob\n        if not do_it:\n            return img\n\n        img = img.filter(\n            ImageFilter.GaussianBlur(\n                radius=random.uniform(self.radius_min, self.radius_max)\n            )\n        )\n        return img\n\nclass Solarization(object):\n    \"\"\"\n    Apply Solarization to the PIL image.\n    \"\"\"\n    def __init__(self, p=0.2):\n        self.p = p\n\n    def __call__(self, img):\n        if random.random() < self.p:\n            return ImageOps.solarize(img)\n        else:\n            return img\n\nclass gray_scale(object):\n    \"\"\"\n    Apply Solarization to the PIL image.\n    \"\"\"\n    def __init__(self, p=0.2):\n        self.p = p\n        self.transf = transforms.Grayscale(3)\n \n    def __call__(self, img):\n        if random.random() < self.p:\n            return self.transf(img)\n        else:\n            return img\n \n    \n    \nclass horizontal_flip(object):\n    \"\"\"\n    Apply Solarization to the PIL image.\n    \"\"\"\n    def __init__(self, p=0.2,activate_pred=False):\n        self.p = p\n        self.transf = transforms.RandomHorizontalFlip(p=1.0)\n \n    def __call__(self, img):\n        if random.random() < self.p:\n            return self.transf(img)\n        else:\n            return img\n        \n    \n    \ndef new_data_aug_generator(args = None):\n    img_size = args.input_size\n    remove_random_resized_crop = args.src\n    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]\n    primary_tfl = []\n    scale=(0.08, 1.0)\n    interpolation='bicubic'\n    if remove_random_resized_crop:\n        primary_tfl = [\n            transforms.Resize(img_size, interpolation=3),\n            transforms.RandomCrop(img_size, padding=4,padding_mode='reflect'),\n            transforms.RandomHorizontalFlip()\n        ]\n    else:\n        primary_tfl = [\n            RandomResizedCropAndInterpolation(\n                img_size, scale=scale, interpolation=interpolation),\n            transforms.RandomHorizontalFlip()\n        ]\n\n        \n    secondary_tfl = [transforms.RandomChoice([gray_scale(p=1.0),\n                                              Solarization(p=1.0),\n                                              GaussianBlur(p=1.0)])]\n   \n    if args.color_jitter is not None and not args.color_jitter==0:\n        secondary_tfl.append(transforms.ColorJitter(args.color_jitter, args.color_jitter, args.color_jitter))\n    final_tfl = [\n            transforms.ToTensor(),\n            transforms.Normalize(\n                mean=torch.tensor(mean),\n                std=torch.tensor(std))\n        
]\n    return transforms.Compose(primary_tfl+secondary_tfl+final_tfl)\n"
  },
  {
    "path": "datasets.py",
    "content": "# Copyright (c) 2015-present, Facebook, Inc.\n# All rights reserved.\nimport os\nimport json\n\nfrom torchvision import datasets, transforms\nfrom torchvision.datasets.folder import ImageFolder, default_loader\nfrom torchvision.transforms import InterpolationMode\nfrom timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD\nfrom timm.data import create_transform\n\n\nclass INatDataset(ImageFolder):\n    def __init__(self, root, train=True, year=2018, transform=None, target_transform=None,\n                 category='name', loader=default_loader):\n        self.transform = transform\n        self.loader = loader\n        self.target_transform = target_transform\n        self.year = year\n        # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name']\n        path_json = os.path.join(root, f'{\"train\" if train else \"val\"}{year}.json')\n        with open(path_json) as json_file:\n            data = json.load(json_file)\n\n        with open(os.path.join(root, 'categories.json')) as json_file:\n            data_catg = json.load(json_file)\n\n        path_json_for_targeter = os.path.join(root, f\"train{year}.json\")\n\n        with open(path_json_for_targeter) as json_file:\n            data_for_targeter = json.load(json_file)\n\n        targeter = {}\n        indexer = 0\n        for elem in data_for_targeter['annotations']:\n            king = []\n            king.append(data_catg[int(elem['category_id'])][category])\n            if king[0] not in targeter.keys():\n                targeter[king[0]] = indexer\n                indexer += 1\n        self.nb_classes = len(targeter)\n\n        self.samples = []\n        for elem in data['images']:\n            cut = elem['file_name'].split('/')\n            target_current = int(cut[2])\n            path_current = os.path.join(root, cut[0], cut[2], cut[3])\n\n            categors = data_catg[target_current]\n            target_current_true = targeter[categors[category]]\n            self.samples.append((path_current, target_current_true))\n\n    # __getitem__ and __len__ inherited from ImageFolder\n\n\ndef build_dataset(is_train, args):\n    transform = build_transform(is_train, args)\n\n    if args.data_set == 'CIFAR':\n        dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform)\n        nb_classes = 100\n    elif args.data_set == 'IMNET':\n        root = os.path.join(args.data_path, 'train' if is_train else 'val')\n        dataset = datasets.ImageFolder(root, transform=transform)\n        nb_classes = 1000\n    elif args.data_set == 'TinyIMNET':\n        root = os.path.join(args.data_path, 'tiny-imagenet-200/train' if is_train else 'tiny-imagenet-200/val')\n        dataset = datasets.ImageFolder(root, transform=transform)\n        nb_classes = 100\n    elif args.data_set == 'INAT':\n        dataset = INatDataset(args.data_path, train=is_train, year=2018,\n                              category=args.inat_category, transform=transform)\n        nb_classes = dataset.nb_classes\n    elif args.data_set == 'INAT19':\n        dataset = INatDataset(args.data_path, train=is_train, year=2019,\n                              category=args.inat_category, transform=transform)\n        nb_classes = dataset.nb_classes\n\n    return dataset, nb_classes\n\n\ndef build_transform(is_train, args):\n    resize_im = args.input_size > 32\n    if is_train:\n        # this should always dispatch to transforms_imagenet_train\n        transform = create_transform(\n     
       input_size=args.input_size,\n            is_training=True,\n            color_jitter=args.color_jitter,\n            auto_augment=args.aa,\n            interpolation=args.train_interpolation,\n            re_prob=args.reprob,\n            re_mode=args.remode,\n            re_count=args.recount,\n        )\n        if not resize_im:\n            # replace RandomResizedCropAndInterpolation with\n            # RandomCrop\n            transform.transforms[0] = transforms.RandomCrop(\n                args.input_size, padding=4)\n        return transform\n\n    t = []\n    if resize_im:\n        size = int(args.input_size / args.eval_crop_ratio)\n        t.append(\n            transforms.Resize(size, interpolation=InterpolationMode.BICUBIC),  # to maintain same ratio w.r.t. 224 images\n        )\n        t.append(transforms.CenterCrop(args.input_size))\n\n    t.append(transforms.ToTensor())\n    t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))\n    return transforms.Compose(t)\n"
  },
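  {
    "path": "examples/build_dataset_sketch.py",
    "content": "# Illustrative sketch (not part of the original repo): build_dataset() in datasets.py\n# expects an argparse-style namespace. The attribute names below come from datasets.py;\n# the values are assumptions (typical DeiT defaults), and the dataset is expected to\n# already exist under data_path, since datasets.py does not download CIFAR-100.\nfrom argparse import Namespace\n\nfrom datasets import build_dataset\n\nargs = Namespace(\n    data_set='CIFAR',             # one of CIFAR / IMNET / TinyIMNET / INAT / INAT19\n    data_path='Datasets/',        # assumed dataset root\n    input_size=224,\n    color_jitter=0.3,\n    aa='rand-m9-mstd0.5-inc1',    # timm AutoAugment policy string\n    train_interpolation='bicubic',\n    reprob=0.25,                  # random-erasing probability\n    remode='pixel',\n    recount=1,\n    eval_crop_ratio=0.875,\n    inat_category='name',         # only used by the INAT datasets\n)\n\nif __name__ == '__main__':\n    train_set, nb_classes = build_dataset(is_train=True, args=args)\n    val_set, _ = build_dataset(is_train=False, args=args)\n    print(len(train_set), len(val_set), nb_classes)\n"
  },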
  {
    "path": "ekan.py",
    "content": "import torch\r\nimport torch.nn.functional as F\r\nimport math\r\n\r\n\r\nclass KANLinear(torch.nn.Module):\r\n    def __init__(\r\n        self,\r\n        in_features,\r\n        out_features,\r\n        grid_size=5,\r\n        spline_order=3,\r\n        scale_noise=0.1,\r\n        scale_base=1.0,\r\n        scale_spline=1.0,\r\n        enable_standalone_scale_spline=False,\r\n        base_activation=torch.nn.SiLU,\r\n        grid_eps=0.02,\r\n        grid_range=[-1, 1],\r\n    ):\r\n        super(KANLinear, self).__init__()\r\n        self.in_features = in_features\r\n        self.out_features = out_features\r\n        self.grid_size = grid_size\r\n        self.spline_order = spline_order\r\n\r\n        h = (grid_range[1] - grid_range[0]) / grid_size\r\n        grid = (\r\n            (\r\n                torch.arange(-spline_order, grid_size + spline_order + 1) * h\r\n                + grid_range[0]\r\n            )\r\n            .expand(in_features, -1)\r\n            .contiguous()\r\n        )\r\n        self.register_buffer(\"grid\", grid)\r\n\r\n        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))\r\n        self.spline_weight = torch.nn.Parameter(\r\n            torch.Tensor(out_features, in_features, grid_size + spline_order)\r\n        )\r\n        if enable_standalone_scale_spline:\r\n            self.spline_scaler = torch.nn.Parameter(\r\n                torch.Tensor(out_features, in_features)\r\n            )\r\n\r\n        self.scale_noise = scale_noise\r\n        self.scale_base = scale_base\r\n        self.scale_spline = scale_spline\r\n        self.enable_standalone_scale_spline = enable_standalone_scale_spline\r\n        self.base_activation = base_activation()\r\n        self.grid_eps = grid_eps\r\n\r\n        self.reset_parameters()\r\n\r\n    def reset_parameters(self):\r\n        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)\r\n        with torch.no_grad():\r\n            noise = (\r\n                (\r\n                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)\r\n                    - 1 / 2\r\n                )\r\n                * self.scale_noise\r\n                / self.grid_size\r\n            )\r\n            self.spline_weight.data.copy_(\r\n                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)\r\n                * self.curve2coeff(\r\n                    self.grid.T[self.spline_order : -self.spline_order],\r\n                    noise,\r\n                )\r\n            )\r\n            if self.enable_standalone_scale_spline:\r\n                # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)\r\n                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)\r\n\r\n    def b_splines(self, x: torch.Tensor):\r\n        \"\"\"\r\n        Compute the B-spline bases for the given input tensor.\r\n\r\n        Args:\r\n            x (torch.Tensor): Input tensor of shape (batch_size, in_features).\r\n\r\n        Returns:\r\n            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).\r\n        \"\"\"\r\n        assert x.dim() == 2 and x.size(1) == self.in_features\r\n\r\n        grid: torch.Tensor = (\r\n            self.grid\r\n        )  # (in_features, grid_size + 2 * spline_order + 1)\r\n        x = x.unsqueeze(-1)\r\n        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)\r\n        for k in 
range(1, self.spline_order + 1):\r\n            bases = (\r\n                (x - grid[:, : -(k + 1)])\r\n                / (grid[:, k:-1] - grid[:, : -(k + 1)])\r\n                * bases[:, :, :-1]\r\n            ) + (\r\n                (grid[:, k + 1 :] - x)\r\n                / (grid[:, k + 1 :] - grid[:, 1:(-k)])\r\n                * bases[:, :, 1:]\r\n            )\r\n\r\n        assert bases.size() == (\r\n            x.size(0),\r\n            self.in_features,\r\n            self.grid_size + self.spline_order,\r\n        )\r\n        return bases.contiguous()\r\n\r\n    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):\r\n        \"\"\"\r\n        Compute the coefficients of the curve that interpolates the given points.\r\n\r\n        Args:\r\n            x (torch.Tensor): Input tensor of shape (batch_size, in_features).\r\n            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).\r\n\r\n        Returns:\r\n            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).\r\n        \"\"\"\r\n        assert x.dim() == 2 and x.size(1) == self.in_features\r\n        assert y.size() == (x.size(0), self.in_features, self.out_features)\r\n\r\n        A = self.b_splines(x).transpose(\r\n            0, 1\r\n        )  # (in_features, batch_size, grid_size + spline_order)\r\n        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)\r\n        solution = torch.linalg.lstsq(\r\n            A, B\r\n        ).solution  # (in_features, grid_size + spline_order, out_features)\r\n        result = solution.permute(\r\n            2, 0, 1\r\n        )  # (out_features, in_features, grid_size + spline_order)\r\n\r\n        assert result.size() == (\r\n            self.out_features,\r\n            self.in_features,\r\n            self.grid_size + self.spline_order,\r\n        )\r\n        return result.contiguous()\r\n\r\n    @property\r\n    def scaled_spline_weight(self):\r\n        return self.spline_weight * (\r\n            self.spline_scaler.unsqueeze(-1)\r\n            if self.enable_standalone_scale_spline\r\n            else 1.0\r\n        )\r\n\r\n    def forward(self, x: torch.Tensor):\r\n        assert x.dim() == 2 and x.size(1) == self.in_features\r\n\r\n        base_output = F.linear(self.base_activation(x), self.base_weight)\r\n        spline_output = F.linear(\r\n            self.b_splines(x).view(x.size(0), -1),\r\n            self.scaled_spline_weight.view(self.out_features, -1),\r\n        )\r\n        return base_output + spline_output\r\n\r\n    @torch.no_grad()\r\n    def update_grid(self, x: torch.Tensor, margin=0.01):\r\n        assert x.dim() == 2 and x.size(1) == self.in_features\r\n        batch = x.size(0)\r\n\r\n        splines = self.b_splines(x)  # (batch, in, coeff)\r\n        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)\r\n        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)\r\n        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)\r\n        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)\r\n        unreduced_spline_output = unreduced_spline_output.permute(\r\n            1, 0, 2\r\n        )  # (batch, in, out)\r\n\r\n        # sort each channel individually to collect data distribution\r\n        x_sorted = torch.sort(x, dim=0)[0]\r\n        grid_adaptive = x_sorted[\r\n            torch.linspace(\r\n                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device\r\n          
  )\r\n        ]\r\n\r\n        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size\r\n        grid_uniform = (\r\n            torch.arange(\r\n                self.grid_size + 1, dtype=torch.float32, device=x.device\r\n            ).unsqueeze(1)\r\n            * uniform_step\r\n            + x_sorted[0]\r\n            - margin\r\n        )\r\n\r\n        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive\r\n        grid = torch.concatenate(\r\n            [\r\n                grid[:1]\r\n                - uniform_step\r\n                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),\r\n                grid,\r\n                grid[-1:]\r\n                + uniform_step\r\n                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),\r\n            ],\r\n            dim=0,\r\n        )\r\n\r\n        self.grid.copy_(grid.T)\r\n        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))\r\n\r\n    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):\r\n        \"\"\"\r\n        Compute the regularization loss.\r\n\r\n        This is a dumb simulation of the original L1 regularization as stated in the\r\n        paper, since the original one requires computing absolutes and entropy from the\r\n        expanded (batch, in_features, out_features) intermediate tensor, which is hidden\r\n        behind the F.linear function if we want an memory efficient implementation.\r\n\r\n        The L1 regularization is now computed as mean absolute value of the spline\r\n        weights. The authors implementation also includes this term in addition to the\r\n        sample-based regularization.\r\n        \"\"\"\r\n        l1_fake = self.spline_weight.abs().mean(-1)\r\n        regularization_loss_activation = l1_fake.sum()\r\n        p = l1_fake / regularization_loss_activation\r\n        regularization_loss_entropy = -torch.sum(p * p.log())\r\n        return (\r\n            regularize_activation * regularization_loss_activation\r\n            + regularize_entropy * regularization_loss_entropy\r\n        )\r\n\r\n\r\nclass KAN(torch.nn.Module):\r\n    def __init__(\r\n        self,\r\n        layers_hidden,\r\n        grid_size=5,\r\n        spline_order=3,\r\n        scale_noise=0.1,\r\n        scale_base=1.0,\r\n        scale_spline=1.0,\r\n        base_activation=torch.nn.SiLU,\r\n        grid_eps=0.02,\r\n        grid_range=[-1, 1],\r\n    ):\r\n        super(KAN, self).__init__()\r\n        self.grid_size = grid_size\r\n        self.spline_order = spline_order\r\n\r\n        self.layers = torch.nn.ModuleList()\r\n        for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):\r\n            self.layers.append(\r\n                KANLinear(\r\n                    in_features,\r\n                    out_features,\r\n                    grid_size=grid_size,\r\n                    spline_order=spline_order,\r\n                    scale_noise=scale_noise,\r\n                    scale_base=scale_base,\r\n                    scale_spline=scale_spline,\r\n                    base_activation=base_activation,\r\n                    grid_eps=grid_eps,\r\n                    grid_range=grid_range,\r\n                )\r\n            )\r\n\r\n    def forward(self, x: torch.Tensor, update_grid=False):\r\n        for layer in self.layers:\r\n            if update_grid:\r\n                layer.update_grid(x)\r\n            x = layer(x)\r\n        
return x\r\n\r\n    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):\r\n        return sum(\r\n            layer.regularization_loss(regularize_activation, regularize_entropy)\r\n            for layer in self.layers\r\n        )"
  },
  {
    "path": "engine.py",
    "content": "# Copyright (c) 2015-present, Facebook, Inc.\n# All rights reserved.\n\"\"\"\nTrain and eval functions used in main.py\n\"\"\"\nimport math\nimport sys\nfrom typing import Iterable, Optional\n\nimport torch\n\nfrom timm.data import Mixup\nfrom timm.utils import accuracy, ModelEma\n\nfrom losses import DistillationLoss\nimport utils\n\n\ndef train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,\n                    data_loader: Iterable, optimizer: torch.optim.Optimizer,\n                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,\n                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,\n                    set_training_mode=True, args = None):\n    model.train(set_training_mode)\n    metric_logger = utils.MetricLogger(delimiter=\"  \")\n    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))\n    header = 'Epoch: [{}]'.format(epoch)\n    print_freq = 50\n    \n    if args.cosub:\n        criterion = torch.nn.BCEWithLogitsLoss()\n        \n    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):\n        samples = samples.to(device, non_blocking=True)\n        targets = targets.to(device, non_blocking=True)\n\n        if mixup_fn is not None:\n            samples, targets = mixup_fn(samples, targets)\n            \n        if args.cosub:\n            samples = torch.cat((samples,samples),dim=0)\n            \n        if args.bce_loss:\n            targets = targets.gt(0.0).type(targets.dtype)\n         \n        with torch.cuda.amp.autocast():\n            outputs = model(samples)\n            if not args.cosub:\n                loss = criterion(samples, outputs, targets)\n            else:\n                outputs = torch.split(outputs, outputs.shape[0]//2, dim=0)\n                loss = 0.25 * criterion(outputs[0], targets) \n                loss = loss + 0.25 * criterion(outputs[1], targets) \n                loss = loss + 0.25 * criterion(outputs[0], outputs[1].detach().sigmoid())\n                loss = loss + 0.25 * criterion(outputs[1], outputs[0].detach().sigmoid()) \n\n        loss_value = loss.item()\n\n        if not math.isfinite(loss_value):\n            print(\"Loss is {}, stopping training\".format(loss_value))\n            sys.exit(1)\n\n        optimizer.zero_grad()\n\n        # this attribute is added by timm on one optimizer (adahessian)\n        is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order\n        loss_scaler(loss, optimizer, clip_grad=max_norm,\n                    parameters=model.parameters(), create_graph=is_second_order)\n\n        torch.cuda.synchronize()\n        if model_ema is not None:\n            model_ema.update(model)\n\n        metric_logger.update(loss=loss_value)\n        metric_logger.update(lr=optimizer.param_groups[0][\"lr\"])\n    # gather the stats from all processes\n    metric_logger.synchronize_between_processes()\n    print(\"Averaged stats:\", metric_logger)\n    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}\n\n\n@torch.no_grad()\ndef evaluate(data_loader, model, device):\n    criterion = torch.nn.CrossEntropyLoss()\n\n    metric_logger = utils.MetricLogger(delimiter=\"  \")\n    header = 'Test:'\n\n    # switch to evaluation mode\n    model.eval()\n\n    for images, target in metric_logger.log_every(data_loader, 10, header):\n        images = images.to(device, non_blocking=True)\n        target = target.to(device, 
non_blocking=True)\n\n        # compute output\n        with torch.cuda.amp.autocast():\n            output = model(images)\n            loss = criterion(output, target)\n\n        acc1, acc5 = accuracy(output, target, topk=(1, 5))\n\n        batch_size = images.shape[0]\n        metric_logger.update(loss=loss.item())\n        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)\n        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)\n    # gather the stats from all processes\n    metric_logger.synchronize_between_processes()\n    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'\n          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))\n\n    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}\n"
  },
  {
    "path": "fasterkan.py",
    "content": "import torch\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\nimport math\r\nfrom typing import *\r\n\r\n\r\nclass SplineLinear(nn.Linear):\r\n    def __init__(self, in_features: int, out_features: int, init_scale: float = 0.1, **kw) -> None:\r\n        self.init_scale = init_scale\r\n        super().__init__(in_features, out_features, bias=False, **kw)\r\n\r\n    def reset_parameters(self) -> None:\r\n        nn.init.xavier_uniform_(self.weight)  # Using Xavier Uniform initialization\r\n\r\n\r\nclass ReflectionalSwitchFunction(nn.Module):\r\n    def __init__(\r\n            self,\r\n            grid_min: float = -2.,\r\n            grid_max: float = 2.,\r\n            num_grids: int = 8,\r\n            exponent: int = 2,\r\n            denominator: float = 0.33,  # larger denominators lead to smoother basis\r\n    ):\r\n        super().__init__()\r\n        grid = torch.linspace(grid_min, grid_max, num_grids)\r\n        self.grid = torch.nn.Parameter(grid, requires_grad=False)\r\n        self.denominator = denominator  # or (grid_max - grid_min) / (num_grids - 1)\r\n        # self.exponent = exponent\r\n        self.inv_denominator = 1 / self.denominator  # Cache the inverse of the denominator\r\n\r\n    def forward(self, x):\r\n        diff = (x[..., None] - self.grid)\r\n        diff_mul = diff.mul(self.inv_denominator)\r\n        diff_tanh = torch.tanh(diff_mul)\r\n        diff_pow = -diff_tanh.mul(diff_tanh)\r\n        diff_pow += 1\r\n        # diff_pow *= 0.667\r\n        return diff_pow  # Replace pow with multiplication for squaring\r\n\r\n\r\nclass FasterKANLayer(nn.Module):\r\n    def __init__(\r\n            self,\r\n            input_dim: int,\r\n            output_dim: int,\r\n            grid_min: float = -2.,\r\n            grid_max: float = 2.,\r\n            num_grids: int = 8,\r\n            exponent: int = 2,\r\n            denominator: float = 0.33,\r\n            use_base_update: bool = True,\r\n            base_activation=F.silu,\r\n            spline_weight_init_scale: float = 0.1,\r\n    ) -> None:\r\n        super().__init__()\r\n        self.layernorm = nn.LayerNorm(input_dim)\r\n        self.rbf = ReflectionalSwitchFunction(grid_min, grid_max, num_grids, exponent, denominator)\r\n        self.spline_linear = SplineLinear(input_dim * num_grids, output_dim, spline_weight_init_scale)\r\n        # self.use_base_update = use_base_update\r\n        # if use_base_update:\r\n        #    self.base_activation = base_activation\r\n        #    self.base_linear = nn.Linear(input_dim, output_dim)\r\n\r\n    def forward(self, x, time_benchmark=False):\r\n        if not time_benchmark:\r\n            spline_basis = self.rbf(self.layernorm(x)).view(x.shape[0], -1)\r\n            # print(\"spline_basis:\", spline_basis.shape)\r\n        else:\r\n            spline_basis = self.rbf(x).view(x.shape[0], -1)\r\n            # print(\"spline_basis:\", spline_basis.shape)\r\n        # print(\"-------------------------\")\r\n        # ret = 0\r\n        ret = self.spline_linear(spline_basis)\r\n        # print(\"spline_basis.shape[:-2]:\", spline_basis.shape[:-2])\r\n        # print(\"*spline_basis.shape[:-2]:\", *spline_basis.shape[:-2])\r\n        # print(\"spline_basis.view(*spline_basis.shape[:-2], -1):\", spline_basis.view(*spline_basis.shape[:-2], -1).shape)\r\n        # print(\"ret:\", ret.shape)\r\n        # print(\"-------------------------\")\r\n        # if self.use_base_update:\r\n        # base = self.base_linear(self.base_activation(x))\r\n   
     # print(\"self.base_activation(x):\", self.base_activation(x).shape)\r\n        # print(\"base:\", base.shape)\r\n        # print(\"@@@@@@@@@\")\r\n        # ret += base\r\n        return ret\r\n\r\n        # spline_basis = spline_basis.reshape(x.shape[0], -1)  # Reshape to [batch_size, input_dim * num_grids]\r\n        # print(\"spline_basis:\", spline_basis.shape)\r\n\r\n        # spline_weight = self.spline_weight.view(-1, self.spline_weight.shape[0])  # Reshape to [input_dim * num_grids, output_dim]\r\n        # print(\"spline_weight:\", spline_weight.shape)\r\n\r\n        # spline = torch.matmul(spline_basis, spline_weight)  # Resulting shape: [batch_size, output_dim]\r\n\r\n        # print(\"-------------------------\")\r\n        # print(\"Base shape:\", base.shape)\r\n        # print(\"Spline shape:\", spline.shape)\r\n        # print(\"@@@@@@@@@\")\r\n\r\n\r\nclass FasterKAN(nn.Module):\r\n    def __init__(\r\n            self,\r\n            layers_hidden: List[int],\r\n            grid_min: float = -2.,\r\n            grid_max: float = 2.,\r\n            num_grids: int = 8,\r\n            exponent: int = 2,\r\n            denominator: float = 0.33,\r\n            use_base_update: bool = True,\r\n            base_activation=F.silu,\r\n            spline_weight_init_scale: float = 0.667,\r\n    ) -> None:\r\n        super().__init__()\r\n        self.layers = nn.ModuleList([\r\n            FasterKANLayer(\r\n                in_dim, out_dim,\r\n                grid_min=grid_min,\r\n                grid_max=grid_max,\r\n                num_grids=num_grids,\r\n                exponent=exponent,\r\n                denominator=denominator,\r\n                use_base_update=use_base_update,\r\n                base_activation=base_activation,\r\n                spline_weight_init_scale=spline_weight_init_scale,\r\n            ) for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])\r\n        ])\r\n\r\n    def forward(self, x):\r\n        for layer in self.layers:\r\n            x = layer(x)\r\n        return x"
  },
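  {
    "path": "examples/fasterkan_tokens_sketch.py",
    "content": "# Illustrative sketch (not part of the original repo): the README notes the switch from\n# efficient-kan to faster-kan, and kit.py applies its KAN to transformer tokens by\n# flattening (B, N, C) into (B*N, C) before the call. The same flatten/reshape pattern\n# works for FasterKAN, shown here on random data; the [192, 20, 192] sizes follow the\n# README's 5.14.2024 note.\nimport torch\n\nfrom fasterkan import FasterKAN\n\nif __name__ == '__main__':\n    B, N, C = 4, 197, 192          # batch, tokens (196 patches + cls), embed dim\n    kan = FasterKAN([C, 20, C])    # input -> hidden -> output dims\n    x = torch.randn(B, N, C)\n    # Flatten tokens into a 2-D batch, apply the KAN, then restore the shape,\n    # mirroring Block.forward in kit.py.\n    y = kan(x.reshape(-1, C)).reshape(B, N, C)\n    print(y.shape)  # torch.Size([4, 197, 192])\n"
  },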
  {
    "path": "hubconf.py",
    "content": "# Copyright (c) 2015-present, Facebook, Inc.\n# All rights reserved.\nfrom models import *\nfrom cait_models import *\nfrom resmlp_models import *\n#from patchconvnet_models import *\n\ndependencies = [\"torch\", \"torchvision\", \"timm\"]\n"
  },
  {
    "path": "kit.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\r\n# All rights reserved.\r\n\r\nimport torch\r\nimport torch.nn as nn\r\nfrom functools import partial\r\n\r\nfrom timm.models.vision_transformer import Mlp, PatchEmbed, _cfg\r\nfrom ekan import KAN\r\n\r\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\r\nfrom timm.models.registry import register_model\r\n__all__ = [\r\n    'deit_tiny_patch16_LS'\r\n]\r\n\r\nclass Attention(nn.Module):\r\n    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py\r\n    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):\r\n        super().__init__()\r\n        self.num_heads = num_heads\r\n        head_dim = dim // num_heads\r\n        self.scale = qk_scale or head_dim ** -0.5\r\n\r\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\r\n        self.attn_drop = nn.Dropout(attn_drop)\r\n        self.proj = nn.Linear(dim, dim)\r\n        self.proj_drop = nn.Dropout(proj_drop)\r\n\r\n    def forward(self, x):\r\n        B, N, C = x.shape\r\n        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\r\n        q, k, v = qkv[0], qkv[1], qkv[2]\r\n\r\n        q = q * self.scale\r\n\r\n        attn = (q @ k.transpose(-2, -1))\r\n        attn = attn.softmax(dim=-1)\r\n        attn = self.attn_drop(attn)\r\n\r\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\r\n        x = self.proj(x)\r\n        x = self.proj_drop(x)\r\n        return x\r\n\r\n\r\nclass Block(nn.Module):\r\n    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py\r\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\r\n                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, Attention_block=Attention, Mlp_block=Mlp\r\n                 , init_values=1e-4):\r\n        super().__init__()\r\n        self.norm1 = norm_layer(dim)\r\n        self.attn = Attention_block(\r\n            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\r\n        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here\r\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\r\n        self.norm2 = norm_layer(dim)\r\n        mlp_hidden_dim = int(dim * mlp_ratio)\r\n        # self.mlp = Mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\r\n        self.kan = KAN([dim, 20, dim])\r\n\r\n    def forward(self, x):\r\n        b,t,d = x.shape\r\n        x = x + self.drop_path(self.attn(self.norm1(x)))\r\n        x = x + self.drop_path(self.kan(self.norm2(x).reshape(-1,x.shape[-1])).reshape(b,t,d))\r\n        return x\r\n\r\n\r\nclass Layer_scale_init_Block(nn.Module):\r\n    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py\r\n    # with slight modifications\r\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\r\n                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, Attention_block=Attention, Mlp_block=Mlp\r\n                 , init_values=1e-4):\r\n        super().__init__()\r\n        self.norm1 = norm_layer(dim)\r\n        self.attn = Attention_block(\r\n            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\r\n        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here\r\n        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()\r\n        self.norm2 = norm_layer(dim)\r\n        mlp_hidden_dim = int(dim * mlp_ratio)\r\n        self.mlp = Mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\r\n        self.kan = KAN([dim, 20, dim])\r\n        self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)\r\n        self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)\r\n\r\n    def forward(self, x):\r\n        b,t,d = x.shape\r\n        x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))\r\n        # x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))\r\n        x = x + self.drop_path(self.kan(self.norm2(x).reshape(-1,x.shape[-1])).reshape(b,t,d))\r\n        return x\r\n\r\n\r\nclass Layer_scale_init_Block_paralx2(nn.Module):\r\n    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py\r\n    # with slight modifications\r\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\r\n                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, Attention_block=Attention, Mlp_block=Mlp\r\n                 , init_values=1e-4):\r\n        super().__init__()\r\n        self.norm1 = norm_layer(dim)\r\n        self.norm11 = norm_layer(dim)\r\n        self.attn = Attention_block(\r\n            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\r\n        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here\r\n        self.attn1 = Attention_block(\r\n            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\r\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\r\n        self.norm2 = norm_layer(dim)\r\n        self.norm21 = norm_layer(dim)\r\n        mlp_hidden_dim = int(dim * mlp_ratio)\r\n        self.mlp = Mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\r\n        self.mlp1 = Mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\r\n        self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)\r\n        self.gamma_1_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)\r\n        self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)\r\n        self.gamma_2_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)\r\n\r\n    def forward(self, x):\r\n        x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x))) + self.drop_path(\r\n            self.gamma_1_1 * self.attn1(self.norm11(x)))\r\n        x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + self.drop_path(\r\n            self.gamma_2_1 * self.mlp1(self.norm21(x)))\r\n        return x\r\n\r\n\r\nclass Block_paralx2(nn.Module):\r\n    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py\r\n    # with slight modifications\r\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\r\n                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, Attention_block=Attention, Mlp_block=Mlp\r\n                 , init_values=1e-4):\r\n        super().__init__()\r\n        self.norm1 = norm_layer(dim)\r\n        self.norm11 = norm_layer(dim)\r\n        self.attn = Attention_block(\r\n            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\r\n        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here\r\n        self.attn1 = Attention_block(\r\n            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\r\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\r\n        self.norm2 = norm_layer(dim)\r\n        self.norm21 = norm_layer(dim)\r\n        mlp_hidden_dim = int(dim * mlp_ratio)\r\n        self.mlp = Mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\r\n        self.mlp1 = Mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\r\n\r\n    def forward(self, x):\r\n        x = x + self.drop_path(self.attn(self.norm1(x))) + self.drop_path(self.attn1(self.norm11(x)))\r\n        x = x + self.drop_path(self.mlp(self.norm2(x))) + self.drop_path(self.mlp1(self.norm21(x)))\r\n        return x\r\n\r\n\r\nclass hMLP_stem(nn.Module):\r\n    \"\"\" hMLP_stem: https://arxiv.org/pdf/2203.09795.pdf\r\n    taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py\r\n    with slight modifications\r\n    \"\"\"\r\n\r\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=nn.SyncBatchNorm):\r\n        super().__init__()\r\n        img_size = to_2tuple(img_size)\r\n        patch_size = to_2tuple(patch_size)\r\n        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])\r\n        self.img_size = img_size\r\n        self.patch_size = patch_size\r\n        self.num_patches = num_patches\r\n        self.proj = torch.nn.Sequential(*[nn.Conv2d(in_chans, embed_dim // 4, kernel_size=4, stride=4),\r\n                                          norm_layer(embed_dim // 4),\r\n                                          nn.GELU(),\r\n                                          nn.Conv2d(embed_dim // 4, embed_dim // 4, kernel_size=2, stride=2),\r\n                                          norm_layer(embed_dim // 4),\r\n                                          nn.GELU(),\r\n                                          nn.Conv2d(embed_dim // 4, embed_dim, kernel_size=2, stride=2),\r\n                                          norm_layer(embed_dim),\r\n                                          ])\r\n\r\n    def forward(self, x):\r\n        B, C, H, W = x.shape\r\n        x = self.proj(x).flatten(2).transpose(1, 2)\r\n        return x\r\n\r\n\r\nclass vit_models(nn.Module):\r\n    \"\"\" Vision Transformer with LayerScale (https://arxiv.org/abs/2103.17239) support\r\n    taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py\r\n    with slight modifications\r\n    \"\"\"\r\n\r\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,\r\n                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,\r\n                 drop_path_rate=0., norm_layer=nn.LayerNorm, global_pool=None,\r\n                 block_layers=Block,\r\n                 Patch_layer=PatchEmbed, act_layer=nn.GELU,\r\n                 Attention_block=Attention, Mlp_block=Mlp,\r\n                 dpr_constant=True, init_scale=1e-4,\r\n                 mlp_ratio_clstk=4.0, **kwargs):\r\n        super().__init__()\r\n\r\n        self.dropout_rate = drop_rate\r\n\r\n        self.num_classes = num_classes\r\n        self.num_features = self.embed_dim = embed_dim\r\n\r\n        self.patch_embed = Patch_layer(\r\n            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)\r\n        num_patches = self.patch_embed.num_patches\r\n\r\n        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\r\n\r\n        self.pos_embed = 
nn.Parameter(torch.zeros(1, num_patches, embed_dim))\r\n\r\n        dpr = [drop_path_rate for i in range(depth)]\r\n        self.blocks = nn.ModuleList([\r\n            block_layers(\r\n                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,\r\n                drop=0.0, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,\r\n                act_layer=act_layer, Attention_block=Attention_block, Mlp_block=Mlp_block, init_values=init_scale)\r\n            for i in range(depth)])\r\n\r\n        self.norm = norm_layer(embed_dim)\r\n\r\n        self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')]\r\n        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()\r\n\r\n        trunc_normal_(self.pos_embed, std=.02)\r\n        trunc_normal_(self.cls_token, std=.02)\r\n        self.apply(self._init_weights)\r\n\r\n    def _init_weights(self, m):\r\n        if isinstance(m, nn.Linear):\r\n            trunc_normal_(m.weight, std=.02)\r\n            if isinstance(m, nn.Linear) and m.bias is not None:\r\n                nn.init.constant_(m.bias, 0)\r\n        elif isinstance(m, nn.LayerNorm):\r\n            nn.init.constant_(m.bias, 0)\r\n            nn.init.constant_(m.weight, 1.0)\r\n\r\n    @torch.jit.ignore\r\n    def no_weight_decay(self):\r\n        return {'pos_embed', 'cls_token'}\r\n\r\n    def get_classifier(self):\r\n        return self.head\r\n\r\n    def get_num_layers(self):\r\n        return len(self.blocks)\r\n\r\n    def reset_classifier(self, num_classes, global_pool=''):\r\n        self.num_classes = num_classes\r\n        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()\r\n\r\n    def forward_features(self, x):\r\n        B = x.shape[0]\r\n        x = self.patch_embed(x)\r\n\r\n        cls_tokens = self.cls_token.expand(B, -1, -1)\r\n\r\n        x = x + self.pos_embed\r\n\r\n        x = torch.cat((cls_tokens, x), dim=1)\r\n\r\n        for i, blk in enumerate(self.blocks):\r\n            x = blk(x)\r\n\r\n        x = self.norm(x)\r\n        return x[:, 0]\r\n\r\n    def forward(self, x):\r\n\r\n        x = self.forward_features(x)\r\n\r\n        if self.dropout_rate:\r\n            x = F.dropout(x, p=float(self.dropout_rate), training=self.training)\r\n        x = self.head(x)\r\n\r\n        return x\r\n\r\n\r\n# DeiT III: Revenge of the ViT (https://arxiv.org/abs/2204.07118)\r\n\r\n@register_model\r\ndef deit_tiny_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):\r\n    model = vit_models(\r\n        img_size=img_size, patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,\r\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block, **kwargs)\r\n\r\n    return model\r\n\r\n\r\n@register_model\r\ndef deit_small_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):\r\n    model = vit_models(\r\n        img_size=img_size, patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,\r\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block, **kwargs)\r\n    model.default_cfg = _cfg()\r\n    if pretrained:\r\n        name = 'https://dl.fbaipublicfiles.com/deit/deit_3_small_' + str(img_size) + '_'\r\n        if pretrained_21k:\r\n            name += '21k.pth'\r\n        else:\r\n            name += '1k.pth'\r\n\r\n        checkpoint = torch.hub.load_state_dict_from_url(\r\n     
       url=name,\r\n            map_location=\"cpu\", check_hash=True\r\n        )\r\n        model.load_state_dict(checkpoint[\"model\"])\r\n\r\n    return model\r\n\r\n\r\n@register_model\r\ndef deit_medium_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):\r\n    model = vit_models(\r\n        patch_size=16, embed_dim=512, depth=12, num_heads=8, mlp_ratio=4, qkv_bias=True,\r\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block, **kwargs)\r\n    model.default_cfg = _cfg()\r\n    if pretrained:\r\n        name = 'https://dl.fbaipublicfiles.com/deit/deit_3_medium_' + str(img_size) + '_'\r\n        if pretrained_21k:\r\n            name += '21k.pth'\r\n        else:\r\n            name += '1k.pth'\r\n\r\n        checkpoint = torch.hub.load_state_dict_from_url(\r\n            url=name,\r\n            map_location=\"cpu\", check_hash=True\r\n        )\r\n        model.load_state_dict(checkpoint[\"model\"])\r\n    return model\r\n\r\n\r\n@register_model\r\ndef deit_base_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):\r\n    model = vit_models(\r\n        img_size=img_size, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,\r\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block, **kwargs)\r\n    if pretrained:\r\n        name = 'https://dl.fbaipublicfiles.com/deit/deit_3_base_' + str(img_size) + '_'\r\n        if pretrained_21k:\r\n            name += '21k.pth'\r\n        else:\r\n            name += '1k.pth'\r\n\r\n        checkpoint = torch.hub.load_state_dict_from_url(\r\n            url=name,\r\n            map_location=\"cpu\", check_hash=True\r\n        )\r\n        model.load_state_dict(checkpoint[\"model\"])\r\n    return model\r\n\r\n\r\n@register_model\r\ndef deit_large_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):\r\n    model = vit_models(\r\n        img_size=img_size, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,\r\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block, **kwargs)\r\n    if pretrained:\r\n        name = 'https://dl.fbaipublicfiles.com/deit/deit_3_large_' + str(img_size) + '_'\r\n        if pretrained_21k:\r\n            name += '21k.pth'\r\n        else:\r\n            name += '1k.pth'\r\n\r\n        checkpoint = torch.hub.load_state_dict_from_url(\r\n            url=name,\r\n            map_location=\"cpu\", check_hash=True\r\n        )\r\n        model.load_state_dict(checkpoint[\"model\"])\r\n    return model\r\n\r\n\r\n@register_model\r\ndef deit_huge_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):\r\n    model = vit_models(\r\n        img_size=img_size, patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True,\r\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block, **kwargs)\r\n    if pretrained:\r\n        name = 'https://dl.fbaipublicfiles.com/deit/deit_3_huge_' + str(img_size) + '_'\r\n        if pretrained_21k:\r\n            name += '21k_v1.pth'\r\n        else:\r\n            name += '1k_v1.pth'\r\n\r\n        checkpoint = torch.hub.load_state_dict_from_url(\r\n            url=name,\r\n            map_location=\"cpu\", check_hash=True\r\n        )\r\n        model.load_state_dict(checkpoint[\"model\"])\r\n    return model\r\n\r\n\r\n@register_model\r\ndef deit_huge_patch14_52_LS(pretrained=False, img_size=224, 
pretrained_21k=False, **kwargs):\r\n    model = vit_models(\r\n        img_size=img_size, patch_size=14, embed_dim=1280, depth=52, num_heads=16, mlp_ratio=4, qkv_bias=True,\r\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block, **kwargs)\r\n\r\n    return model\r\n\r\n\r\n@register_model\r\ndef deit_huge_patch14_26x2_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):\r\n    model = vit_models(\r\n        img_size=img_size, patch_size=14, embed_dim=1280, depth=26, num_heads=16, mlp_ratio=4, qkv_bias=True,\r\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block_paralx2, **kwargs)\r\n\r\n    return model\r\n\r\n\r\n@register_model\r\ndef deit_Giant_48x2_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):\r\n    model = vit_models(\r\n        img_size=img_size, patch_size=14, embed_dim=1664, depth=48, num_heads=16, mlp_ratio=4, qkv_bias=True,\r\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Block_paral_LS, **kwargs)\r\n\r\n    return model\r\n\r\n\r\n@register_model\r\ndef deit_giant_40x2_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):\r\n    model = vit_models(\r\n        img_size=img_size, patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=4, qkv_bias=True,\r\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Block_paral_LS, **kwargs)\r\n    return model\r\n\r\n\r\n@register_model\r\ndef deit_Giant_48_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):\r\n    model = vit_models(\r\n        img_size=img_size, patch_size=14, embed_dim=1664, depth=48, num_heads=16, mlp_ratio=4, qkv_bias=True,\r\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block, **kwargs)\r\n    return model\r\n\r\n\r\n@register_model\r\ndef deit_giant_40_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):\r\n    model = vit_models(\r\n        img_size=img_size, patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=4, qkv_bias=True,\r\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block, **kwargs)\r\n    # model.default_cfg = _cfg()\r\n\r\n    return model\r\n\r\n\r\n# Models from Three things everyone should know about Vision Transformers (https://arxiv.org/pdf/2203.09795.pdf)\r\n\r\n@register_model\r\ndef deit_small_patch16_36_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):\r\n    model = vit_models(\r\n        img_size=img_size, patch_size=16, embed_dim=384, depth=36, num_heads=6, mlp_ratio=4, qkv_bias=True,\r\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block, **kwargs)\r\n\r\n    return model\r\n\r\n\r\n@register_model\r\ndef deit_small_patch16_36(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):\r\n    model = vit_models(\r\n        img_size=img_size, patch_size=16, embed_dim=384, depth=36, num_heads=6, mlp_ratio=4, qkv_bias=True,\r\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\r\n\r\n    return model\r\n\r\n\r\n@register_model\r\ndef deit_small_patch16_18x2_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):\r\n    model = vit_models(\r\n        img_size=img_size, patch_size=16, embed_dim=384, depth=18, num_heads=6, mlp_ratio=4, qkv_bias=True,\r\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block_paralx2, **kwargs)\r\n\r\n    return 
model\r\n\r\n\r\n@register_model\r\ndef deit_small_patch16_18x2(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):\r\n    model = vit_models(\r\n        img_size=img_size, patch_size=16, embed_dim=384, depth=18, num_heads=6, mlp_ratio=4, qkv_bias=True,\r\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Block_paralx2, **kwargs)\r\n\r\n    return model\r\n\r\n\r\n@register_model\r\ndef deit_base_patch16_18x2_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):\r\n    model = vit_models(\r\n        img_size=img_size, patch_size=16, embed_dim=768, depth=18, num_heads=12, mlp_ratio=4, qkv_bias=True,\r\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block_paralx2, **kwargs)\r\n\r\n    return model\r\n\r\n\r\n@register_model\r\ndef deit_base_patch16_18x2(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):\r\n    model = vit_models(\r\n        img_size=img_size, patch_size=16, embed_dim=768, depth=18, num_heads=12, mlp_ratio=4, qkv_bias=True,\r\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Block_paralx2, **kwargs)\r\n\r\n    return model\r\n\r\n\r\n@register_model\r\ndef deit_base_patch16_36x1_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):\r\n    model = vit_models(\r\n        img_size=img_size, patch_size=16, embed_dim=768, depth=36, num_heads=12, mlp_ratio=4, qkv_bias=True,\r\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block, **kwargs)\r\n\r\n    return model\r\n\r\n\r\n@register_model\r\ndef deit_base_patch16_36x1(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):\r\n    model = vit_models(\r\n        img_size=img_size, patch_size=16, embed_dim=768, depth=36, num_heads=12, mlp_ratio=4, qkv_bias=True,\r\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\r\n\r\n    return model\r\ndef create_model(model_name,**kwargs):\r\n    if model_name in __all__:\r\n        create_fn = globals()[model_name]\r\n        model = create_fn(**kwargs)\r\n        model.default_cfg = _cfg()\r\n        return model\r\n    else:\r\n        raise RuntimeError('Unknown model (%s)' % model_name)\r\nif __name__ == '__main__':\r\n    batch_size = 5\r\n    model = deit_tiny_patch16_LS().cuda()\r\n    print(model(torch.randn(5, 3, 224, 224).cuda()).shape)\r\n"
  },
  {
    "path": "losses.py",
    "content": "# Copyright (c) 2015-present, Facebook, Inc.\n# All rights reserved.\n\"\"\"\nImplements the knowledge distillation loss\n\"\"\"\nimport torch\nfrom torch.nn import functional as F\n\n\nclass DistillationLoss(torch.nn.Module):\n    \"\"\"\n    This module wraps a standard criterion and adds an extra knowledge distillation loss by\n    taking a teacher model prediction and using it as additional supervision.\n    \"\"\"\n    def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,\n                 distillation_type: str, alpha: float, tau: float):\n        super().__init__()\n        self.base_criterion = base_criterion\n        self.teacher_model = teacher_model\n        assert distillation_type in ['none', 'soft', 'hard']\n        self.distillation_type = distillation_type\n        self.alpha = alpha\n        self.tau = tau\n\n    def forward(self, inputs, outputs, labels):\n        \"\"\"\n        Args:\n            inputs: The original inputs that are fed to the teacher model\n            outputs: the outputs of the model to be trained. It is expected to be\n                either a Tensor, or a Tuple[Tensor, Tensor], with the original output\n                in the first position and the distillation predictions as the second output\n            labels: the labels for the base criterion\n        \"\"\"\n        outputs_kd = None\n        if not isinstance(outputs, torch.Tensor):\n            # assume that the model outputs a tuple of [outputs, outputs_kd]\n            outputs, outputs_kd = outputs\n        base_loss = self.base_criterion(outputs, labels)\n        if self.distillation_type == 'none':\n            return base_loss\n\n        if outputs_kd is None:\n            raise ValueError(\"When knowledge distillation is enabled, the model is \"\n                             \"expected to return a Tuple[Tensor, Tensor] with the output of the \"\n                             \"class_token and the dist_token\")\n        # don't backprop through the teacher\n        with torch.no_grad():\n            teacher_outputs = self.teacher_model(inputs)\n\n        if self.distillation_type == 'soft':\n            T = self.tau\n            # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100\n            # with slight modifications\n            distillation_loss = F.kl_div(\n                F.log_softmax(outputs_kd / T, dim=1),\n                #We provide the teacher's targets in log probability because we use log_target=True \n                #(as recommended in pytorch https://github.com/pytorch/pytorch/blob/9324181d0ac7b4f7949a574dbc3e8be30abe7041/torch/nn/functional.py#L2719)\n                #but it is possible to give just the probabilities and set log_target=False. In our experiments we tried both.\n                F.log_softmax(teacher_outputs / T, dim=1),\n                reduction='sum',\n                log_target=True\n            ) * (T * T) / outputs_kd.numel()\n            #We divide by outputs_kd.numel() to have the legacy PyTorch behavior. \n            #But we also experimented with output_kd.size(0) \n            #see issue 61(https://github.com/facebookresearch/deit/issues/61) for more details\n        elif self.distillation_type == 'hard':\n            distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1))\n\n        loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha\n        return loss\n"
  },
  {
    "path": "main.py",
    "content": "# Copyright (c) 2015-present, Facebook, Inc.\n# All rights reserved.\nimport argparse\nimport datetime\nimport numpy as np\nimport time\nimport torch\nimport torch.backends.cudnn as cudnn\nimport json\n\nfrom pathlib import Path\n\nfrom timm.data import Mixup\n# from timm.models import create_model\nfrom models_kan import create_model\nfrom timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy\nfrom timm.scheduler import create_scheduler\nfrom timm.optim import create_optimizer\nfrom timm.utils import NativeScaler, get_state_dict, ModelEma\n\nfrom datasets import build_dataset\nfrom engine import train_one_epoch, evaluate\nfrom losses import DistillationLoss\nfrom samplers import RASampler\nfrom augment import new_data_aug_generator\n\n# import models\n# import models_v2\nimport kit\n\nimport utils\n\n\ndef get_args_parser():\n    parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False)\n    parser.add_argument('--batch-size', default=144, type=int)\n    parser.add_argument('--epochs', default=300, type=int)\n    parser.add_argument('--bce-loss', action='store_true')\n    parser.add_argument('--unscale-lr', action='store_true',default=True)\n\n    # Model parameters\n    parser.add_argument('--model', default='deit_tiny_patch16_224', type=str, metavar='MODEL',\n                        help='Name of model to train')\n    parser.add_argument('--input-size', default=224, type=int, help='images input size')\n    parser.add_argument('--hdim_kan', default=192, type=int, help='hidden dimension for KAN')\n\n    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',\n                        help='Dropout rate (default: 0.)')\n    parser.add_argument('--drop-path', type=float, default=0.05, metavar='PCT',\n                        help='Drop path rate (default: 0.1)')\n\n    parser.add_argument('--model-ema', action='store_true')\n    parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')\n    parser.set_defaults(model_ema=True)\n    parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='')\n    parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='')\n\n    # Optimizer parameters\n    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',\n                        help='Optimizer (default: \"adamw\"')\n    parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',\n                        help='Optimizer Epsilon (default: 1e-8)')\n    parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',\n                        help='Optimizer Betas (default: None, use opt default)')\n    parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',\n                        help='Clip gradient norm (default: None, no clipping)')\n    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',\n                        help='SGD momentum (default: 0.9)')\n    parser.add_argument('--weight-decay', type=float, default=0.05,\n                        help='weight decay (default: 0.05)')\n    # Learning rate schedule parameters\n    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',\n                        help='LR scheduler (default: \"cosine\"')\n    parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',\n                        help='learning rate (default: 5e-4)')\n    parser.add_argument('--lr-noise', type=float, nargs='+', 
default=None, metavar='pct, pct',\n                        help='learning rate noise on/off epoch percentages')\n    parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',\n                        help='learning rate noise limit percent (default: 0.67)')\n    parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',\n                        help='learning rate noise std-dev (default: 1.0)')\n    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',\n                        help='warmup learning rate (default: 1e-6)')\n    parser.add_argument('--min-lr', type=float, default=1e-6, metavar='LR',\n                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')\n\n    parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',\n                        help='epoch interval to decay LR')\n    parser.add_argument('--warmup-epochs', type=int, default=1, metavar='N',\n                        help='epochs to warmup LR, if scheduler supports')\n    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',\n                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')\n    parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',\n                        help='patience epochs for Plateau LR scheduler (default: 10')\n    parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',\n                        help='LR decay rate (default: 0.1)')\n\n    # Augmentation parameters\n    parser.add_argument('--color-jitter', type=float, default=0.3, metavar='PCT',\n                        help='Color jitter factor (default: 0.3)')\n    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',\n                        help='Use AutoAugment policy. \"v0\" or \"original\". 
(default: rand-m9-mstd0.5-inc1)')\n    parser.add_argument('--smoothing', type=float, default=0., help='Label smoothing (default: 0.)')\n    parser.add_argument('--train-interpolation', type=str, default='bicubic',\n                        help='Training interpolation (random, bilinear, bicubic default: \"bicubic\")')\n\n    parser.add_argument('--repeated-aug', action='store_true')\n    parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')\n    parser.set_defaults(repeated_aug=True)\n    \n    parser.add_argument('--train-mode', action='store_true')\n    parser.add_argument('--no-train-mode', action='store_false', dest='train_mode')\n    parser.set_defaults(train_mode=True)\n    \n    parser.add_argument('--ThreeAugment', default=True, action='store_true') #3augment\n\n    \n    parser.add_argument('--src', action='store_true') #simple random crop\n    \n    # * Random Erase params\n    parser.add_argument('--reprob', type=float, default=0., metavar='PCT',\n                        help='Random erase prob (default: 0.)')\n    parser.add_argument('--remode', type=str, default='pixel',\n                        help='Random erase mode (default: \"pixel\")')\n    parser.add_argument('--recount', type=int, default=1,\n                        help='Random erase count (default: 1)')\n    parser.add_argument('--resplit', action='store_true', default=False,\n                        help='Do not random erase first (clean) augmentation split')\n\n    # * Mixup params\n    parser.add_argument('--mixup', type=float, default=0.8,\n                        help='mixup alpha, mixup enabled if > 0. (default: 0.8)')\n    parser.add_argument('--cutmix', type=float, default=1.0,\n                        help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')\n    parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,\n                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')\n    parser.add_argument('--mixup-prob', type=float, default=1.0,\n                        help='Probability of performing mixup or cutmix when either/both is enabled')\n    parser.add_argument('--mixup-switch-prob', type=float, default=0.5,\n                        help='Probability of switching to cutmix when both mixup and cutmix enabled')\n    parser.add_argument('--mixup-mode', type=str, default='batch',\n                        help='How to apply mixup/cutmix params. 
Per \"batch\", \"pair\", or \"elem\"')\n\n    # Distillation parameters\n    parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL',\n                        help='Name of teacher model to train (default: \"regnety_160\"')\n    parser.add_argument('--teacher-path', type=str, default='')\n    parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help=\"\")\n    parser.add_argument('--distillation-alpha', default=0.5, type=float, help=\"\")\n    parser.add_argument('--distillation-tau', default=1.0, type=float, help=\"\")\n    \n    # * Cosub params\n    parser.add_argument('--cosub', action='store_true') \n    \n    # * Finetuning params\n    parser.add_argument('--finetune', default='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth', help='finetune from checkpoint')\n    parser.add_argument('--attn-only', action='store_true') \n    \n    # Dataset parameters\n    parser.add_argument('--data-path', default='/scratch/sg7729/KAN/Vision-KAN/Datasets', type=str,\n                        help='dataset path')\n    parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19','TinyIMNET'],\n                        type=str, help='Image Net dataset path')\n    parser.add_argument('--inat-category', default='name',\n                        choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],\n                        type=str, help='semantic granularity')\n\n    parser.add_argument('--output_dir', default='./output',\n                        help='path where to save, empty for no saving')\n    parser.add_argument('--device', default='cuda',\n                        help='device to use for training / testing')\n    parser.add_argument('--seed', default=0, type=int)\n    parser.add_argument('--resume', default='', help='resume from checkpoint')\n    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',\n                        help='start epoch')\n    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')\n    parser.add_argument('--eval-crop-ratio', default=1, type=float, help=\"Crop ratio for evaluation\")\n    parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')\n    parser.add_argument('--num_workers', default=10, type=int)\n    parser.add_argument('--pin-mem', action='store_true',\n                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')\n    parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',\n                        help='')\n    parser.set_defaults(pin_mem=True)\n\n    # distributed training parameters\n    parser.add_argument('--distributed', action='store_true', default=False, help='Enabling distributed training')\n    # parser.add_argument('--world_size', default=1, type=int,\n    #                     help='number of distributed processes')\n    # parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')\n    return parser\n\n\ndef main(args):\n    # utils.init_distributed_mode(args)\n\n    print(args)\n\n    if args.distillation_type != 'none' and args.finetune and not args.eval:\n        raise NotImplementedError(\"Finetuning with distillation not yet supported\")\n\n    device = torch.device(args.device)\n\n    # fix the seed for reproducibility\n    seed = args.seed + utils.get_rank()\n    
torch.manual_seed(seed)\n    np.random.seed(seed)\n    # random.seed(seed)\n\n    cudnn.benchmark = True\n\n    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)\n    dataset_val, _ = build_dataset(is_train=False, args=args)\n\n    # if args.distributed:\n    #     num_tasks = utils.get_world_size()\n    #     global_rank = utils.get_rank()\n    #     if args.repeated_aug:\n    #         sampler_train = RASampler(\n    #             dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True\n    #         )\n    #     else:\n    #         sampler_train = torch.utils.data.DistributedSampler(\n    #             dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True\n    #         )\n    #     if args.dist_eval:\n    #         if len(dataset_val) % num_tasks != 0:\n    #             print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '\n    #                   'This will slightly alter validation results as extra duplicate entries are added to achieve '\n    #                   'equal num of samples per-process.')\n    #         sampler_val = torch.utils.data.DistributedSampler(\n    #             dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)\n    #     else:\n    #         sampler_val = torch.utils.data.SequentialSampler(dataset_val)\n    # else:\n    sampler_train = torch.utils.data.RandomSampler(dataset_train)\n    sampler_val = torch.utils.data.SequentialSampler(dataset_val)\n\n    data_loader_train = torch.utils.data.DataLoader(\n        dataset_train, sampler=sampler_train,\n        batch_size=args.batch_size,\n        num_workers=args.num_workers,\n        pin_memory=args.pin_mem,\n        drop_last=True,\n    )\n    if args.ThreeAugment:\n        data_loader_train.dataset.transform = new_data_aug_generator(args)\n\n    data_loader_val = torch.utils.data.DataLoader(\n        dataset_val, sampler=sampler_val,\n        batch_size=int(1 * args.batch_size),\n        num_workers=args.num_workers,\n        pin_memory=args.pin_mem,\n        drop_last=True\n    )\n\n    mixup_fn = None\n    mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None\n    if mixup_active:\n        mixup_fn = Mixup(\n            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,\n            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,\n            label_smoothing=args.smoothing, num_classes=args.nb_classes)\n\n    print(f\"Creating model: {args.model}\")\n    model = create_model(\n        args.model,\n        pretrained=False,\n        num_classes=args.nb_classes,\n        drop_rate=args.drop,\n        drop_path_rate=args.drop_path,\n        # drop_block_rate=None,\n        img_size=args.input_size,\n        batch_size=args.batch_size\n    )\n\n                    \n    if args.finetune:\n        if args.finetune.startswith('https'):\n            checkpoint = torch.hub.load_state_dict_from_url(\n                args.finetune, map_location='cpu', check_hash=True)\n        else:\n            checkpoint = torch.load(args.finetune, map_location='cpu')\n\n        checkpoint_model = checkpoint['model']\n        state_dict = model.state_dict()\n        for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']:\n            if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:\n                print(f\"Removing key {k} from pretrained checkpoint\")\n                del checkpoint_model[k]\n\n        # interpolate position embedding\n        pos_embed_checkpoint = checkpoint_model['pos_embed']\n        embedding_size = pos_embed_checkpoint.shape[-1]\n        num_patches = model.patch_embed.num_patches\n        num_extra_tokens = model.pos_embed.shape[-2] - num_patches\n        # height (== width) for the checkpoint position embedding\n        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)\n        # height (== width) for the new position embedding\n        new_size = int(num_patches ** 0.5)\n        # class_token and dist_token are kept unchanged\n        extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]\n        # only the position tokens are interpolated\n        pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]\n        pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)\n        pos_tokens = torch.nn.functional.interpolate(\n            pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)\n        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)\n        new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)\n        checkpoint_model['pos_embed'] = new_pos_embed\n\n        model.load_state_dict(checkpoint_model, strict=False)\n        \n    if args.attn_only:\n        for name_p,p in model.named_parameters():\n            if '.attn.' 
in name_p:\n                p.requires_grad = True\n            else:\n                p.requires_grad = False\n        try:\n            model.head.weight.requires_grad = True\n            model.head.bias.requires_grad = True\n        except:\n            model.fc.weight.requires_grad = True\n            model.fc.bias.requires_grad = True\n        try:\n            model.pos_embed.requires_grad = True\n        except:\n            print('no position encoding')\n        try:\n            for p in model.patch_embed.parameters():\n                p.requires_grad = False\n        except:\n            print('no patch embed')\n            \n    model.to(device)\n\n    model_ema = None\n    if args.model_ema:\n        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper\n        model_ema = ModelEma(\n            model,\n            decay=args.model_ema_decay,\n            device='cpu' if args.model_ema_force_cpu else '',\n            resume='')\n\n    model_without_ddp = model\n    if args.distributed:\n        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])\n        model_without_ddp = model.module\n    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)\n    print('number of params:', n_parameters)\n    if not args.unscale_lr:\n        linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0\n        args.lr = linear_scaled_lr\n    optimizer = create_optimizer(args, model_without_ddp)\n    loss_scaler = NativeScaler()\n\n    lr_scheduler, _ = create_scheduler(args, optimizer)\n\n    criterion = LabelSmoothingCrossEntropy()\n\n    if mixup_active:\n        # smoothing is handled with mixup label transform\n        criterion = SoftTargetCrossEntropy()\n    elif args.smoothing:\n        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)\n    else:\n        criterion = torch.nn.CrossEntropyLoss()\n        \n    if args.bce_loss:\n        criterion = torch.nn.BCEWithLogitsLoss()\n        \n    teacher_model = None\n    if args.distillation_type != 'none':\n        assert args.teacher_path, 'need to specify teacher-path when using distillation'\n        print(f\"Creating teacher model: {args.teacher_model}\")\n        teacher_model = create_model(\n            args.teacher_model,\n            pretrained=False,\n            num_classes=args.nb_classes,\n            global_pool='avg',\n        )\n        if args.teacher_path.startswith('https'):\n            checkpoint = torch.hub.load_state_dict_from_url(\n                args.teacher_path, map_location='cpu', check_hash=True)\n        else:\n            checkpoint = torch.load(args.teacher_path, map_location='cpu')\n        teacher_model.load_state_dict(checkpoint['model'])\n        teacher_model.to(device)\n        teacher_model.eval()\n\n    # wrap the criterion in our custom DistillationLoss, which\n    # just dispatches to the original criterion if args.distillation_type is 'none'\n    criterion = DistillationLoss(\n        criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau\n    )\n\n    output_dir = Path(args.output_dir)\n    if args.resume:\n        if args.resume.startswith('https'):\n            checkpoint = torch.hub.load_state_dict_from_url(\n                args.resume, map_location='cpu', check_hash=True)\n        else:\n            checkpoint = torch.load(args.resume, map_location='cpu')\n        model_without_ddp.load_state_dict(checkpoint['model'])\n   
     if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:\n            optimizer.load_state_dict(checkpoint['optimizer'])\n            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])\n            args.start_epoch = checkpoint['epoch'] + 1\n            if args.model_ema:\n                utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])\n            if 'scaler' in checkpoint:\n                loss_scaler.load_state_dict(checkpoint['scaler'])\n        lr_scheduler.step(args.start_epoch)\n    if args.eval:\n        test_stats = evaluate(data_loader_val, model, device)\n        print(f\"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%\")\n        return\n\n    print(f\"Start training for {args.epochs} epochs\")\n    start_time = time.time()\n    max_accuracy = 0.0\n    for epoch in range(args.start_epoch, args.epochs):\n        if args.distributed:\n            data_loader_train.sampler.set_epoch(epoch)\n\n        train_stats = train_one_epoch(\n            model, criterion, data_loader_train,\n            optimizer, device, epoch, loss_scaler,\n            args.clip_grad, model_ema, mixup_fn,\n            set_training_mode=args.train_mode,  # keep in eval mode for deit finetuning / train mode for training and deit III finetuning\n            args = args,\n        )\n\n        lr_scheduler.step(epoch)\n        if args.output_dir:\n            checkpoint_paths = [output_dir / 'checkpoint.pth']\n            for checkpoint_path in checkpoint_paths:\n                utils.save_on_master({\n                    'model': model_without_ddp.state_dict(),\n                    'optimizer': optimizer.state_dict(),\n                    'lr_scheduler': lr_scheduler.state_dict(),\n                    'epoch': epoch,\n                    'model_ema': get_state_dict(model_ema),\n                    'scaler': loss_scaler.state_dict(),\n                    'args': args,\n                }, checkpoint_path)\n             \n\n        test_stats = evaluate(data_loader_val, model, device)\n        print(f\"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%\")\n        \n        if max_accuracy < test_stats[\"acc1\"]:\n            max_accuracy = test_stats[\"acc1\"]\n            if args.output_dir:\n                checkpoint_paths = [output_dir / 'best_checkpoint.pth']\n                for checkpoint_path in checkpoint_paths:\n                    utils.save_on_master({\n                        'model': model_without_ddp.state_dict(),\n                        'optimizer': optimizer.state_dict(),\n                        'lr_scheduler': lr_scheduler.state_dict(),\n                        'epoch': epoch,\n                        'model_ema': get_state_dict(model_ema),\n                        'scaler': loss_scaler.state_dict(),\n                        'args': args,\n                    }, checkpoint_path)\n            \n        print(f'Max accuracy: {max_accuracy:.2f}%')\n\n        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},\n                     **{f'test_{k}': v for k, v in test_stats.items()},\n                     'epoch': epoch,\n                     'n_parameters': n_parameters}\n        \n        \n        \n        \n        if args.output_dir and utils.is_main_process():\n            with (output_dir / \"log.txt\").open(\"a\") as f:\n                f.write(json.dumps(log_stats) + \"\\n\")\n\n    total_time = time.time() - 
start_time\n    total_time_str = str(datetime.timedelta(seconds=int(total_time)))\n    print('Training time {}'.format(total_time_str))\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser('DeiT training and evaluation script', parents=[get_args_parser()])\n    args = parser.parse_args()\n    if args.output_dir:\n        Path(args.output_dir).mkdir(parents=True, exist_ok=True)\n    main(args)\n"
  },
  {
    "path": "minimal_example.py",
    "content": "from models_kan import create_model\nimport torch.optim as optim\nimport torch\nimport torchvision\nimport torchvision.transforms as transforms\nimport torch.nn as nn\nfrom engine import train_one_epoch, evaluate\n\nKAN_model = create_model(\n    model_name='deit_tiny_patch16_224_KAN',\n    pretrained=False,\n    hdim_kan=192,\n    num_classes=10,\n    drop_rate=0.0,\n    drop_path_rate=0.05,\n    img_size=32,\n    batch_size=144\n)\n\n# dataset CIFAR10\n\ntransform = transforms.Compose(\n    [transforms.ToTensor(),\n     transforms.Normalize((0.5,), (0.5,))])\n\ntrainset = torchvision.datasets.CIFAR10(root='./data', train=True,\n                                        download=True, transform=transform)\ntrainloader = torch.utils.data.DataLoader(trainset, batch_size=144,\n                                          shuffle=True, num_workers=2)\n\ntestset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)\ntestloader = torch.utils.data.DataLoader(testset, batch_size=144, shuffle=False, num_workers=2)\n\nclasses = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')\n\n#optimizer\noptimizer = optim.SGD(KAN_model.parameters(), lr=0.001, momentum=0.9)\ncriterion = torch.nn.CrossEntropyLoss()\ndevice = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\nKAN_model.to(device)\n\n#train using engine.py\n\nfor epoch in range(2):  # loop over the dataset multiple times\n        running_loss = 0.0\n        for i, data in enumerate(trainloader, 0):\n            # get the inputs; data is a list of [inputs, labels]\n            inputs, labels = data[0].to(device), data[1].to(device)\n    \n            # zero the parameter gradients\n            optimizer.zero_grad()\n    \n            # forward + backward + optimize\n            outputs = KAN_model(inputs)\n            loss = criterion(outputs, labels)\n            loss.backward()\n            optimizer.step()\n    \n            # print statistics\n            running_loss += loss.item()\n            if i % 2000 == 1999:    # print every 2000 mini-batches\n                print('[%d, %5d] loss: %.3f' %\n                    (epoch + 1, i + 1, running_loss / 2000))\n                running_loss = 0.0\n        \n# evaluate\ntest_stats = evaluate(testloader, KAN_model, device=device)\nprint(f\"Accuracy of the network on the {len(testset)} test images: {test_stats['acc1']:.1f}%\")\n\nprint('Finished Training')"
  },
  {
    "path": "models_kan.py",
    "content": "# Copyright (c) 2015-present, Facebook, Inc.\n# All rights reserved.\nimport torch\nimport torch.nn as nn\nfrom functools import partial\n\nfrom timm.models.vision_transformer import VisionTransformer, _cfg, Block, Attention\nfrom timm.models.registry import register_model\nfrom timm.models.layers import trunc_normal_, DropPath\nfrom fasterkan import FasterKAN as KAN\n\n__all__KAN = [\n    'deit_base_patch16_224_KAN', 'deit_small_patch16_224_KAN',  \n    'deit_base_patch16_384_KAN', 'deit_tiny_patch16_224_KAN', \n    'deit_tiny_distilled_patch16_224_KAN', 'deit_base_distilled_patch16_224_KAN', \n    'deit_small_distilled_patch16_224_KAN', 'deit_base_distilled_patch16_384_KAN']\n\n\n__all__ViT = [\n    'deit_base_patch16_224_ViT', 'deit_small_patch16_224_ViT',  \n    'deit_base_patch16_384_ViT', 'deit_tiny_patch16_224_ViT', \n    'deit_tiny_distilled_patch16_224_ViT', 'deit_base_distilled_patch16_224_ViT', \n    'deit_small_distilled_patch16_224_ViT', 'deit_base_distilled_patch16_384_ViT']\n\n\nclass kanBlock(Block):\n\n    def __init__(self, dim, num_heads=8, hdim_kan=192, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):\n        super().__init__(dim, num_heads)\n        self.norm1 = norm_layer(dim)\n        self.attn = Attention(\n            dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)\n        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here\n        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        # self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n        self.kan = KAN([dim, hdim_kan, dim])\n\n    def forward(self, x):\n        b, t, d = x.shape\n        x = x + self.drop_path(self.attn(self.norm1(x)))\n        x = x + self.drop_path(self.kan(self.norm2(x).reshape(-1, x.shape[-1])).reshape(b, t, d))\n        return x\n\n\nclass VisionKAN(VisionTransformer):\n    def __init__(self, *args, num_heads=8, batch_size=16, **kwargs):\n        if 'hdim_kan' in kwargs:\n            self.hdim_kan = kwargs['hdim_kan']\n            del kwargs['hdim_kan']\n        else:\n            self.hdim_kan = 192\n        \n        super().__init__(*args, **kwargs)\n        self.num_heads = num_heads\n        # Newer timm versions do not save the depth to self.depth, so we need to check it\n        try:\n            self.depth\n        except AttributeError:\n            if 'depth' in kwargs:\n                self.depth = kwargs['depth']\n            else:\n                self.depth = 12\n\n        block_list = [\n            kanBlock(dim=self.embed_dim, num_heads=self.num_heads, hdim_kan=self.hdim_kan)\n            for i in range(self.depth)\n        ]\n        # rebuild self.blocks with kanBlocks, keeping the original container type\n        # (nn.Sequential or nn.ModuleList, depending on the timm version)\n        if type(self.blocks) == nn.Sequential:\n            self.blocks = nn.Sequential(*block_list)\n        elif type(self.blocks) == nn.ModuleList:\n            self.blocks = nn.ModuleList(block_list)\n\n\n\nclass DistilledVisionTransformer(VisionTransformer):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.dist_token = nn.Parameter(torch.zeros(1, 1, 
self.embed_dim))\n        num_patches = self.patch_embed.num_patches\n        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))\n        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()\n\n        trunc_normal_(self.dist_token, std=.02)\n        trunc_normal_(self.pos_embed, std=.02)\n        self.head_dist.apply(self._init_weights)\n\n    def forward_features(self, x):\n        # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py\n        # with slight modifications to add the dist_token\n        B = x.shape[0]\n        x = self.patch_embed(x)\n\n        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks\n        dist_token = self.dist_token.expand(B, -1, -1)\n        x = torch.cat((cls_tokens, dist_token, x), dim=1)\n\n        x = x + self.pos_embed\n        x = self.pos_drop(x)\n\n        for blk in self.blocks:\n            x = blk(x)\n\n        x = self.norm(x)\n        return x[:, 0], x[:, 1]\n\n    def forward(self, x):\n        x, x_dist = self.forward_features(x)\n        x = self.head(x)\n        x_dist = self.head_dist(x_dist)\n        if self.training:\n            return x, x_dist\n        else:\n            # during inference, return the average of both classifier predictions\n            return (x + x_dist) / 2\n\ndef create_kan(model_name, pretrained, **kwargs):\n\n    if model_name == 'deit_tiny_patch16_224_KAN':\n        model = VisionKAN(\n        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n        model.default_cfg = _cfg()\n        if pretrained:\n            checkpoint = torch.hub.load_state_dict_from_url(\n                url=\"https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth\",\n                map_location=\"cpu\", check_hash=True\n            )\n            model.load_state_dict(checkpoint[\"model\"])\n        return model\n    \n    elif model_name == 'deit_small_patch16_224_KAN':\n        model = VisionKAN(\n        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n        model.default_cfg = _cfg()\n        if pretrained:\n            checkpoint = torch.hub.load_state_dict_from_url(\n                url=\"https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth\",\n                map_location=\"cpu\", check_hash=True\n            )\n            model.load_state_dict(checkpoint[\"model\"])\n        return model\n\n    elif model_name == 'deit_base_patch16_224_KAN':\n        model = VisionKAN(\n        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n        model.default_cfg = _cfg()\n        if pretrained:\n            checkpoint = torch.hub.load_state_dict_from_url(\n                url=\"https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth\",\n                map_location=\"cpu\", check_hash=True\n            )\n            model.load_state_dict(checkpoint[\"model\"])\n        return model\n\n    elif model_name == 'deit_base_patch16_384_KAN':\n        model = VisionKAN(\n        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n 
       model.default_cfg = _cfg()\n        if pretrained:\n            checkpoint = torch.hub.load_state_dict_from_url(\n                url=\"https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth\",\n                map_location=\"cpu\", check_hash=True\n            )\n            model.load_state_dict(checkpoint[\"model\"])\n        return model\n\n    elif model_name == 'deit_tiny_distilled_patch16_224_KAN':\n        raise RuntimeError('Distilled models are not yet implemented in KAN')\n\n    elif model_name == 'deit_small_distilled_patch16_224_KAN':\n        raise RuntimeError('Distilled models are not yet implemented in KAN')\n\n    elif model_name == 'deit_base_distilled_patch16_224_KAN':\n        raise RuntimeError('Distilled models are not yet implemented in KAN')\n\ndef create_ViT(model_name, pretrained, **kwargs):\n    if 'batch_size' in kwargs:\n        del kwargs['batch_size']\n    if model_name == 'deit_tiny_patch16_224_ViT':\n        model = VisionTransformer(\n        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n        model.default_cfg = _cfg()\n        if pretrained:\n            checkpoint = torch.hub.load_state_dict_from_url(\n                url=\"https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth\",\n                map_location=\"cpu\", check_hash=True\n            )\n            model.load_state_dict(checkpoint[\"model\"])\n        return model\n    \n    elif model_name == 'deit_small_patch16_224_ViT':\n        model = VisionTransformer(\n        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n        model.default_cfg = _cfg()\n        if pretrained:\n            checkpoint = torch.hub.load_state_dict_from_url(\n                url=\"https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth\",\n                map_location=\"cpu\", check_hash=True\n            )\n            model.load_state_dict(checkpoint[\"model\"])\n        return model\n\n    elif model_name == 'deit_base_patch16_224_ViT':\n        model = VisionTransformer(\n        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n        model.default_cfg = _cfg()\n        if pretrained:\n            checkpoint = torch.hub.load_state_dict_from_url(\n                url=\"https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth\",\n                map_location=\"cpu\", check_hash=True\n            )\n            model.load_state_dict(checkpoint[\"model\"])\n        return model\n\n    elif model_name == 'deit_base_patch16_384_ViT':\n        model = VisionTransformer(\n        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n        model.default_cfg = _cfg()\n        if pretrained:\n            checkpoint = torch.hub.load_state_dict_from_url(\n                url=\"https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth\",\n                map_location=\"cpu\", check_hash=True\n            )\n            model.load_state_dict(checkpoint[\"model\"])\n        return model\n\n    elif model_name == 'deit_tiny_distilled_patch16_224_ViT':\n        model = DistilledVisionTransformer(\n        patch_size=16, embed_dim=192, depth=12, 
num_heads=3, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n        model.default_cfg = _cfg()\n        if pretrained:\n            checkpoint = torch.hub.load_state_dict_from_url(\n                url=\"https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth\",\n                map_location=\"cpu\", check_hash=True\n            )\n            model.load_state_dict(checkpoint[\"model\"])\n        return model\n    \n    elif model_name == 'deit_small_distilled_patch16_224_ViT':\n        model = DistilledVisionTransformer(\n        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n        model.default_cfg = _cfg()\n        if pretrained:\n            checkpoint = torch.hub.load_state_dict_from_url(\n                url=\"https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth\",\n                map_location=\"cpu\", check_hash=True\n            )\n            model.load_state_dict(checkpoint[\"model\"])\n        return model\n\n    elif model_name == 'deit_base_distilled_patch16_224_ViT':\n        model = DistilledVisionTransformer(\n        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n        model.default_cfg = _cfg()\n        if pretrained:\n            checkpoint = torch.hub.load_state_dict_from_url(\n                url=\"https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth\",\n                map_location=\"cpu\", check_hash=True\n            )\n            model.load_state_dict(checkpoint[\"model\"])\n        return model\n\n    elif model_name == 'deit_base_distilled_patch16_384_ViT':\n        model = DistilledVisionTransformer(\n        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)\n        model.default_cfg = _cfg()\n        if pretrained:\n            checkpoint = torch.hub.load_state_dict_from_url(\n                url=\"https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth\",\n                map_location=\"cpu\", check_hash=True\n            )\n            model.load_state_dict(checkpoint[\"model\"])\n        return model\n\n\ndef create_model(model_name,**kwargs):\n    pretrained = kwargs['pretrained'] if 'pretrained' in kwargs else False\n    if 'pretrained' in kwargs:\n        del kwargs['pretrained']\n    print(kwargs)\n    if model_name in __all__KAN:\n        model = create_kan(model_name, pretrained, **kwargs)\n        model.default_cfg = _cfg()\n        return model\n    elif model_name in __all__ViT:\n        model = create_ViT(model_name, pretrained, **kwargs)\n        model.default_cfg = _cfg()\n        return model\n    else:\n        raise RuntimeError('Unknown model (%s)' % model_name)\n\nif __name__ == '__main__':\n    model = create_model('deit_tiny_patch16_224_KAN').cuda()\n    img = torch.randn(5, 3, 224, 224).cuda()\n    out = model(img)\n    print(out.shape)\n\n"
  },
  {
    "path": "pyproject.toml",
    "content": "# pyproject.toml\n\n[build-system]\nrequires = [\"poetry-core>=1.0.0\"]\nbuild-backend = \"poetry.core.masonry.api\"\n\n[tool.poetry]\nname = \"viskan\"\nversion = \"0.1.0\"\ndescription = \"Viskan\"\nauthors = [\"Saaketh Koundinya <three2saki@gmail.com>\"]\n\n[tool.poetry.dependencies]\npython = \"^3.6\"\ntorch = \">=1.13.1\"\ntorchvision = \">=0.8.1\"\ntimm = \">=0.3.2\"\n\n[tool.poetry.dev-dependencies]\n# Add any development dependencies here if needed\n"
  },
  {
    "path": "requirements.txt",
    "content": "torch==1.13.1\ntorchvision==0.8.1\ntimm==0.3.2\n"
  },
  {
    "path": "samplers.py",
    "content": "# Copyright (c) 2015-present, Facebook, Inc.\n# All rights reserved.\nimport torch\nimport torch.distributed as dist\nimport math\n\n\nclass RASampler(torch.utils.data.Sampler):\n    \"\"\"Sampler that restricts data loading to a subset of the dataset for distributed,\n    with repeated augmentation.\n    It ensures that different each augmented version of a sample will be visible to a\n    different process (GPU)\n    Heavily based on torch.utils.data.DistributedSampler\n    \"\"\"\n\n    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, num_repeats: int = 3):\n        if num_replicas is None:\n            if not dist.is_available():\n                raise RuntimeError(\"Requires distributed package to be available\")\n            num_replicas = dist.get_world_size()\n        if rank is None:\n            if not dist.is_available():\n                raise RuntimeError(\"Requires distributed package to be available\")\n            rank = dist.get_rank()\n        if num_repeats < 1:\n            raise ValueError(\"num_repeats should be greater than 0\")\n        self.dataset = dataset\n        self.num_replicas = num_replicas\n        self.rank = rank\n        self.num_repeats = num_repeats\n        self.epoch = 0\n        self.num_samples = int(math.ceil(len(self.dataset) * self.num_repeats / self.num_replicas))\n        self.total_size = self.num_samples * self.num_replicas\n        # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))\n        self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))\n        self.shuffle = shuffle\n\n    def __iter__(self):\n        if self.shuffle:\n            # deterministically shuffle based on epoch\n            g = torch.Generator()\n            g.manual_seed(self.epoch)\n            indices = torch.randperm(len(self.dataset), generator=g)\n        else:\n            indices = torch.arange(start=0, end=len(self.dataset))\n\n        # add extra samples to make it evenly divisible\n        indices = torch.repeat_interleave(indices, repeats=self.num_repeats, dim=0).tolist()\n        padding_size: int = self.total_size - len(indices)\n        if padding_size > 0:\n            indices += indices[:padding_size]\n        assert len(indices) == self.total_size\n\n        # subsample\n        indices = indices[self.rank:self.total_size:self.num_replicas]\n        assert len(indices) == self.num_samples\n\n        return iter(indices[:self.num_selected_samples])\n\n    def __len__(self):\n        return self.num_selected_samples\n\n    def set_epoch(self, epoch):\n        self.epoch = epoch\n"
  },
  {
    "path": "utils.py",
    "content": "# Copyright (c) 2015-present, Facebook, Inc.\n# All rights reserved.\n\"\"\"\nMisc functions, including distributed helpers.\n\nMostly copy-paste from torchvision references.\n\"\"\"\nimport io\nimport os\nimport time\nfrom collections import defaultdict, deque\nimport datetime\n\nimport torch\nimport torch.distributed as dist\n\n\nclass SmoothedValue(object):\n    \"\"\"Track a series of values and provide access to smoothed values over a\n    window or the global series average.\n    \"\"\"\n\n    def __init__(self, window_size=20, fmt=None):\n        if fmt is None:\n            fmt = \"{median:.4f} ({global_avg:.4f})\"\n        self.deque = deque(maxlen=window_size)\n        self.total = 0.0\n        self.count = 0\n        self.fmt = fmt\n\n    def update(self, value, n=1):\n        self.deque.append(value)\n        self.count += n\n        self.total += value * n\n\n    def synchronize_between_processes(self):\n        \"\"\"\n        Warning: does not synchronize the deque!\n        \"\"\"\n        if not is_dist_avail_and_initialized():\n            return\n        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')\n        dist.barrier()\n        dist.all_reduce(t)\n        t = t.tolist()\n        self.count = int(t[0])\n        self.total = t[1]\n\n    @property\n    def median(self):\n        d = torch.tensor(list(self.deque))\n        return d.median().item()\n\n    @property\n    def avg(self):\n        d = torch.tensor(list(self.deque), dtype=torch.float32)\n        return d.mean().item()\n\n    @property\n    def global_avg(self):\n        return self.total / self.count\n\n    @property\n    def max(self):\n        return max(self.deque)\n\n    @property\n    def value(self):\n        return self.deque[-1]\n\n    def __str__(self):\n        return self.fmt.format(\n            median=self.median,\n            avg=self.avg,\n            global_avg=self.global_avg,\n            max=self.max,\n            value=self.value)\n\n\nclass MetricLogger(object):\n    def __init__(self, delimiter=\"\\t\"):\n        self.meters = defaultdict(SmoothedValue)\n        self.delimiter = delimiter\n\n    def update(self, **kwargs):\n        for k, v in kwargs.items():\n            if isinstance(v, torch.Tensor):\n                v = v.item()\n            assert isinstance(v, (float, int))\n            self.meters[k].update(v)\n\n    def __getattr__(self, attr):\n        if attr in self.meters:\n            return self.meters[attr]\n        if attr in self.__dict__:\n            return self.__dict__[attr]\n        raise AttributeError(\"'{}' object has no attribute '{}'\".format(\n            type(self).__name__, attr))\n\n    def __str__(self):\n        loss_str = []\n        for name, meter in self.meters.items():\n            loss_str.append(\n                \"{}: {}\".format(name, str(meter))\n            )\n        return self.delimiter.join(loss_str)\n\n    def synchronize_between_processes(self):\n        for meter in self.meters.values():\n            meter.synchronize_between_processes()\n\n    def add_meter(self, name, meter):\n        self.meters[name] = meter\n\n    def log_every(self, iterable, print_freq, header=None):\n        i = 0\n        if not header:\n            header = ''\n        start_time = time.time()\n        end = time.time()\n        iter_time = SmoothedValue(fmt='{avg:.4f}')\n        data_time = SmoothedValue(fmt='{avg:.4f}')\n        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'\n        log_msg = [\n          
  header,\n            '[{0' + space_fmt + '}/{1}]',\n            'eta: {eta}',\n            '{meters}',\n            'time: {time}',\n            'data: {data}'\n        ]\n        if torch.cuda.is_available():\n            log_msg.append('max mem: {memory:.0f}')\n        log_msg = self.delimiter.join(log_msg)\n        MB = 1024.0 * 1024.0\n        for obj in iterable:\n            data_time.update(time.time() - end)\n            yield obj\n            iter_time.update(time.time() - end)\n            if i % print_freq == 0 or i == len(iterable) - 1:\n                eta_seconds = iter_time.global_avg * (len(iterable) - i)\n                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))\n                if torch.cuda.is_available():\n                    print(log_msg.format(\n                        i, len(iterable), eta=eta_string,\n                        meters=str(self),\n                        time=str(iter_time), data=str(data_time),\n                        memory=torch.cuda.max_memory_allocated() / MB))\n                else:\n                    print(log_msg.format(\n                        i, len(iterable), eta=eta_string,\n                        meters=str(self),\n                        time=str(iter_time), data=str(data_time)))\n            i += 1\n            end = time.time()\n        total_time = time.time() - start_time\n        total_time_str = str(datetime.timedelta(seconds=int(total_time)))\n        print('{} Total time: {} ({:.4f} s / it)'.format(\n            header, total_time_str, total_time / len(iterable)))\n\n\ndef _load_checkpoint_for_ema(model_ema, checkpoint):\n    \"\"\"\n    Workaround for ModelEma._load_checkpoint to accept an already-loaded object\n    \"\"\"\n    mem_file = io.BytesIO()\n    torch.save({'state_dict_ema':checkpoint}, mem_file)\n    mem_file.seek(0)\n    model_ema._load_checkpoint(mem_file)\n\n\ndef setup_for_distributed(is_master):\n    \"\"\"\n    This function disables printing when not in master process\n    \"\"\"\n    import builtins as __builtin__\n    builtin_print = __builtin__.print\n\n    def print(*args, **kwargs):\n        force = kwargs.pop('force', False)\n        if is_master or force:\n            builtin_print(*args, **kwargs)\n\n    __builtin__.print = print\n\n\ndef is_dist_avail_and_initialized():\n    if not dist.is_available():\n        return False\n    if not dist.is_initialized():\n        return False\n    return True\n\n\ndef get_world_size():\n    if not is_dist_avail_and_initialized():\n        return 1\n    return dist.get_world_size()\n\n\ndef get_rank():\n    if not is_dist_avail_and_initialized():\n        return 0\n    return dist.get_rank()\n\n\ndef is_main_process():\n    return get_rank() == 0\n\n\ndef save_on_master(*args, **kwargs):\n    if is_main_process():\n        torch.save(*args, **kwargs)\n\n\ndef init_distributed_mode(args):\n    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:\n        args.rank = int(os.environ[\"RANK\"])\n        args.world_size = int(os.environ['WORLD_SIZE'])\n        args.gpu = int(os.environ['LOCAL_RANK'])\n    elif 'SLURM_PROCID' in os.environ:\n        args.rank = int(os.environ['SLURM_PROCID'])\n        args.gpu = args.rank % torch.cuda.device_count()\n    else:\n        print('Not using distributed mode')\n        args.distributed = False\n        return\n\n    args.distributed = True\n\n    torch.cuda.set_device(args.gpu)\n    args.dist_backend = 'nccl'\n    print('| distributed init (rank {}): {}'.format(\n        args.rank, args.dist_url), 
flush=True)\n    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,\n                                         world_size=args.world_size, rank=args.rank)\n    torch.distributed.barrier()\n    setup_for_distributed(args.rank == 0)\n"
  }
]