Repository: ZHUANGHP/Analytic-continual-learning Branch: main Commit: 0afd6bf87622 Files: 29 Total size: 106.1 KB Directory structure: gitextract_vo_20iyw/ ├── .gitignore ├── LICENSE ├── README.md ├── README_CN.md ├── analytic/ │ ├── ACIL.py │ ├── AEFOCL.py │ ├── AIR.py │ ├── AnalyticLinear.py │ ├── Buffer.py │ ├── DSAL.py │ ├── GKEAL.py │ ├── Learner.py │ └── __init__.py ├── config.py ├── datasets/ │ ├── CIFAR.py │ ├── DatasetWrapper.py │ ├── Features.py │ ├── ImageNet.py │ ├── MNIST.py │ └── __init__.py ├── environment.yaml ├── main.py ├── models/ │ ├── CifarResNet.py │ └── __init__.py └── utils/ ├── __init__.py ├── metrics.py ├── set_determinism.py ├── set_weight_decay.py └── validate.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ saved_models/** .vscode/** *.pt *.pth backbones/** # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. 
*.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder .pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: # .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # poetry # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. # This is especially recommended for binary packages to ensure reproducibility, and is more # commonly ignored for libraries. # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control #poetry.lock # pdm # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. #pdm.lock # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it # in version control. # https://pdm.fming.dev/#use-with-ide .pdm.toml # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # Cython debug symbols cython_debug/ # PyCharm # JetBrains specific template is maintained in a separate JetBrains.gitignore that can # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. .idea/ ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2024 Huiping Zhuang Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ [中文](README_CN.md) | English # Analytic Continual Learning Official implementation of the following papers. [1] Zhuang, Huiping, et al. "[ACIL: Analytic Class-Incremental Learning with Absolute Memorization and Privacy Protection.](https://proceedings.neurips.cc/paper_files/paper/2022/hash/4b74a42fc81fc7ee252f6bcb6e26c8be-Abstract-Conference.html)" Advances in Neural Information Processing Systems 35 (2022): 11602-11614. [2] Zhuang, Huiping, et al. "[GKEAL: Gaussian Kernel Embedded Analytic Learning for Few-Shot Class Incremental Task.](https://openaccess.thecvf.com/content/CVPR2023/html/Zhuang_GKEAL_Gaussian_Kernel_Embedded_Analytic_Learning_for_Few-Shot_Class_Incremental_CVPR_2023_paper.html)" Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2023. [3] Zhuang, Huiping, et al. "[DS-AL: A Dual-Stream Analytic Learning for Exemplar-Free Class-Incremental Learning.](https://ojs.aaai.org/index.php/AAAI/article/view/29670)" Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 38. No. 15. 2024. [4] Zhuang, Huiping, et al. "[GACL: Exemplar-Free Generalized Analytic Continual Learning](https://neurips.cc/virtual/2024/poster/95330)" Advances in Neural Information Processing Systems 37 (2024). [[OpenReview]](https://openreview.net/forum?id=P6aJ7BqYlc) [[arXiv]](https://arxiv.org/abs/2403.15706) [5] Zhuang, Huiping, et al. "[Online Analytic Exemplar-Free Continual Learning with Large Models for Imbalanced Autonomous Driving Task.](https://ieeexplore.ieee.org/document/10721370)" IEEE Transactions on Vehicular Technology (2024). [6] Fang, Di, et al. 
"[AIR: Analytic Imbalance Rectifier for Continual Learning.](https://arxiv.org/abs/2408.10349)" arXiv preprint arXiv:2408.10349 (2024). ![](figures/acc_cmp.jpg) **Welcome to join our Tencent QQ group: [954528161](http://qm.qq.com/cgi-bin/qm/qr?_wv=1027&k=qaK4W8Jw6d--VWHlx7iUs93T2qMJT9k_&authKey=5e8hSXX8rALjM12iGwrZ9BmRBP9iUfCuRGNCaZ3%2Bx0msiRFVcSwu%2FuZpeKig1XQH&noverify=0&group_code=954528161). Chinese tutorial is available at [Bilibili](https://www.bilibili.com/video/BV1wq421A7YM/).** ## Dual Branch We have a dual branch at "[Analytic Federated Learning.](https://github.com/ZHUANGHP/Analytic-federated-learning)" ## Environment We recommend using [Anaconda](https://anaconda.org/) to install the development environment. ```bash git clone --depth=1 git@github.com:ZHUANGHP/Analytic-continual-learning.git cd Analytic-continual-learning conda env create -f environment.yaml conda activate AL mkdir backbones ``` Download the pre-train weight at the [release page](https://github.com/ZHUANGHP/Analytic-continual-learning/releases) for quick start. We suggest you extract the pre-train backbone (zip file) under the `backbones` folder. For the macOS users and the CPU-only users, you need to delete the items related to CUDA in the `environment.yaml` file. We highly recommend that you run our code in Linux. Windows and macOS users are also welcome to submit issues if they have problems running the code. ## Quick Start Put the base training weights (provided at the [release page](https://github.com/ZHUANGHP/Analytic-continual-learning/releases)) at the `backbones` directory. Gradients are not used in continual learning. **You can run our code even on CPUs.** Here are some examples.
```bash # ACIL (CIFAR-100, B50 25 phases) python main.py ACIL --dataset CIFAR-100 --base-ratio 0.5 --phases 25 \ --data-root ~/dataset --IL-batch-size 4096 --num-workers 16 --backbone resnet32 \ --gamma 0.1 --buffer-size 8192 \ --cache-features --backbone-path ./backbones/resnet32_CIFAR-100_0.5_None ``` ```bash # G-ACIL (CIFAR-100, B50 25 phases) python main.py G-ACIL --dataset CIFAR-100 --base-ratio 0.5 --phases 25 \ --data-root ~/dataset --IL-batch-size 4096 --num-workers 16 --backbone resnet32 \ --gamma 0.1 --buffer-size 8192 \ --cache-features --backbone-path ./backbones/resnet32_CIFAR-100_0.5_None ``` ```bash # GKEAL (CIFAR-100, B50 10 phases) python main.py GKEAL --dataset CIFAR-100 --base-ratio 0.5 --phases 10 \ --data-root ~/dataset --IL-batch-size 4096 --num-workers 16 --backbone resnet32 \ --gamma 0.1 --sigma 10 --buffer-size 8192 \ --cache-features --backbone-path ./backbones/resnet32_CIFAR-100_0.5_None ``` ```bash # DS-AL (CIFAR-100, B50 50 phases) python main.py DS-AL --dataset CIFAR-100 --base-ratio 0.5 --phases 50 \ --data-root ~/dataset --IL-batch-size 4096 --num-workers 16 --backbone resnet32 \ --gamma 0.1 --gamma-comp 0.1 --compensation-ratio 0.6 --buffer-size 8192 \ --cache-features --backbone-path ./backbones/resnet32_CIFAR-100_0.5_None ``` ```bash # DS-AL (ImageNet-1k, B50 20 phases) python main.py DS-AL --dataset ImageNet-1k --base-ratio 0.5 --phases 20 \ --data-root ~/dataset --IL-batch-size 4096 --num-workers 16 --backbone resnet18 \ --gamma 0.1 --gamma-comp 0.1 --compensation-ratio 1.5 --buffer-size 16384 \ --cache-features --backbone-path ./backbones/resnet18_ImageNet-1k_0.5_None ``` ## Training From Scratch ```bash # ACIL (CIFAR-100) python main.py ACIL --dataset CIFAR-100 --base-ratio 0.5 --phases 25 \ --data-root ~/dataset --batch-size 256 --num-workers 16 --backbone resnet32 \ --learning-rate 0.5 --label-smoothing 0 --base-epochs 300 --weight-decay 5e-4 \ --gamma 0.1 --buffer-size 8192 --cache-features --IL-batch-size 4096 ``` ```bash 
# ACIL (ImageNet-1k) python main.py ACIL --dataset ImageNet-1k --base-ratio 0.5 --phases 25 \ --data-root ~/dataset --batch-size 256 --num-workers 16 --backbone resnet18 \ --learning-rate 0.5 --label-smoothing 0.05 --base-epochs 300 --weight-decay 5e-5 \ --gamma 0.1 --buffer-size 16384 --cache-features --IL-batch-size 4096 ``` ## Reproduction Details ### Difference Between the ACIL and the G-ACIL The G-ACIL is a general version of the ACIL for the general CIL setting. For the tradition CIL setting, the G-ACIL is equivalent to the ACIL. Thus, we use the same implementation in this repository. ### Benchmarks (B50, 25 phases, with `TrivialAugmentWide`) Metrics are shown in 95% confidence intervals ($\mu \pm 1.96\sigma$). | Dataset | Method | Backbone | Buffer Size | Average Accuracy (%) | Last Phase Accuracy (%) | | :---------: | :------------: | :-------: | :---------: | :------------------: | :---------------------: | | CIFAR-100 | ACIL & G-ACIL | ResNet-32 | 8192 | $71.047\pm0.252$ | $63.384\pm0.330$ | | CIFAR-100 | DS-AL | ResNet-32 | 8192 | $71.277\pm0.251$ | $64.043\pm0.184$ | | CIFAR-100 | GKEAL | ResNet-32 | 8192 | $70.371\pm0.168$ | $62.301\pm0.191$ | | ImageNet-1k | ACIL & G-ACIL | ResNet-18 | 16384 | $67.497\pm0.092$ | $58.349\pm0.111$ | | ImageNet-1k | DS-AL | ResNet-18 | 16384 | $68.354\pm0.084$ | $59.762\pm0.086$ | | ImageNet-1k | GKEAL | ResNet-18 | 16384 | $66.881\pm0.061$ | $57.295\pm0.105$ | ![Top-1 Accuracy](figures/acc@1.svg) ### Hyper-Parameters (Analytic Continual Learning) The backbones are frozen during the incremental learning process of our algorithm. You can use the `--cache-features` option to save the features output by the backbones to improve the efficiency of parameter adjustment. 1. **Buffer Size** For the ACIL, the buffer size means the *expansion size* of the random projection layer. For the GKEAL, the buffer size means the number of *center vectors* of the *Gaussian kernel embedding*. 
We summarize the "random projection" and the "Gaussian projection" into one concept "buffer" in the DS-AL. On most datasets, the performance of the algorithm first increases and then decreases as the buffer size increases. You can see further experiments on this hyperparameter in our papers. We recommend using a buffer size of 8192 on CIFAR-100 and 16384 or greater on ImageNet for optimal performance. A larger buffer size requires more memory. 2. **$\gamma$ (Coefficient of the Regularization Term)** For the dataset used in the papers, $\gamma$ is insensitive within an interval. However, a $\gamma$ that is too small may cause numerical stability problems in matrix inversion, and a $\gamma$ that is too large may cause under-fitting of the classifier. On both CIFAR-100 and ImageNet-1k, $\gamma$ is 0.1. When you migrate our algorithm to other datasets, we still recommend that you do some experiments to check whether $\gamma$ is appropriate. 3. **$\beta$ and $\sigma$ (GKEAL Only)** In the GKEAL, the width-adjusting parameter $\beta$ controls the width of the Gaussian kernels. There is a comfortable range for $\sigma$ at around $[5, 15]$ for CIFAR-100 and ImageNet-1k that gives good results, where $\beta = \frac{1}{2\sigma^2}$. 4. **Compensation Ratio $\mathcal{C}$ (DS-AL Only)** We recommend using grid search to find the best compensation ratio in the interval $[0, 2]$. The best value is 0.6 for the CIFAR-100, while the best value for the ImageNet-1k is 1.5. Further analysis on hyper-parameters is shown in our papers. ### Hyper-Parameters (Base Training) In the base training process, the backbones reach over 80% top-1 accuracy on the first half of CIFAR-100 (ResNet-32) and ImageNet-1k (ResNet-18). Important hyper-parameters are listed below. 1. **Learning Rate** In this implementation, we use a cosine scheduler instead of choosing the same piece-wise smooth scheduler as in the papers to reduce the number of hyper-parameters.
We recommend using a learning rate of 0.5 (when the batch size is 256) on CIFAR-100 and ImageNet-1k to obtain better convergence. The number of epochs we use for provided backbones is 300. 2. **Label Smoothing and Weight Decay** Properly setting label smoothing and weight decay can help prevent over-fitting of the backbone. In CIFAR-100, label smoothing is not significantly helpful; while in ImageNet-1k, we empirically selected 0.05. For CIFAR-100, we choose a weight decay of 5e-4, while in ImageNet-1k, this value is 5e-5. 3. **Image Augmentation** Using image augmentation to obtain a more generalizable backbone in the base training dataset can significantly improve performance. No image augmentation is used in the experiments of our papers. But in this implementation, data augmentation is enabled on by default. **So using this implementation will achieve higher performance than reported in the papers (about 2%~5%)**. Note that we do not use any data augmentation during the re-alignment and the continual learning processes because each sample will be learned only once. # Cite Our Papers ```bib @InProceedings{ACIL_Zhuang_NeurIPS2022, author = {Zhuang, Huiping and Weng, Zhenyu and Wei, Hongxin and Xie, Renchunzi and Toh, Kar-Ann and Lin, Zhiping}, title = {{ACIL}: Analytic Class-Incremental Learning with Absolute Memorization and Privacy Protection}, booktitle = {Advances in Neural Information Processing Systems}, editor = {S. Koyejo and S. Mohamed and A. Agarwal and D. Belgrave and K. Cho and A. 
Oh}, pages = {11602--11614}, publisher = {Curran Associates, Inc.}, volume = {35}, year = {2022}, url = {https://proceedings.neurips.cc/paper_files/paper/2022/file/4b74a42fc81fc7ee252f6bcb6e26c8be-Paper-Conference.pdf} } @InProceedings{GKEAL_Zhuang_CVPR2023, author = {Zhuang, Huiping and Weng, Zhenyu and He, Run and Lin, Zhiping and Zeng, Ziqian}, title = {{GKEAL}: Gaussian Kernel Embedded Analytic Learning for Few-Shot Class Incremental Task}, booktitle = {2023 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, month = jun, year = {2023}, pages = {7746--7755}, doi = {10.1109/CVPR52729.2023.00748} } @Article{DS-AL_Zhuang_AAAI2024, title = {{DS-AL}: A Dual-Stream Analytic Learning for Exemplar-Free Class-Incremental Learning}, author = {Zhuang, Huiping and He, Run and Tong, Kai and Zeng, Ziqian and Chen, Cen and Lin, Zhiping}, journal = {Proceedings of the AAAI Conference on Artificial Intelligence}, volume = {38}, number = {15}, pages = {17237--17244}, year = {2024}, month = mar, doi = {10.1609/aaai.v38i15.29670}, url = {https://ojs.aaai.org/index.php/AAAI/Article/view/29670} } @InProceedings{GACL_Zhuang_NeurIPS2024, title = {{GACL}: Exemplar-Free Generalized Analytic Continual Learning}, author = {Huiping Zhuang and Yizhu Chen and Di Fang and Run He and Kai Tong and Hongxin Wei and Ziqian Zeng and Cen Chen}, year = {2024}, booktitle = {Advances in Neural Information Processing Systems}, publisher = {Curran Associates, Inc.}, month = dec } @article{AEF-OCL_Zhuang_TVT2024, title = {Online Analytic Exemplar-Free Continual Learning with Large Models for Imbalanced Autonomous Driving Task}, author = {Zhuang, Huiping and Fang, Di and Tong, Kai and Liu, Yuchen and Zeng, Ziqian and Zhou, Xu and Chen, Cen}, year = {2024}, journal = {IEEE Transactions on Vehicular Technology}, pages = {1-10}, doi = {10.1109/TVT.2024.3483557} } @misc{AIR_Fang_arXiv2024, title = {{AIR}: Analytic Imbalance Rectifier for Continual Learning}, author = {Di Fang and Yinan Zhu 
and Zhiping Lin and Cen Chen and Ziqian Zeng and Huiping Zhuang}, year = {2024}, month = aug, archivePrefix = {arXiv}, primaryClass = {cs.LG}, eprint = {2408.10349}, doi = {10.48550/arXiv.2408.10349}, url = {https://arxiv.org/abs/2408.10349}, } ``` ================================================ FILE: README_CN.md ================================================ 中文 | [English](README.md) # 解析持续学习 (Analytic Continual Learning) 该项目的工作已被发表在以下论文中: [1] Zhuang, Huiping, et al. "[ACIL: Analytic Class-Incremental Learning with Absolute Memorization and Privacy Protection.](https://proceedings.neurips.cc/paper_files/paper/2022/hash/4b74a42fc81fc7ee252f6bcb6e26c8be-Abstract-Conference.html)" Advances in Neural Information Processing Systems 35 (2022): 11602-11614. [2] Zhuang, Huiping, et al. "[GKEAL: Gaussian Kernel Embedded Analytic Learning for Few-Shot Class Incremental Task.](https://openaccess.thecvf.com/content/CVPR2023/html/Zhuang_GKEAL_Gaussian_Kernel_Embedded_Analytic_Learning_for_Few-Shot_Class_Incremental_CVPR_2023_paper.html)" Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2023. [3] Zhuang, Huiping, et al. "[DS-AL: A Dual-Stream Analytic Learning for Exemplar-Free Class-Incremental Learning.](https://ojs.aaai.org/index.php/AAAI/article/view/29670)" Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 38. No. 15. 2024. [4] Zhuang, Huiping, et al. "[GACL: Exemplar-Free Generalized Analytic Continual Learning](https://neurips.cc/virtual/2024/poster/95330)" Advances in Neural Information Processing Systems 37 (2024). [[OpenReview]](https://openreview.net/forum?id=P6aJ7BqYlc) [[arXiv]](https://arxiv.org/abs/2403.15706) [5] Zhuang, Huiping, et al. "[Online Analytic Exemplar-Free Continual Learning with Large Models for Imbalanced Autonomous Driving Task.](https://ieeexplore.ieee.org/document/10721370)" IEEE Transactions on Vehicular Technology (2024). [6] Fang, Di, et al. 
"[AIR: Analytic Imbalance Rectifier for Continual Learning.](https://arxiv.org/abs/2408.10349)" arXiv preprint arXiv:2408.10349 (2024). ![](figures/acc_cmp.jpg) **欢迎加入我们的交流QQ群: [954528161](http://qm.qq.com/cgi-bin/qm/qr?_wv=1027&k=qaK4W8Jw6d--VWHlx7iUs93T2qMJT9k_&authKey=5e8hSXX8rALjM12iGwrZ9BmRBP9iUfCuRGNCaZ3%2Bx0msiRFVcSwu%2FuZpeKig1XQH&noverify=0&group_code=954528161)。中文解读教程可以在[Bilibili](https://www.bilibili.com/video/BV1wq421A7YM/)中观看。** ## 解析学习的另一分支:解析联邦学习 我们开源了解析学习的另一分支——“[解析联邦学习](https://github.com/ZHUANGHP/Analytic-federated-learning)”相关的工作。 ## 环境配置 (Environment) 我们建议使用[Anaconda](https://anaconda.org/)来配置运行环境。 ```bash git clone --depth=1 git@github.com:ZHUANGHP/Analytic-continual-learning.git cd Analytic-continual-learning conda env create -f environment.yaml conda activate AL mkdir backbones ``` 您可以从[这里](https://github.com/ZHUANGHP/Analytic-continual-learning/releases)下载预训练模型,来快速上手体验我们的算法。我们建议将预训练的骨干网络(backbone)提取在`backbones`文件夹下。 对于使用macOS系统或使用CPUs计算的用户,您需要删除`environment.yaml`文件中有关CUDA的项。 我们强烈建议您在Linux中运行我们的算法。当然如果Windows和macOS用户在运行时遇到任何问题也欢迎提交Issues。 ## 快速开始 (Quick Start) 在开始体验算法之前,请您先将基础训练权重(该[发布页](https://github.com/ZHUANGHP/Analytic-continual-learning/releases)中提供)放入`backbones`目录。由于本算法持续学习阶段不需要进行梯度计算,**您甚至可以在CPUs上运行我们的代码。** 这是一些参考案例: ```bash # ACIL (CIFAR-100, B50 25 phases) python main.py ACIL --dataset CIFAR-100 --base-ratio 0.5 --phases 25 \ --data-root ~/dataset --IL-batch-size 4096 --num-workers 16 --backbone resnet32 \ --gamma 0.1 --buffer-size 8192 \ --cache-features --backbone-path ./backbones/resnet32_CIFAR-100_0.5_None ``` ```bash # G-ACIL (CIFAR-100, B50 25 phases) python main.py G-ACIL --dataset CIFAR-100 --base-ratio 0.5 --phases 25 \ --data-root ~/dataset --IL-batch-size 4096 --num-workers 16 --backbone resnet32 \ --gamma 0.1 --buffer-size 8192 \ --cache-features --backbone-path ./backbones/resnet32_CIFAR-100_0.5_None ``` ```bash # GKEAL (CIFAR-100, B50 10 phases) python main.py GKEAL --dataset CIFAR-100 --base-ratio 0.5 --phases 10 \ 
--data-root ~/dataset --IL-batch-size 4096 --num-workers 16 --backbone resnet32 \ --gamma 0.1 --sigma 10 --buffer-size 8192 \ --cache-features --backbone-path ./backbones/resnet32_CIFAR-100_0.5_None ``` ```bash # DS-AL (CIFAR-100, B50 50 phases) python main.py DS-AL --dataset CIFAR-100 --base-ratio 0.5 --phases 50 \ --data-root ~/dataset --IL-batch-size 4096 --num-workers 16 --backbone resnet32 \ --gamma 0.1 --gamma-comp 0.1 --compensation-ratio 0.6 --buffer-size 8192 \ --cache-features --backbone-path ./backbones/resnet32_CIFAR-100_0.5_None ``` ```bash # DS-AL (ImageNet-1k, B50 20 phases) python main.py DS-AL --dataset ImageNet-1k --base-ratio 0.5 --phases 20 \ --data-root ~/dataset --IL-batch-size 4096 --num-workers 16 --backbone resnet18 \ --gamma 0.1 --gamma-comp 0.1 --compensation-ratio 1.5 --buffer-size 16384 \ --cache-features --backbone-path ./backbones/resnet18_ImageNet-1k_0.5_None ``` ## 从零开始训练 (Training From Scratch) ```bash # ACIL (CIFAR-100) python main.py ACIL --dataset CIFAR-100 --base-ratio 0.5 --phases 25 \ --data-root ~/dataset --batch-size 256 --num-workers 16 --backbone resnet32 \ --learning-rate 0.5 --label-smoothing 0 --base-epochs 300 --weight-decay 5e-4 \ --gamma 0.1 --buffer-size 8192 --cache-features --IL-batch-size 4096 ``` ```bash # ACIL (ImageNet-1k) python main.py ACIL --dataset ImageNet-1k --base-ratio 0.5 --phases 25 \ --data-root ~/dataset --batch-size 256 --num-workers 16 --backbone resnet18 \ --learning-rate 0.5 --label-smoothing 0.05 --base-epochs 300 --weight-decay 5e-5 \ --gamma 0.1 --buffer-size 16384 --cache-features --IL-batch-size 4096 ``` ## 复现的细节 (Reproduction Details) ### ACIL与G-ACIL之间的区别 G-ACIL是用于一般CIL设置的ACIL的通用版本。在传统的CIL任务上,G-ACIL相当于ACIL。因此我们在本仓库中使用了相同的实现。 ### 基准测试(B50,25阶段,使用`TrivialAugmentWide`) 以下指标有95%的置信水平($\mu \pm 1.96\sigma$): | Dataset | Method | Backbone | Buffer Size | Average Accuracy (%) | Last Phase Accuracy (%) | | :---------: | :-----------: | :-------: | :---------: | :------------------: | 
:---------------------: | | CIFAR-100 | ACIL & G-ACIL | ResNet-32 | 8192 | $71.047\pm0.252$ | $63.384\pm0.330$ | | CIFAR-100 | DS-AL | ResNet-32 | 8192 | $71.277\pm0.251$ | $64.043\pm0.184$ | | CIFAR-100 | GKEAL | ResNet-32 | 8192 | $70.371\pm0.168$ | $62.301\pm0.191$ | | ImageNet-1k | ACIL & G-ACIL | ResNet-18 | 16384 | $67.497\pm0.092$ | $58.349\pm0.111$ | | ImageNet-1k | DS-AL | ResNet-18 | 16384 | $68.354\pm0.084$ | $59.762\pm0.086$ | | ImageNet-1k | GKEAL | ResNet-18 | 16384 | $66.881\pm0.061$ | $57.295\pm0.105$ | ![Top-1 Accuracy](figures/acc@1.svg) ### 超参数(持续学习阶段) 在算法的持续学习阶段中,骨干网络(backbone)是被冻结的。您可以使用`--cache-features`选项保存骨干网络输出的特征,以提高参数调整的效率。下面列出了一些重要的超参数: 1. **Buffer Size** 对于ACIL,缓冲区(buffer)大小表示随机投影层的 *扩展尺寸(expansion size)* 。对于GKEAL,缓冲区大小是指 *高斯核嵌入(Gaussian kernel embedding)* 的 *中心向量(center vectors)* 的个数。在DS-AL中,我们将“随机投影(random projection)”和“高斯投影(Gaussian projection)”均归为“缓冲区”这一概念。 在大多数数据集上,随着缓冲区大小的增加,本算法的性能先增加后降低,您可以在我们的论文中看到关于这个超参数的详细实验。根据实验,我们建议在CIFAR-100上将缓冲区大小设置为8192,在ImageNet-1k上将缓冲区大小设置为16384或更大,以获得最佳性能。当然,较大的缓冲区大小代表着更多的内存占用。 2. **$\gamma$(正则化项的系数)** 对于论文中使用的数据集, $\gamma$ 在一定范围内不敏感。但是,太小的 $\gamma$ 可能会导致矩阵求逆过程所得数值不稳定,而太大的 $\gamma$ 可能会导致分类器欠拟合。根据实验,我们在CIFAR-100和ImageNet-1k上将 $\gamma$ 均设置为0.1。若您计划将我们的算法应用到其他数据集时,我们还是建议您做一些实验以检查 $\gamma$ 是否合适,避免无法充分发挥算法性能。 3. **$\beta$ and $\sigma$(只在GKEAL中设置)** 在GKEAL中,宽度调整(width-adjusting)参数 $\beta$ 控制高斯核(Gaussian kernels)的宽度。对于CIFAR-100和ImageNet-1k, $\sigma$ 设置在 $[5, 15]$ 左右时效果会较好,这里有转换关系 $\beta = \frac{1}{2\sigma^2}$ 。 4. **Compensation Ratio $\mathcal{C}$(只在DS-AL中设置)** 我们建议使用网格搜索(grid search)在区间 $[0,2]$ 中找到最佳补偿比(compensation ratio)。根据实验,我们建议在CIFAR-100和ImageNet-1k上分别将补偿比设置为0.6和1.5。 更为详细的超参数设置工作您可以在我们的论文中查阅。 ### 超参数(基础训练阶段) 在基础训练阶段中,骨干网络在CIFAR-100 (ResNet-32)和ImageNet-1k (ResNet-18)的前半数据集上达到了80%以上的top-1准确率。下面列出了一些重要的超参数: 1. 
**Learning Rate** 在本仓库实现中,我们使用了“余弦调整器(cosine scheduler)”,而不是像论文中那样使用“分段平滑调整器(piece-wise smooth scheduler)”,这可以有效的减少需要设置的超参数数量。我们建议在CIFAR-100和ImageNet-1k上将学习率设置为0.5(当批大小为256时)以获得更好的收敛性。此外,提供的骨干网络训练的epoch数是300。 2. **Label Smoothing and Weight Decay** 适当的设置标签平滑(label smoothing)和权重衰减(weight decay)可以防止骨干网络过拟合。有关标签平滑,在CIFAR-100中设置该参数没有显著的效果,在ImageNet-1k中我们设置为0.05;有关权重衰减,在CIFAR-100中我们设置为5e-4,在ImageNet-1k中我们设置为5e-5。 3. **Image Augmentation** 在基础训练数据集中使用图像增强可以获得有更好泛化能力的骨干网络,能够显著提高性能。在论文的实验中,我们并没有使用图像增强。而在本仓库实现中,默认情况下我们设置启用了数据增强。**因此使用本仓库实现有着相比论文中指标更高的性能(约2%~5%)。** 请注意,在重新对齐(re-alignment)和持续学习过程中,由于每个样本只学习一次,我们没有使用任何数据增强。 # 欢迎引用我们的论文 ```bib @InProceedings{ACIL_Zhuang_NeurIPS2022, author = {Zhuang, Huiping and Weng, Zhenyu and Wei, Hongxin and Xie, Renchunzi and Toh, Kar-Ann and Lin, Zhiping}, title = {{ACIL}: Analytic Class-Incremental Learning with Absolute Memorization and Privacy Protection}, booktitle = {Advances in Neural Information Processing Systems}, editor = {S. Koyejo and S. Mohamed and A. Agarwal and D. Belgrave and K. Cho and A. 
Oh}, pages = {11602--11614}, publisher = {Curran Associates, Inc.}, volume = {35}, year = {2022}, url = {https://proceedings.neurips.cc/paper_files/paper/2022/file/4b74a42fc81fc7ee252f6bcb6e26c8be-Paper-Conference.pdf} } @InProceedings{GKEAL_Zhuang_CVPR2023, author = {Zhuang, Huiping and Weng, Zhenyu and He, Run and Lin, Zhiping and Zeng, Ziqian}, title = {{GKEAL}: Gaussian Kernel Embedded Analytic Learning for Few-Shot Class Incremental Task}, booktitle = {2023 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, month = jun, year = {2023}, pages = {7746--7755}, doi = {10.1109/CVPR52729.2023.00748} } @Article{DS-AL_Zhuang_AAAI2024, title = {{DS-AL}: A Dual-Stream Analytic Learning for Exemplar-Free Class-Incremental Learning}, author = {Zhuang, Huiping and He, Run and Tong, Kai and Zeng, Ziqian and Chen, Cen and Lin, Zhiping}, journal = {Proceedings of the AAAI Conference on Artificial Intelligence}, volume = {38}, number = {15}, pages = {17237--17244}, year = {2024}, month = mar, doi = {10.1609/aaai.v38i15.29670}, url = {https://ojs.aaai.org/index.php/AAAI/Article/view/29670} } @InProceedings{GACL_Zhuang_NeurIPS2024, title = {{GACL}: Exemplar-Free Generalized Analytic Continual Learning}, author = {Huiping Zhuang and Yizhu Chen and Di Fang and Run He and Kai Tong and Hongxin Wei and Ziqian Zeng and Cen Chen}, year = {2024}, booktitle = {Advances in Neural Information Processing Systems}, publisher = {Curran Associates, Inc.}, month = dec } @article{AEF-OCL_Zhuang_TVT2024, title = {Online Analytic Exemplar-Free Continual Learning with Large Models for Imbalanced Autonomous Driving Task}, author = {Zhuang, Huiping and Fang, Di and Tong, Kai and Liu, Yuchen and Zeng, Ziqian and Zhou, Xu and Chen, Cen}, year = {2024}, journal = {IEEE Transactions on Vehicular Technology}, pages = {1-10}, doi = {10.1109/TVT.2024.3483557} } @misc{AIR_Fang_arXiv2024, title = {{AIR}: Analytic Imbalance Rectifier for Continual Learning}, author = {Di Fang and Yinan Zhu 
and Zhiping Lin and Cen Chen and Ziqian Zeng and Huiping Zhuang}, year = {2024}, month = aug, archivePrefix = {arXiv}, primaryClass = {cs.LG}, eprint = {2408.10349}, doi = {10.48550/arXiv.2408.10349}, url = {https://arxiv.org/abs/2408.10349}, } ``` ================================================ FILE: analytic/ACIL.py ================================================ # -*- coding: utf-8 -*- """ Implementation of the ACIL [1] and the G-ACIL [2]. The G-ACIL is a generalization of the ACIL in the generalized setting. For the popular setting, the G-ACIL is equivalent to the ACIL. References: [1] Zhuang, Huiping, et al. "ACIL: Analytic class-incremental learning with absolute memorization and privacy protection." Advances in Neural Information Processing Systems 35 (2022): 11602-11614. [2] Zhuang, Huiping, et al. "G-ACIL: Analytic Learning for Exemplar-Free Generalized Class Incremental Learning" arXiv preprint arXiv:2403.15706 (2024). """ import torch from os import path from tqdm import tqdm from typing import Any, Dict, Optional, Sequence from utils import set_weight_decay, validate from torch._prims_common import DeviceLikeType from .Buffer import RandomBuffer from torch.nn import DataParallel from .Learner import Learner, loader_t from .AnalyticLinear import AnalyticLinear, RecursiveLinear class ACIL(torch.nn.Module): def __init__( self, backbone_output: int, backbone: torch.nn.Module = torch.nn.Flatten(), buffer_size: int = 8192, gamma: float = 1e-3, device=None, dtype=torch.double, linear: type[AnalyticLinear] = RecursiveLinear, ) -> None: super().__init__() factory_kwargs = {"device": device, "dtype": dtype} self.backbone = backbone self.backbone_output = backbone_output self.buffer_size = buffer_size self.buffer = RandomBuffer(backbone_output, buffer_size, **factory_kwargs) self.analytic_linear = linear(buffer_size, gamma, **factory_kwargs) self.eval() @torch.no_grad() def feature_expansion(self, X: torch.Tensor) -> torch.Tensor: return 
self.buffer(self.backbone(X)) @torch.no_grad() def forward(self, X: torch.Tensor) -> torch.Tensor: return self.analytic_linear(self.feature_expansion(X)) @torch.no_grad() def fit(self, X: torch.Tensor, y: torch.Tensor, *args, **kwargs) -> None: Y = torch.nn.functional.one_hot(y) X = self.feature_expansion(X) self.analytic_linear.fit(X, Y) @torch.no_grad() def update(self) -> None: self.analytic_linear.update() class ACILLearner(Learner): """ This implementation is for the G-ACIL [2], a general version of the ACIL [1] that supports mini-batch learning and the general CIL setting. In the traditional CIL settings, the G-ACIL is equivalent to the ACIL. """ def __init__( self, args: Dict[str, Any], backbone: torch.nn.Module, backbone_output: int, device=None, all_devices: Optional[Sequence[DeviceLikeType]] = None, ) -> None: super().__init__(args, backbone, backbone_output, device, all_devices) self.learning_rate: float = args["learning_rate"] self.buffer_size: int = args["buffer_size"] self.gamma: float = args["gamma"] self.base_epochs: int = args["base_epochs"] self.warmup_epochs: int = args["warmup_epochs"] self.make_model() def base_training( self, train_loader: loader_t, val_loader: loader_t, baseset_size: int, ) -> None: model = torch.nn.Sequential( self.backbone, torch.nn.Linear(self.backbone_output, baseset_size), ).to(self.device, non_blocking=True) model = self.wrap_data_parallel(model) if self.args["separate_decay"]: params = set_weight_decay(model, self.args["weight_decay"]) else: params = model.parameters() optimizer = torch.optim.SGD( params, lr=self.learning_rate, momentum=self.args["momentum"], weight_decay=self.args["weight_decay"], ) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( optimizer, T_max=self.base_epochs - self.warmup_epochs, eta_min=1e-6 # type: ignore ) if self.warmup_epochs > 0: warmup_scheduler = torch.optim.lr_scheduler.LinearLR( optimizer, start_factor=1e-3, total_iters=self.warmup_epochs, ) scheduler = 
torch.optim.lr_scheduler.SequentialLR( optimizer, [warmup_scheduler, scheduler], [self.warmup_epochs] ) criterion = torch.nn.CrossEntropyLoss( label_smoothing=self.args["label_smoothing"] ).to(self.device, non_blocking=True) best_acc = 0.0 logging_file_path = path.join(self.args["saving_root"], "base_training.csv") logging_file = open(logging_file_path, "w", buffering=1) print( "epoch", "best_acc@1", "loss", "acc@1", "acc@5", "f1-micro", "training_loss", "training_acc@1", "training_acc@5", "training_f1-micro", "training_learning-rate", file=logging_file, sep=",", ) for epoch in range(self.base_epochs + 1): if epoch != 0: print( f"Base Training - Epoch {epoch}/{self.base_epochs}", f"(Learning Rate: {optimizer.state_dict()['param_groups'][0]['lr']})", ) model.train() for X, y in tqdm(train_loader, "Training"): X: torch.Tensor = X.to(self.device, non_blocking=True) y: torch.Tensor = y.to(self.device, non_blocking=True) assert y.max() < baseset_size optimizer.zero_grad(set_to_none=True) logits = model(X) loss: torch.Tensor = criterion(logits, y) loss.backward() optimizer.step() scheduler.step() # Validation on training set model.eval() train_meter = validate( model, train_loader, baseset_size, desc="Training (Validation)" ) print( f"loss: {train_meter.loss:.4f}", f"acc@1: {train_meter.accuracy * 100:.3f}%", f"acc@5: {train_meter.accuracy5 * 100:.3f}%", f"f1-micro: {train_meter.f1_micro * 100:.3f}%", sep=" ", ) val_meter = validate(model, val_loader, baseset_size, desc="Testing") if val_meter.accuracy > best_acc: best_acc = val_meter.accuracy if epoch != 0: self.save_object( (self.backbone, X.shape[1], self.backbone_output), "backbone.pth", ) # Validation on testing set print( f"loss: {val_meter.loss:.4f}", f"acc@1: {val_meter.accuracy * 100:.3f}%", f"acc@5: {val_meter.accuracy5 * 100:.3f}%", f"f1-micro: {val_meter.f1_micro * 100:.3f}%", f"best_acc@1: {best_acc * 100:.3f}%", sep=" ", ) print( epoch, best_acc, val_meter.loss, val_meter.accuracy, val_meter.accuracy5, 
class AEFOCL(ACIL):
    """
    Network structure of the AEF-OCL [1], an analytic method for imbalanced
    continual learning. On top of the ACIL pipeline it keeps running per-class
    feature statistics (E[X], E[X^2], counts) and, at `update` time,
    re-balances the classifier by fitting synthetic features sampled from a
    Gaussian fitted to each class prototype.

    References:
    [1] Zhuang, Huiping, et al.
        "Online Analytic Exemplar-Free Continual Learning with Large Models
        for Imbalanced Autonomous Driving Task"
        arXiv preprint arXiv:2405.17779 (2024).
    """

    def __init__(
        self,
        backbone_output: int,
        backbone: torch.nn.Module = torch.nn.Flatten(),
        buffer_size: int = 8192,
        gamma: float = 1e-3,
        noise: float = 1,
        device=None,
        dtype=torch.double,
        linear: type[AnalyticLinear] = RecursiveLinear,
    ) -> None:
        """
        Args:
            backbone_output: Dimension of the backbone's output features.
            backbone: Frozen feature extractor.
            buffer_size: Width of the random expansion buffer.
            gamma: Regularization of the analytic classifier.
            noise: Scale applied to the per-class std when sampling synthetic
                prototypes in `update`.
            device / dtype: Factory arguments for the analytic modules.
            linear: Analytic linear classifier class.
        """
        super().__init__(
            backbone_output, backbone, buffer_size, gamma, device, dtype, linear
        )
        self._linear_log = dict()  # History prototype: classifier buffer backup
        self.noise = noise
        # Running per-class sum of features (E[X] after division by cnt).
        self.register_buffer("ex", torch.zeros((0, backbone_output), dtype=torch.double))
        self.ex: torch.Tensor
        # Running per-class sum of squared features (E[X^2] after division by cnt).
        self.register_buffer("ex2", torch.zeros((0, backbone_output), dtype=torch.double))
        self.ex2: torch.Tensor
        # Number of samples seen per class.
        self.register_buffer("cnt", torch.zeros((0,), dtype=torch.long))
        self.cnt: torch.Tensor
        # Set the device
        self.to(device)

    @torch.no_grad()
    def fit(self, X: torch.Tensor, y: torch.Tensor, *args, **kwargs) -> None:
        """Accumulate per-class statistics and fit the classifier on one batch."""
        # Roll the classifier back to the state saved before the last
        # augmentation pass, so synthetic samples are not learned permanently.
        for name, buffer in self._linear_log.items():
            self.analytic_linear.register_buffer(name, buffer)
        self._linear_log.clear()

        X = self.backbone(X)
        # Grow the per-class statistics when new class labels appear.
        if (increment_size := int(y.max().item()) - self.ex.shape[0] + 1) > 0:
            # self.cnt
            tail = torch.zeros((increment_size,)).to(self.cnt)
            self.cnt = torch.concat((self.cnt, tail), dim=0)
            # self.ex
            tail = torch.zeros((increment_size, self.ex.shape[1])).to(self.ex)
            self.ex = torch.concat((self.ex, tail), dim=0)
            # self.ex2
            tail = torch.zeros((increment_size, self.ex2.shape[1])).to(self.ex2)
            self.ex2 = torch.concat((self.ex2, tail), dim=0)

        labels, counts = torch.unique(y, return_counts=True)
        self.cnt[labels] += counts
        for i in labels:
            X_i = X[y == i]
            # Accumulate E[X]
            self.ex[i] += torch.sum(X_i.to(self.ex), dim=0)
            # Accumulate E[X^2]
            self.ex2[i] += torch.sum(torch.square(X_i.to(self.ex2)), dim=0)

        X = self.buffer(X)
        Y = torch.nn.functional.one_hot(y)
        self.analytic_linear.fit(X, Y)

    def update(self) -> None:
        """Re-balance the classifier by oversampling minority classes with
        Gaussian-sampled synthetic prototypes, then finalize the weights."""
        peak_cnt = int(self.cnt.max())
        mean = self.proto_mean
        std = self.proto_std
        print("Counts:", self.cnt.tolist())
        # Backup the iterative classifier so the next `fit` can roll back
        # the synthetic samples added below.
        for name, buffer in self.analytic_linear.named_buffers():
            self._linear_log[name] = buffer.clone().detach()
        aug_bar = tqdm(
            desc="Augmenting",
            total=(peak_cnt * len(self.cnt.nonzero()) - int(self.cnt.sum())),
        )
        for i in self.cnt.nonzero():
            i = int(i.item())
            rest_cnt = int(peak_cnt - self.cnt[i])
            while rest_cnt > 0:
                fill_cnt = min(rest_cnt, 8192)  # chunked to bound memory use
                fill_y = torch.empty((fill_cnt,), dtype=torch.long).fill_(i)
                fill_proto = torch.randn((fill_cnt, self.buffer.in_features)).to(
                    self.buffer.weight
                )
                fill_proto = (
                    fill_proto * std[i][None, :] * self.noise + mean[i][None, :]
                )
                fill_proto = self.buffer(fill_proto)
                fill_y = torch.nn.functional.one_hot(fill_y)
                self.analytic_linear.fit(fill_proto, fill_y)
                aug_bar.update(fill_cnt)
                rest_cnt -= fill_cnt
        self.analytic_linear.update()

    @property
    def proto_mean(self) -> torch.Tensor:
        # E[X] per class; classes with cnt == 0 yield NaN rows (handled by
        # proto_std's NaN masking and never sampled in `update`).
        return self.ex / self.cnt[:, None]

    @property
    def proto_std(self) -> torch.Tensor:
        """Per-class sample standard deviation of the backbone features.

        Fix: Var[X] = E[X^2] - E[X]^2 can come out marginally negative due to
        floating-point round-off, so the previous `assert (std >= 0).all()`
        could fail spuriously; negative variances are now clamped to zero.
        """
        var = self.ex2 / self.cnt[:, None] - torch.square(self.proto_mean)
        var[torch.isnan(var)] = 0  # classes not seen yet
        var.clamp_(min=0)  # numerical robustness (replaces the assert)
        # Bessel's correction n / (n - 1); classes with cnt <= 1 give NaN -> 0.
        proto_std = torch.sqrt(var * (self.cnt / (self.cnt - 1))[:, None])
        proto_std[torch.isnan(proto_std)] = 0
        return proto_std
# -*- coding: utf-8 -*-
"""
Implementation of the AIR [1], an online exemplar-free generalized CIL
approach on imbalanced datasets.

References:
[1] Fang, Di, et al.
    "AIR: Analytic Imbalance Rectifier for Continual Learning."
    arXiv preprint arXiv:2408.10349 (2024).
"""

import torch

from .ACIL import ACILLearner, ACIL
from .AnalyticLinear import GeneralizedARM

__all__ = ["AIR", "AIRLearner", "GeneralizedAIRLearner"]


class AIR(ACIL):
    """ACIL variant whose classifier re-weights classes by frequency.

    Unlike `ACIL.fit`, the raw integer labels are passed straight to the
    analytic linear module (a `GeneralizedARM`), which needs them to keep
    per-class sample counts.
    """

    def fit(self, X: torch.Tensor, y: torch.Tensor, *args, **kwargs) -> None:
        expanded = self.feature_expansion(X)
        self.analytic_linear.fit(expanded, y)


class AIRLearner(ACILLearner):
    """Learner that wires the AIR network with the ARM classifier."""

    def make_model(self) -> None:
        self.model = AIR(
            self.backbone_output,
            self.wrap_data_parallel(self.backbone),
            self.buffer_size,
            self.gamma,
            device=self.device,
            dtype=torch.double,
            linear=GeneralizedARM,
        )


class GeneralizedAIRLearner(AIRLearner):
    """Alias for the generalized-CIL setting; identical to `AIRLearner`."""

    pass
import torch
from torch.nn import functional as F
from typing import Optional, Union
from abc import abstractmethod, ABCMeta


class AnalyticLinear(torch.nn.Linear, metaclass=ABCMeta):
    """Base class of the analytic (closed-form, ridge-style) linear classifiers.

    The weight is a buffer, not a Parameter: it is computed analytically by
    `fit`/`update`, never by gradient descent. Following the ACIL formulation
    the layout is transposed w.r.t. `torch.nn.Linear`: `forward` computes
    `X @ weight`, so `weight` has shape (in_features, out_features) and the
    number of outputs grows as new classes are fitted.
    """

    def __init__(
        self,
        in_features: int,
        gamma: float = 1e-1,
        bias: bool = False,
        device: Optional[Union[torch.device, str, int]] = None,
        dtype=torch.double,
    ) -> None:
        super(torch.nn.Linear, self).__init__()  # Skip the Linear class
        factory_kwargs = {"device": device, "dtype": dtype}
        self.gamma: float = gamma
        self.bias: bool = bias
        self.dtype = dtype

        # Linear layer: the bias is modeled as an extra all-ones input column.
        if bias:
            in_features += 1
        weight = torch.zeros((in_features, 0), **factory_kwargs)
        self.register_buffer("weight", weight)

    @torch.inference_mode()
    def forward(self, X: torch.Tensor) -> torch.Tensor:
        X = X.to(self.weight)
        if self.bias:
            X = torch.cat((X, torch.ones(X.shape[0], 1).to(X)), dim=-1)
        return X @ self.weight

    @property
    def in_features(self) -> int:
        if self.bias:
            return self.weight.shape[0] - 1
        return self.weight.shape[0]

    @property
    def out_features(self) -> int:
        return self.weight.shape[1]

    def reset_parameters(self) -> None:
        # Following the equation (4) of ACIL, self.weight is set to \hat{W}_{FCN}^{-1}
        self.weight = torch.zeros((self.weight.shape[0], 0)).to(self.weight)

    @abstractmethod
    def fit(self, X: torch.Tensor, Y: torch.Tensor) -> None:
        raise NotImplementedError()

    def update(self) -> None:
        assert torch.isfinite(self.weight).all(), (
            "Pay attention to the numerical stability! "
            "A possible solution is to increase the value of gamma. "
            "Setting self.dtype=torch.double also helps."
        )


class RecursiveLinear(AnalyticLinear):
    """Recursive least-squares classifier used by ACIL / G-ACIL."""

    def __init__(
        self,
        in_features: int,
        gamma: float = 1e-1,
        bias: bool = False,
        device: Optional[Union[torch.device, str, int]] = None,
        dtype=torch.double,
    ) -> None:
        super().__init__(in_features, gamma, bias, device, dtype)
        factory_kwargs = {"device": device, "dtype": dtype}

        # Regularized Feature Autocorrelation Matrix (RFAuM),
        # initialized to (gamma * I)^{-1}.
        self.R: torch.Tensor
        R = torch.eye(self.weight.shape[0], **factory_kwargs) / self.gamma
        self.register_buffer("R", R)

    @torch.no_grad()
    def fit(self, X: torch.Tensor, Y: torch.Tensor) -> None:
        """The core code of the ACIL and the G-ACIL.
        This implementation, which is different but equivalent to the
        equations shown in [1], is proposed in the G-ACIL [4], which supports
        mini-batch learning and the general CIL setting.
        """
        X, Y = X.to(self.weight), Y.to(self.weight)
        if self.bias:
            X = torch.cat((X, torch.ones(X.shape[0], 1).to(X)), dim=-1)

        # Pad either the weight (new classes appeared) or the targets
        # (batch covers fewer classes than already known) with zero columns.
        num_targets = Y.shape[1]
        if num_targets > self.out_features:
            increment_size = num_targets - self.out_features
            tail = torch.zeros((self.weight.shape[0], increment_size)).to(self.weight)
            self.weight = torch.cat((self.weight, tail), dim=1)
        elif num_targets < self.out_features:
            increment_size = self.out_features - num_targets
            tail = torch.zeros((Y.shape[0], increment_size)).to(Y)
            Y = torch.cat((Y, tail), dim=1)

        # Please update your PyTorch & CUDA if the `cusolver error` occurs.
        # If you insist on using this version, doing the `torch.inverse` on CPUs might help.
        # >>> K_inv = torch.eye(X.shape[0]).to(X) + X @ self.R @ X.T
        # >>> K = torch.inverse(K_inv.cpu()).to(self.weight.device)
        K = torch.inverse(torch.eye(X.shape[0]).to(X) + X @ self.R @ X.T)
        # Equation (10) of ACIL
        self.R -= self.R @ X.T @ K @ X @ self.R
        # Equation (9) of ACIL
        self.weight += self.R @ X.T @ (Y - X @ self.weight)


class GeneralizedARM(AnalyticLinear):
    """Analytic Re-weighting Module (ARM) for generalized CIL.

    Keeps one feature autocorrelation matrix per class plus per-class sample
    counts, so `update` can solve a ridge system where every class
    contributes inversely proportional to its frequency (AIR [5]).
    """

    def __init__(
        self,
        in_features: int,
        gamma: float = 1e-1,
        bias: bool = False,
        device: Optional[Union[torch.device, str, int]] = None,
        dtype=torch.double,
    ) -> None:
        super().__init__(in_features, gamma, bias, device, dtype)
        factory_kwargs = {"device": device, "dtype": dtype}

        # NOTE(review): with bias=True the shapes below do not include the
        # extra bias column added by the base class; only bias=False is
        # exercised by the visible callers (AIRLearner).
        weight = torch.zeros((in_features, 0), **factory_kwargs)
        self.register_buffer("weight", weight)

        # Per-class feature autocorrelation matrices (one per class).
        A = torch.zeros((0, in_features, in_features), **factory_kwargs)
        self.register_buffer("A", A)

        # Cross-correlation between features and one-hot targets.
        C = torch.zeros((in_features, 0), **factory_kwargs)
        self.register_buffer("C", C)

        # Per-class sample counts.
        self.cnt = torch.zeros(0, dtype=torch.int, device=device)

    @property
    def out_features(self) -> int:
        return self.C.shape[1]

    @torch.inference_mode()
    def fit(self, X: torch.Tensor, y: torch.Tensor) -> None:
        """Accumulate statistics from one batch (`y` holds integer labels)."""
        X = X.to(self.weight)
        # Bias
        if self.bias:
            # Fix: build the bias column on X's device/dtype; previously it
            # was allocated on the default (CPU float) device and the concat
            # failed for CUDA or double inputs.
            X = torch.concat((X, torch.ones(X.shape[0], 1).to(X)), dim=-1)

        # GCIL: grow the per-class statistics when new labels appear.
        num_targets = int(y.max()) + 1
        if num_targets > self.out_features:
            increment_size = num_targets - self.out_features
            torch.cuda.empty_cache()
            # Increment C
            tail = torch.zeros((self.C.shape[0], increment_size)).to(self.weight)
            self.C = torch.concat((self.C, tail), dim=1)
            # Increment cnt
            tail = torch.zeros((increment_size,)).to(self.cnt)
            self.cnt = torch.concat((self.cnt, tail))
            # Increment A
            tail = torch.zeros((increment_size, self.in_features, self.in_features))
            self.A = torch.concat((self.A, tail.to(self.A)))
            torch.cuda.empty_cache()
        else:
            num_targets = self.out_features

        # Cross-correlation (the original `max(num_targets, num_targets)`
        # was a no-op; `num_targets` already covers all known classes here).
        Y = F.one_hot(y, num_targets).to(self.C)
        self.C += X.T @ Y

        # Label balancing: per-class sample counts.
        y_labels, label_cnt = torch.unique(y, sorted=True, return_counts=True)
        y_labels, label_cnt = y_labels.to(self.cnt.device), label_cnt.to(
            self.cnt.device
        )
        self.cnt[y_labels] += label_cnt

        # Accumulate the autocorrelation only for labels present in this
        # batch; absent classes would just add zero matrices.
        for i in y_labels.tolist():
            X_i = X[y == i]
            self.A[i] += X_i.T @ X_i

    @torch.inference_mode()
    def update(self):
        """Solve the class-re-weighted ridge system for the weights."""
        cnt_inv = 1 / self.cnt.to(self.dtype)
        cnt_inv[torch.isinf(cnt_inv)] = 0  # replace inf with 0 (unseen classes)
        cnt_inv *= len(self.cnt) / cnt_inv.sum()  # normalize weights to mean 1
        weighted_A = torch.sum(cnt_inv[:, None, None].mul(self.A), dim=0)
        A = weighted_A + self.gamma * torch.eye(self.in_features).to(self.A)
        C = self.C.mul(cnt_inv[None, :])
        self.weight = torch.inverse(A) @ C
import torch
from typing import Optional, Union, Callable
from abc import ABCMeta, abstractmethod

activation_t = Union[Callable[[torch.Tensor], torch.Tensor], torch.nn.Module]


class Buffer(torch.nn.Module, metaclass=ABCMeta):
    """Abstract feature-expansion buffer placed between backbone and classifier."""

    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def forward(self, X: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError()


class RandomBuffer(torch.nn.Linear, Buffer):
    """Frozen random linear projection followed by an activation (ACIL).

    Weight and bias are registered as buffers (not parameters): they are
    drawn once by `reset_parameters` and never trained.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        device=None,
        dtype=torch.float,
        activation: Optional[activation_t] = torch.relu_,
    ) -> None:
        super(torch.nn.Linear, self).__init__()  # Skip torch.nn.Linear.__init__
        factory_kwargs = {"device": device, "dtype": dtype}
        self.in_features = in_features
        self.out_features = out_features
        self.activation: activation_t = (
            torch.nn.Identity() if activation is None else activation
        )

        W = torch.empty((out_features, in_features), **factory_kwargs)
        b = torch.empty(out_features, **factory_kwargs) if bias else None

        # Using buffer instead of parameter: the projection stays frozen.
        self.register_buffer("weight", W)
        self.register_buffer("bias", b)

        # Random initialization (torch.nn.Linear.reset_parameters).
        self.reset_parameters()

    @torch.no_grad()
    def forward(self, X: torch.Tensor) -> torch.Tensor:
        X = X.to(self.weight)
        return self.activation(super().forward(X))


class GaussianKernel(Buffer):
    """Gaussian (RBF) kernel buffer used by GKEAL.

    `forward` returns exp(-beta * ||X - mean||^2) against every center row,
    with beta = 1 / (2 * sigma^2).
    """

    def __init__(
        self, mean: torch.Tensor, sigma: float = 1, device=None, dtype=torch.float
    ) -> None:
        super().__init__()
        self.device = device
        self.dtype = dtype
        factory_kwargs = {"device": device, "dtype": dtype}
        assert len(mean.shape) == 2, "The mean should be a 2D tensor."
        # NOTE(review): the placeholder mean is stored with an extra leading
        # axis here, but `init` later replaces it with a 2D tensor — confirm
        # the 3D form is only ever used before `init` is called.
        mean = mean[None, :, :].to(**factory_kwargs)
        beta = 1 / (2 * (sigma**2))
        self.register_buffer("mean", mean)
        self.register_buffer("beta", torch.tensor(beta, **factory_kwargs))

    @torch.no_grad()
    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # Squared Euclidean distance to each center, then the RBF response.
        X = torch.square_(torch.cdist(X.to(self.mean), self.mean))
        return torch.exp_(X.mul_(-self.beta))

    def init(self, X: torch.Tensor, size: Optional[int] = None) -> None:
        """Choose the kernel centers from the feature rows of `X`.

        Args:
            X: Candidate center vectors, one per row.
            size: Number of centers to keep; `None` keeps all rows.
        """
        if size is not None:
            if size <= X.shape[0]:
                # Fix: subsample `size` centers uniformly from ALL rows.
                # The previous `torch.randperm(size)` only permuted indices
                # 0..size-1, so rows beyond the first `size` could never be
                # selected as centers.
                idx = torch.randperm(X.shape[0], device=X.device)[:size]
                X = X[idx]
            else:
                # The buffer size is suggested to be greater than the number of initial samples.
                # Generate the missing center vectors as random combinations
                # of the available rows.
                n_require = size - X.shape[0]
                W_proj = torch.normal(mean=0, std=1, size=(n_require, X.shape[0])).to(X)
                # NOTE(review): dim=0 normalizes across the generated rows,
                # not within each combination — confirm this is intended.
                W_proj /= torch.sum(W_proj, dim=0)
                X = torch.cat([X, W_proj @ X], dim=0)
        self.mean = X.to(self.mean)
class DSAL(torch.nn.Module):
    """Dual-Stream Analytic Learning network (DS-AL [1]).

    Two analytic linear classifiers share one *linear* random-expansion
    buffer: the main stream fits the one-hot labels, and the compensation
    stream fits the main stream's residual restricted to the newly added
    classes. The final logits are `main + C * comp`.
    """

    def __init__(
        self,
        backbone_output: int,
        backbone: Callable[[torch.Tensor], torch.Tensor] = torch.nn.Flatten(),
        expansion_size: int = 8192,
        gamma_main: float = 1e-3,
        gamma_comp: float = 1e-3,
        C: float = 1,
        activation_main: activation_t = torch.relu,
        activation_comp: activation_t = torch.tanh,
        device=None,
        dtype=torch.double,
        linear: type[AnalyticLinear] = RecursiveLinear,
    ) -> None:
        """
        Args:
            backbone_output: Dimension of the backbone's output features.
            backbone: Frozen feature extractor.
            expansion_size: Width of the shared random expansion buffer.
            gamma_main / gamma_comp: Regularization of each stream.
            C: Compensation ratio mixing the two streams in `forward`.
            activation_main / activation_comp: Per-stream non-linearity
                applied to the shared (linear) buffer output.
            device / dtype: Factory arguments for the analytic modules.
            linear: Analytic linear classifier class.
        """
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.backbone = backbone
        self.expansion_size = expansion_size
        # The buffer itself is linear (Identity); each stream applies its
        # own activation on top of the shared expansion.
        self.buffer = RandomBuffer(
            backbone_output,
            expansion_size,
            activation=torch.nn.Identity(),
            **factory_kwargs
        )
        # The main stream
        self.activation_main = activation_main
        self.main_stream = linear(expansion_size, gamma_main, **factory_kwargs)
        # The compensation stream
        self.C = C
        self.activation_comp = activation_comp
        self.comp_stream = linear(expansion_size, gamma_comp, **factory_kwargs)
        self.eval()

    @torch.no_grad()
    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Combined logits: main stream plus C-scaled compensation stream."""
        X = self.buffer(self.backbone(X))
        X_main = self.main_stream(self.activation_main(X))
        X_comp = self.comp_stream(self.activation_comp(X))
        return X_main + self.C * X_comp

    @torch.no_grad()
    def fit(self, X: torch.Tensor, y: torch.Tensor, increase_size: int) -> None:
        """Fit both streams on one batch.

        Args:
            X: Raw inputs.
            y: Integer class labels.
            increase_size: Number of classes added in the current phase; the
                compensation targets are zeroed for all older classes.
        """
        num_classes = max(self.main_stream.out_features, int(y.max().item()) + 1)
        Y_main = torch.nn.functional.one_hot(y, num_classes=num_classes)
        X = self.buffer(self.backbone(X))

        # Train the main stream (updated immediately so its residual below
        # reflects this batch).
        X_main = self.activation_main(X)
        self.main_stream.fit(X_main, Y_main)
        self.main_stream.update()

        # Previous label cleansing (PLC): keep the residual only on the
        # `increase_size` newest classes.
        Y_comp = Y_main - self.main_stream(X_main)
        Y_comp[:, :-increase_size] = 0

        # Train the compensation stream
        X_comp = self.activation_comp(X)
        self.comp_stream.fit(X_comp, Y_comp)

    @torch.no_grad()
    def update(self) -> None:
        """Finalize both streams' weights."""
        self.main_stream.update()
        self.comp_stream.update()


class DSALLearner(ACILLearner):
    """Learner for the DS-AL; reuses the ACIL base-training loop."""

    def __init__(
        self,
        args: Dict[str, Any],
        backbone: torch.nn.Module,
        backbone_output: int,
        device=None,
        all_devices: Optional[Sequence[DeviceLikeType]] = None,
    ) -> None:
        # Read the DS-AL specific hyper-parameters BEFORE the base
        # constructor runs, because it calls `make_model` which needs them.
        self.gamma_comp = args["gamma_comp"]
        self.compensation_ratio = args["compensation_ratio"]
        super().__init__(args, backbone, backbone_output, device, all_devices)

    def make_model(self) -> None:
        self.model = DSAL(
            self.backbone_output,
            self.backbone,
            self.buffer_size,
            self.gamma,
            self.gamma_comp,
            self.compensation_ratio,
            device=self.device,
            dtype=torch.double,
            linear=RecursiveLinear,
        )
class GKEAL(ACIL):
    """GKEAL network: the ACIL pipeline with the random projection buffer
    replaced by a Gaussian-kernel buffer whose centers are filled in later
    (see `GKEALLearner.learn`)."""

    def __init__(
        self,
        backbone_output: int,
        backbone: torch.nn.Module = torch.nn.Flatten(),
        buffer_size: int = 512,
        gamma: float = 1e-3,
        sigma: float = 10,
        device=None,
        dtype=torch.double,
        linear: type[AnalyticLinear] = RecursiveLinear,
    ):
        """
        Args:
            backbone_output: Dimension of the backbone's output features.
            backbone: Frozen feature extractor.
            buffer_size: Number of Gaussian kernel centers.
            gamma: Regularization of the analytic classifier.
            sigma: Width parameter of the Gaussian kernels.
            device / dtype: Factory arguments for the analytic modules.
            linear: Analytic linear classifier class.
        """
        super().__init__(
            backbone_output, backbone, buffer_size, gamma, device, dtype, linear
        )
        # Replace the ACIL random buffer with a Gaussian-kernel buffer;
        # the zero placeholder centers are overwritten by `GaussianKernel.init`.
        placeholder_centers = torch.zeros((self.buffer_size, self.backbone_output))
        self.buffer = GaussianKernel(
            placeholder_centers,
            sigma,
            device=device,
            dtype=dtype,
        )
import torch
from os import path
from abc import ABCMeta, abstractmethod
from torch.utils.data import DataLoader
from torch._prims_common import DeviceLikeType
from typing import Union, Dict, Any, Optional, Sequence

loader_t = DataLoader[Union[torch.Tensor, torch.Tensor]]


class Learner(metaclass=ABCMeta):
    """Abstract interface of the continual-learning methods.

    Subclasses implement base training, incremental learning, and inference;
    `__call__` delegates to `inference`.
    """

    def __init__(
        self,
        args: Dict[str, Any],
        backbone: torch.nn.Module,
        backbone_output: int,
        device=None,
        all_devices: Optional[Sequence[DeviceLikeType]] = None,
    ) -> None:
        """
        Args:
            args: Parsed configuration dictionary (see config.py).
            backbone: Feature extractor network.
            backbone_output: Dimension of the backbone's output features.
            device: Primary device for training/inference.
            all_devices: All devices for data-parallel execution, if any.
        """
        self.args = args
        self.backbone = backbone
        self.backbone_output = backbone_output
        self.device = device
        self.all_devices = all_devices
        self.model: torch.nn.Module

    @abstractmethod
    def base_training(
        self,
        train_loader: loader_t,
        val_loader: loader_t,
        baseset_size: int,
    ) -> None:
        """Train the backbone on the base task."""
        raise NotImplementedError()

    @abstractmethod
    def learn(
        self,
        data_loader: loader_t,
        incremental_size: int,
        desc: str = "Incremental Learning",
    ) -> None:
        """Learn one incremental phase from `data_loader`."""
        raise NotImplementedError()

    @abstractmethod
    def before_validation(self) -> None:
        # Fix: the abstract signature was `def before_validation()` without
        # `self`, which is not a valid instance method and disagreed with the
        # concrete overrides (e.g. ACILLearner.before_validation(self)).
        raise NotImplementedError()

    @abstractmethod
    def inference(self, X: torch.Tensor) -> torch.Tensor:
        """Return the model's predictions for a batch of inputs."""
        raise NotImplementedError()

    def save_object(self, model, file_name: str) -> None:
        """Serialize `model` under the experiment's saving root."""
        torch.save(model, path.join(self.args["saving_root"], file_name))

    def __call__(self, X: torch.Tensor) -> torch.Tensor:
        return self.inference(X)
"DS-AL": DSALLearner, "GKEAL": GKEALLearner, "AEF-OCL": AEFOCLLearner, "AIR": AIRLearner, "G-AIR": GeneralizedAIRLearner, # The G-AIL is a generalization of the AIR for generalized CIL. } __all__ = ["load_args", "ALL_METHODS"] _parser = argparse.ArgumentParser(description="Analytic Continual Learning") # Method Options _parser.add_argument( "method", choices=ALL_METHODS.keys(), help="The method to use for continual learning.", ) _parser.add_argument( "--exp-name", type=str, default="", help="Name of the experiment", ) _parser.add_argument( "--cpu-only", action="store_true", help="Run the program on CPU only.", ) _parser.add_argument( "--gpus", default=None, type=int, action="extend", nargs="+", help="List of GPUs to use.", ) # Dataset settings _data_group = _parser.add_argument_group("Dataset arguments") _data_group.add_argument( "-d", "--dataset", default="CIFAR-100", choices=dataset_list.keys(), ) _data_group.add_argument( "--data-root", metavar="DIR", type=str, help="Root path to the dataset", default="~/dataset", ) _data_group.add_argument( "-j", "--num-workers", default=8, type=int, metavar="N", help="Number of data loading workers (default: 8)", ) _data_group.add_argument( "--base-ratio", default=0.5, type=float, help="The ratio of base classes in the training set.", ) _data_group.add_argument( "--phases", "--tasks", default=10, type=int, help="Number of incremental phases (tasks).", ) _data_group.add_argument( "-b", "--batch-size", default=256, type=int, metavar="N", help="The size of one mini-batch per GPU.", ) _data_group.add_argument( "--cache-features", action="store_true", help="Load the features extracted by the frozen backbone to speed up inference.", ) # Model settings _model_group = _parser.add_argument_group("Model arguments") _model_group.add_argument( "-a", "-arch", "--backbone", type=str, default="resnet32", help="model to use for training", choices=models.keys(), metavar="ARCH", ) _model_group.add_argument( "--cache-path", "--backbone-path", 
metavar="DIR", type=str, help=( "Path to the base pretrain backbone." "If file exists, the base training will be skipped." ), ) # Training Settings _model_group.add_argument("--seed", default=None, type=int, help="Seed for models.") _model_group.add_argument( "--dataset-seed", default=None, type=int, help="Seed for shuffling the dataset." ) # Base training arguments _base_group = _parser.add_argument_group("Base training arguments") _base_group.add_argument( "--base-epochs", default=300, type=int, metavar="N", help="Number of total epochs to run for base training.", ) _base_group.add_argument( "--warmup-epochs", default=10, type=int, metavar="N", help="Number of warmup epochs.", ) _base_group.add_argument( "-lr", "--learning-rate", default=0.5, type=float, metavar="LR", help="Initial learning rate", ) _base_group.add_argument( "--momentum", default=0.9, type=float, metavar="M", help="Momentum for SGD" ) _base_group.add_argument( "--wd", "--weight-decay", default=5e-4, type=float, metavar="W", dest="weight_decay", ) _base_group.add_argument( "--separate-decay", action="store_true", help="Separating the normalization parameters from the rest of the model parameters", ) _base_group.add_argument("--label-smoothing", default=0.05, type=float) # IL hyper-parameters _il_group = _parser.add_argument_group("IL Hyper-parameters") _il_group.add_argument( "--IL-batch-size", default=None, type=int, help="The size of mini-batch during the incremental learning process.", ) _il_group.add_argument( "--gamma", "--gamma-main", default=0.1, type=float, help="The regularization of the (main stream) linear classifier.", ) _il_group.add_argument( "--buffer-size", "--expansion-size", default=8192, type=int, help="The buffer size of the classifier.", ) _il_group.add_argument( "--gamma-comp", default=0.1, type=float, help="The regularization of the linear classifier in compensation stream (DS-AL only)", ) _il_group.add_argument( "--sigma", default=10, type=float, help="The width-adjusting 
def load_args() -> Dict[str, Any]:
    """Parse the CLI arguments, prepare the saving directory, and dump the
    configuration to `<saving_root>/args.yaml`.

    Returns:
        The argument dictionary, including the derived keys `backbone_path`
        (when --cache-path is given), `saving_root`, and `argv`.
    """
    global _parser
    args = vars(_parser.parse_args())
    args["data_root"] = path.expanduser(args["data_root"])

    # Resolve and validate the pretrained-backbone cache, if given.
    if args["cache_path"] is not None:
        assert path.isdir(args["cache_path"]), "The cache path is not a directory."
        args["backbone_path"] = path.join(args["cache_path"], "backbone.pth")
        assert path.isfile(args["backbone_path"]), \
            f"Backbone file \"{args['backbone_path']}\" doesn't exist."

    # Layout: saved_models/<backbone>_<dataset>_<base_ratio>_<dataset_seed>/<exp_name>/<time>
    saving_root = path.join(
        "saved_models",
        f"{args['backbone']}_{args['dataset']}_{args['base_ratio']}_{args['dataset_seed']}",
    )
    args["exp_name"] = args["exp_name"].strip()
    if args["exp_name"] == "":
        # Default experiment name: the method itself.
        args["exp_name"] = args["method"]
    saving_root = path.join(saving_root, args["exp_name"])

    if args["IL_batch_size"] is None:
        args["IL_batch_size"] = args["batch_size"]

    # Windows does not support ":" in the path
    current_time = datetime.now().isoformat(timespec="seconds").replace(":", "-")
    saving_root = path.join(saving_root, current_time)
    args["saving_root"] = saving_root
    args["argv"] = str(argv)

    makedirs(saving_root, exist_ok=True)
    # Dump the configuration BEFORE data_root gets the dataset suffix below.
    with open(path.join(saving_root, "args.yaml"), "w", encoding="utf-8") as yaml_file:
        yaml.safe_dump(args, yaml_file)

    args["data_root"] = path.join(args["data_root"], args["dataset"])
    return args
0.44653091444546568627)
    # Per-channel std of the CIFAR-10 training split.
    std = (0.21117028181572183225, 0.20857934290628859220, 0.21205155387102001073)
    # Eval-time pipeline: to float tensor in [0, 1], then normalize.
    basic_transform = transforms.Compose(
        [
            transforms.ToImage(),
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Normalize(mean, std, inplace=True),
            transforms.ToPureTensor(),
        ]
    )
    # Train-time pipeline: pad-crop, flip, and TrivialAugment before normalizing.
    augment_transform = transforms.Compose(
        [
            transforms.RandomCrop(32, 4),
            transforms.RandomHorizontalFlip(),
            transforms.TrivialAugmentWide(
                interpolation=transforms.InterpolationMode.BILINEAR
            ),
            transforms.ToImage(),
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Normalize(mean, std, inplace=True),
            transforms.ToPureTensor(),
        ]
    )

    def __init__(
        self,
        root: str,
        train: bool,
        base_ratio: float,
        num_phases: int,
        augment: bool = False,
        inplace_repeat: int = 1,
        shuffle_seed: int | None = None,
    ) -> None:
        """Wrap torchvision CIFAR10 with incremental-phase splitting."""
        self.dataset = CIFAR10(root, train=train, download=True)
        super().__init__(
            self.dataset.targets,
            base_ratio,
            num_phases,
            augment,
            inplace_repeat,
            shuffle_seed,
        )


class CIFAR100_(DatasetWrapper[Tuple[Tensor, int]]):
    num_classes = 100
    # Per-channel mean/std of the CIFAR-100 training split.
    mean = (0.50707515923713235294, 0.48654887331495098039, 0.44091784336703431373)
    std = (0.26733428848992695514, 0.25643846542136995765, 0.27615047402246589731)
    # std = (0.21103932286924015314, 0.20837755491382136483, 0.21551368222930648019)
    basic_transform = transforms.Compose(
        [
            transforms.ToImage(),
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Normalize(mean, std, inplace=True),
            transforms.ToPureTensor(),
        ]
    )
    augment_transform = transforms.Compose(
        [
            transforms.RandomCrop(32, 4),
            transforms.RandomHorizontalFlip(),
            transforms.TrivialAugmentWide(
                interpolation=transforms.InterpolationMode.BILINEAR
            ),
            transforms.ToImage(),
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Normalize(mean, std, inplace=True),
            transforms.ToPureTensor(),
        ]
    )

    def __init__(
        self,
        root: str,
        train: bool,
        base_ratio: float,
        num_phases: int,
        augment: bool = False,
        inplace_repeat: int = 1,
        shuffle_seed: int | None = None,
    ) -> None:
        """Wrap torchvision CIFAR100 with incremental-phase splitting."""
        self.dataset = CIFAR100(root, train=train, download=True)
        super().__init__(
            self.dataset.targets,
            base_ratio,
            num_phases,
            augment,
            inplace_repeat,
            shuffle_seed,
        )


if __name__ == "__main__":
    # Smoke test: every base-phase sample must be a 3x32x32 tensor.
    dataset_train = CIFAR100_(
        "~/.dataset", train=True, base_ratio=0.1, num_phases=3, augment=True
    )
    dataset_test = CIFAR100_(
        "~/.dataset", train=False, base_ratio=0.1, num_phases=3, augment=False
    )
    for X, y in dataset_train.subset_at_phase(0):
        assert X.shape == (3, 32, 32)
    for X, y in dataset_test.subset_at_phase(0):
        assert X.shape == (3, 32, 32)
    print("test passed")


# ================ FILE: datasets/DatasetWrapper.py ================
# -*- coding: utf-8 -*-
from typing import Callable, Iterable, Optional
from torch.utils.data import Dataset, Subset

try:
    from torch.utils.data.dataset import T_co
except ImportError:
    # Fallback import location for T_co on torch versions that moved it.
    from torch.utils._ordered_set import T_co
from abc import ABCMeta
from random import Random
from numpy import repeat
from itertools import chain


class DatasetWrapper(Dataset[T_co], metaclass=ABCMeta):
    """Splits a labeled dataset into a base phase plus equal-sized
    incremental phases, applying a train- or eval-time transform per item."""

    # Subclasses provide both transform pipelines as class attributes.
    basic_transform: Callable[[T_co], T_co]
    augment_transform: Callable[[T_co], T_co]

    def __init__(
        self,
        labels: Iterable[int],
        base_ratio: float,
        num_phases: int,
        augment: bool,
        inplace_repeat: int = 1,
        shuffle_seed: Optional[int] = None,
    ) -> None:
        # Type hints
        self.dataset: Dataset[T_co]
        self.num_classes: int
        # Initialization
        super().__init__()
        self.inplace_repeat = inplace_repeat
        self.base_ratio = base_ratio
        self.num_phases = num_phases
        # base_size classes form phase 0; the rest split evenly across phases.
        self.base_size = int(self.num_classes * self.base_ratio)
        self.incremental_size = self.num_classes - self.base_size
        self.phase_size = self.incremental_size // num_phases if num_phases > 0 else 0
        # Create a list of indices for each class
        self.class_indices: list[list[int]] = [[] for _ in range(self.num_classes)]
        for idx, label in enumerate(labels):
            self.class_indices[label].append(idx)
        self._transform = self.augment_transform if augment else self.basic_transform
        # Shuffle the class indices
self.real_labels: list[int] = list(range(self.num_classes))
        if shuffle_seed is not None:
            # Re-seed for each shuffle so real_labels and class_indices receive
            # the identical permutation and stay aligned.
            Random(shuffle_seed).shuffle(self.real_labels)
            Random(shuffle_seed).shuffle(self.class_indices)

    def __getitem__(self, index: int) -> T_co:
        # Apply the transform pipeline chosen at construction time.
        return self._transform(self.dataset[index])

    def _subset(self, label_begin: int, label_end: int) -> Subset[T_co]:
        """Subset of all samples with class in [label_begin, label_end),
        each index duplicated `inplace_repeat` times (numpy.repeat)."""
        sub_ids = tuple(chain.from_iterable(self.class_indices[label_begin:label_end]))
        return Subset(self, repeat(sub_ids, self.inplace_repeat).tolist())

    def subset_at_phase(self, phase: int) -> Subset[T_co]:
        """Samples belonging to exactly this phase (phase 0 = base classes)."""
        if phase == 0:
            return self._subset(0, self.base_size)
        return self._subset(
            self.base_size + (phase - 1) * self.phase_size,
            self.base_size + phase * self.phase_size,
        )

    def subset_until_phase(self, phase: int) -> Subset[T_co]:
        """Samples of every class seen up to and including this phase."""
        return self._subset(
            0,
            self.base_size + phase * self.phase_size,
        )


# ================ FILE: datasets/Features.py ================
# -*- coding: utf-8 -*-
import torch
from os import path
from .DatasetWrapper import DatasetWrapper
from torch.utils.data import TensorDataset
from torchvision.transforms import v2 as transforms


class Features(DatasetWrapper[tuple[torch.Tensor, torch.LongTensor]]):
    """Dataset of pre-extracted backbone features loaded from X_/y_ tensors."""

    # Features are already processed; transforms are identity.
    basic_transform = transforms.Identity()
    augment_transform = transforms.Identity()

    def __init__(
        self,
        root: str,
        train: bool,
        base_ratio: float,
        num_phases: int,
        augment: bool = False,
        inplace_repeat: int = 1,
        shuffle_seed: int | None = None,
    ) -> None:
        assert augment == False, "Augmentation is not supported for Features dataset"
        # weights_only=True: safe tensor-only deserialization.
        if train:
            X: torch.Tensor = torch.load(path.join(root, "X_train.pt"), weights_only=True)
            y: torch.Tensor = torch.load(path.join(root, "y_train.pt"), weights_only=True)
        else:
            X: torch.Tensor = torch.load(path.join(root, "X_test.pt"), weights_only=True)
            y: torch.Tensor = torch.load(path.join(root, "y_test.pt"), weights_only=True)
        y = y.to(torch.long, non_blocking=True)
        self.dataset = TensorDataset(X, y)  # type: ignore
        # Infer the class count from labels (assumes labels are 0..C-1 — TODO confirm).
        self.num_classes = int(y.max().item()) + 1
        super().__init__(
            y.numpy().tolist(),
            base_ratio,
            num_phases,
            False,
            inplace_repeat,
            shuffle_seed,
        )


# ================ FILE: datasets/ImageNet.py ================
# -*- coding: utf-8 -*-
from typing import Tuple
import torch
from .DatasetWrapper import DatasetWrapper
from torchvision.datasets import ImageNet
from torchvision.transforms import v2 as transforms
from os import path


class ImageNet_(DatasetWrapper[Tuple[torch.Tensor, int]]):
    num_classes = 1000
    # Standard ImageNet channel statistics.
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    basic_transform = transforms.Compose(
        [
            transforms.Resize(232),
            transforms.CenterCrop(224),
            transforms.PILToTensor(),
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Normalize(mean, std, inplace=True),
            transforms.ToPureTensor(),
        ]
    )
    # Train-time recipe: 176-crop, flip, TrivialAugment, random erasing.
    augment_transform = transforms.Compose(
        [
            transforms.RandomResizedCrop(176),
            transforms.RandomHorizontalFlip(0.5),
            transforms.TrivialAugmentWide(
                interpolation=transforms.InterpolationMode.BILINEAR
            ),
            transforms.ToImage(),
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Normalize(mean, std, inplace=True),
            transforms.RandomErasing(0.1),
            transforms.ToPureTensor(),
        ]
    )

    def __init__(
        self,
        root: str,
        train: bool,
        base_ratio: float,
        num_phases: int,
        augment: bool = False,
        inplace_repeat: int = 1,
        shuffle_seed: int | None = None,
    ) -> None:
        """Wrap torchvision ImageNet (train/val split) with phase splitting."""
        root = path.expanduser(root)
        self.dataset = ImageNet(root, split="train" if train else "val")
        super().__init__(
            self.dataset.targets,
            base_ratio,
            num_phases,
            augment,
            inplace_repeat,
            shuffle_seed,
        )


# ================ FILE: datasets/MNIST.py ================
# -*- coding: utf-8 -*-
import torch
from torch import Tensor
from torchvision.datasets import MNIST
from torchvision.transforms import v2 as transforms
from .DatasetWrapper import DatasetWrapper


class MNIST_(DatasetWrapper[tuple[Tensor, int]]):
    num_classes = 10
    # Single-channel MNIST mean/std.
    mean = 
(0.13066047627384287048,)
    std = (0.30524474224261827502,)
    basic_transform = transforms.Compose(
        [
            transforms.ToImage(),
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Normalize(mean, std, inplace=True),
        ]
    )
    # MNIST uses no augmentation; train and eval share one pipeline.
    augment_transform = basic_transform

    def __init__(
        self,
        root: str,
        train: bool,
        base_ratio: float,
        num_phases: int,
        augment: bool = False,
        inplace_repeat: int = 1,
        shuffle_seed: int | None = None,
    ) -> None:
        """Wrap torchvision MNIST with incremental-phase splitting."""
        self.dataset = MNIST(root, train=train, download=True)
        super().__init__(
            self.dataset.targets.tolist(),
            base_ratio,
            num_phases,
            augment,
            inplace_repeat,
            shuffle_seed,
        )


# ================ FILE: datasets/__init__.py ================
# -*- coding: utf-8 -*-
from .DatasetWrapper import DatasetWrapper
from .MNIST import MNIST_ as MNIST
from .CIFAR import CIFAR10_ as CIFAR10
from .CIFAR import CIFAR100_ as CIFAR100
from .ImageNet import ImageNet_ as ImageNet
from typing import Union
from .Features import Features

__all__ = [
    "load_dataset",
    "dataset_list",
    "MNIST",
    "CIFAR10",
    "CIFAR100",
    "ImageNet",
    "DatasetWrapper",
    "Features",
]

# CLI dataset name -> dataset class.
dataset_list = {
    "MNIST": MNIST,
    "CIFAR-10": CIFAR10,
    "CIFAR-100": CIFAR100,
    "ImageNet-1k": ImageNet,
}


def load_dataset(
    name: str,
    root: str,
    train: bool,
    base_ratio: float,
    num_phases: int,
    augment: bool = False,
    inplace_repeat: int = 1,
    shuffle_seed: int | None = None,
    *args,
    **kwargs
) -> Union[MNIST, CIFAR10, CIFAR100, ImageNet]:
    """Instantiate the dataset registered under `name` with phase options."""
    return dataset_list[name](
        root=root,
        train=train,
        base_ratio=base_ratio,
        num_phases=num_phases,
        augment=augment,
        inplace_repeat=inplace_repeat,
        shuffle_seed=shuffle_seed,
        *args,
        **kwargs
    )


# ================ FILE: environment.yaml ================
# Usage: conda env create -f environment.yaml
name: AL
channels:
  - pytorch
  - nvidia
  - conda-forge
dependencies:
  - python=3.11
  # PyTorch
  - pytorch>=2.2
  - torchvision
  - pytorch-cuda # For Nvidia GPU
  # Necessary Utils
  - numpy
  - tqdm
  - scikit-learn
  - pip
  - pip:
      - prefetch_generator
  # Optional
  - black
  - mypy


# ================ FILE: main.py ================
# -*- coding: utf-8 -*-
import torch
from os import path
from tqdm import tqdm
from config import load_args, ALL_METHODS
from models import load_backbone
from typing import Any, Dict, List, Tuple, Optional
from datasets import Features, load_dataset
from utils import set_determinism, validate
from torch._prims_common import DeviceLikeType
from torch.utils.data import Dataset, DataLoader


def make_dataloader(
    dataset: Dataset,
    shuffle: bool = False,
    batch_size: int = 256,
    num_workers: int = 8,
    device: Optional[DeviceLikeType] = None,
    persistent_workers: bool = False,
) -> DataLoader:
    """Build a DataLoader, preferring a background-prefetching variant
    when `prefetch_generator` is installed."""
    # Pin host memory only when the target device is CUDA.
    pin_memory = (device is not None) and (torch.device(device).type == "cuda")
    config = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "num_workers": num_workers,
        "pin_memory": pin_memory,
        "pin_memory_device": str(device) if pin_memory else "",
        "persistent_workers": persistent_workers,
    }
    try:
        from prefetch_generator import BackgroundGenerator

        class DataLoaderX(DataLoader):
            # Iterate through a background thread to overlap loading and compute.
            def __iter__(self):
                return BackgroundGenerator(super().__iter__())

        return DataLoaderX(dataset, **config)
    except ImportError:
        # prefetch_generator is optional; fall back to the stock DataLoader.
        return DataLoader(dataset, **config)


def check_cache_features(root: str) -> bool:
    """Return True iff all four cached feature tensor files exist under root."""
    files_list = ["X_train.pt", "y_train.pt", "X_test.pt", "y_test.pt"]
    for file in files_list:
        if not path.isfile(path.join(root, file)):
            return False
    return True


@torch.no_grad()
def cache_features(
    backbone: torch.nn.Module,
    dataloader: DataLoader[Tuple[torch.Tensor, torch.Tensor]],
    device: Optional[DeviceLikeType] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the frozen backbone over the loader; collect (features, labels) on CPU."""
    backbone.eval()
    X_all: List[torch.Tensor] = []
    y_all: List[torch.Tensor] = []
    for X, y in tqdm(dataloader, "Caching"):
        X: torch.Tensor = backbone(X.to(device))
        # int16 labels halve memory; sufficient for <= 32767 classes.
        y: torch.Tensor = y.to(torch.int16, non_blocking=True)
        X_all.append(X.cpu())
        y_all.append(y.cpu())
    return 
torch.cat(X_all), torch.cat(y_all)


def main(args: Dict[str, Any]):
    """End-to-end pipeline: device selection, backbone (re)use or base
    training, optional feature caching, then phase-wise incremental learning
    with CSV logging of validation metrics."""
    backbone_name = args["backbone"]
    # Select device
    if args["cpu_only"] or not torch.cuda.is_available():
        main_device = torch.device("cpu")
        all_gpus = None
    elif args["gpus"] is not None:
        gpus = args["gpus"]
        main_device = torch.device(f"cuda:{gpus[0]}")
        all_gpus = [torch.device(f"cuda:{gpu}") for gpu in gpus]
    else:
        main_device = torch.device("cuda:0")
        all_gpus = None

    if args["seed"] is not None:
        set_determinism(args["seed"])

    if "backbone_path" in args:
        assert path.isfile(
            args["backbone_path"]
        ), f"Backbone file \"{args['backbone_path']}\" doesn't exist."
        preload_backbone = True
        backbone, _, feature_size = torch.load(
            args["backbone_path"], map_location=main_device, weights_only=False
        )
    else:
        # Load model pre-train on ImageNet if there is no base training dataset.
        preload_backbone = False
        load_pretrain = args["base_ratio"] == 0 or "ImageNet" not in args["dataset"]
        backbone, _, feature_size = load_backbone(backbone_name, pretrain=load_pretrain)
        if load_pretrain:
            # NOTE(review): registry keys are e.g. "ImageNet-1k", never exactly
            # "ImageNet", so this assert can never fire; an "in" check was
            # probably intended — confirm.
            assert args["dataset"] != "ImageNet", "Data may leak!!!"
    backbone = backbone.to(main_device, non_blocking=True)

    dataset_args = {
        "name": args["dataset"],
        "root": args["data_root"],
        "base_ratio": args["base_ratio"],
        "num_phases": args["phases"],
        "shuffle_seed": args["dataset_seed"] if "dataset_seed" in args else None,
    }
    dataset_train = load_dataset(train=True, augment=True, **dataset_args)
    dataset_test = load_dataset(train=False, augment=False, **dataset_args)

    # Select algorithm
    assert args["method"] in ALL_METHODS, f"Unknown method: {args['method']}"
    learner = ALL_METHODS[args["method"]](
        args, backbone, feature_size, main_device, all_devices=all_gpus
    )

    # Base training
    if args["base_ratio"] > 0 and not preload_backbone:
        train_subset = dataset_train.subset_at_phase(0)
        test_subset = dataset_test.subset_at_phase(0)
        train_loader = make_dataloader(
            train_subset,
            True,
            args["batch_size"],
            args["num_workers"],
            device=main_device,
        )
        test_loader = make_dataloader(
            test_subset,
            False,
            args["batch_size"],
            args["num_workers"],
            device=main_device,
        )
        learner.base_training(
            train_loader,
            test_loader,
            dataset_train.base_size,
        )

    # Load dataset
    if args["cache_features"]:
        if "cache_path" not in args or args["cache_path"] is None:
            args["cache_path"] = args["saving_root"]
        if not check_cache_features(args["cache_path"]):
            # Cache miss: extract features for the FULL dataset (base_ratio=1,
            # 0 phases) once, then reuse them for all phases below.
            backbone = learner.backbone.eval()
            dataset_train = load_dataset(
                args["dataset"], args["data_root"], True, 1, 0, augment=False
            )
            dataset_test = load_dataset(
                args["dataset"], args["data_root"], False, 1, 0, augment=False
            )
            train_loader = make_dataloader(
                dataset_train.subset_at_phase(0),
                False,
                args["batch_size"],
                args["num_workers"],
                device=main_device,
            )
            test_loader = make_dataloader(
                dataset_test.subset_at_phase(0),
                False,
                args["batch_size"],
                args["num_workers"],
                device=main_device,
            )
            if all_gpus is not None and len(all_gpus) > 1:
                backbone = torch.nn.DataParallel(backbone, device_ids=all_gpus)
            X_train, y_train = cache_features(
                backbone, train_loader, device=main_device
            )
            X_test, y_test = cache_features(backbone, test_loader, device=main_device)
            torch.save(X_train, path.join(args["cache_path"], "X_train.pt"))
            torch.save(y_train, path.join(args["cache_path"], "y_train.pt"))
            torch.save(X_test, path.join(args["cache_path"], "X_test.pt"))
            torch.save(y_test, path.join(args["cache_path"], "y_test.pt"))
        dataset_train = Features(
            args["cache_path"],
            train=True,
            base_ratio=args["base_ratio"],
            num_phases=args["phases"],
            augment=False,
        )
        dataset_test = Features(
            args["cache_path"],
            train=False,
            base_ratio=args["base_ratio"],
            num_phases=args["phases"],
            augment=False,
        )
        # Features are precomputed, so the backbone becomes a no-op.
        learner.backbone = torch.nn.Identity()
        learner.model.backbone = torch.nn.Identity()
    else:
        dataset_train = load_dataset(train=True, augment=False, **dataset_args)
        dataset_test = load_dataset(train=False, augment=False, **dataset_args)

    # Incremental learning
    sum_acc = 0
    log_file_path = path.join(args["saving_root"], "IL.csv")
    # buffering=1: line-buffered so the CSV is readable while training runs.
    log_file = open(log_file_path, "w", buffering=1)
    print(
        "phase", "acc@avg", "acc@1", "acc@5", "f1-micro", "loss", file=log_file, sep=","
    )
    for phase in range(0, args["phases"] + 1):
        # Train on the new phase only; validate on all classes seen so far.
        train_subset = dataset_train.subset_at_phase(phase)
        test_subset = dataset_test.subset_until_phase(phase)
        train_loader = make_dataloader(
            train_subset,
            True,
            args["IL_batch_size"],
            args["num_workers"],
            device=main_device,
        )
        test_loader = make_dataloader(
            test_subset,
            False,
            args["IL_batch_size"],
            args["num_workers"],
            device=main_device,
        )
        if phase == 0:
            learner.learn(train_loader, dataset_train.base_size, "Re-align")
        else:
            learner.learn(train_loader, dataset_train.phase_size)
        learner.before_validation()
        # Validation
        val_meter = validate(
            learner,
            test_loader,
            dataset_train.num_classes,
            desc=f"Phase {phase}",
        )
        sum_acc += val_meter.accuracy
        print(
            f"loss: {val_meter.loss:.4f}",
            f"acc@1: {val_meter.accuracy * 100:.3f}%",
            f"acc@5: {val_meter.accuracy5 * 100:.3f}%",
            f"f1-micro: {val_meter.f1_micro * 100:.3f}%",
            f"acc@avg: {sum_acc / (phase + 1) * 100:.3f}%",
            sep=" ",
        )
        print(
            phase,
            sum_acc / (phase + 1),
            val_meter.accuracy,
            val_meter.accuracy5,
            val_meter.f1_micro,
            val_meter.loss,
            file=log_file,
            sep=",",
        )
    log_file.close()


if __name__ == "__main__":
    main(load_args())


# ================ FILE: models/CifarResNet.py ================
# -*- coding: utf-8 -*-
"""
Properly implemented ResNet-s for CIFAR10 as described in paper [1].

The implementation and structure of this file is hugely influenced by [2]
which is implemented for ImageNet and doesn't have option A for identity.
Moreover, most of the implementations on the web is copy-paste from
torchvision's resnet and has wrong number of params.

Proper ResNet-s for CIFAR10 (for fair comparison and etc.) has following
number of layers and parameters:

name      | layers | params
ResNet20  |    20  | 0.27M
ResNet32  |    32  | 0.46M
ResNet44  |    44  | 0.66M
ResNet56  |    56  | 0.85M
ResNet110 |   110  |  1.7M
ResNet1202|  1202  | 19.4m

which this implementation indeed has.

Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
[2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py

If you use this implementation in your work, please don't forget to mention the
author, Yerlan Idelbayev.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

__all__ = [
    "CifarResNet",
    "resnet20",
    "resnet32",
    "resnet44",
    "resnet56",
    "resnet110",
    "resnet1202",
]


def _weights_init(m):
    # Kaiming-normal init for every conv and linear layer.
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)


class ShortcutA(nn.Module):
    """Option-A identity shortcut: stride-2 spatial subsample plus
    zero-padding of channels (parameter-free)."""

    def __init__(self, planes) -> None:
        super().__init__()
        # planes//4 zero channels padded on each side to match the new width.
        self.pad = planes // 4

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.pad(
            x[:, :, ::2, ::2],
            (0, 0, 0, 0, self.pad, self.pad),
            "constant",
            0,
        )


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option="A"):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        # Identity shortcut unless the shape changes.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == "A":
                """
                For CIFAR10 ResNet paper uses option A.
                """
                self.shortcut = ShortcutA(planes)
            elif option == "B":
                # Option B: 1x1 projection conv + batch norm.
                self.shortcut = nn.Sequential(
                    nn.Conv2d(
                        in_planes,
                        self.expansion * planes,
                        kernel_size=1,
                        stride=stride,
                        bias=False,
                    ),
                    nn.BatchNorm2d(self.expansion * planes),
                )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class CifarResNet(nn.Module):
    """Three-stage (16/32/64 channel) ResNet for 32x32 inputs."""

    def __init__(self, block, num_blocks, num_classes=10):
        super(CifarResNet, self).__init__()
        self.in_planes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.fc = nn.Linear(64, num_classes)
        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        # Only the first block of a stage may downsample; the rest use stride 1.
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        # Global average pool to a 64-dim feature, then classify.
        out = F.avg_pool2d(out, out.size()[3])
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out


def resnet20(num_classes=10):
    return CifarResNet(BasicBlock, [3, 3, 3], num_classes)


def resnet32(num_classes=10):
    return CifarResNet(BasicBlock, [5, 5, 5], num_classes)


def resnet44(num_classes=10):
    return CifarResNet(BasicBlock, [7, 7, 7], num_classes)


def resnet56(num_classes=10):
    return CifarResNet(BasicBlock, [9, 9, 9], num_classes)


def resnet110(num_classes=10):
    return CifarResNet(BasicBlock, [18, 18, 18], num_classes)


def resnet1202(num_classes=10):
    return CifarResNet(BasicBlock, [200, 200, 200], num_classes)


def calc_num_params(model: torch.nn.Module) -> int:
    """Count trainable parameters of a model."""
    return sum(p.numel() for p in model.parameters() if 
p.requires_grad)


if __name__ == "__main__":
    print(calc_num_params(resnet32()))


# ================ FILE: models/__init__.py ================
# -*- coding: utf-8 -*-
import torch
from typing import Dict, Tuple, Union, Optional, Callable
from torchvision.models import WeightsEnum
from torch.nn import Flatten
from torchvision.models.resnet import (
    ResNet,
    resnet18,
    resnet34,
    resnet50,
    resnet101,
    resnet152,
    ResNet18_Weights,
    ResNet34_Weights,
    ResNet50_Weights,
    ResNet101_Weights,
    ResNet152_Weights,
)
from .CifarResNet import (
    CifarResNet,
    resnet20,
    resnet32,
    resnet44,
    resnet56,
    resnet110,
    resnet1202,
)
from torchvision.models.vision_transformer import (
    VisionTransformer,
    vit_b_16,
    vit_b_32,
    vit_l_16,
    vit_l_32,
    vit_h_14,
    ViT_B_16_Weights,
    ViT_B_32_Weights,
    ViT_L_16_Weights,
    ViT_L_32_Weights,
    ViT_H_14_Weights,
)

__all__ = [
    "ResNet",
    "resnet18",
    "resnet34",
    "resnet50",
    "resnet101",
    "resnet152",
    "CifarResNet",
    "resnet20",
    "resnet32",
    "resnet44",
    "resnet56",
    "resnet110",
    "resnet1202",
    "VisionTransformer",
    "vit_b_16",
    "vit_b_32",
    "vit_l_16",
    "vit_l_32",
    "vit_h_14",
    "load_backbone",
]

# Backbone registry: name -> (input image size, constructor, pretrain weights).
# fmt: off
models: Dict[str, Tuple[
        int,  # Input image size
        Callable[[], Union[CifarResNet, ResNet, VisionTransformer, Flatten]],  # Model constructor
        Optional[WeightsEnum]
    ]
] = {
    # MNIST: No backbone
    "Flatten": (28, Flatten, None),
    # ResNet for CIFAR
    "resnet20": (32, resnet20, None),
    "resnet32": (32, resnet32, None),
    "resnet44": (32, resnet44, None),
    "resnet56": (32, resnet56, None),
    "resnet110": (32, resnet110, None),
    "resnet1202": (32, resnet1202, None),
    # ResNet for ImageNet
    "resnet18": (224, resnet18, ResNet18_Weights.DEFAULT),
    "resnet34": (224, resnet34, ResNet34_Weights.DEFAULT),
    "resnet50": (224, resnet50, ResNet50_Weights.DEFAULT),
    "resnet101": (224, resnet101, ResNet101_Weights.DEFAULT),
    "resnet152": (224, resnet152, ResNet152_Weights.DEFAULT),
    # Vision Transformer for ImageNet
    "vit_b_16": (384, vit_b_16, ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1),
    "vit_b_32": (224, vit_b_32, ViT_B_32_Weights.IMAGENET1K_V1),
    "vit_l_16": (512, vit_l_16, ViT_L_16_Weights.IMAGENET1K_SWAG_E2E_V1),
    "vit_l_32": (224, vit_l_32, ViT_L_32_Weights.IMAGENET1K_V1),
    "vit_h_14": (518, vit_h_14, ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1),
}
# fmt: on


def load_backbone(
    name: str, pretrain: bool = False, *args, **kwargs
) -> Tuple[torch.nn.Module, int, int]:
    """Build a registered backbone and strip its classification head.

    Returns (backbone, input_img_size, feature_size).
    """
    input_img_size, model, weights = models[name]
    if pretrain and (weights is not None) and ("weights" not in kwargs):
        kwargs["weights"] = weights
    backbone = model(*args, **kwargs)
    if isinstance(backbone, VisionTransformer):
        feature_size: int = backbone.heads[-1].in_features
        backbone.heads = torch.nn.Identity()  # type: ignore
    elif isinstance(backbone, (ResNet, CifarResNet)):
        feature_size = backbone.fc.in_features
        backbone.fc = torch.nn.Identity()  # type: ignore
    elif isinstance(backbone, Flatten):
        # NOTE(review): assumes a single-channel input_img_size^2 image (MNIST);
        # multi-channel inputs would make this feature_size wrong — confirm.
        feature_size = input_img_size ** 2
    # NOTE(review): feature_size is unbound if a future entry matches none of
    # the branches above — consider a final else that raises.
    return backbone, input_img_size, feature_size


if __name__ == "__main__":
    for name in models.keys():
        backbone, input_img_size, feature_size = load_backbone(name, pretrain=True)
        # NOTE(review): this probe is 3-channel, which does not match the
        # "Flatten" (MNIST, 1-channel) entry; the asserts would fail there.
        test_img = torch.randn((1, 3, input_img_size, input_img_size))
        prototype: torch.Tensor = backbone(test_img)
        assert len(prototype.shape) == 2 and prototype.shape[0] == 1
        assert feature_size == prototype.shape[1]


# ================ FILE: utils/__init__.py ================
# -*- coding: utf-8 -*-
from .validate import validate
from .set_weight_decay import set_weight_decay
from .set_determinism import set_determinism
from .metrics import ClassificationMeter

__all__ = ["validate", "set_weight_decay", "set_determinism", "ClassificationMeter"]


# ================ FILE: utils/metrics.py ================
import torch
import numpy as np
from sklearn import metrics


class ClassificationMeter:
    """Accumulates labels, predictions, and loss across batches and exposes
    aggregate classification metrics as properties."""

    def __init__(self, num_classes: int,
record_logits: bool = False) -> None: self.num_classes = num_classes self.total_loss = 0.0 self.labels = np.zeros((0,), dtype=np.int32) self.prediction = np.zeros((0,), dtype=np.int32) self.acc5_cnt = 0 self.record_logits = record_logits if self.record_logits: self.logits = np.ndarray((0, num_classes)) def record(self, y_true: torch.Tensor, logits: torch.Tensor) -> None: self.labels = np.concatenate([self.labels, y_true.cpu().numpy()]) # Record logits if self.record_logits: logits_softmax = torch.nn.functional.softmax(logits, dim=1).cpu().numpy() self.logits = np.concatenate([self.logits, logits_softmax]) # Loss self.total_loss += float( torch.nn.functional.cross_entropy(logits, y_true, reduction="sum").item() ) # Top-5 accuracy y_pred = logits.topk(5, largest=True).indices.to(torch.int) acc5_judge = (y_pred == y_true[:, None]).any(dim=-1) self.acc5_cnt += int(acc5_judge.sum().item()) # Record the predictions self.prediction = np.concatenate([self.prediction, y_pred[:, 0].cpu().numpy()]) @property def accuracy(self) -> float: return float(metrics.accuracy_score(self.labels, self.prediction)) @property def balanced_accuracy(self) -> float: result = metrics.balanced_accuracy_score( self.labels, self.prediction, adjusted=True ) return float(result) @property def f1_micro(self) -> float: result = metrics.f1_score(self.labels, self.prediction, average="micro") return float(result) @property def f1_macro(self) -> float: result = metrics.f1_score(self.labels, self.prediction, average="macro") return float(result) @property def accuracy5(self) -> float: return self.acc5_cnt / len(self.labels) @property def loss(self) -> float: return float(self.total_loss / len(self.labels)) ================================================ FILE: utils/set_determinism.py ================================================ # -*- coding: utf-8 -*- import torch import numpy from os import environ import random def set_determinism(seed: int) -> None: environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" 
torch.use_deterministic_algorithms(True)
    torch.manual_seed(seed)
    numpy.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        # Disable auto-tuned (non-deterministic) cuDNN kernels.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        torch.cuda.manual_seed_all(seed)


# ================ FILE: utils/set_weight_decay.py ================
# -*- coding: utf-8 -*-
# NOTE(review): `disable` is never used in this module — candidate for removal.
from gc import disable
import torch
from typing import Dict, List, Any


def set_weight_decay(
    model: torch.nn.Module,
    weight_decay: float,
    disable_norm_decay: bool = True,
    disable_bias_decay: bool = True,
    disable_embedding_decay: bool = True,
):
    """Partition model parameters into optimizer param groups so that
    normalization, bias, and embedding parameters can skip weight decay.

    Returns a list of {"params": ..., "weight_decay": ...} dicts.
    """
    # See: https://github.com/pytorch/vision/blob/main/references/classification/utils.py
    norm_classes = (
        torch.nn.modules.batchnorm._BatchNorm,
        torch.nn.LayerNorm,
        torch.nn.GroupNorm,
        torch.nn.modules.instancenorm._InstanceNorm,
        torch.nn.LocalResponseNorm,
    )
    params = {
        "other": [],
        "norm": [],
        "bias": [],
        "class_token": [],
        "position_embedding": [],
        "relative_position_bias_table": [],
    }
    params_weight_decay = {
        "bias": 0 if disable_bias_decay else weight_decay,
        "class_token": 0 if disable_embedding_decay else weight_decay,
        "position_embedding": 0 if disable_embedding_decay else weight_decay,
        "relative_position_bias_table": 0 if disable_embedding_decay else weight_decay,
    }

    def _add_params(module: torch.nn.Module, prefix=""):
        # Recursively walk modules, assigning each parameter to one bucket.
        for name, p in module.named_parameters(recurse=False):
            for key in params_weight_decay.keys():
                # Keys containing "." match the fully qualified name.
                target_name = (
                    f"{prefix}.{name}" if prefix != "" and "." in key else name
                )
                if key == target_name:
                    params[key].append(p)
                    break
            else:
                # No special key matched: norm layers get their own bucket.
                if isinstance(module, norm_classes):
                    params["norm"].append(p)
                else:
                    params["other"].append(p)

        for child_name, child_module in module.named_children():
            child_prefix = f"{prefix}.{child_name}" if prefix != "" else child_name
            _add_params(child_module, prefix=child_prefix)

    _add_params(model)

    params_weight_decay["other"] = weight_decay
    params_weight_decay["norm"] = 0.0 if disable_norm_decay else weight_decay

    # Only emit non-empty groups.
    param_groups: List = []
    for key in params:
        if len(params[key]) > 0:
            param_groups.append(
                {"params": params[key], "weight_decay": params_weight_decay[key]}
            )
    return param_groups


# ================ FILE: utils/validate.py ================
# -*- coding: utf-8 -*-
import torch
from tqdm import tqdm
from typing import Tuple, Iterable, Optional, Callable
from .metrics import ClassificationMeter


@torch.no_grad()
def validate(
    model: Callable[[torch.Tensor], torch.Tensor],
    data_loader: Iterable[Tuple[torch.Tensor, torch.Tensor]],
    num_classes: int,
    desc: Optional[str] = None
) -> ClassificationMeter:
    """Run the model over a loader and return a populated ClassificationMeter.

    `model` may be an nn.Module (device inferred from parameters) or any
    callable exposing a `.device` attribute (e.g. an analytic learner).
    """
    if isinstance(model, torch.nn.Module):
        model.eval()
        device = next(model.parameters()).device
    else:
        device = model.device
    meter = ClassificationMeter(num_classes)
    for X, y in tqdm(data_loader, desc=desc):
        X = X.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        # Calculate the loss
        logits: torch.Tensor = model(X)
        meter.record(y, logits)
    return meter