[
  {
    "path": ".gitignore",
    "content": "saved_models/**\n.vscode/**\n*.pt\n*.pth\nbackbones/**\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# poetry\n#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#   
https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control\n#poetry.lock\n\n# pdm\n#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.\n#pdm.lock\n#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it\n#   in version control.\n#   https://pdm.fming.dev/#use-with-ide\n.pdm.toml\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n\n# PyCharm\n#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can\n#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore\n#  and can be added to the global gitignore or merged into this file.  For a more nuclear\n#  option (not recommended) you can uncomment the following to ignore the entire idea folder.\n.idea/"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2024 Huiping Zhuang\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "[中文](README_CN.md) ｜ English\n\n# Analytic Continual Learning\n\nOfficial implementation of the following papers.\n\n[1] Zhuang, Huiping, et al. \"[ACIL: Analytic Class-Incremental Learning with Absolute Memorization and Privacy Protection.](https://proceedings.neurips.cc/paper_files/paper/2022/hash/4b74a42fc81fc7ee252f6bcb6e26c8be-Abstract-Conference.html)\" Advances in Neural Information Processing Systems 35 (2022): 11602-11614.\n\n[2] Zhuang, Huiping, et al. \"[GKEAL: Gaussian Kernel Embedded Analytic Learning for Few-Shot Class Incremental Task.](https://openaccess.thecvf.com/content/CVPR2023/html/Zhuang_GKEAL_Gaussian_Kernel_Embedded_Analytic_Learning_for_Few-Shot_Class_Incremental_CVPR_2023_paper.html)\" Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2023.\n\n[3] Zhuang, Huiping, et al. \"[DS-AL: A Dual-Stream Analytic Learning for Exemplar-Free Class-Incremental Learning.](https://ojs.aaai.org/index.php/AAAI/article/view/29670)\" Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 38. No. 15. 2024.\n\n[4] Zhuang, Huiping, et al. \"[GACL: Exemplar-Free Generalized Analytic Continual Learning](https://neurips.cc/virtual/2024/poster/95330)\" Advances in Neural Information Processing Systems 37 (2024). [[OpenReview]](https://openreview.net/forum?id=P6aJ7BqYlc) [[arXiv]](https://arxiv.org/abs/2403.15706)\n\n[5] Zhuang, Huiping, et al. \"[Online Analytic Exemplar-Free Continual Learning with Large Models for Imbalanced Autonomous Driving Task.](https://ieeexplore.ieee.org/document/10721370)\" IEEE Transactions on Vehicular Technology (2024).\n\n[6] Fang, Di, et al. 
\"[AIR: Analytic Imbalance Rectifier for Continual Learning.](https://arxiv.org/abs/2408.10349)\" arXiv preprint arXiv:2408.10349 (2024).\n\n![](figures/acc_cmp.jpg)\n\n**You are welcome to join our Tencent QQ group: [954528161](http://qm.qq.com/cgi-bin/qm/qr?_wv=1027&k=qaK4W8Jw6d--VWHlx7iUs93T2qMJT9k_&authKey=5e8hSXX8rALjM12iGwrZ9BmRBP9iUfCuRGNCaZ3%2Bx0msiRFVcSwu%2FuZpeKig1XQH&noverify=0&group_code=954528161). A Chinese tutorial is available on [Bilibili](https://www.bilibili.com/video/BV1wq421A7YM/).**\n\n## Dual Branch\n\nWe maintain a dual branch, \"[Analytic Federated Learning](https://github.com/ZHUANGHP/Analytic-federated-learning)\".\n\n## Environment\nWe recommend using [Anaconda](https://anaconda.org/) to set up the development environment.\n\n```bash\ngit clone --depth=1 git@github.com:ZHUANGHP/Analytic-continual-learning.git\n\ncd Analytic-continual-learning\nconda env create -f environment.yaml\nconda activate AL\n\nmkdir backbones\n```\n\nDownload the pre-trained weights from the [release page](https://github.com/ZHUANGHP/Analytic-continual-learning/releases) for a quick start. We suggest extracting the pre-trained backbone (zip file) into the `backbones` folder.\n\nmacOS and CPU-only users need to delete the CUDA-related items in the `environment.yaml` file.\n\nWe highly recommend running our code on Linux. Windows and macOS users are welcome to submit issues if they have problems running the code.\n\n## Quick Start\nPut the base training weights (provided on the [release page](https://github.com/ZHUANGHP/Analytic-continual-learning/releases)) in the `backbones` directory. Gradients are not used during continual learning. 
**You can run our code even on CPUs.**\n\nHere are some examples.\n\n```bash\n# ACIL (CIFAR-100, B50 25 phases)\npython main.py ACIL --dataset CIFAR-100 --base-ratio 0.5 --phases 25 \\\n    --data-root ~/dataset --IL-batch-size 4096 --num-workers 16 --backbone resnet32 \\\n    --gamma 0.1 --buffer-size 8192 \\\n    --cache-features --backbone-path ./backbones/resnet32_CIFAR-100_0.5_None\n```\n```bash\n# G-ACIL (CIFAR-100, B50 25 phases)\npython main.py G-ACIL --dataset CIFAR-100 --base-ratio 0.5 --phases 25 \\\n    --data-root ~/dataset --IL-batch-size 4096 --num-workers 16 --backbone resnet32 \\\n    --gamma 0.1 --buffer-size 8192 \\\n    --cache-features --backbone-path ./backbones/resnet32_CIFAR-100_0.5_None\n```\n```bash\n# GKEAL (CIFAR-100, B50 10 phases)\npython main.py GKEAL --dataset CIFAR-100 --base-ratio 0.5 --phases 10 \\\n    --data-root ~/dataset --IL-batch-size 4096 --num-workers 16 --backbone resnet32 \\\n    --gamma 0.1 --sigma 10 --buffer-size 8192 \\\n    --cache-features --backbone-path ./backbones/resnet32_CIFAR-100_0.5_None\n```\n```bash\n# DS-AL (CIFAR-100, B50 50 phases)\npython main.py DS-AL --dataset CIFAR-100 --base-ratio 0.5 --phases 50 \\\n    --data-root ~/dataset --IL-batch-size 4096 --num-workers 16 --backbone resnet32 \\\n    --gamma 0.1 --gamma-comp 0.1 --compensation-ratio 0.6 --buffer-size 8192 \\\n    --cache-features --backbone-path ./backbones/resnet32_CIFAR-100_0.5_None\n```\n```bash\n# DS-AL (ImageNet-1k, B50 20 phases)\npython main.py DS-AL --dataset ImageNet-1k --base-ratio 0.5 --phases 20 \\\n    --data-root ~/dataset --IL-batch-size 4096 --num-workers 16 --backbone resnet18 \\\n    --gamma 0.1 --gamma-comp 0.1 --compensation-ratio 1.5 --buffer-size 16384 \\\n    --cache-features --backbone-path ./backbones/resnet18_ImageNet-1k_0.5_None\n```\n\n## Training From Scratch\n\n```bash\n# ACIL (CIFAR-100)\npython main.py ACIL --dataset CIFAR-100 --base-ratio 0.5 --phases 25 \\\n    --data-root ~/dataset --batch-size 256 
--num-workers 16 --backbone resnet32 \\\n    --learning-rate 0.5 --label-smoothing 0 --base-epochs 300 --weight-decay 5e-4 \\\n    --gamma 0.1 --buffer-size 8192 --cache-features --IL-batch-size 4096\n```\n```bash\n# ACIL (ImageNet-1k)\npython main.py ACIL --dataset ImageNet-1k --base-ratio 0.5 --phases 25 \\\n    --data-root ~/dataset --batch-size 256 --num-workers 16 --backbone resnet18 \\\n    --learning-rate 0.5 --label-smoothing 0.05 --base-epochs 300 --weight-decay 5e-5 \\\n    --gamma 0.1 --buffer-size 16384 --cache-features --IL-batch-size 4096\n```\n\n## Reproduction Details\n\n### Difference Between the ACIL and the G-ACIL\n\nThe G-ACIL is a generalized version of the ACIL that supports the general CIL setting. In the traditional CIL setting, the G-ACIL is equivalent to the ACIL. Thus, we use the same implementation in this repository.\n\n### Benchmarks (B50, 25 phases, with `TrivialAugmentWide`)\n\nMetrics are shown as 95% confidence intervals ($\\mu \\pm 1.96\\sigma$).\n\n|   Dataset   | Method         | Backbone  | Buffer Size | Average Accuracy (%) | Last Phase Accuracy (%) |\n| :---------: | :------------: | :-------: | :---------: | :------------------: | :---------------------: |\n|  CIFAR-100  |  ACIL & G-ACIL | ResNet-32 |    8192     |   $71.047\\pm0.252$   |    $63.384\\pm0.330$     |\n|  CIFAR-100  |  DS-AL         | ResNet-32 |    8192     |   $71.277\\pm0.251$   |    $64.043\\pm0.184$     |\n|  CIFAR-100  |  GKEAL         | ResNet-32 |    8192     |   $70.371\\pm0.168$   |    $62.301\\pm0.191$     |\n| ImageNet-1k |  ACIL & G-ACIL | ResNet-18 |    16384    |   $67.497\\pm0.092$   |    $58.349\\pm0.111$     |\n| ImageNet-1k |  DS-AL         | ResNet-18 |    16384    |   $68.354\\pm0.084$   |    $59.762\\pm0.086$     |\n| ImageNet-1k |  GKEAL         | ResNet-18 |    16384    |   $66.881\\pm0.061$   |    $57.295\\pm0.105$     |\n\n![Top-1 Accuracy](figures/acc@1.svg)\n\n### Hyper-Parameters (Analytic Continual Learning)\nThe backbones are frozen during the 
incremental learning process of our algorithm. You can use the `--cache-features` option to save the features output by the backbones to improve the efficiency of parameter adjustment.\n\n1. **Buffer Size**\n\n    For the ACIL, the buffer size means the *expansion size* of the random projection layer. For the GKEAL, the buffer size means the number of *center vectors* of the *Gaussian kernel embedding*. In the DS-AL, we summarize the \"random projection\" and the \"Gaussian projection\" into the single concept of a \"buffer\".\n\n    On most datasets, the performance of the algorithm first increases and then decreases as the buffer size increases. You can see further experiments on this hyper-parameter in our papers. We recommend using a buffer size of 8192 on CIFAR-100 and 16384 or greater on ImageNet for optimal performance. A larger buffer size requires more memory.\n\n2. **$\\gamma$ (Coefficient of the Regularization Term)**\n\n    For the datasets used in the papers, $\\gamma$ is insensitive within an interval. However, a $\\gamma$ that is too small may cause numerical stability problems in matrix inversion, and a $\\gamma$ that is too large may cause under-fitting of the classifier. On both CIFAR-100 and ImageNet-1k, $\\gamma$ is 0.1. When you migrate our algorithm to other datasets, we still recommend running some experiments to check whether $\\gamma$ is appropriate.\n\n3. **$\\beta$ and $\\sigma$ (GKEAL Only)**\n\n    In the GKEAL, the width-adjusting parameter $\\beta$ controls the width of the Gaussian kernels. There is a comfortable range for $\\sigma$ of around $[5, 15]$ on CIFAR-100 and ImageNet-1k that gives good results, where $\\beta = \\frac{1}{2\\sigma^2}$.\n\n4. **Compensation Ratio $\\mathcal{C}$ (DS-AL Only)**\n\n    We recommend using a grid search to find the best compensation ratio in the interval $[0, 2]$. 
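A minimal sketch of such a grid search, reusing the DS-AL flags from the Quick Start examples above (the step values and paths are illustrative, not prescribed):

```shell
# Hypothetical grid search over the compensation ratio C in [0, 2].
# With --cache-features, the backbone features are cached, so parameter
# adjustment across runs stays cheap; adjust the step size as needed.
for C in 0.0 0.3 0.6 0.9 1.2 1.5 1.8; do
    python main.py DS-AL --dataset CIFAR-100 --base-ratio 0.5 --phases 50 \
        --data-root ~/dataset --IL-batch-size 4096 --backbone resnet32 \
        --gamma 0.1 --gamma-comp 0.1 --compensation-ratio "$C" --buffer-size 8192 \
        --cache-features --backbone-path ./backbones/resnet32_CIFAR-100_0.5_None
done
```
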
The best value is 0.6 for CIFAR-100, while the best value for ImageNet-1k is 1.5.\n\nFurther analyses of the hyper-parameters are presented in our papers.\n\n### Hyper-Parameters (Base Training)\nIn the base training process, the backbones reach over 80% top-1 accuracy on the first half of CIFAR-100 (ResNet-32) and ImageNet-1k (ResNet-18). Important hyper-parameters are listed below.\n\n1. **Learning Rate**\n\n    In this implementation, we use a cosine scheduler instead of the piece-wise smooth scheduler used in the papers to reduce the number of hyper-parameters. We recommend using a learning rate of 0.5 (when the batch size is 256) on CIFAR-100 and ImageNet-1k to obtain better convergence. The number of epochs used for the provided backbones is 300.\n\n2. **Label Smoothing and Weight Decay**\n\n    Properly setting label smoothing and weight decay can help prevent over-fitting of the backbone. On CIFAR-100, label smoothing is not significantly helpful, while on ImageNet-1k we empirically selected 0.05. For CIFAR-100, we choose a weight decay of 5e-4, while for ImageNet-1k this value is 5e-5.\n\n3. **Image Augmentation**\n\n    Using image augmentation to obtain a more generalizable backbone on the base training dataset can significantly improve performance. No image augmentation is used in the experiments of our papers, but in this implementation data augmentation is enabled by default. 
**So using this implementation will achieve higher performance than reported in the papers (about 2%~5%)**.\n\n    Note that we do not use any data augmentation during the re-alignment and the continual learning processes because each sample will be learned only once.\n\n# Cite Our Papers\n\n```bib\n@InProceedings{ACIL_Zhuang_NeurIPS2022,\n    author    = {Zhuang, Huiping and Weng, Zhenyu and Wei, Hongxin and Xie, Renchunzi and Toh, Kar-Ann and Lin, Zhiping},\n    title     = {{ACIL}: Analytic Class-Incremental Learning with Absolute Memorization and Privacy Protection},\n    booktitle = {Advances in Neural Information Processing Systems},\n    editor    = {S. Koyejo and S. Mohamed and A. Agarwal and D. Belgrave and K. Cho and A. Oh},\n    pages     = {11602--11614},\n    publisher = {Curran Associates, Inc.},\n    volume    = {35},\n    year      = {2022},\n    url       = {https://proceedings.neurips.cc/paper_files/paper/2022/file/4b74a42fc81fc7ee252f6bcb6e26c8be-Paper-Conference.pdf}\n}\n\n@InProceedings{GKEAL_Zhuang_CVPR2023,\n    author    = {Zhuang, Huiping and Weng, Zhenyu and He, Run and Lin, Zhiping and Zeng, Ziqian},\n    title     = {{GKEAL}: Gaussian Kernel Embedded Analytic Learning for Few-Shot Class Incremental Task},\n    booktitle = {2023 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},\n    month     = jun,\n    year      = {2023},\n    pages     = {7746--7755},\n    doi       = {10.1109/CVPR52729.2023.00748}\n}\n\n@Article{DS-AL_Zhuang_AAAI2024,\n    title   = {{DS-AL}: A Dual-Stream Analytic Learning for Exemplar-Free Class-Incremental Learning},\n    author  = {Zhuang, Huiping and He, Run and Tong, Kai and Zeng, Ziqian and Chen, Cen and Lin, Zhiping},\n    journal = {Proceedings of the AAAI Conference on Artificial Intelligence},\n    volume  = {38},\n    number  = {15},\n    pages   = {17237--17244},\n    year    = {2024},\n    month   = mar,\n    doi     = {10.1609/aaai.v38i15.29670},\n    url     = 
{https://ojs.aaai.org/index.php/AAAI/Article/view/29670}\n}\n\n@InProceedings{GACL_Zhuang_NeurIPS2024,\n    title     = {{GACL}: Exemplar-Free Generalized Analytic Continual Learning},\n    author    = {Huiping Zhuang and Yizhu Chen and Di Fang and Run He and Kai Tong and Hongxin Wei and Ziqian Zeng and Cen Chen},\n    year      = {2024},\n    booktitle = {Advances in Neural Information Processing Systems},\n    publisher = {Curran Associates, Inc.},\n    month     = dec\n}\n\n@article{AEF-OCL_Zhuang_TVT2024,\n    title   = {Online Analytic Exemplar-Free Continual Learning with Large Models for Imbalanced Autonomous Driving Task},\n    author  = {Zhuang, Huiping and Fang, Di and Tong, Kai and Liu, Yuchen and Zeng, Ziqian and Zhou, Xu and Chen, Cen},\n    year    = {2024},\n    journal = {IEEE Transactions on Vehicular Technology},\n    pages   = {1-10},\n    doi     = {10.1109/TVT.2024.3483557}\n}\n\n@misc{AIR_Fang_arXiv2024,\n    title         = {{AIR}: Analytic Imbalance Rectifier for Continual Learning}, \n    author        = {Di Fang and Yinan Zhu and Zhiping Lin and Cen Chen and Ziqian Zeng and Huiping Zhuang},\n    year          = {2024},\n    month         = aug,\n    archivePrefix = {arXiv},\n    primaryClass  = {cs.LG},\n    eprint        = {2408.10349},\n    doi           = {10.48550/arXiv.2408.10349},\n    url           = {https://arxiv.org/abs/2408.10349},\n}\n```\n"
  },
  {
    "path": "README_CN.md",
    "content": "中文 ｜ [English](README.md)\n\n# 解析持续学习 (Analytic Continual Learning)\n\n该项目的工作已被发表在以下论文中：\n\n[1] Zhuang, Huiping, et al. \"[ACIL: Analytic Class-Incremental Learning with Absolute Memorization and Privacy Protection.](https://proceedings.neurips.cc/paper_files/paper/2022/hash/4b74a42fc81fc7ee252f6bcb6e26c8be-Abstract-Conference.html)\" Advances in Neural Information Processing Systems 35 (2022): 11602-11614.\n\n[2] Zhuang, Huiping, et al. \"[GKEAL: Gaussian Kernel Embedded Analytic Learning for Few-Shot Class Incremental Task.](https://openaccess.thecvf.com/content/CVPR2023/html/Zhuang_GKEAL_Gaussian_Kernel_Embedded_Analytic_Learning_for_Few-Shot_Class_Incremental_CVPR_2023_paper.html)\" Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2023.\n\n[3] Zhuang, Huiping, et al. \"[DS-AL: A Dual-Stream Analytic Learning for Exemplar-Free Class-Incremental Learning.](https://ojs.aaai.org/index.php/AAAI/article/view/29670)\" Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 38. No. 15. 2024.\n\n[4] Zhuang, Huiping, et al. \"[GACL: Exemplar-Free Generalized Analytic Continual Learning](https://neurips.cc/virtual/2024/poster/95330)\" Advances in Neural Information Processing Systems 37 (2024). [[OpenReview]](https://openreview.net/forum?id=P6aJ7BqYlc) [[arXiv]](https://arxiv.org/abs/2403.15706)\n\n[5] Zhuang, Huiping, et al. \"[Online Analytic Exemplar-Free Continual Learning with Large Models for Imbalanced Autonomous Driving Task.](https://ieeexplore.ieee.org/document/10721370)\" IEEE Transactions on Vehicular Technology (2024).\n\n[6] Fang, Di, et al. 
\"[AIR: Analytic Imbalance Rectifier for Continual Learning.](https://arxiv.org/abs/2408.10349)\" arXiv preprint arXiv:2408.10349 (2024).\n\n![](figures/acc_cmp.jpg)\n\n**欢迎加入我们的交流QQ群: [954528161](http://qm.qq.com/cgi-bin/qm/qr?_wv=1027&k=qaK4W8Jw6d--VWHlx7iUs93T2qMJT9k_&authKey=5e8hSXX8rALjM12iGwrZ9BmRBP9iUfCuRGNCaZ3%2Bx0msiRFVcSwu%2FuZpeKig1XQH&noverify=0&group_code=954528161)。中文解读教程可以在[Bilibili](https://www.bilibili.com/video/BV1wq421A7YM/)中观看。**\n\n## 解析学习的另一分支：解析联邦学习\n我们开源了解析学习的另一分支——“[解析联邦学习](https://github.com/ZHUANGHP/Analytic-federated-learning)”相关的工作。\n\n## 环境配置 (Environment)\n我们建议使用[Anaconda](https://anaconda.org/)来配置运行环境。\n\n```bash\ngit clone --depth=1 git@github.com:ZHUANGHP/Analytic-continual-learning.git\n\ncd Analytic-continual-learning\nconda env create -f environment.yaml\nconda activate AL\n\nmkdir backbones\n```\n\n您可以从[这里](https://github.com/ZHUANGHP/Analytic-continual-learning/releases)下载预训练模型，来快速上手体验我们的算法。我们建议将预训练的骨干网络（backbone）提取在`backbones`文件夹下。\n\n对于使用macOS系统或使用CPUs计算的用户，您需要删除`environment.yaml`文件中有关CUDA的项。\n\n我们强烈建议您在Linux中运行我们的算法。当然如果Windows和macOS用户在运行时遇到任何问题也欢迎提交Issues。\n\n## 快速开始 (Quick Start)\n在开始体验算法之前，请您先将基础训练权重（该[发布页](https://github.com/ZHUANGHP/Analytic-continual-learning/releases)中提供）放入`backbones`目录。由于本算法持续学习阶段不需要进行梯度计算，**您甚至可以在CPUs上运行我们的代码。**\n\n这是一些参考案例：\n\n```bash\n# ACIL (CIFAR-100, B50 25 phases)\npython main.py ACIL --dataset CIFAR-100 --base-ratio 0.5 --phases 25 \\\n    --data-root ~/dataset --IL-batch-size 4096 --num-workers 16 --backbone resnet32 \\\n    --gamma 0.1 --buffer-size 8192 \\\n    --cache-features --backbone-path ./backbones/resnet32_CIFAR-100_0.5_None\n```\n```bash\n# G-ACIL (CIFAR-100, B50 25 phases)\npython main.py G-ACIL --dataset CIFAR-100 --base-ratio 0.5 --phases 25 \\\n    --data-root ~/dataset --IL-batch-size 4096 --num-workers 16 --backbone resnet32 \\\n    --gamma 0.1 --buffer-size 8192 \\\n    --cache-features --backbone-path ./backbones/resnet32_CIFAR-100_0.5_None\n```\n```bash\n# GKEAL 
(CIFAR-100, B50 10 phases)\npython main.py GKEAL --dataset CIFAR-100 --base-ratio 0.5 --phases 10 \\\n    --data-root ~/dataset --IL-batch-size 4096 --num-workers 16 --backbone resnet32 \\\n    --gamma 0.1 --sigma 10 --buffer-size 8192 \\\n    --cache-features --backbone-path ./backbones/resnet32_CIFAR-100_0.5_None\n```\n```bash\n# DS-AL (CIFAR-100, B50 50 phases)\npython main.py DS-AL --dataset CIFAR-100 --base-ratio 0.5 --phases 50 \\\n    --data-root ~/dataset --IL-batch-size 4096 --num-workers 16 --backbone resnet32 \\\n    --gamma 0.1 --gamma-comp 0.1 --compensation-ratio 0.6 --buffer-size 8192 \\\n    --cache-features --backbone-path ./backbones/resnet32_CIFAR-100_0.5_None\n```\n```bash\n# DS-AL (ImageNet-1k, B50 20 phases)\npython main.py DS-AL --dataset ImageNet-1k --base-ratio 0.5 --phases 20 \\\n    --data-root ~/dataset --IL-batch-size 4096 --num-workers 16 --backbone resnet18 \\\n    --gamma 0.1 --gamma-comp 0.1 --compensation-ratio 1.5 --buffer-size 16384 \\\n    --cache-features --backbone-path ./backbones/resnet18_ImageNet-1k_0.5_None\n```\n\n## 从零开始训练 (Training From Scratch)\n\n```bash\n# ACIL (CIFAR-100)\npython main.py ACIL --dataset CIFAR-100 --base-ratio 0.5 --phases 25 \\\n    --data-root ~/dataset --batch-size 256 --num-workers 16 --backbone resnet32 \\\n    --learning-rate 0.5 --label-smoothing 0 --base-epochs 300 --weight-decay 5e-4 \\\n    --gamma 0.1 --buffer-size 8192 --cache-features --IL-batch-size 4096\n```\n```bash\n# ACIL (ImageNet-1k)\npython main.py ACIL --dataset ImageNet-1k --base-ratio 0.5 --phases 25 \\\n    --data-root ~/dataset --batch-size 256 --num-workers 16 --backbone resnet18 \\\n    --learning-rate 0.5 --label-smoothing 0.05 --base-epochs 300 --weight-decay 5e-5 \\\n    --gamma 0.1 --buffer-size 16384 --cache-features --IL-batch-size 4096\n```\n\n## 复现的细节 (Reproduction Details)\n\n### ACIL与G-ACIL之间的区别\n\nG-ACIL是用于一般CIL设置的ACIL的通用版本。在传统的CIL任务上，G-ACIL相当于ACIL。因此我们在本仓库中使用了相同的实现。\n\n### 
基准测试(B50，25阶段，使用`TrivialAugmentWide`)\n\n以下指标有95%的置信水平($\\mu \\pm 1.96\\sigma$)：\n\n|   Dataset   |    Method     | Backbone  | Buffer Size | Average Accuracy (%) | Last Phase Accuracy (%) |\n| :---------: | :-----------: | :-------: | :---------: | :------------------: | :---------------------: |\n|  CIFAR-100  | ACIL & G-ACIL | ResNet-32 |    8192     |   $71.047\\pm0.252$   |    $63.384\\pm0.330$     |\n|  CIFAR-100  |     DS-AL     | ResNet-32 |    8192     |   $71.277\\pm0.251$   |    $64.043\\pm0.184$     |\n|  CIFAR-100  |     GKEAL     | ResNet-32 |    8192     |   $70.371\\pm0.168$   |    $62.301\\pm0.191$     |\n| ImageNet-1k | ACIL & G-ACIL | ResNet-18 |    16384    |   $67.497\\pm0.092$   |    $58.349\\pm0.111$     |\n| ImageNet-1k |     DS-AL     | ResNet-18 |    16384    |   $68.354\\pm0.084$   |    $59.762\\pm0.086$     |\n| ImageNet-1k |     GKEAL     | ResNet-18 |    16384    |   $66.881\\pm0.061$   |    $57.295\\pm0.105$     |\n\n![Top-1 Accuracy](figures/acc@1.svg)\n\n### 超参数(持续学习阶段)\n在算法的持续学习阶段中，骨干网络（backbone）是被冻结的。您可以使用`--cache-features`选项保存骨干网络输出的特征，以提高参数调整的效率。下面列出了一些重要的超参数：\n\n1. **Buffer Size**\n\n    对于ACIL，缓冲区（buffer）大小表示随机投影层的 *扩展尺寸（expansion size）* 。对于GKEAL，缓冲区大小是指 *高斯核嵌入（Gaussian kernel embedding）* 的 *中心向量（center vectors）* 的个数。在DS-AL中，我们将“随机投影（random projection）”和“高斯投影（Gaussian projection）”均归为“缓冲区”这一概念。\n\n    在大多数数据集上，随着缓冲区大小的增加，本算法的性能先增加后降低，您可以在我们的论文中看到关于这个超参数的详细实验。根据实验，我们建议在CIFAR-100上将缓冲区大小设置为8192，在ImageNet-1k上将缓冲区大小设置为16384或更大，以获得最佳性能。当然，较大的缓冲区大小代表着更多的内存占用。\n\n2. **$\\gamma$(正则化项的系数)**\n\n    对于论文中使用的数据集， $\\gamma$ 在一定范围内不敏感。但是，太小的 $\\gamma$ 可能会导致矩阵求逆过程所得数值不稳定，而太大的 $\\gamma$ 可能会导致分类器欠拟合。根据实验，我们在CIFAR-100和ImageNet-1k上将 $\\gamma$ 均设置为0.1。若您计划将我们的算法应用到其他数据集时，我们还是建议您做一些实验以检查 $\\gamma$ 是否合适，避免无法充分发挥算法性能。\n\n3. **$\\beta$ and $\\sigma$(只在GKEAL中设置)**\n\n    在GKEAL中，宽度调整（width-adjusting）参数 $\\beta$ 控制高斯核（Gaussian kernels）的宽度。对于CIFAR-100和ImageNet-1k， $\\sigma$ 设置在 $[5, 15]$ 左右时效果会较好，这里有转换关系 $\\beta = \\frac{1}{2\\sigma^2}$ 。\n\n4. 
**Compensation Ratio $\\mathcal{C}$(只在DS-AL中设置)**\n\n    我们建议使用网格搜索（grid search）在区间 $[0,2]$ 中找到最佳补偿比（compensation ratio）。根据实验，我们建议在CIFAR-100和ImageNet-1k上分别将补偿比设置为0.6和1.5。\n\n更为详细的超参数设置工作您可以在我们的论文中查阅。\n\n### 超参数(基础训练阶段)\n在基础训练阶段中，骨干网络在CIFAR-100 (ResNet-32)和ImageNet-1k (ResNet-18)的前半数据集上达到了80%以上的top-1准确率。下面列出了一些重要的超参数：\n\n1. **Learning Rate**\n\n    在本仓库实现中，我们使用了“余弦调整器（cosine scheduler）”，而不是像论文中那样使用“分段平滑调整器（piece-wise smooth scheduler）”，这可以有效的减少需要设置的超参数数量。我们建议在CIFAR-100和ImageNet-1k上将学习率设置为0.5（当批大小为256时）以获得更好的收敛性。此外，提供的骨干网络训练的epoch数是300。\n\n2. **Label Smoothing and Weight Decay**\n\n    适当的设置标签平滑（label smoothing）和权重衰减（weight decay）可以防止骨干网络过拟合。有关标签平滑，在CIFAR-100中设置该参数没有显著的效果，在ImageNet-1k中我们设置为0.05；有关权重衰减，在CIFAR-100中我们设置为5e-4，在ImageNet-1k中我们设置为5e-5。\n\n3. **Image Augmentation**\n\n    在基础训练数据集中使用图像增强可以获得有更好泛化能力的骨干网络，能够显著提高性能。在论文的实验中，我们并没有使用图像增强。而在本仓库实现中，默认情况下我们设置启用了数据增强。**因此使用本仓库实现有着相比论文中指标更高的性能（约2%~5%）。**\n\n    请注意，在重新对齐（re-alignment）和持续学习过程中，由于每个样本只学习一次，我们没有使用任何数据增强。\n\n# 欢迎引用我们的论文\n\n```bib\n@InProceedings{ACIL_Zhuang_NeurIPS2022,\n    author    = {Zhuang, Huiping and Weng, Zhenyu and Wei, Hongxin and Xie, Renchunzi and Toh, Kar-Ann and Lin, Zhiping},\n    title     = {{ACIL}: Analytic Class-Incremental Learning with Absolute Memorization and Privacy Protection},\n    booktitle = {Advances in Neural Information Processing Systems},\n    editor    = {S. Koyejo and S. Mohamed and A. Agarwal and D. Belgrave and K. Cho and A. 
Oh},\n    pages     = {11602--11614},\n    publisher = {Curran Associates, Inc.},\n    volume    = {35},\n    year      = {2022},\n    url       = {https://proceedings.neurips.cc/paper_files/paper/2022/file/4b74a42fc81fc7ee252f6bcb6e26c8be-Paper-Conference.pdf}\n}\n\n@InProceedings{GKEAL_Zhuang_CVPR2023,\n    author    = {Zhuang, Huiping and Weng, Zhenyu and He, Run and Lin, Zhiping and Zeng, Ziqian},\n    title     = {{GKEAL}: Gaussian Kernel Embedded Analytic Learning for Few-Shot Class Incremental Task},\n    booktitle = {2023 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},\n    month     = jun,\n    year      = {2023},\n    pages     = {7746--7755},\n    doi       = {10.1109/CVPR52729.2023.00748}\n}\n\n@Article{DS-AL_Zhuang_AAAI2024,\n    title   = {{DS-AL}: A Dual-Stream Analytic Learning for Exemplar-Free Class-Incremental Learning},\n    author  = {Zhuang, Huiping and He, Run and Tong, Kai and Zeng, Ziqian and Chen, Cen and Lin, Zhiping},\n    journal = {Proceedings of the AAAI Conference on Artificial Intelligence},\n    volume  = {38},\n    number  = {15},\n    pages   = {17237--17244},\n    year    = {2024},\n    month   = mar,\n    doi     = {10.1609/aaai.v38i15.29670},\n    url     = {https://ojs.aaai.org/index.php/AAAI/Article/view/29670}\n}\n\n@InProceedings{GACL_Zhuang_NeurIPS2024,\n    title     = {{GACL}: Exemplar-Free Generalized Analytic Continual Learning},\n    author    = {Huiping Zhuang and Yizhu Chen and Di Fang and Run He and Kai Tong and Hongxin Wei and Ziqian Zeng and Cen Chen},\n    year      = {2024},\n    booktitle = {Advances in Neural Information Processing Systems},\n    publisher = {Curran Associates, Inc.},\n    month     = dec\n}\n\n@article{AEF-OCL_Zhuang_TVT2024,\n    title       = {Online Analytic Exemplar-Free Continual Learning with Large Models for Imbalanced Autonomous Driving Task},\n    author      = {Zhuang, Huiping and Fang, Di and Tong, Kai and Liu, Yuchen and Zeng, Ziqian and Zhou, Xu and 
Chen, Cen},\n    year        = {2024},\n    journal     = {IEEE Transactions on Vehicular Technology},\n    pages       = {1-10},\n    doi         = {10.1109/TVT.2024.3483557}\n}\n\n@misc{AIR_Fang_arXiv2024,\n    title         = {{AIR}: Analytic Imbalance Rectifier for Continual Learning}, \n    author        = {Di Fang and Yinan Zhu and Zhiping Lin and Cen Chen and Ziqian Zeng and Huiping Zhuang},\n    year          = {2024},\n    month         = aug,\n    archivePrefix = {arXiv},\n    primaryClass  = {cs.LG},\n    eprint        = {2408.10349},\n    doi           = {10.48550/arXiv.2408.10349},\n    url           = {https://arxiv.org/abs/2408.10349},\n}\n```\n"
  },
  {
    "path": "analytic/ACIL.py",
    "content": "# -*- coding: utf-8 -*-\n\"\"\"\nImplementation of the ACIL [1] and the G-ACIL [2].\nThe G-ACIL is a generalization of the ACIL in the generalized setting.\nFor the popular setting, the G-ACIL is equivalent to the ACIL.\n\nReferences:\n[1] Zhuang, Huiping, et al.\n    \"ACIL: Analytic class-incremental learning with absolute memorization and privacy protection.\"\n    Advances in Neural Information Processing Systems 35 (2022): 11602-11614.\n[2] Zhuang, Huiping, et al.\n    \"G-ACIL: Analytic Learning for Exemplar-Free Generalized Class Incremental Learning\"\n    arXiv preprint arXiv:2403.15706 (2024).\n\"\"\"\n\nimport torch\nfrom os import path\nfrom tqdm import tqdm\nfrom typing import Any, Dict, Optional, Sequence\nfrom utils import set_weight_decay, validate\nfrom torch._prims_common import DeviceLikeType\nfrom .Buffer import RandomBuffer\nfrom torch.nn import DataParallel\nfrom .Learner import Learner, loader_t\nfrom .AnalyticLinear import AnalyticLinear, RecursiveLinear\n\n\nclass ACIL(torch.nn.Module):\n    def __init__(\n        self,\n        backbone_output: int,\n        backbone: torch.nn.Module = torch.nn.Flatten(),\n        buffer_size: int = 8192,\n        gamma: float = 1e-3,\n        device=None,\n        dtype=torch.double,\n        linear: type[AnalyticLinear] = RecursiveLinear,\n    ) -> None:\n        super().__init__()\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        self.backbone = backbone\n        self.backbone_output = backbone_output\n        self.buffer_size = buffer_size\n        self.buffer = RandomBuffer(backbone_output, buffer_size, **factory_kwargs)\n        self.analytic_linear = linear(buffer_size, gamma, **factory_kwargs)\n        self.eval()\n\n    @torch.no_grad()\n    def feature_expansion(self, X: torch.Tensor) -> torch.Tensor:\n        return self.buffer(self.backbone(X))\n\n    @torch.no_grad()\n    def forward(self, X: torch.Tensor) -> torch.Tensor:\n        return 
self.analytic_linear(self.feature_expansion(X))\n\n    @torch.no_grad()\n    def fit(self, X: torch.Tensor, y: torch.Tensor, *args, **kwargs) -> None:\n        Y = torch.nn.functional.one_hot(y)\n        X = self.feature_expansion(X)\n        self.analytic_linear.fit(X, Y)\n\n    @torch.no_grad()\n    def update(self) -> None:\n        self.analytic_linear.update()\n\n\nclass ACILLearner(Learner):\n    \"\"\"\n    This implementation is for the G-ACIL [2], a general version of the ACIL [1] that\n    supports mini-batch learning and the general CIL setting.\n    In the traditional CIL settings, the G-ACIL is equivalent to the ACIL.\n    \"\"\"\n\n    def __init__(\n        self,\n        args: Dict[str, Any],\n        backbone: torch.nn.Module,\n        backbone_output: int,\n        device=None,\n        all_devices: Optional[Sequence[DeviceLikeType]] = None,\n    ) -> None:\n        super().__init__(args, backbone, backbone_output, device, all_devices)\n        self.learning_rate: float = args[\"learning_rate\"]\n        self.buffer_size: int = args[\"buffer_size\"]\n        self.gamma: float = args[\"gamma\"]\n        self.base_epochs: int = args[\"base_epochs\"]\n        self.warmup_epochs: int = args[\"warmup_epochs\"]\n        self.make_model()\n\n    def base_training(\n        self,\n        train_loader: loader_t,\n        val_loader: loader_t,\n        baseset_size: int,\n    ) -> None:\n        model = torch.nn.Sequential(\n            self.backbone,\n            torch.nn.Linear(self.backbone_output, baseset_size),\n        ).to(self.device, non_blocking=True)\n        model = self.wrap_data_parallel(model)\n\n        if self.args[\"separate_decay\"]:\n            params = set_weight_decay(model, self.args[\"weight_decay\"])\n        else:\n            params = model.parameters()\n        optimizer = torch.optim.SGD(\n            params,\n            lr=self.learning_rate,\n            momentum=self.args[\"momentum\"],\n            
weight_decay=self.args[\"weight_decay\"],\n        )\n\n        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n            optimizer, T_max=self.base_epochs - self.warmup_epochs, eta_min=1e-6 # type: ignore\n        )\n        if self.warmup_epochs > 0:\n            warmup_scheduler = torch.optim.lr_scheduler.LinearLR(\n                optimizer,\n                start_factor=1e-3,\n                total_iters=self.warmup_epochs,\n            )\n            scheduler = torch.optim.lr_scheduler.SequentialLR(\n                optimizer, [warmup_scheduler, scheduler], [self.warmup_epochs]\n            )\n\n        criterion = torch.nn.CrossEntropyLoss(\n            label_smoothing=self.args[\"label_smoothing\"]\n        ).to(self.device, non_blocking=True)\n\n        best_acc = 0.0\n        logging_file_path = path.join(self.args[\"saving_root\"], \"base_training.csv\")\n        logging_file = open(logging_file_path, \"w\", buffering=1)\n        print(\n            \"epoch\",\n            \"best_acc@1\",\n            \"loss\",\n            \"acc@1\",\n            \"acc@5\",\n            \"f1-micro\",\n            \"training_loss\",\n            \"training_acc@1\",\n            \"training_acc@5\",\n            \"training_f1-micro\",\n            \"training_learning-rate\",\n            file=logging_file,\n            sep=\",\",\n        )\n\n        for epoch in range(self.base_epochs + 1):\n            if epoch != 0:\n                print(\n                    f\"Base Training - Epoch {epoch}/{self.base_epochs}\",\n                    f\"(Learning Rate: {optimizer.state_dict()['param_groups'][0]['lr']})\",\n                )\n                model.train()\n                for X, y in tqdm(train_loader, \"Training\"):\n                    X: torch.Tensor = X.to(self.device, non_blocking=True)\n                    y: torch.Tensor = y.to(self.device, non_blocking=True)\n                    assert y.max() < baseset_size\n\n                    
optimizer.zero_grad(set_to_none=True)\n                    logits = model(X)\n                    loss: torch.Tensor = criterion(logits, y)\n                    loss.backward()\n                    optimizer.step()\n                scheduler.step()\n\n            # Validation on training set\n            model.eval()\n            train_meter = validate(\n                model, train_loader, baseset_size, desc=\"Training (Validation)\"\n            )\n            print(\n                f\"loss: {train_meter.loss:.4f}\",\n                f\"acc@1: {train_meter.accuracy * 100:.3f}%\",\n                f\"acc@5: {train_meter.accuracy5 * 100:.3f}%\",\n                f\"f1-micro: {train_meter.f1_micro * 100:.3f}%\",\n                sep=\"    \",\n            )\n\n            val_meter = validate(model, val_loader, baseset_size, desc=\"Testing\")\n            if val_meter.accuracy > best_acc:\n                best_acc = val_meter.accuracy\n                if epoch != 0:\n                    self.save_object(\n                        (self.backbone, X.shape[1], self.backbone_output),\n                        \"backbone.pth\",\n                    )\n\n            # Validation on testing set\n            print(\n                f\"loss: {val_meter.loss:.4f}\",\n                f\"acc@1: {val_meter.accuracy * 100:.3f}%\",\n                f\"acc@5: {val_meter.accuracy5 * 100:.3f}%\",\n                f\"f1-micro: {val_meter.f1_micro * 100:.3f}%\",\n                f\"best_acc@1: {best_acc * 100:.3f}%\",\n                sep=\"    \",\n            )\n            print(\n                epoch,\n                best_acc,\n                val_meter.loss,\n                val_meter.accuracy,\n                val_meter.accuracy5,\n                val_meter.f1_micro,\n                train_meter.loss,\n                train_meter.accuracy,\n                train_meter.accuracy5,\n                train_meter.f1_micro,\n                
optimizer.state_dict()[\"param_groups\"][0][\"lr\"],\n                file=logging_file,\n                sep=\",\",\n            )\n        logging_file.close()\n        self.backbone.eval()\n        self.make_model()\n\n    def make_model(self) -> None:\n        self.model = ACIL(\n            self.backbone_output,\n            self.wrap_data_parallel(self.backbone),\n            self.buffer_size,\n            self.gamma,\n            device=self.device,\n            dtype=torch.double,\n            linear=RecursiveLinear,\n        )\n\n    @torch.no_grad()\n    def learn(\n        self,\n        data_loader: loader_t,\n        incremental_size: int,\n        desc: str = \"Incremental Learning\",\n    ) -> None:\n        self.model.eval()\n        for X, y in tqdm(data_loader, desc=desc):\n            X: torch.Tensor = X.to(self.device, non_blocking=True)\n            y: torch.Tensor = y.to(self.device, non_blocking=True)\n            self.model.fit(X, y, increase_size=incremental_size)\n\n    def before_validation(self) -> None:\n        self.model.update()\n\n    def inference(self, X: torch.Tensor) -> torch.Tensor:\n        return self.model(X)\n\n    @torch.no_grad()\n    def wrap_data_parallel(self, model: torch.nn.Module) -> torch.nn.Module:\n        if self.all_devices is not None and len(self.all_devices) > 1:\n            return DataParallel(model, self.all_devices, output_device=self.device) # type: ignore\n        return model\n"
  },
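Note: `ACIL.analytic_linear` is a ridge regression solved in closed form on the buffered features, and `RecursiveLinear` reproduces that joint solution batch by batch (Equations (9)-(10) of ACIL). The sketch below, a minimal pure-Python illustration with names of our own choosing rather than repository code, feeds samples one at a time so the Woodbury inverse `K` collapses to a scalar, and checks that the recursive update lands on the same weights as the one-shot ridge solution:

```python
def mat2_inv(m):
    """Inverse of a 2x2 matrix given as [[a, b], [c, d]]."""
    (a, b), (c, d) = m
    det = a * d - b * c
    return [[d / det, -b / det], [-c / det, a / det]]

def rls_init(gamma):
    """R = (gamma * I)^-1 and W = 0 for 2-dim features, 1 output."""
    return [[1.0 / gamma, 0.0], [0.0, 1.0 / gamma]], [0.0, 0.0]

def rls_step(R, W, x, y):
    """One-sample recursive update; K = (1 + x R x^T)^-1 is a scalar."""
    Rx = [R[0][0] * x[0] + R[0][1] * x[1], R[1][0] * x[0] + R[1][1] * x[1]]
    k = 1.0 / (1.0 + x[0] * Rx[0] + x[1] * Rx[1])
    for i in range(2):            # R <- R - R x^T K x R  (Equation (10))
        for j in range(2):
            R[i][j] -= Rx[i] * k * Rx[j]
    err = y - (x[0] * W[0] + x[1] * W[1])
    W[0] += (R[0][0] * x[0] + R[0][1] * x[1]) * err   # Equation (9)
    W[1] += (R[1][0] * x[0] + R[1][1] * x[1]) * err
    return R, W

def ridge_joint(X, Y, gamma):
    """Closed-form W = (X^T X + gamma I)^-1 X^T Y for 2-dim features."""
    G = [[gamma, 0.0], [0.0, gamma]]
    for x in X:
        for i in range(2):
            for j in range(2):
                G[i][j] += x[i] * x[j]
    Ginv = mat2_inv(G)
    XtY = [sum(x[i] * y for x, y in zip(X, Y)) for i in range(2)]
    return [sum(Ginv[i][j] * XtY[j] for j in range(2)) for i in range(2)]

X = [[1.0, 2.0], [3.0, 1.0], [0.5, -1.0], [2.0, 2.0]]
Y = [1.0, 0.0, 1.0, 0.0]
R, W = rls_init(1e-3)
for x, y in zip(X, Y):
    R, W = rls_step(R, W, x, y)
W_joint = ridge_joint(X, Y, 1e-3)
assert all(abs(a - b) < 1e-6 for a, b in zip(W, W_joint))
```

This exact equivalence is why ACIL can claim "absolute memorization" without storing exemplars: only `R` and `W` carry history.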
  {
    "path": "analytic/AEFOCL.py",
    "content": "# -*- coding: utf-8 -*-\n\"\"\"\nImplementation of the AEF-OCL [1], an analytic method for imbalanced continual learning.\n\nReferences:\n[1] Zhuang, Huiping, et al.\n    \"Online Analytic Exemplar-Free Continual Learning with Large Models for Imbalanced Autonomous Driving Task\"\n    arXiv preprint arXiv:2405.17779 (2024).\n\"\"\"\n\nfrom copy import deepcopy\nimport torch\nfrom tqdm import tqdm\nfrom .ACIL import ACILLearner, ACIL\nfrom .AnalyticLinear import AnalyticLinear, RecursiveLinear\n\n__all__ = [\"AEFOCL\", \"AEFOCLLearner\"]\n\n\nclass AEFOCL(ACIL):\n    \"\"\"\n    Network structure of the AEF-OCL [1], an analytic method for imbalanced continual learning.\n\n    References:\n    [1] Zhuang, Huiping, et al.\n        \"Online Analytic Exemplar-Free Continual Learning with Large Models for Imbalanced Autonomous Driving Task\"\n        arXiv preprint arXiv:2405.17779 (2024).\n    \"\"\"\n\n    def __init__(\n        self,\n        backbone_output: int,\n        backbone: torch.nn.Module = torch.nn.Flatten(),\n        buffer_size: int = 8192,\n        gamma: float = 1e-3,\n        noise: float = 1,\n        device=None,\n        dtype=torch.double,\n        linear: type[AnalyticLinear] = RecursiveLinear,\n    ) -> None:\n        super().__init__(\n            backbone_output, backbone, buffer_size, gamma, device, dtype, linear\n        )\n        self._linear_log = dict()\n        # History prototype\n        self.noise = noise\n        # Expectation of the prototypes E[X]\n        self.register_buffer(\"ex\", torch.zeros((0, backbone_output), dtype=torch.double))\n        self.ex: torch.Tensor\n        # Expectation of the squares of the prototypes E[X^2]\n        self.register_buffer(\"ex2\", torch.zeros((0, backbone_output), dtype=torch.double))\n        self.ex2: torch.Tensor\n        # Number of the samples of the prototypes\n        self.register_buffer(\"cnt\", torch.zeros((0,), dtype=torch.long))\n        self.cnt: torch.Tensor\n     
   # Set the device\n        self.to(device)\n\n    @torch.no_grad()\n    def fit(self, X: torch.Tensor, y: torch.Tensor, *args, **kwargs) -> None:\n        for name, buffer in self._linear_log.items():\n            self.analytic_linear.register_buffer(name, buffer)\n        self._linear_log.clear()\n\n        X = self.backbone(X)\n        if (increment_size := int(y.max().item()) - self.ex.shape[0] + 1) > 0:\n            # self.cnt\n            tail = torch.zeros((increment_size,)).to(self.cnt)\n            self.cnt = torch.concat((self.cnt, tail), dim=0)\n            # self.ex\n            tail = torch.zeros((increment_size, self.ex.shape[1])).to(self.ex)\n            self.ex = torch.concat((self.ex, tail), dim=0)\n            # self.ex2\n            tail = torch.zeros((increment_size, self.ex2.shape[1])).to(self.ex2)\n            self.ex2 = torch.concat((self.ex2, tail), dim=0)\n\n        labels, counts = torch.unique(y, return_counts=True)\n        self.cnt[labels] += counts\n        for i in labels:\n            X_i = X[y == i]\n            # Calculate E[X]\n            self.ex[i] += torch.sum(X_i.to(self.ex), dim=0)\n            # Calculate E[X^2]\n            self.ex2[i] += torch.sum(torch.square(X_i.to(self.ex2)), dim=0)\n\n        X = self.buffer(X)\n        Y = torch.nn.functional.one_hot(y)\n        self.analytic_linear.fit(X, Y)\n\n    def update(self) -> None:\n        peak_cnt = int(self.cnt.max())\n        mean = self.proto_mean\n        std = self.proto_std\n        print(\"Counts:\", self.cnt.tolist())\n\n        # Backup the iterative classifier\n        for name, buffer in self.analytic_linear.named_buffers():\n            self._linear_log[name] = buffer.clone().detach()\n\n        aug_bar = tqdm(\n            desc=\"Augmenting\",\n            total=(peak_cnt * len(self.cnt.nonzero()) - int(self.cnt.sum())),\n        )\n        for i in self.cnt.nonzero():\n            i = int(i.item())\n            rest_cnt = int(peak_cnt - self.cnt[i])\n        
    while rest_cnt > 0:\n                fill_cnt = min(rest_cnt, 8192)\n                fill_y = torch.empty((fill_cnt,), dtype=torch.long).fill_(i)\n                fill_proto = torch.randn((fill_cnt, self.buffer.in_features)).to(\n                    self.buffer.weight\n                )\n                fill_proto = (\n                    fill_proto * std[i][None, :] * self.noise + mean[i][None, :]\n                )\n                fill_proto = self.buffer(fill_proto)\n                fill_y = torch.nn.functional.one_hot(fill_y)\n                self.analytic_linear.fit(fill_proto, fill_y)\n                aug_bar.update(fill_cnt)\n                rest_cnt -= fill_cnt\n        self.analytic_linear.update()\n\n    @property\n    def proto_mean(self) -> torch.Tensor:\n        return self.ex / self.cnt[:, None]\n\n    @property\n    def proto_std(self) -> torch.Tensor:\n        std = self.ex2 / self.cnt[:, None] - torch.square(self.proto_mean)\n        std[torch.isnan(std)] = 0\n        assert (std >= 0).all()\n        proto_std = torch.sqrt(std * (self.cnt / (self.cnt - 1))[:, None])\n        proto_std[torch.isnan(proto_std)] = 0\n        return proto_std\n\n\nclass AEFOCLLearner(ACILLearner):\n    \"\"\"\n    Learner of the AEF-OCL [1], an analytic method for imbalanced continual learning.\n\n    References:\n    [1] Zhuang, Huiping, et al.\n        \"Online Analytic Exemplar-Free Continual Learning with Large Models for Imbalanced Autonomous Driving Task\"\n        arXiv preprint arXiv:2405.17779 (2024).\n    \"\"\"\n\n    def make_model(self) -> None:\n        self.model = AEFOCL(\n            self.backbone_output,\n            self.backbone,\n            self.buffer_size,\n            self.gamma,\n            device=self.device,\n            dtype=torch.double,\n            linear=RecursiveLinear,\n        )\n"
  },
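Note: AEF-OCL never stores past features; it keeps per-class running sums (`ex`, `ex2`, `cnt`) and recovers the prototype mean and standard deviation on demand. The scalar sketch below is illustrative only (the class and method names are ours) and mirrors how `proto_mean` and `proto_std` derive from those three accumulators, including Bessel's correction `cnt / (cnt - 1)`:

```python
import math

class ProtoStats:
    def __init__(self):
        self.ex = 0.0    # running sum of features       (numerator of E[X])
        self.ex2 = 0.0   # running sum of squared values (numerator of E[X^2])
        self.cnt = 0     # number of samples seen

    def fit(self, xs):
        for x in xs:
            self.ex += x
            self.ex2 += x * x
            self.cnt += 1

    def mean(self):
        return self.ex / self.cnt

    def std(self):
        # Var = E[X^2] - E[X]^2, clamped at 0, with Bessel's correction
        var = self.ex2 / self.cnt - self.mean() ** 2
        return math.sqrt(max(var, 0.0) * self.cnt / (self.cnt - 1))

s = ProtoStats()
s.fit([1.0, 2.0, 3.0, 4.0])
assert s.mean() == 2.5
assert abs(s.std() - math.sqrt(5.0 / 3.0)) < 1e-12
```

In the real module these are per-class vectors, and `update` samples Gaussian pseudo-features from (mean, std) to top up minority classes to the majority count.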
  {
    "path": "analytic/AIR.py",
    "content": "# -*- coding: utf-8 -*-\n\"\"\"\nImplementation of the AIR [1], an online exemplar-free generalized CIL approach on imbalanced datasets.\n\nReferences:\n[1] Fang, Di, et al.\n    \"AIR: Analytic Imbalance Rectifier for Continual Learning.\"\n    arXiv preprint arXiv:2408.10349 (2024).\n\"\"\"\n\nimport torch\nfrom .ACIL import ACILLearner, ACIL\nfrom .AnalyticLinear import GeneralizedARM\n\n__all__ = [\"AIR\", \"AIRLearner\", \"GeneralizedAIRLearner\"]\n\n\nclass AIR(ACIL):\n    def fit(self, X: torch.Tensor, y: torch.Tensor, *args, **kwargs) -> None:\n        X = self.feature_expansion(X)\n        self.analytic_linear.fit(X, y)\n\n\nclass AIRLearner(ACILLearner):\n    def make_model(self) -> None:\n        self.model = AIR(\n            self.backbone_output,\n            self.wrap_data_parallel(self.backbone),\n            self.buffer_size,\n            self.gamma,\n            device=self.device,\n            dtype=torch.double,\n            linear=GeneralizedARM,\n        )\n\n\nclass GeneralizedAIRLearner(AIRLearner):\n    pass\n"
  },
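Note: AIR's rectification is the class-frequency re-weighting performed in `GeneralizedARM.update`: each class contributes with weight proportional to `1 / cnt_i`, rescaled so the weights sum to the number of classes (so a perfectly balanced dataset gets all-ones and the method reduces to plain ridge regression). A stand-alone sketch, with a function name of our own:

```python
def balance_weights(counts):
    """Per-class weights proportional to 1/count, rescaled to sum to the
    number of classes; empty classes get weight 0 (mirroring the inf -> 0
    replacement in GeneralizedARM.update)."""
    inv = [0.0 if c == 0 else 1.0 / c for c in counts]
    return [w * len(counts) / sum(inv) for w in inv]

w = balance_weights([100, 10, 10])
assert abs(sum(w) - 3.0) < 1e-12     # weights sum to the number of classes
assert w[1] == w[2] and w[1] > w[0]  # rarer classes weigh more
```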
  {
    "path": "analytic/AnalyticLinear.py",
    "content": "# -*- coding: utf-8 -*-\n\"\"\"\nBasic analytic linear modules for the analytic continual learning [1-5].\n\nReferences:\n[1] Zhuang, Huiping, et al.\n    \"ACIL: Analytic class-incremental learning with absolute memorization and privacy protection.\"\n    Advances in Neural Information Processing Systems 35 (2022): 11602-11614.\n[2] Zhuang, Huiping, et al.\n    \"GKEAL: Gaussian Kernel Embedded Analytic Learning for Few-Shot Class Incremental Task.\"\n    Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2023.\n[3] Zhuang, Huiping, et al.\n    \"DS-AL: A Dual-Stream Analytic Learning for Exemplar-Free Class-Incremental Learning.\"\n    Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 38. No. 15. 2024.\n[4] Zhuang, Huiping, et al.\n    \"G-ACIL: Analytic Learning for Exemplar-Free Generalized Class Incremental Learning\"\n    arXiv preprint arXiv:2403.15706 (2024).\n[5] Fang, Di, et al.\n    \"AIR: Analytic Imbalance Rectifier for Continual Learning.\"\n    arXiv preprint arXiv:2408.10349 (2024).\n\"\"\"\n\nimport torch\nfrom torch.nn import functional as F\nfrom typing import Optional, Union\nfrom abc import abstractmethod, ABCMeta\n\n\nclass AnalyticLinear(torch.nn.Linear, metaclass=ABCMeta):\n    def __init__(\n        self,\n        in_features: int,\n        gamma: float = 1e-1,\n        bias: bool = False,\n        device: Optional[Union[torch.device, str, int]] = None,\n        dtype=torch.double,\n    ) -> None:\n        super(torch.nn.Linear, self).__init__()  # Skip the Linear class\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        self.gamma: float = gamma\n        self.bias: bool = bias\n        self.dtype = dtype\n\n        # Linear Layer\n        if bias:\n            in_features += 1\n        weight = torch.zeros((in_features, 0), **factory_kwargs)\n        self.register_buffer(\"weight\", weight)\n\n    @torch.inference_mode()\n    def forward(self, X: 
torch.Tensor) -> torch.Tensor:\n        X = X.to(self.weight)\n        if self.bias:\n            X = torch.cat((X, torch.ones(X.shape[0], 1).to(X)), dim=-1)\n        return X @ self.weight\n\n    @property\n    def in_features(self) -> int:\n        if self.bias:\n            return self.weight.shape[0] - 1\n        return self.weight.shape[0]\n\n    @property\n    def out_features(self) -> int:\n        return self.weight.shape[1]\n\n    def reset_parameters(self) -> None:\n        # Following the equation (4) of ACIL, self.weight is set to \\hat{W}_{FCN}^{-1}\n        self.weight = torch.zeros((self.weight.shape[0], 0)).to(self.weight)\n\n    @abstractmethod\n    def fit(self, X: torch.Tensor, Y: torch.Tensor) -> None:\n        raise NotImplementedError()\n\n    def update(self) -> None:\n        assert torch.isfinite(self.weight).all(), (\n            \"Pay attention to the numerical stability! \"\n            \"A possible solution is to increase the value of gamma. \"\n            \"Setting self.dtype=torch.double also helps.\"\n        )\n\n\nclass RecursiveLinear(AnalyticLinear):\n    def __init__(\n        self,\n        in_features: int,\n        gamma: float = 1e-1,\n        bias: bool = False,\n        device: Optional[Union[torch.device, str, int]] = None,\n        dtype=torch.double,\n    ) -> None:\n        super().__init__(in_features, gamma, bias, device, dtype)\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n\n        # Regularized Feature Autocorrelation Matrix (RFAuM)\n        self.R: torch.Tensor\n        R = torch.eye(self.weight.shape[0], **factory_kwargs) / self.gamma\n        self.register_buffer(\"R\", R)\n\n    @torch.no_grad()\n    def fit(self, X: torch.Tensor, Y: torch.Tensor) -> None:\n        \"\"\"The core code of the ACIL and the G-ACIL.\n        This implementation, which is different but equivalent to the equations shown in [1],\n        is proposed in the G-ACIL [4], which supports mini-batch learning and the 
general CIL setting.\n        \"\"\"\n        X, Y = X.to(self.weight), Y.to(self.weight)\n        if self.bias:\n            X = torch.cat((X, torch.ones(X.shape[0], 1).to(X)), dim=-1)\n\n        num_targets = Y.shape[1]\n        if num_targets > self.out_features:\n            increment_size = num_targets - self.out_features\n            tail = torch.zeros((self.weight.shape[0], increment_size)).to(self.weight)\n            self.weight = torch.cat((self.weight, tail), dim=1)\n        elif num_targets < self.out_features:\n            increment_size = self.out_features - num_targets\n            tail = torch.zeros((Y.shape[0], increment_size)).to(Y)\n            Y = torch.cat((Y, tail), dim=1)\n\n        # Please update your PyTorch & CUDA if the `cusolver error` occurs.\n        # If you insist on using this version, doing the `torch.inverse` on CPUs might help.\n        # >>> K_inv = torch.eye(X.shape[0]).to(X) + X @ self.R @ X.T\n        # >>> K = torch.inverse(K_inv.cpu()).to(self.weight.device)\n        K = torch.inverse(torch.eye(X.shape[0]).to(X) + X @ self.R @ X.T)\n        # Equation (10) of ACIL\n        self.R -= self.R @ X.T @ K @ X @ self.R\n        # Equation (9) of ACIL\n        self.weight += self.R @ X.T @ (Y - X @ self.weight)\n\n\nclass GeneralizedARM(AnalyticLinear):\n    \"\"\"Analytic Re-weighting Module (ARM) for generalized CIL.\"\"\"\n\n    def __init__(\n        self,\n        in_features: int,\n        gamma: float = 1e-1,\n        bias: bool = False,\n        device: Optional[Union[torch.device, str, int]] = None,\n        dtype=torch.double,\n    ) -> None:\n        super().__init__(in_features, gamma, bias, device, dtype)\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n\n        weight = torch.zeros((in_features, 0), **factory_kwargs)\n        self.register_buffer(\"weight\", weight)\n\n        A = torch.zeros((0, in_features, in_features), **factory_kwargs)\n        self.register_buffer(\"A\", A)\n\n        C = 
torch.zeros((in_features, 0), **factory_kwargs)\n        self.register_buffer(\"C\", C)\n\n        self.cnt = torch.zeros(0, dtype=torch.int, device=device)\n\n    @property\n    def out_features(self) -> int:\n        return self.C.shape[1]\n\n    @torch.inference_mode()\n    def fit(self, X: torch.Tensor, y: torch.Tensor) -> None:\n        X = X.to(self.weight)\n        # Bias\n        if self.bias:\n            X = torch.concat((X, torch.ones(X.shape[0], 1).to(X)), dim=-1)\n\n        # GCIL\n        num_targets = int(y.max()) + 1\n        if num_targets > self.out_features:\n            increment_size = num_targets - self.out_features\n            torch.cuda.empty_cache()\n            # Increment C\n            tail = torch.zeros((self.C.shape[0], increment_size)).to(self.weight)\n            self.C = torch.concat((self.C, tail), dim=1)\n            # Increment cnt\n            tail = torch.zeros((increment_size,)).to(self.cnt)\n            self.cnt = torch.concat((self.cnt, tail))\n            # Increment A\n            tail = torch.zeros((increment_size, self.in_features, self.in_features))\n            self.A = torch.concat((self.A, tail.to(self.A)))\n            torch.cuda.empty_cache()\n        else:\n            num_targets = self.out_features\n\n        # One-hot encoding of the labels\n        Y = F.one_hot(y, num_targets).to(self.C)\n        self.C += X.T @ Y\n\n        # Label Balancing\n        y_labels, label_cnt = torch.unique(y, sorted=True, return_counts=True)\n        y_labels, label_cnt = y_labels.to(self.cnt.device), label_cnt.to(\n            self.cnt.device\n        )\n        self.cnt[y_labels] += label_cnt\n\n        # Accumulate the per-class autocorrelation matrices\n        for i in range(num_targets):\n            X_i = X[y == i]\n            self.A[i] += X_i.T @ X_i\n\n    @torch.inference_mode()\n    def update(self):\n        cnt_inv = 1 / self.cnt.to(self.dtype)\n        cnt_inv[torch.isinf(cnt_inv)] = 0  # replace inf with 0\n        cnt_inv *= len(self.cnt) / cnt_inv.sum()\n\n        weighted_A = torch.sum(cnt_inv[:, None, None].mul(self.A), dim=0)\n        A = weighted_A + self.gamma * torch.eye(self.in_features).to(self.A)\n        C = self.C.mul(cnt_inv[None, :])\n\n        self.weight = torch.inverse(A) @ C\n"
  },
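Note: the `R` update in `RecursiveLinear.fit` (Equation (10) of ACIL) is the rank-update form of the Woodbury matrix identity, `(A + x^T x)^-1 = A^-1 - A^-1 x^T x A^-1 / (1 + x A^-1 x^T)`. In the one-sample, one-feature case everything is scalar and the identity can be checked directly (the numbers below are arbitrary):

```python
# Scalar Woodbury check: (a + x^2)^-1 == a^-1 - (a^-1 x)^2 / (1 + x a^-1 x)
a, x = 2.5, 0.7
lhs = 1.0 / (a + x * x)
ainv = 1.0 / a
rhs = ainv - (ainv * x) ** 2 / (1.0 + x * ainv * x)
assert abs(lhs - rhs) < 1e-12
```

This identity is what lets the module refresh `R` with a small `batch x batch` inverse instead of re-inverting the full `buffer_size x buffer_size` autocorrelation matrix each step.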
  {
    "path": "analytic/Buffer.py",
    "content": "# -*- coding: utf-8 -*-\n\"\"\"\nBuffer layers for the analytic learning based CIL [1-4].\n\nReferences:\n[1] Zhuang, Huiping, et al.\n    \"ACIL: Analytic class-incremental learning with absolute memorization and privacy protection.\"\n    Advances in Neural Information Processing Systems 35 (2022): 11602-11614.\n[2] Zhuang, Huiping, et al.\n    \"GKEAL: Gaussian Kernel Embedded Analytic Learning for Few-Shot Class Incremental Task.\"\n    Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2023.\n[3] Zhuang, Huiping, et al.\n    \"DS-AL: A Dual-Stream Analytic Learning for Exemplar-Free Class-Incremental Learning.\"\n    Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 38. No. 15. 2024.\n[4] Zhuang, Huiping, et al.\n    \"G-ACIL: Analytic Learning for Exemplar-Free Generalized Class Incremental Learning\"\n    arXiv preprint arXiv:2403.15706 (2024).\n\"\"\"\n\nimport torch\nfrom typing import Optional, Union, Callable\nfrom abc import ABCMeta, abstractmethod\n\n\nactivation_t = Union[Callable[[torch.Tensor], torch.Tensor], torch.nn.Module]\n\n\nclass Buffer(torch.nn.Module, metaclass=ABCMeta):\n    def __init__(self) -> None:\n        super().__init__()\n\n    @abstractmethod\n    def forward(self, X: torch.Tensor) -> torch.Tensor:\n        raise NotImplementedError()\n\n\nclass RandomBuffer(torch.nn.Linear, Buffer):\n    def __init__(\n        self,\n        in_features: int,\n        out_features: int,\n        bias: bool = False,\n        device=None,\n        dtype=torch.float,\n        activation: Optional[activation_t] = torch.relu_,\n    ) -> None:\n        super(torch.nn.Linear, self).__init__()\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        self.in_features = in_features\n        self.out_features = out_features\n        self.activation: activation_t = (\n            torch.nn.Identity() if activation is None else activation\n        )\n\n        W = 
torch.empty((out_features, in_features), **factory_kwargs)\n        b = torch.empty(out_features, **factory_kwargs) if bias else None\n\n        # Using buffer instead of parameter\n        self.register_buffer(\"weight\", W)\n        self.register_buffer(\"bias\", b)\n\n        # Random Initialization\n        self.reset_parameters()\n\n    @torch.no_grad()\n    def forward(self, X: torch.Tensor) -> torch.Tensor:\n        X = X.to(self.weight)\n        return self.activation(super().forward(X))\n\n\nclass GaussianKernel(Buffer):\n    def __init__(\n        self, mean: torch.Tensor, sigma: float = 1, device=None, dtype=torch.float\n    ) -> None:\n        super().__init__()\n        self.device = device\n        self.dtype = dtype\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        assert len(mean.shape) == 2, \"The mean should be a 2D tensor.\"\n        mean = mean[None, :, :].to(**factory_kwargs)\n        beta = 1 / (2 * (sigma**2))\n        self.register_buffer(\"mean\", mean)\n        self.register_buffer(\"beta\", torch.tensor(beta, **factory_kwargs))\n\n    @torch.no_grad()\n    def forward(self, X: torch.Tensor) -> torch.Tensor:\n        X = torch.square_(torch.cdist(X.to(self.mean), self.mean))\n        return torch.exp_(X.mul_(-self.beta))\n\n    def init(self, X: torch.Tensor, size: Optional[int] = None) -> None:\n        if size is not None:\n            if size <= X.shape[0]:\n                # Randomly select `size` of the samples as the kernel centers\n                idx = torch.randperm(X.shape[0], device=X.device)[:size]\n                X = X[idx]\n            else:\n                # The buffer size is suggested to be greater than the number of initial samples.\n                # Generate center vectors randomly\n                n_require = size - X.shape[0]\n                W_proj = torch.normal(mean=0, std=1, size=(n_require, X.shape[0])).to(X)\n                W_proj /= torch.sum(W_proj, dim=0)\n                X = torch.cat([X, W_proj @ X], dim=0)\n        self.mean = X.to(self.mean)\n"
  },
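Note: `RandomBuffer` is a frozen random projection followed by a ReLU, registered as buffers rather than parameters so no gradient ever updates it. A pure-Python sketch of the same idea (all names are ours, not repository API):

```python
import random

def make_buffer(in_features, out_features, seed=0):
    """Fixed random Gaussian projection + ReLU; the closure never changes."""
    rng = random.Random(seed)
    W = [[rng.gauss(0.0, 1.0) for _ in range(in_features)]
         for _ in range(out_features)]
    def expand(x):
        return [max(0.0, sum(w * xi for w, xi in zip(row, x))) for row in W]
    return expand

expand = make_buffer(3, 8)
h1 = expand([1.0, -0.5, 2.0])
h2 = expand([1.0, -0.5, 2.0])
assert h1 == h2                    # frozen: same input -> same features
assert all(v >= 0.0 for v in h1)   # ReLU output is non-negative
```

Determinism is the point: the analytic classifier assumes the feature map is fixed, so the expansion must give identical features for identical inputs across all tasks.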
  {
    "path": "analytic/DSAL.py",
    "content": "# -*- coding: utf-8 -*-\n\"\"\"\nImplementation of the DS-AL [1].\n\nReferences:\n[1] Zhuang, Huiping, et al.\n    \"DS-AL: A Dual-Stream Analytic Learning for Exemplar-Free Class-Incremental Learning.\"\n    Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 38. No. 15. 2024.\n\"\"\"\n\nimport torch\nfrom .ACIL import ACILLearner\nfrom typing import Callable, Dict, Any, Optional, Sequence\nfrom .AnalyticLinear import AnalyticLinear, RecursiveLinear\nfrom .Buffer import activation_t, RandomBuffer\nfrom torch._prims_common import DeviceLikeType\n\n\nclass DSAL(torch.nn.Module):\n    def __init__(\n        self,\n        backbone_output: int,\n        backbone: Callable[[torch.Tensor], torch.Tensor] = torch.nn.Flatten(),\n        expansion_size: int = 8192,\n        gamma_main: float = 1e-3,\n        gamma_comp: float = 1e-3,\n        C: float = 1,\n        activation_main: activation_t = torch.relu,\n        activation_comp: activation_t = torch.tanh,\n        device=None,\n        dtype=torch.double,\n        linear: type[AnalyticLinear] = RecursiveLinear,\n    ) -> None:\n        super().__init__()\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        self.backbone = backbone\n        self.expansion_size = expansion_size\n        self.buffer = RandomBuffer(\n            backbone_output,\n            expansion_size,\n            activation=torch.nn.Identity(),\n            **factory_kwargs\n        )\n        # The main stream\n        self.activation_main = activation_main\n        self.main_stream = linear(expansion_size, gamma_main, **factory_kwargs)\n        # The compensation stream\n        self.C = C\n        self.activation_comp = activation_comp\n        self.comp_stream = linear(expansion_size, gamma_comp, **factory_kwargs)\n        self.eval()\n\n    @torch.no_grad()\n    def forward(self, X: torch.Tensor) -> torch.Tensor:\n        X = self.buffer(self.backbone(X))\n        X_main = 
self.main_stream(self.activation_main(X))\n        X_comp = self.comp_stream(self.activation_comp(X))\n        return X_main + self.C * X_comp\n\n    @torch.no_grad()\n    def fit(self, X: torch.Tensor, y: torch.Tensor, increase_size: int) -> None:\n        num_classes = max(self.main_stream.out_features, int(y.max().item()) + 1)\n        Y_main = torch.nn.functional.one_hot(y, num_classes=num_classes)\n        X = self.buffer(self.backbone(X))\n\n        # Train the main stream\n        X_main = self.activation_main(X)\n        self.main_stream.fit(X_main, Y_main)\n        self.main_stream.update()\n\n        # Previous label cleansing (PLC)\n        Y_comp = Y_main - self.main_stream(X_main)\n        Y_comp[:, :-increase_size] = 0\n\n        # Train the compensation stream\n        X_comp = self.activation_comp(X)\n        self.comp_stream.fit(X_comp, Y_comp)\n\n    @torch.no_grad()\n    def update(self) -> None:\n        self.main_stream.update()\n        self.comp_stream.update()\n\n\nclass DSALLearner(ACILLearner):\n    def __init__(\n        self,\n        args: Dict[str, Any],\n        backbone: torch.nn.Module,\n        backbone_output: int,\n        device=None,\n        all_devices: Optional[Sequence[DeviceLikeType]] = None,\n    ) -> None:\n        self.gamma_comp = args[\"gamma_comp\"]\n        self.compensation_ratio = args[\"compensation_ratio\"]\n        super().__init__(args, backbone, backbone_output, device, all_devices)\n\n    def make_model(self) -> None:\n        self.model = DSAL(\n            self.backbone_output,\n            self.backbone,\n            self.buffer_size,\n            self.gamma,\n            self.gamma_comp,\n            self.compensation_ratio,\n            device=self.device,\n            dtype=torch.double,\n            linear=RecursiveLinear,\n        )\n"
  },
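Note: the key line in `DSAL.fit` is previous label cleansing (PLC), `Y_comp[:, :-increase_size] = 0`: the compensation stream only learns the main stream's residual on the columns introduced by the current task. A minimal sketch of that masking (the function name is ours; `increase_size` is the number of new classes):

```python
def plc(residual, increase_size):
    """Zero all but the last `increase_size` columns of each residual row."""
    return [[0.0] * (len(row) - increase_size) + row[-increase_size:]
            for row in residual]

res = [[0.2, -0.1, 0.7, -0.3]]     # residual over 4 classes, 2 of them new
assert plc(res, 2) == [[0.0, 0.0, 0.7, -0.3]]
```

Without PLC, fitting the compensation stream on old-class residuals would re-touch decisions the main stream has already settled, defeating the dual-stream split.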
  {
    "path": "analytic/GKEAL.py",
    "content": "# -*- coding: utf-8 -*-\n\"\"\"\nImplementation of the GKEAL [1].\n\nThe GKEAL is a CIL method specially proposed for the few-shot CIL.\nBut the implementation here is just a simplified version for common CIL settings.\nCompared with the method proposed in the paper, we do not perform image augmentation here.\nEach sample will only be learned once by default.\n\nReferences:\n[1] Zhuang, Huiping, et al.\n    \"GKEAL: Gaussian Kernel Embedded Analytic Learning for Few-Shot Class Incremental Task.\"\n    Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2023.\n\"\"\"\n\nimport torch\nfrom tqdm import tqdm\nfrom typing import Dict, Any, Sequence, Optional\nfrom torch._prims_common import DeviceLikeType\nfrom .Learner import loader_t\nfrom .ACIL import ACIL, ACILLearner\nfrom .Buffer import GaussianKernel\nfrom .AnalyticLinear import AnalyticLinear, RecursiveLinear\n\n\nclass GKEAL(ACIL):\n    def __init__(\n        self,\n        backbone_output: int,\n        backbone: torch.nn.Module = torch.nn.Flatten(),\n        buffer_size: int = 512,\n        gamma: float = 1e-3,\n        sigma: float = 10,\n        device=None,\n        dtype=torch.double,\n        linear: type[AnalyticLinear] = RecursiveLinear,\n    ):\n        super().__init__(\n            backbone_output, backbone, buffer_size, gamma, device, dtype, linear\n        )\n        self.buffer = GaussianKernel(\n            torch.zeros((self.buffer_size, self.backbone_output)),\n            sigma,\n            device=device,\n            dtype=dtype,\n        )\n\n\nclass GKEALLearner(ACILLearner):\n    def __init__(\n        self,\n        args: Dict[str, Any],\n        backbone: torch.nn.Module,\n        backbone_output: int,\n        device=None,\n        all_devices: Optional[Sequence[DeviceLikeType]] = None,\n    ) -> None:\n        self.initialized = False\n        # The width-adjusting parameter β controls the width of the Gaussian kernels.\n        # There 
is a comfortable range for σ at around [5, 15] for CIFAR-100 and ImageNet-1k\n        # that gives good results, where β = 1 / (2σ²).\n        self.sigma = args[\"sigma\"]\n        super().__init__(args, backbone, backbone_output, device, all_devices)\n\n    def make_model(self) -> None:\n        self.model = GKEAL(\n            self.backbone_output,\n            self.backbone,\n            self.buffer_size,\n            self.gamma,\n            self.sigma,\n            device=self.device,\n            dtype=torch.double,\n            linear=RecursiveLinear,\n        )\n\n    @torch.no_grad()\n    def learn(\n        self,\n        data_loader: loader_t,\n        incremental_size: int,\n        desc: str = \"Incremental Learning\",\n    ) -> None:\n        torch.cuda.empty_cache()\n        if self.initialized:\n            return super().learn(data_loader, incremental_size, desc)\n        total_X = []\n        total_y = []\n        for X, y in tqdm(data_loader, desc=\"Selecting center vectors\"):\n            X: torch.Tensor = X.to(self.device, non_blocking=True)\n            y: torch.Tensor = y.to(self.device, non_blocking=True)\n            X = self.backbone(X)\n            total_X.append(X)\n            total_y.append(y)\n\n        self.model.buffer.init(torch.cat(total_X), self.buffer_size)\n        torch.cuda.empty_cache()\n        for X, y in tqdm(zip(total_X, total_y), total=len(total_X), desc=desc):\n            X = self.model.buffer(X)\n            Y = torch.nn.functional.one_hot(y, incremental_size)\n            self.model.analytic_linear.fit(X, Y)\n        self.model.analytic_linear.update()\n        self.initialized = True\n"
  },
  {
    "path": "analytic/Learner.py",
    "content": "import torch\nfrom os import path\nfrom abc import ABCMeta, abstractmethod\nfrom torch.utils.data import DataLoader\nfrom torch._prims_common import DeviceLikeType\nfrom typing import Union, Dict, Any, Optional, Sequence\n\nloader_t = DataLoader[Union[torch.Tensor, torch.Tensor]]\n\n\nclass Learner(metaclass=ABCMeta):\n    def __init__(\n        self,\n        args: Dict[str, Any],\n        backbone: torch.nn.Module,\n        backbone_output: int,\n        device=None,\n        all_devices: Optional[Sequence[DeviceLikeType]] = None,\n    ) -> None:\n        self.args = args\n        self.backbone = backbone\n        self.backbone_output = backbone_output\n        self.device = device\n        self.all_devices = all_devices\n        self.model: torch.nn.Module\n\n    @abstractmethod\n    def base_training(\n        self,\n        train_loader: loader_t,\n        val_loader: loader_t,\n        baseset_size: int,\n    ) -> None:\n        raise NotImplementedError()\n\n    @abstractmethod\n    def learn(\n        self,\n        data_loader: loader_t,\n        incremental_size: int,\n        desc: str = \"Incremental Learning\"\n    ) -> None:\n        raise NotImplementedError()\n\n    @abstractmethod\n    def before_validation() -> None:\n        raise NotImplementedError()\n\n    @abstractmethod\n    def inference(self, X: torch.Tensor) -> torch.Tensor:\n        raise NotImplementedError()\n\n    def save_object(self, model, file_name: str) -> None:\n        torch.save(model, path.join(self.args[\"saving_root\"], file_name))\n\n    def __call__(self, X: torch.Tensor) -> torch.Tensor:\n        return self.inference(X)\n"
  },
  {
    "path": "analytic/__init__.py",
    "content": "# -*- coding: utf-8 -*-\nfrom .Learner import Learner\nfrom .Buffer import Buffer, RandomBuffer, GaussianKernel\nfrom .AnalyticLinear import AnalyticLinear, RecursiveLinear\nfrom .ACIL import ACIL, ACILLearner\nfrom .DSAL import DSAL, DSALLearner\nfrom .GKEAL import GKEAL, GKEALLearner\nfrom .AEFOCL import AEFOCL, AEFOCLLearner\nfrom .AIR import AIRLearner, GeneralizedAIRLearner\n\n\n__all__ = [\n    \"Learner\",\n    \"Buffer\",\n    \"RandomBuffer\",\n    \"GaussianKernel\",\n    \"AnalyticLinear\",\n    \"RecursiveLinear\",\n    \"ACIL\",\n    \"DSAL\",\n    \"GKEAL\",\n    \"AEFOCL\",\n    \"ACILLearner\",\n    \"DSALLearner\",\n    \"GKEALLearner\",\n    \"AEFOCLLearner\",\n    \"AIRLearner\",\n    \"GeneralizedAIRLearner\",\n]\n"
  },
  {
    "path": "config.py",
    "content": "import argparse\n\nfrom models import models\nfrom typing import Any, Dict\nfrom os import path, makedirs\nfrom datasets import dataset_list\nfrom datetime import datetime\nimport yaml\nfrom sys import argv\n\nfrom analytic import (\n    ACILLearner,\n    DSALLearner,\n    GKEALLearner,\n    AEFOCLLearner,\n    AIRLearner,\n    GeneralizedAIRLearner,\n    Learner,\n)\n\nALL_METHODS: dict[str, type[Learner]] = {\n    \"ACIL\": ACILLearner,\n    \"G-ACIL\": ACILLearner,  # The G-ACIL is a generalization of the ACIL in the generalized setting.\n    \"DS-AL\": DSALLearner,\n    \"GKEAL\": GKEALLearner,\n    \"AEF-OCL\": AEFOCLLearner,\n    \"AIR\": AIRLearner,\n    \"G-AIR\": GeneralizedAIRLearner,  # The G-AIL is a generalization of the AIR for generalized CIL.\n}\n\n__all__ = [\"load_args\", \"ALL_METHODS\"]\n\n_parser = argparse.ArgumentParser(description=\"Analytic Continual Learning\")\n\n# Method Options\n_parser.add_argument(\n    \"method\",\n    choices=ALL_METHODS.keys(),\n    help=\"The method to use for continual learning.\",\n)\n\n_parser.add_argument(\n    \"--exp-name\",\n    type=str,\n    default=\"\",\n    help=\"Name of the experiment\",\n)\n\n_parser.add_argument(\n    \"--cpu-only\",\n    action=\"store_true\",\n    help=\"Run the program on CPU only.\",\n)\n\n_parser.add_argument(\n    \"--gpus\",\n    default=None,\n    type=int,\n    action=\"extend\",\n    nargs=\"+\",\n    help=\"List of GPUs to use.\",\n)\n\n# Dataset settings\n_data_group = _parser.add_argument_group(\"Dataset arguments\")\n_data_group.add_argument(\n    \"-d\",\n    \"--dataset\",\n    default=\"CIFAR-100\",\n    choices=dataset_list.keys(),\n)\n\n_data_group.add_argument(\n    \"--data-root\",\n    metavar=\"DIR\",\n    type=str,\n    help=\"Root path to the dataset\",\n    default=\"~/dataset\",\n)\n\n_data_group.add_argument(\n    \"-j\",\n    \"--num-workers\",\n    default=8,\n    type=int,\n    metavar=\"N\",\n    help=\"Number of data loading workers 
(default: 8)\",\n)\n\n_data_group.add_argument(\n    \"--base-ratio\",\n    default=0.5,\n    type=float,\n    help=\"The ratio of base classes in the training set.\",\n)\n\n_data_group.add_argument(\n    \"--phases\",\n    \"--tasks\",\n    default=10,\n    type=int,\n    help=\"Number of incremental phases (tasks).\",\n)\n\n_data_group.add_argument(\n    \"-b\",\n    \"--batch-size\",\n    default=256,\n    type=int,\n    metavar=\"N\",\n    help=\"The size of one mini-batch per GPU.\",\n)\n\n_data_group.add_argument(\n    \"--cache-features\",\n    action=\"store_true\",\n    help=\"Load the features extracted by the frozen backbone to speed up inference.\",\n)\n\n# Model settings\n_model_group = _parser.add_argument_group(\"Model arguments\")\n_model_group.add_argument(\n    \"-a\",\n    \"-arch\",\n    \"--backbone\",\n    type=str,\n    default=\"resnet32\",\n    help=\"model to use for training\",\n    choices=models.keys(),\n    metavar=\"ARCH\",\n)\n\n_model_group.add_argument(\n    \"--cache-path\",\n    \"--backbone-path\",\n    metavar=\"DIR\",\n    type=str,\n    help=(\n        \"Path to the base pretrain backbone.\"\n        \"If file exists, the base training will be skipped.\"\n    ),\n)\n\n# Training Settings\n_model_group.add_argument(\"--seed\", default=None, type=int, help=\"Seed for models.\")\n\n_model_group.add_argument(\n    \"--dataset-seed\", default=None, type=int, help=\"Seed for shuffling the dataset.\"\n)\n\n# Base training arguments\n_base_group = _parser.add_argument_group(\"Base training arguments\")\n\n_base_group.add_argument(\n    \"--base-epochs\",\n    default=300,\n    type=int,\n    metavar=\"N\",\n    help=\"Number of total epochs to run for base training.\",\n)\n\n_base_group.add_argument(\n    \"--warmup-epochs\",\n    default=10,\n    type=int,\n    metavar=\"N\",\n    help=\"Number of warmup epochs.\",\n)\n\n_base_group.add_argument(\n    \"-lr\",\n    \"--learning-rate\",\n    default=0.5,\n    type=float,\n    
metavar=\"LR\",\n    help=\"Initial learning rate\",\n)\n\n_base_group.add_argument(\n    \"--momentum\", default=0.9, type=float, metavar=\"M\", help=\"Momentum for SGD\"\n)\n\n_base_group.add_argument(\n    \"--wd\",\n    \"--weight-decay\",\n    default=5e-4,\n    type=float,\n    metavar=\"W\",\n    dest=\"weight_decay\",\n)\n\n_base_group.add_argument(\n    \"--separate-decay\",\n    action=\"store_true\",\n    help=\"Separating the normalization parameters from the rest of the model parameters\",\n)\n\n_base_group.add_argument(\"--label-smoothing\", default=0.05, type=float)\n\n# IL hyper-parameters\n_il_group = _parser.add_argument_group(\"IL Hyper-parameters\")\n\n_il_group.add_argument(\n    \"--IL-batch-size\",\n    default=None,\n    type=int,\n    help=\"The size of mini-batch during the incremental learning process.\",\n)\n\n_il_group.add_argument(\n    \"--gamma\",\n    \"--gamma-main\",\n    default=0.1,\n    type=float,\n    help=\"The regularization of the (main stream) linear classifier.\",\n)\n\n_il_group.add_argument(\n    \"--buffer-size\",\n    \"--expansion-size\",\n    default=8192,\n    type=int,\n    help=\"The buffer size of the classifier.\",\n)\n\n_il_group.add_argument(\n    \"--gamma-comp\",\n    default=0.1,\n    type=float,\n    help=\"The regularization of the linear classifier in compensation stream (DS-AL only)\",\n)\n\n_il_group.add_argument(\n    \"--sigma\",\n    default=10,\n    type=float,\n    help=\"The width-adjusting of the Gaussian kernel (GKEAL only)\",\n)\n\n_il_group.add_argument(\n    \"-C\",\n    \"--compensation-ratio\",\n    default=1,\n    type=float,\n    help=\"The regularization of the linear classifier in compensation stream (DS-AL only)\",\n)\n\n\ndef load_args() -> Dict[str, Any]:\n    global _parser\n    args = vars(_parser.parse_args())\n    args[\"data_root\"] = path.expanduser(args[\"data_root\"])\n    if args[\"cache_path\"] is not None:\n        assert path.isdir(args[\"cache_path\"]), \"The cache 
path is not a directory.\"\n        args[\"backbone_path\"] = path.join(args[\"cache_path\"], \"backbone.pth\")\n        assert path.isfile(args[\"backbone_path\"]), \\\n            f\"Backbone file \\\"{args['backbone_path']}\\\" doesn't exist.\"\n    saving_root = path.join(\n        \"saved_models\",\n        f\"{args['backbone']}_{args['dataset']}_{args['base_ratio']}_{args['dataset_seed']}\",\n    )\n    args[\"exp_name\"] = args[\"exp_name\"].strip()\n    if args[\"exp_name\"] == \"\":\n        args[\"exp_name\"] = args[\"method\"]\n    saving_root = path.join(saving_root, args[\"exp_name\"])\n\n    if args[\"IL_batch_size\"] is None:\n        args[\"IL_batch_size\"] = args[\"batch_size\"]\n\n    # Windows does not support \":\" in the path\n    current_time = datetime.now().isoformat(timespec=\"seconds\").replace(\":\", \"-\")\n    saving_root = path.join(saving_root, current_time)\n    args[\"saving_root\"] = saving_root\n    args[\"argv\"] = str(argv)\n    makedirs(saving_root, exist_ok=True)\n    with open(path.join(saving_root, \"args.yaml\"), \"w\", encoding=\"utf-8\") as yaml_file:\n        yaml.safe_dump(args, yaml_file)\n    args[\"data_root\"] = path.join(args[\"data_root\"], args[\"dataset\"])\n    return args\n\n\nif __name__ == \"__main__\":\n    print(load_args())\n"
  },
  {
    "path": "datasets/CIFAR.py",
    "content": "# -*- coding: utf-8 -*-\n\nimport torch\nfrom torch import Tensor\nfrom torchvision.datasets import CIFAR10, CIFAR100\nfrom torchvision.transforms import v2 as transforms\nfrom typing import Tuple\nfrom .DatasetWrapper import DatasetWrapper\n\n\nclass CIFAR10_(DatasetWrapper[Tuple[Tensor, int]]):\n    num_classes = 10\n    mean = (0.49139967861519607843, 0.48215840839460784314, 0.44653091444546568627)\n    std = (0.21117028181572183225, 0.20857934290628859220, 0.21205155387102001073)\n    basic_transform = transforms.Compose(\n        [\n            transforms.ToImage(),\n            transforms.ToDtype(torch.float32, scale=True),\n            transforms.Normalize(mean, std, inplace=True),\n            transforms.ToPureTensor(),\n        ]\n    )\n    augment_transform = transforms.Compose(\n        [\n            transforms.RandomCrop(32, 4),\n            transforms.RandomHorizontalFlip(),\n            transforms.TrivialAugmentWide(\n                interpolation=transforms.InterpolationMode.BILINEAR\n            ),\n            transforms.ToImage(),\n            transforms.ToDtype(torch.float32, scale=True),\n            transforms.Normalize(mean, std, inplace=True),\n            transforms.ToPureTensor(),\n        ]\n    )\n\n    def __init__(\n        self,\n        root: str,\n        train: bool,\n        base_ratio: float,\n        num_phases: int,\n        augment: bool = False,\n        inplace_repeat: int = 1,\n        shuffle_seed: int | None = None,\n    ) -> None:\n        self.dataset = CIFAR10(root, train=train, download=True)\n        super().__init__(\n            self.dataset.targets,\n            base_ratio,\n            num_phases,\n            augment,\n            inplace_repeat,\n            shuffle_seed,\n        )\n\n\nclass CIFAR100_(DatasetWrapper[Tuple[Tensor, int]]):\n    num_classes = 100\n    mean = (0.50707515923713235294, 0.48654887331495098039, 0.44091784336703431373)\n    std = (0.26733428848992695514, 
0.25643846542136995765, 0.27615047402246589731)\n    # std = (0.21103932286924015314, 0.20837755491382136483, 0.21551368222930648019)\n    basic_transform = transforms.Compose(\n        [\n            transforms.ToImage(),\n            transforms.ToDtype(torch.float32, scale=True),\n            transforms.Normalize(mean, std, inplace=True),\n            transforms.ToPureTensor(),\n        ]\n    )\n    augment_transform = transforms.Compose(\n        [\n            transforms.RandomCrop(32, 4),\n            transforms.RandomHorizontalFlip(),\n            transforms.TrivialAugmentWide(\n                interpolation=transforms.InterpolationMode.BILINEAR\n            ),\n            transforms.ToImage(),\n            transforms.ToDtype(torch.float32, scale=True),\n            transforms.Normalize(mean, std, inplace=True),\n            transforms.ToPureTensor(),\n        ]\n    )\n\n    def __init__(\n        self,\n        root: str,\n        train: bool,\n        base_ratio: float,\n        num_phases: int,\n        augment: bool = False,\n        inplace_repeat: int = 1,\n        shuffle_seed: int | None = None,\n    ) -> None:\n        self.dataset = CIFAR100(root, train=train, download=True)\n        super().__init__(\n            self.dataset.targets,\n            base_ratio,\n            num_phases,\n            augment,\n            inplace_repeat,\n            shuffle_seed,\n        )\n\n\nif __name__ == \"__main__\":\n    dataset_train = CIFAR100_(\n        \"~/.dataset\", train=True, base_ratio=0.1, num_phases=3, augment=True\n    )\n    dataset_test = CIFAR100_(\n        \"~/.dataset\", train=False, base_ratio=0.1, num_phases=3, augment=False\n    )\n\n    for X, y in dataset_train.subset_at_phase(0):\n        assert X.shape == (3, 32, 32)\n    for X, y in dataset_test.subset_at_phase(0):\n        assert X.shape == (3, 32, 32)\n    print(\"test passed\")\n"
  },
  {
    "path": "datasets/DatasetWrapper.py",
    "content": "# -*- coding: utf-8 -*-\n\nfrom typing import Callable, Iterable, Optional\nfrom torch.utils.data import Dataset, Subset\ntry:\n    from torch.utils.data.dataset import T_co\nexcept ImportError:\n    from torch.utils._ordered_set import T_co\nfrom abc import ABCMeta\nfrom random import Random\nfrom numpy import repeat\nfrom itertools import chain\n\n\nclass DatasetWrapper(Dataset[T_co], metaclass=ABCMeta):\n    basic_transform: Callable[[T_co], T_co]\n    augment_transform: Callable[[T_co], T_co]\n\n    def __init__(\n        self,\n        labels: Iterable[int],\n        base_ratio: float,\n        num_phases: int,\n        augment: bool,\n        inplace_repeat: int = 1,\n        shuffle_seed: Optional[int] = None,\n    ) -> None:\n        # Type hints\n        self.dataset: Dataset[T_co]\n        self.num_classes: int\n\n        # Initialization\n        super().__init__()\n        self.inplace_repeat = inplace_repeat\n        self.base_ratio = base_ratio\n        self.num_phases = num_phases\n        self.base_size = int(self.num_classes * self.base_ratio)\n        self.incremental_size = self.num_classes - self.base_size\n        self.phase_size = self.incremental_size // num_phases if num_phases > 0 else 0\n        # Create a list of indices for each class\n        self.class_indices: list[list[int]] = [[] for _ in range(self.num_classes)]\n        for idx, label in enumerate(labels):\n            self.class_indices[label].append(idx)\n        self._transform = self.augment_transform if augment else self.basic_transform\n        # Shuffle the class indices\n        self.real_labels: list[int] = list(range(self.num_classes))\n        if shuffle_seed is not None:\n            Random(shuffle_seed).shuffle(self.real_labels)\n            Random(shuffle_seed).shuffle(self.class_indices)\n\n    def __getitem__(self, index: int) -> T_co:\n        return self._transform(self.dataset[index])\n\n    def _subset(self, label_begin: int, label_end: int) -> 
Subset[T_co]:\n        sub_ids = tuple(chain.from_iterable(self.class_indices[label_begin:label_end]))\n        return Subset(self, repeat(sub_ids, self.inplace_repeat).tolist())\n\n    def subset_at_phase(self, phase: int) -> Subset[T_co]:\n        if phase == 0:\n            return self._subset(0, self.base_size)\n        return self._subset(\n            self.base_size + (phase - 1) * self.phase_size,\n            self.base_size + phase * self.phase_size,\n        )\n\n    def subset_until_phase(self, phase: int) -> Subset[T_co]:\n        return self._subset(\n            0,\n            self.base_size + phase * self.phase_size,\n        )\n"
  },
  {
    "path": "datasets/Features.py",
    "content": "# -*- coding: utf-8 -*-\n\nimport torch\nfrom os import path\nfrom .DatasetWrapper import DatasetWrapper\nfrom torch.utils.data import TensorDataset\nfrom torchvision.transforms import v2 as transforms\n\n\nclass Features(DatasetWrapper[tuple[torch.Tensor, torch.LongTensor]]):\n    basic_transform = transforms.Identity()\n    augment_transform = transforms.Identity()\n\n    def __init__(\n        self,\n        root: str,\n        train: bool,\n        base_ratio: float,\n        num_phases: int,\n        augment: bool = False,\n        inplace_repeat: int = 1,\n        shuffle_seed: int | None = None,\n    ) -> None:\n        assert augment == False, \"Augmentation is not supported for Features dataset\"\n\n        if train:\n            X: torch.Tensor = torch.load(path.join(root, \"X_train.pt\"), weights_only=True)\n            y: torch.Tensor = torch.load(path.join(root, \"y_train.pt\"), weights_only=True)\n        else:\n            X: torch.Tensor = torch.load(path.join(root, \"X_test.pt\"), weights_only=True)\n            y: torch.Tensor = torch.load(path.join(root, \"y_test.pt\"), weights_only=True)\n\n        y = y.to(torch.long, non_blocking=True)\n        self.dataset = TensorDataset(X, y)  # type: ignore\n        self.num_classes = int(y.max().item()) + 1\n\n        super().__init__(\n            y.numpy().tolist(),\n            base_ratio,\n            num_phases,\n            False,\n            inplace_repeat,\n            shuffle_seed,\n        )\n"
  },
  {
    "path": "datasets/ImageNet.py",
    "content": "# -*- coding: utf-8 -*-\nfrom typing import Tuple\nimport torch\nfrom .DatasetWrapper import DatasetWrapper\nfrom torchvision.datasets import ImageNet\nfrom torchvision.transforms import v2 as transforms\nfrom os import path\n\n\nclass ImageNet_(DatasetWrapper[Tuple[torch.Tensor, int]]):\n    num_classes = 1000\n    mean = (0.485, 0.456, 0.406)\n    std = (0.229, 0.224, 0.225)\n\n    basic_transform = transforms.Compose(\n        [\n            transforms.Resize(232),\n            transforms.CenterCrop(224),\n            transforms.PILToTensor(),\n            transforms.ToDtype(torch.float32, scale=True),\n            transforms.Normalize(mean, std, inplace=True),\n            transforms.ToPureTensor(),\n        ]\n    )\n\n    augment_transform = transforms.Compose(\n        [\n            transforms.RandomResizedCrop(176),\n            transforms.RandomHorizontalFlip(0.5),\n            transforms.TrivialAugmentWide(\n                interpolation=transforms.InterpolationMode.BILINEAR\n            ),\n            transforms.ToImage(),\n            transforms.ToDtype(torch.float32, scale=True),\n            transforms.Normalize(mean, std, inplace=True),\n            transforms.RandomErasing(0.1),\n            transforms.ToPureTensor(),\n        ]\n    )\n\n    def __init__(\n        self,\n        root: str,\n        train: bool,\n        base_ratio: float,\n        num_phases: int,\n        augment: bool = False,\n        inplace_repeat: int = 1,\n        shuffle_seed: int | None = None,\n    ) -> None:\n        root = path.expanduser(root)\n        self.dataset = ImageNet(root, split=\"train\" if train else \"val\")\n        super().__init__(\n            self.dataset.targets,\n            base_ratio,\n            num_phases,\n            augment,\n            inplace_repeat,\n            shuffle_seed,\n        )\n"
  },
  {
    "path": "datasets/MNIST.py",
    "content": "# -*- coding: utf-8 -*-\n\nimport torch\nfrom torch import Tensor\nfrom torchvision.datasets import MNIST\nfrom torchvision.transforms import v2 as transforms\nfrom .DatasetWrapper import DatasetWrapper\n\n\nclass MNIST_(DatasetWrapper[tuple[Tensor, int]]):\n    num_classes = 10\n    mean = (0.13066047627384287048,)\n    std = (0.30524474224261827502,)\n\n    basic_transform = transforms.Compose(\n        [\n            transforms.ToImage(),\n            transforms.ToDtype(torch.float32, scale=True),\n            transforms.Normalize(mean, std, inplace=True),\n        ]\n    )\n    augment_transform = basic_transform\n\n    def __init__(\n        self,\n        root: str,\n        train: bool,\n        base_ratio: float,\n        num_phases: int,\n        augment: bool = False,\n        inplace_repeat: int = 1,\n        shuffle_seed: int | None = None,\n    ) -> None:\n        self.dataset = MNIST(root, train=train, download=True)\n        super().__init__(\n            self.dataset.targets.tolist(),\n            base_ratio,\n            num_phases,\n            augment,\n            inplace_repeat,\n            shuffle_seed,\n        )\n"
  },
  {
    "path": "datasets/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n\nfrom .DatasetWrapper import DatasetWrapper\nfrom .MNIST import MNIST_ as MNIST\nfrom .CIFAR import CIFAR10_ as CIFAR10\nfrom .CIFAR import CIFAR100_ as CIFAR100\nfrom .ImageNet import ImageNet_ as ImageNet\nfrom typing import Union\nfrom .Features import Features\n\n\n__all__ = [\n    \"load_dataset\",\n    \"dataset_list\",\n    \"MNIST\",\n    \"CIFAR10\",\n    \"CIFAR100\",\n    \"ImageNet\",\n    \"DatasetWrapper\",\n    \"Features\",\n]\n\ndataset_list = {\n    \"MNIST\": MNIST,\n    \"CIFAR-10\": CIFAR10,\n    \"CIFAR-100\": CIFAR100,\n    \"ImageNet-1k\": ImageNet,\n}\n\n\ndef load_dataset(\n    name: str,\n    root: str,\n    train: bool,\n    base_ratio: float,\n    num_phases: int,\n    augment: bool = False,\n    inplace_repeat: int = 1,\n    shuffle_seed: int | None = None,\n    *args,\n    **kwargs\n) -> Union[MNIST, CIFAR10, CIFAR100, ImageNet]:\n    return dataset_list[name](\n        root=root,\n        train=train,\n        base_ratio=base_ratio,\n        num_phases=num_phases,\n        augment=augment,\n        inplace_repeat=inplace_repeat,\n        shuffle_seed=shuffle_seed,\n        *args,\n        **kwargs\n    )\n"
  },
  {
    "path": "environment.yaml",
    "content": "# Usage: conda env create -f environment.yaml\n\nname: AL\n\nchannels:\n  - pytorch\n  - nvidia\n  - conda-forge\n\ndependencies:\n  - python=3.11\n\n  # PyTorch\n  - pytorch>=2.2\n  - torchvision\n  - pytorch-cuda  # For Nvidia GPU\n\n  # Necessary Utils\n  - numpy\n  - tqdm\n  - scikit-learn\n  - pip\n  - pip:\n    - prefetch_generator\n\n  # Optional\n  - black\n  - mypy\n"
  },
  {
    "path": "main.py",
    "content": "# -*- coding: utf-8 -*-\n\nimport torch\nfrom os import path\nfrom tqdm import tqdm\nfrom config import load_args, ALL_METHODS\nfrom models import load_backbone\nfrom typing import Any, Dict, List, Tuple, Optional\nfrom datasets import Features, load_dataset\nfrom utils import set_determinism, validate\nfrom torch._prims_common import DeviceLikeType\nfrom torch.utils.data import Dataset, DataLoader\n\n\ndef make_dataloader(\n    dataset: Dataset,\n    shuffle: bool = False,\n    batch_size: int = 256,\n    num_workers: int = 8,\n    device: Optional[DeviceLikeType] = None,\n    persistent_workers: bool = False,\n) -> DataLoader:\n    pin_memory = (device is not None) and (torch.device(device).type == \"cuda\")\n    config = {\n        \"batch_size\": batch_size,\n        \"shuffle\": shuffle,\n        \"num_workers\": num_workers,\n        \"pin_memory\": pin_memory,\n        \"pin_memory_device\": str(device) if pin_memory else \"\",\n        \"persistent_workers\": persistent_workers,\n    }\n    try:\n        from prefetch_generator import BackgroundGenerator\n\n        class DataLoaderX(DataLoader):\n            def __iter__(self):\n                return BackgroundGenerator(super().__iter__())\n\n        return DataLoaderX(dataset, **config)\n    except ImportError:\n        return DataLoader(dataset, **config)\n\n\ndef check_cache_features(root: str) -> bool:\n    files_list = [\"X_train.pt\", \"y_train.pt\", \"X_test.pt\", \"y_test.pt\"]\n    for file in files_list:\n        if not path.isfile(path.join(root, file)):\n            return False\n    return True\n\n\n@torch.no_grad()\ndef cache_features(\n    backbone: torch.nn.Module,\n    dataloader: DataLoader[Tuple[torch.Tensor, torch.Tensor]],\n    device: Optional[DeviceLikeType] = None,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    backbone.eval()\n    X_all: List[torch.Tensor] = []\n    y_all: List[torch.Tensor] = []\n    for X, y in tqdm(dataloader, \"Caching\"):\n        X: 
torch.Tensor = backbone(X.to(device))\n        y: torch.Tensor = y.to(torch.int16, non_blocking=True)\n        X_all.append(X.cpu())\n        y_all.append(y.cpu())\n    return torch.cat(X_all), torch.cat(y_all)\n\n\ndef main(args: Dict[str, Any]):\n    backbone_name = args[\"backbone\"]\n\n    # Select device\n    if args[\"cpu_only\"] or not torch.cuda.is_available():\n        main_device = torch.device(\"cpu\")\n        all_gpus = None\n    elif args[\"gpus\"] is not None:\n        gpus = args[\"gpus\"]\n        main_device = torch.device(f\"cuda:{gpus[0]}\")\n        all_gpus = [torch.device(f\"cuda:{gpu}\") for gpu in gpus]\n    else:\n        main_device = torch.device(\"cuda:0\")\n        all_gpus = None\n\n    if args[\"seed\"] is not None:\n        set_determinism(args[\"seed\"])\n\n    if \"backbone_path\" in args:\n        assert path.isfile(\n            args[\"backbone_path\"]\n        ), f\"Backbone file \\\"{args['backbone_path']}\\\" doesn't exist.\"\n        preload_backbone = True\n        backbone, _, feature_size = torch.load(\n            args[\"backbone_path\"], map_location=main_device, weights_only=False\n        )\n    else:\n        # Load model pre-train on ImageNet if there is no base training dataset.\n        preload_backbone = False\n        load_pretrain = args[\"base_ratio\"] == 0 or \"ImageNet\" not in args[\"dataset\"]\n        backbone, _, feature_size = load_backbone(backbone_name, pretrain=load_pretrain)\n        if load_pretrain:\n            assert args[\"dataset\"] != \"ImageNet\", \"Data may leak!!!\"\n    backbone = backbone.to(main_device, non_blocking=True)\n\n    dataset_args = {\n        \"name\": args[\"dataset\"],\n        \"root\": args[\"data_root\"],\n        \"base_ratio\": args[\"base_ratio\"],\n        \"num_phases\": args[\"phases\"],\n        \"shuffle_seed\": args[\"dataset_seed\"] if \"dataset_seed\" in args else None,\n    }\n    dataset_train = load_dataset(train=True, augment=True, **dataset_args)\n    
dataset_test = load_dataset(train=False, augment=False, **dataset_args)\n\n    # Select algorithm\n    assert args[\"method\"] in ALL_METHODS, f\"Unknown method: {args['method']}\"\n    learner = ALL_METHODS[args[\"method\"]](\n        args, backbone, feature_size, main_device, all_devices=all_gpus\n    )\n\n    # Base training\n    if args[\"base_ratio\"] > 0 and not preload_backbone:\n        train_subset = dataset_train.subset_at_phase(0)\n        test_subset = dataset_test.subset_at_phase(0)\n        train_loader = make_dataloader(\n            train_subset,\n            True,\n            args[\"batch_size\"],\n            args[\"num_workers\"],\n            device=main_device,\n        )\n        test_loader = make_dataloader(\n            test_subset,\n            False,\n            args[\"batch_size\"],\n            args[\"num_workers\"],\n            device=main_device,\n        )\n        learner.base_training(\n            train_loader,\n            test_loader,\n            dataset_train.base_size,\n        )\n\n    # Load dataset\n    if args[\"cache_features\"]:\n        if \"cache_path\" not in args or args[\"cache_path\"] is None:\n            args[\"cache_path\"] = args[\"saving_root\"]\n        if not check_cache_features(args[\"cache_path\"]):\n            backbone = learner.backbone.eval()\n            dataset_train = load_dataset(\n                args[\"dataset\"], args[\"data_root\"], True, 1, 0, augment=False\n            )\n            dataset_test = load_dataset(\n                args[\"dataset\"], args[\"data_root\"], False, 1, 0, augment=False\n            )\n            train_loader = make_dataloader(\n                dataset_train.subset_at_phase(0),\n                False,\n                args[\"batch_size\"],\n                args[\"num_workers\"],\n                device=main_device,\n            )\n            test_loader = make_dataloader(\n                dataset_test.subset_at_phase(0),\n                False,\n                
args[\"batch_size\"],\n                args[\"num_workers\"],\n                device=main_device,\n            )\n\n            if all_gpus is not None and len(all_gpus) > 1:\n                backbone = torch.nn.DataParallel(backbone, device_ids=all_gpus)\n            X_train, y_train = cache_features(\n                backbone, train_loader, device=main_device\n            )\n            X_test, y_test = cache_features(backbone, test_loader, device=main_device)\n            torch.save(X_train, path.join(args[\"cache_path\"], \"X_train.pt\"))\n            torch.save(y_train, path.join(args[\"cache_path\"], \"y_train.pt\"))\n            torch.save(X_test, path.join(args[\"cache_path\"], \"X_test.pt\"))\n            torch.save(y_test, path.join(args[\"cache_path\"], \"y_test.pt\"))\n        dataset_train = Features(\n            args[\"cache_path\"],\n            train=True,\n            base_ratio=args[\"base_ratio\"],\n            num_phases=args[\"phases\"],\n            augment=False,\n        )\n        dataset_test = Features(\n            args[\"cache_path\"],\n            train=False,\n            base_ratio=args[\"base_ratio\"],\n            num_phases=args[\"phases\"],\n            augment=False,\n        )\n        learner.backbone = torch.nn.Identity()\n        learner.model.backbone = torch.nn.Identity()\n    else:\n        dataset_train = load_dataset(train=True, augment=False, **dataset_args)\n        dataset_test = load_dataset(train=False, augment=False, **dataset_args)\n\n    # Incremental learning\n    sum_acc = 0\n    log_file_path = path.join(args[\"saving_root\"], \"IL.csv\")\n    log_file = open(log_file_path, \"w\", buffering=1)\n    print(\n        \"phase\", \"acc@avg\", \"acc@1\", \"acc@5\", \"f1-micro\", \"loss\", file=log_file, sep=\",\"\n    )\n\n    for phase in range(0, args[\"phases\"] + 1):\n        train_subset = dataset_train.subset_at_phase(phase)\n        test_subset = dataset_test.subset_until_phase(phase)\n        train_loader 
= make_dataloader(\n            train_subset,\n            True,\n            args[\"IL_batch_size\"],\n            args[\"num_workers\"],\n            device=main_device,\n        )\n        test_loader = make_dataloader(\n            test_subset,\n            False,\n            args[\"IL_batch_size\"],\n            args[\"num_workers\"],\n            device=main_device,\n        )\n        if phase == 0:\n            learner.learn(train_loader, dataset_train.base_size, \"Re-align\")\n        else:\n            learner.learn(train_loader, dataset_train.phase_size)\n        learner.before_validation()\n\n        # Validation\n        val_meter = validate(\n            learner,\n            test_loader,\n            dataset_train.num_classes,\n            desc=f\"Phase {phase}\",\n        )\n        sum_acc += val_meter.accuracy\n        print(\n            f\"loss: {val_meter.loss:.4f}\",\n            f\"acc@1: {val_meter.accuracy * 100:.3f}%\",\n            f\"acc@5: {val_meter.accuracy5 * 100:.3f}%\",\n            f\"f1-micro: {val_meter.f1_micro * 100:.3f}%\",\n            f\"acc@avg: {sum_acc / (phase + 1) * 100:.3f}%\",\n            sep=\"    \",\n        )\n        print(\n            phase,\n            sum_acc / (phase + 1),\n            val_meter.accuracy,\n            val_meter.accuracy5,\n            val_meter.f1_micro,\n            val_meter.loss,\n            file=log_file,\n            sep=\",\",\n        )\n    log_file.close()\n\n\nif __name__ == \"__main__\":\n    main(load_args())\n"
  },
  {
    "path": "models/CifarResNet.py",
    "content": "# -*- coding: utf-8 -*-\n\"\"\"\nProperly implemented ResNet-s for CIFAR10 as described in paper [1].\n\nThe implementation and structure of this file are hugely influenced by [2],\nwhich is implemented for ImageNet and doesn't have option A for the identity\nshortcut. Moreover, most of the implementations on the web are copy-pasted\nfrom torchvision's resnet and have the wrong number of params.\n\nProper ResNet-s for CIFAR10 (for fair comparison, etc.) have the following\nnumbers of layers and parameters:\n\nname      | layers | params\nResNet20  |    20  | 0.27M\nResNet32  |    32  | 0.46M\nResNet44  |    44  | 0.66M\nResNet56  |    56  | 0.85M\nResNet110 |   110  |  1.7M\nResNet1202|  1202  | 19.4M\n\nwhich this implementation indeed has.\n\nReference:\n[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun\n    Deep Residual Learning for Image Recognition. arXiv:1512.03385\n[2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py\n\nIf you use this implementation in your work, please don't forget to mention\nthe author, Yerlan Idelbayev.\n\"\"\"\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.nn.init as init\n\n\n__all__ = [\n    \"CifarResNet\",\n    \"resnet20\",\n    \"resnet32\",\n    \"resnet44\",\n    \"resnet56\",\n    \"resnet110\",\n    \"resnet1202\",\n]\n\n\ndef _weights_init(m):\n    if isinstance(m, (nn.Linear, nn.Conv2d)):\n        init.kaiming_normal_(m.weight)\n\n\nclass ShortcutA(nn.Module):\n    def __init__(self, planes) -> None:\n        super().__init__()\n        self.pad = planes // 4\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        # Subsample spatially by striding, then zero-pad the channel dimension.\n        return F.pad(\n            x[:, :, ::2, ::2],\n            (0, 0, 0, 0, self.pad, self.pad),\n            \"constant\",\n            0,\n        )\n\n\nclass BasicBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, in_planes, planes, stride=1, option=\"A\"):\n        super(BasicBlock, self).__init__()\n        self.conv1 = nn.Conv2d(\n            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False\n        )\n        self.bn1 = nn.BatchNorm2d(planes)\n        self.conv2 = nn.Conv2d(\n            planes, planes, kernel_size=3, stride=1, padding=1, bias=False\n        )\n        self.bn2 = nn.BatchNorm2d(planes)\n\n        self.shortcut = nn.Sequential()\n        if stride != 1 or in_planes != planes:\n            if option == \"A\":\n                # For CIFAR10, the ResNet paper uses option A (parameter-free).\n                self.shortcut = ShortcutA(planes)\n            elif option == \"B\":\n                self.shortcut = nn.Sequential(\n                    nn.Conv2d(\n                        in_planes,\n                        self.expansion * planes,\n                        kernel_size=1,\n                        stride=stride,\n                        bias=False,\n                    ),\n                    nn.BatchNorm2d(self.expansion * planes),\n                )\n\n    def forward(self, x):\n        out = F.relu(self.bn1(self.conv1(x)))\n        out = self.bn2(self.conv2(out))\n        out += self.shortcut(x)\n        out = F.relu(out)\n        return out\n\n\nclass CifarResNet(nn.Module):\n    def __init__(self, block, num_blocks, num_classes=10):\n        super(CifarResNet, self).__init__()\n        self.in_planes = 16\n\n        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)\n        self.bn1 = nn.BatchNorm2d(16)\n        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)\n        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)\n        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)\n        self.fc = nn.Linear(64, num_classes)\n\n        self.apply(_weights_init)\n\n    def _make_layer(self, block, planes, num_blocks, stride):\n        # Only the first block of a layer may downsample.\n        strides = [stride] + [1] * (num_blocks - 1)\n        layers = []\n        for s in strides:\n            layers.append(block(self.in_planes, planes, s))\n            self.in_planes = planes * block.expansion\n\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        out = F.relu(self.bn1(self.conv1(x)))\n        out = self.layer1(out)\n        out = self.layer2(out)\n        out = self.layer3(out)\n        out = F.avg_pool2d(out, out.size(3))\n        out = out.view(out.size(0), -1)\n        out = self.fc(out)\n        return out\n\n\ndef resnet20(num_classes=10):\n    return CifarResNet(BasicBlock, [3, 3, 3], num_classes)\n\n\ndef resnet32(num_classes=10):\n    return CifarResNet(BasicBlock, [5, 5, 5], num_classes)\n\n\ndef resnet44(num_classes=10):\n    return CifarResNet(BasicBlock, [7, 7, 7], num_classes)\n\n\ndef resnet56(num_classes=10):\n    return CifarResNet(BasicBlock, [9, 9, 9], num_classes)\n\n\ndef resnet110(num_classes=10):\n    return CifarResNet(BasicBlock, [18, 18, 18], num_classes)\n\n\ndef resnet1202(num_classes=10):\n    return CifarResNet(BasicBlock, [200, 200, 200], num_classes)\n\n\ndef calc_num_params(model: torch.nn.Module) -> int:\n    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n\n\nif __name__ == \"__main__\":\n    print(calc_num_params(resnet32()))\n"
  },
  {
    "path": "models/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n\nimport torch\nfrom typing import Dict, Tuple, Union, Optional, Callable\n\nfrom torchvision.models import WeightsEnum\nfrom torch.nn import Flatten\n\nfrom torchvision.models.resnet import (\n    ResNet,\n    resnet18,\n    resnet34,\n    resnet50,\n    resnet101,\n    resnet152,\n    ResNet18_Weights,\n    ResNet34_Weights,\n    ResNet50_Weights,\n    ResNet101_Weights,\n    ResNet152_Weights,\n)\n\nfrom .CifarResNet import (\n    CifarResNet,\n    resnet20,\n    resnet32,\n    resnet44,\n    resnet56,\n    resnet110,\n    resnet1202,\n)\n\nfrom torchvision.models.vision_transformer import (\n    VisionTransformer,\n    vit_b_16,\n    vit_b_32,\n    vit_l_16,\n    vit_l_32,\n    vit_h_14,\n    ViT_B_16_Weights,\n    ViT_B_32_Weights,\n    ViT_L_16_Weights,\n    ViT_L_32_Weights,\n    ViT_H_14_Weights,\n)\n\n__all__ = [\n    \"ResNet\",\n    \"resnet18\",\n    \"resnet34\",\n    \"resnet50\",\n    \"resnet101\",\n    \"resnet152\",\n    \"CifarResNet\",\n    \"resnet20\",\n    \"resnet32\",\n    \"resnet44\",\n    \"resnet56\",\n    \"resnet110\",\n    \"resnet1202\",\n    \"VisionTransformer\",\n    \"vit_b_16\",\n    \"vit_b_32\",\n    \"vit_l_16\",\n    \"vit_l_32\",\n    \"vit_h_14\",\n    \"load_backbone\",\n]\n\n# fmt: off\nmodels: Dict[str, Tuple[\n        int,    # Input image size\n        Callable[[], Union[CifarResNet, ResNet, VisionTransformer, Flatten]], # Model constructor\n        Optional[WeightsEnum]\n    ]\n] = {\n    # MNIST: No backbone\n    \"Flatten\": (28, Flatten, None),\n    # ResNet for CIFAR\n    \"resnet20\":   (32, resnet20  , None),\n    \"resnet32\":   (32, resnet32  , None),\n    \"resnet44\":   (32, resnet44  , None),\n    \"resnet56\":   (32, resnet56  , None),\n    \"resnet110\":  (32, resnet110 , None),\n    \"resnet1202\": (32, resnet1202, None),\n    # ResNet for ImageNet\n    \"resnet18\":  (224, resnet18,  ResNet18_Weights.DEFAULT ),\n    \"resnet34\":  (224, resnet34,  
ResNet34_Weights.DEFAULT ),\n    \"resnet50\":  (224, resnet50,  ResNet50_Weights.DEFAULT ),\n    \"resnet101\": (224, resnet101, ResNet101_Weights.DEFAULT),\n    \"resnet152\": (224, resnet152, ResNet152_Weights.DEFAULT),\n    # Vision Transformer for ImageNet\n    \"vit_b_16\": (384, vit_b_16, ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1),\n    \"vit_b_32\": (224, vit_b_32, ViT_B_32_Weights.IMAGENET1K_V1         ),\n    \"vit_l_16\": (512, vit_l_16, ViT_L_16_Weights.IMAGENET1K_SWAG_E2E_V1),\n    \"vit_l_32\": (224, vit_l_32, ViT_L_32_Weights.IMAGENET1K_V1         ),\n    \"vit_h_14\": (518, vit_h_14, ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1),\n}\n# fmt: on\n\n\ndef load_backbone(\n    name: str, pretrain: bool = False, *args, **kwargs\n) -> Tuple[torch.nn.Module, int, int]:\n    input_img_size, model, weights = models[name]\n    if pretrain and (weights is not None) and (\"weights\" not in kwargs):\n        kwargs[\"weights\"] = weights\n    backbone = model(*args, **kwargs)\n\n    # Strip the classification head so the backbone outputs raw features.\n    if isinstance(backbone, VisionTransformer):\n        feature_size: int = backbone.heads[-1].in_features\n        backbone.heads = torch.nn.Identity()  # type: ignore\n    elif isinstance(backbone, (ResNet, CifarResNet)):\n        feature_size = backbone.fc.in_features\n        backbone.fc = torch.nn.Identity()  # type: ignore\n    elif isinstance(backbone, Flatten):\n        feature_size = input_img_size ** 2\n    else:\n        raise TypeError(f\"Unsupported backbone type: {type(backbone).__name__}\")\n    return backbone, input_img_size, feature_size\n\n\nif __name__ == \"__main__\":\n    for name in models.keys():\n        backbone, input_img_size, feature_size = load_backbone(name, pretrain=True)\n        # Flatten is the no-backbone option for single-channel MNIST images;\n        # all other backbones expect 3-channel inputs.\n        num_channels = 1 if name == \"Flatten\" else 3\n        test_img = torch.randn((1, num_channels, input_img_size, input_img_size))\n        prototype: torch.Tensor = backbone(test_img)\n        assert len(prototype.shape) == 2 and prototype.shape[0] == 1\n        assert feature_size == prototype.shape[1]\n"
  },
  {
    "path": "utils/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n\nfrom .validate import validate\nfrom .set_weight_decay import set_weight_decay\nfrom .set_determinism import set_determinism\nfrom .metrics import ClassificationMeter\n\n__all__ = [\"validate\", \"set_weight_decay\", \"set_determinism\", \"ClassificationMeter\"]\n"
  },
  {
    "path": "utils/metrics.py",
    "content": "import torch\nimport numpy as np\nfrom sklearn import metrics\n\n\nclass ClassificationMeter:\n    def __init__(self, num_classes: int, record_logits: bool = False) -> None:\n        self.num_classes = num_classes\n        self.total_loss = 0.0\n        self.labels = np.zeros((0,), dtype=np.int32)\n        self.prediction = np.zeros((0,), dtype=np.int32)\n        self.acc5_cnt = 0\n        self.record_logits = record_logits\n        if self.record_logits:\n            self.logits = np.zeros((0, num_classes))\n\n    def record(self, y_true: torch.Tensor, logits: torch.Tensor) -> None:\n        self.labels = np.concatenate([self.labels, y_true.cpu().numpy()])\n        # Record softmax-normalized logits\n        if self.record_logits:\n            logits_softmax = torch.nn.functional.softmax(logits, dim=1).cpu().numpy()\n            self.logits = np.concatenate([self.logits, logits_softmax])\n\n        # Loss\n        self.total_loss += float(\n            torch.nn.functional.cross_entropy(logits, y_true, reduction=\"sum\").item()\n        )\n        # Top-5 accuracy\n        y_pred = logits.topk(5, largest=True).indices.to(torch.int)\n        acc5_judge = (y_pred == y_true[:, None]).any(dim=-1)\n        self.acc5_cnt += int(acc5_judge.sum().item())\n\n        # Record the top-1 predictions\n        self.prediction = np.concatenate([self.prediction, y_pred[:, 0].cpu().numpy()])\n\n    @property\n    def accuracy(self) -> float:\n        return float(metrics.accuracy_score(self.labels, self.prediction))\n\n    @property\n    def balanced_accuracy(self) -> float:\n        result = metrics.balanced_accuracy_score(\n            self.labels, self.prediction, adjusted=True\n        )\n        return float(result)\n\n    @property\n    def f1_micro(self) -> float:\n        result = metrics.f1_score(self.labels, self.prediction, average=\"micro\")\n        return float(result)\n\n    @property\n    def f1_macro(self) -> float:\n        result = metrics.f1_score(self.labels, self.prediction, average=\"macro\")\n        return float(result)\n\n    @property\n    def accuracy5(self) -> float:\n        return self.acc5_cnt / len(self.labels)\n\n    @property\n    def loss(self) -> float:\n        return float(self.total_loss / len(self.labels))\n"
  },
  {
    "path": "utils/set_determinism.py",
    "content": "# -*- coding: utf-8 -*-\n\nimport torch\nimport numpy\nfrom os import environ\nimport random\n\n\ndef set_determinism(seed: int) -> None:\n    environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n    torch.use_deterministic_algorithms(True)\n    torch.manual_seed(seed)\n    numpy.random.seed(seed)\n    random.seed(seed)\n    if torch.cuda.is_available():\n        torch.backends.cudnn.benchmark = False\n        torch.backends.cudnn.deterministic = True\n        torch.cuda.manual_seed_all(seed)\n"
  },
  {
    "path": "utils/set_weight_decay.py",
    "content": "# -*- coding: utf-8 -*-\nimport torch\nfrom typing import Dict, List, Any\n\n\ndef set_weight_decay(\n    model: torch.nn.Module,\n    weight_decay: float,\n    disable_norm_decay: bool = True,\n    disable_bias_decay: bool = True,\n    disable_embedding_decay: bool = True,\n) -> List[Dict[str, Any]]:\n    # See: https://github.com/pytorch/vision/blob/main/references/classification/utils.py\n    norm_classes = (\n        torch.nn.modules.batchnorm._BatchNorm,\n        torch.nn.LayerNorm,\n        torch.nn.GroupNorm,\n        torch.nn.modules.instancenorm._InstanceNorm,\n        torch.nn.LocalResponseNorm,\n    )\n\n    params = {\n        \"other\": [],\n        \"norm\": [],\n        \"bias\": [],\n        \"class_token\": [],\n        \"position_embedding\": [],\n        \"relative_position_bias_table\": [],\n    }\n\n    params_weight_decay = {\n        \"bias\": 0 if disable_bias_decay else weight_decay,\n        \"class_token\": 0 if disable_embedding_decay else weight_decay,\n        \"position_embedding\": 0 if disable_embedding_decay else weight_decay,\n        \"relative_position_bias_table\": 0 if disable_embedding_decay else weight_decay,\n    }\n\n    def _add_params(module: torch.nn.Module, prefix=\"\"):\n        for name, p in module.named_parameters(recurse=False):\n            for key in params_weight_decay.keys():\n                target_name = (\n                    f\"{prefix}.{name}\" if prefix != \"\" and \".\" in key else name\n                )\n                if key == target_name:\n                    params[key].append(p)\n                    break\n            else:\n                if isinstance(module, norm_classes):\n                    params[\"norm\"].append(p)\n                else:\n                    params[\"other\"].append(p)\n\n        for child_name, child_module in module.named_children():\n            child_prefix = f\"{prefix}.{child_name}\" if prefix != \"\" else child_name\n            _add_params(child_module, prefix=child_prefix)\n\n    _add_params(model)\n\n    params_weight_decay[\"other\"] = weight_decay\n    params_weight_decay[\"norm\"] = 0.0 if disable_norm_decay else weight_decay\n\n    param_groups: List[Dict[str, Any]] = []\n    for key in params:\n        if len(params[key]) > 0:\n            param_groups.append(\n                {\"params\": params[key], \"weight_decay\": params_weight_decay[key]}\n            )\n    return param_groups\n"
  },
  {
    "path": "utils/validate.py",
    "content": "# -*- coding: utf-8 -*-\nimport torch\nfrom tqdm import tqdm\nfrom typing import Tuple, Iterable, Optional, Callable\nfrom .metrics import ClassificationMeter\n\n\n@torch.no_grad()\ndef validate(\n    model: Callable[[torch.Tensor], torch.Tensor],\n    data_loader: Iterable[Tuple[torch.Tensor, torch.Tensor]],\n    num_classes: int,\n    desc: Optional[str] = None\n) -> ClassificationMeter:\n    if isinstance(model, torch.nn.Module):\n        model.eval()\n        device = next(model.parameters()).device\n    else:\n        device = model.device\n    meter = ClassificationMeter(num_classes)\n\n    for X, y in tqdm(data_loader, desc=desc):\n        X = X.to(device, non_blocking=True)\n        y = y.to(device, non_blocking=True)\n\n        # Calculate the loss\n        logits: torch.Tensor = model(X)\n        meter.record(y, logits)\n    return meter\n"
  }
]