Repository: BeierZhu/Prompt-align Branch: main Commit: 1b762036cbe9 Files: 244 Total size: 606.4 KB Directory structure: gitextract_cxcff6dn/ ├── Dassl.ProGrad.pytorch/ │ ├── .flake8 │ ├── .gitignore │ ├── .isort.cfg │ ├── .style.yapf │ ├── DATASETS.md │ ├── LICENSE │ ├── README.md │ ├── configs/ │ │ ├── README.md │ │ ├── datasets/ │ │ │ ├── da/ │ │ │ │ ├── cifar_stl.yaml │ │ │ │ ├── digit5.yaml │ │ │ │ ├── domainnet.yaml │ │ │ │ ├── mini_domainnet.yaml │ │ │ │ ├── office31.yaml │ │ │ │ ├── office_home.yaml │ │ │ │ └── visda17.yaml │ │ │ ├── dg/ │ │ │ │ ├── cifar100_c.yaml │ │ │ │ ├── cifar10_c.yaml │ │ │ │ ├── digit_single.yaml │ │ │ │ ├── digits_dg.yaml │ │ │ │ ├── office_home_dg.yaml │ │ │ │ ├── pacs.yaml │ │ │ │ └── vlcs.yaml │ │ │ └── ssl/ │ │ │ ├── cifar10.yaml │ │ │ ├── cifar100.yaml │ │ │ ├── stl10.yaml │ │ │ └── svhn.yaml │ │ └── trainers/ │ │ ├── da/ │ │ │ ├── dael/ │ │ │ │ ├── digit5.yaml │ │ │ │ ├── domainnet.yaml │ │ │ │ └── mini_domainnet.yaml │ │ │ ├── m3sda/ │ │ │ │ ├── digit5.yaml │ │ │ │ ├── domainnet.yaml │ │ │ │ └── mini_domainnet.yaml │ │ │ └── source_only/ │ │ │ ├── digit5.yaml │ │ │ ├── mini_domainnet.yaml │ │ │ ├── office31.yaml │ │ │ └── visda17.yaml │ │ ├── dg/ │ │ │ ├── dael/ │ │ │ │ ├── digits_dg.yaml │ │ │ │ ├── office_home_dg.yaml │ │ │ │ └── pacs.yaml │ │ │ ├── ddaig/ │ │ │ │ ├── digits_dg.yaml │ │ │ │ ├── office_home_dg.yaml │ │ │ │ └── pacs.yaml │ │ │ └── vanilla/ │ │ │ ├── digits_dg.yaml │ │ │ ├── mini_domainnet.yaml │ │ │ ├── office_home_dg.yaml │ │ │ └── pacs.yaml │ │ └── ssl/ │ │ └── fixmatch/ │ │ └── cifar10.yaml │ ├── dassl/ │ │ ├── __init__.py │ │ ├── config/ │ │ │ ├── __init__.py │ │ │ └── defaults.py │ │ ├── data/ │ │ │ ├── __init__.py │ │ │ ├── data_manager.py │ │ │ ├── datasets/ │ │ │ │ ├── __init__.py │ │ │ │ ├── base_dataset.py │ │ │ │ ├── build.py │ │ │ │ ├── da/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── cifarstl.py │ │ │ │ │ ├── digit5.py │ │ │ │ │ ├── domainnet.py │ │ │ │ │ ├── mini_domainnet.py │ │ │ │ │ ├── 
office31.py │ │ │ │ │ ├── office_home.py │ │ │ │ │ └── visda17.py │ │ │ │ ├── dg/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── cifar_c.py │ │ │ │ │ ├── digit_single.py │ │ │ │ │ ├── digits_dg.py │ │ │ │ │ ├── office_home_dg.py │ │ │ │ │ ├── pacs.py │ │ │ │ │ └── vlcs.py │ │ │ │ └── ssl/ │ │ │ │ ├── __init__.py │ │ │ │ ├── cifar.py │ │ │ │ ├── stl10.py │ │ │ │ └── svhn.py │ │ │ ├── samplers.py │ │ │ └── transforms/ │ │ │ ├── __init__.py │ │ │ ├── autoaugment.py │ │ │ ├── randaugment.py │ │ │ └── transforms.py │ │ ├── engine/ │ │ │ ├── __init__.py │ │ │ ├── build.py │ │ │ ├── da/ │ │ │ │ ├── __init__.py │ │ │ │ ├── adabn.py │ │ │ │ ├── adda.py │ │ │ │ ├── dael.py │ │ │ │ ├── dann.py │ │ │ │ ├── m3sda.py │ │ │ │ ├── mcd.py │ │ │ │ ├── mme.py │ │ │ │ ├── self_ensembling.py │ │ │ │ └── source_only.py │ │ │ ├── dg/ │ │ │ │ ├── __init__.py │ │ │ │ ├── crossgrad.py │ │ │ │ ├── daeldg.py │ │ │ │ ├── ddaig.py │ │ │ │ └── vanilla.py │ │ │ ├── ssl/ │ │ │ │ ├── __init__.py │ │ │ │ ├── entmin.py │ │ │ │ ├── fixmatch.py │ │ │ │ ├── mean_teacher.py │ │ │ │ ├── mixmatch.py │ │ │ │ └── sup_baseline.py │ │ │ └── trainer.py │ │ ├── evaluation/ │ │ │ ├── __init__.py │ │ │ ├── build.py │ │ │ └── evaluator.py │ │ ├── metrics/ │ │ │ ├── __init__.py │ │ │ ├── accuracy.py │ │ │ └── distance.py │ │ ├── modeling/ │ │ │ ├── __init__.py │ │ │ ├── backbone/ │ │ │ │ ├── __init__.py │ │ │ │ ├── alexnet.py │ │ │ │ ├── backbone.py │ │ │ │ ├── build.py │ │ │ │ ├── cnn_digit5_m3sda.py │ │ │ │ ├── cnn_digitsdg.py │ │ │ │ ├── cnn_digitsingle.py │ │ │ │ ├── efficientnet/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── model.py │ │ │ │ │ └── utils.py │ │ │ │ ├── mobilenetv2.py │ │ │ │ ├── preact_resnet18.py │ │ │ │ ├── resnet.py │ │ │ │ ├── shufflenetv2.py │ │ │ │ ├── vgg.py │ │ │ │ └── wide_resnet.py │ │ │ ├── head/ │ │ │ │ ├── __init__.py │ │ │ │ ├── build.py │ │ │ │ └── mlp.py │ │ │ ├── network/ │ │ │ │ ├── __init__.py │ │ │ │ ├── build.py │ │ │ │ └── ddaig_fcn.py │ │ │ └── ops/ │ │ │ ├── __init__.py │ │ │ ├── 
cross_entropy.py │ │ │ ├── dsbn.py │ │ │ ├── efdmix.py │ │ │ ├── mixstyle.py │ │ │ ├── mixup.py │ │ │ ├── mmd.py │ │ │ ├── optimal_transport.py │ │ │ ├── reverse_grad.py │ │ │ ├── sequential2.py │ │ │ ├── transnorm.py │ │ │ └── utils.py │ │ ├── optim/ │ │ │ ├── __init__.py │ │ │ ├── lr_scheduler.py │ │ │ ├── optimizer.py │ │ │ └── radam.py │ │ └── utils/ │ │ ├── __init__.py │ │ ├── logger.py │ │ ├── meters.py │ │ ├── registry.py │ │ ├── tools.py │ │ └── torchtools.py │ ├── datasets/ │ │ ├── da/ │ │ │ ├── cifar_stl.py │ │ │ ├── digit5.py │ │ │ └── visda17.sh │ │ ├── dg/ │ │ │ └── cifar_c.py │ │ └── ssl/ │ │ ├── cifar10_cifar100_svhn.py │ │ └── stl10.py │ ├── linter.sh │ ├── requirements.txt │ ├── setup.py │ └── tools/ │ ├── parse_test_res.py │ ├── replace_text.py │ └── train.py ├── ProGrad.public/ │ ├── .gitignore │ ├── DATASETS.md │ ├── LICENSE │ ├── README.md │ ├── clip/ │ │ ├── __init__.py │ │ ├── clip.py │ │ ├── model.py │ │ └── simple_tokenizer.py │ ├── configs/ │ │ ├── datasets/ │ │ │ ├── caltech101.yaml │ │ │ ├── dtd.yaml │ │ │ ├── eurosat.yaml │ │ │ ├── fgvc_aircraft.yaml │ │ │ ├── food101.yaml │ │ │ ├── imagenet.yaml │ │ │ ├── imagenet_a.yaml │ │ │ ├── imagenet_r.yaml │ │ │ ├── imagenet_sketch.yaml │ │ │ ├── imagenetv2.yaml │ │ │ ├── oxford_flowers.yaml │ │ │ ├── oxford_pets.yaml │ │ │ ├── stanford_cars.yaml │ │ │ ├── sun397.yaml │ │ │ └── ucf101.yaml │ │ └── trainers/ │ │ ├── CoCoOp/ │ │ │ ├── rn50_c4_ep10_batch1_ctxv1.yaml │ │ │ ├── rn50_ep100_init.yaml │ │ │ ├── rn50_ep50.yaml │ │ │ ├── vit_b16_c16_ep10_batch1.yaml │ │ │ ├── vit_b16_c4_ep10_batch1.yaml │ │ │ ├── vit_b16_c4_ep10_batch1_ctxv1.yaml │ │ │ └── vit_b16_c8_ep10_batch1.yaml │ │ ├── CoOp/ │ │ │ ├── rn50.yaml │ │ │ ├── rn50_ep100.yaml │ │ │ ├── rn50_ep50.yaml │ │ │ └── rn50_val.yaml │ │ └── ProGrad/ │ │ ├── rn50.yaml │ │ ├── rn50_ep100.yaml │ │ └── rn50_ep50.yaml │ ├── datasets/ │ │ ├── __init__.py │ │ ├── caltech101.py │ │ ├── dtd.py │ │ ├── eurosat.py │ │ ├── fgvc_aircraft.py │ │ ├── food101.py 
│ │ ├── imagenet.py │ │ ├── imagenet_a.py │ │ ├── imagenet_r.py │ │ ├── imagenet_sketch.py │ │ ├── imagenetv2.py │ │ ├── oxford_flowers.py │ │ ├── oxford_pets.py │ │ ├── stanford_cars.py │ │ ├── sun397.py │ │ └── ucf101.py │ ├── interpret_prompt.py │ ├── lpclip/ │ │ ├── README.md │ │ ├── feat_extractor.py │ │ ├── feat_extractor.sh │ │ ├── linear_probe.py │ │ ├── linear_probe.sh │ │ └── linear_probe_transfer.py │ ├── parse_test_res.py │ ├── requirements.txt │ ├── scripts/ │ │ ├── base2new_test_main.sh │ │ ├── base2new_test_prograd.sh │ │ ├── base2new_train_main.sh │ │ ├── base2new_train_prograd.sh │ │ ├── eval.sh │ │ ├── main.sh │ │ ├── prograd.sh │ │ └── zeroshot.sh │ ├── train.py │ └── trainers/ │ ├── __init__.py │ ├── cocoop.py │ ├── coop.py │ ├── imagenet_templates.py │ ├── prograd.py │ └── zsclip.py └── readme.md ================================================ FILE CONTENTS ================================================ ================================================ FILE: Dassl.ProGrad.pytorch/.flake8 ================================================ [flake8] ignore = # At least two spaces before inline comment E261, # Line lengths are recommended to be no greater than 79 characters E501, # Missing whitespace around arithmetic operator E226, # Blank line contains whitespace W293, # Do not use bare 'except' E722, # Line break after binary operator W504, # Too many leading '#' for block comment E266, # line break before binary operator W503, # continuation line over-indented for hanging indent E126 max-line-length = 79 exclude = __init__.py, build ================================================ FILE: Dassl.ProGrad.pytorch/.gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ pip-wheel-metadata/ share/python-wheels/ *.egg-info/ 
.installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # OS X .DS_Store .Spotlight-V100 .Trashes ._* # This project output/ debug.sh debug.py ================================================ FILE: Dassl.ProGrad.pytorch/.isort.cfg ================================================ [isort] line_length=79 multi_line_output=6 length_sort=true known_standard_library=numpy,setuptools known_myself=dassl known_third_party=matplotlib,cv2,torch,torchvision,PIL,yacs,scipy,gdown no_lines_before=STDLIB,THIRDPARTY sections=FUTURE,STDLIB,THIRDPARTY,myself,FIRSTPARTY,LOCALFOLDER default_section=FIRSTPARTY ================================================ FILE: Dassl.ProGrad.pytorch/.style.yapf ================================================ [style] BASED_ON_STYLE = pep8 BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true DEDENT_CLOSING_BRACKETS = true SPACES_BEFORE_COMMENT = 2 ARITHMETIC_PRECEDENCE_INDICATION = true ================================================ FILE: Dassl.ProGrad.pytorch/DATASETS.md ================================================ # How to Install Datasets `$DATA` denotes the location where datasets are installed, e.g. 
``` $DATA/ |–– office31/ |–– office_home/ |–– visda17/ ``` [Domain Adaptation](#domain-adaptation) - [Office-31](#office-31) - [Office-Home](#office-home) - [VisDA17](#visda17) - [CIFAR10-STL10](#cifar10-stl10) - [Digit-5](#digit-5) - [DomainNet](#domainnet) - [miniDomainNet](#miniDomainNet) [Domain Generalization](#domain-generalization) - [PACS](#pacs) - [VLCS](#vlcs) - [Office-Home-DG](#office-home-dg) - [Digits-DG](#digits-dg) - [Digit-Single](#digit-single) - [CIFAR-10-C](#cifar-10-c) - [CIFAR-100-C](#cifar-100-c) [Semi-Supervised Learning](#semi-supervised-learning) - [CIFAR10/100 and SVHN](#cifar10100-and-svhn) - [STL10](#stl10) ## Domain Adaptation ### Office-31 Download link: https://people.eecs.berkeley.edu/~jhoffman/domainadapt/#datasets_code. File structure: ``` office31/ |–– amazon/ | |–– back_pack/ | |–– bike/ | |–– ... |–– dslr/ | |–– back_pack/ | |–– bike/ | |–– ... |–– webcam/ | |–– back_pack/ | |–– bike/ | |–– ... ``` Note that within each domain folder you need to move all class folders out of the `images/` folder and then delete the `images/` folder. ### Office-Home Download link: http://hemanthdv.org/OfficeHome-Dataset/. File structure: ``` office_home/ |–– art/ |–– clipart/ |–– product/ |–– real_world/ ``` ### VisDA17 Download link: http://ai.bu.edu/visda-2017/. The dataset can also be downloaded using our script at `datasets/da/visda17.sh`. Run the following command in your terminal under `Dassl.pytorch/datasets/da`, ```bash sh visda17.sh $DATA ``` Once the download is finished, the file structure will look like ``` visda17/ |–– train/ |–– test/ |–– validation/ ``` ### CIFAR10-STL10 Run the following command in your terminal under `Dassl.pytorch/datasets/da`, ```bash python cifar_stl.py $DATA/cifar_stl ``` This will create a folder named `cifar_stl` under `$DATA`. The file structure will look like ``` cifar_stl/ |–– cifar/ | |–– train/ | |–– test/ |–– stl/ | |–– train/ | |–– test/ ``` Note that only 9 classes shared by both datasets are kept. 
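The Office-31 step described above can be scripted: move each domain's class folders out of `images/`, then delete `images/`. The snippet below is a minimal sketch and is not a script shipped with Dassl. The function name `flatten_office31` is hypothetical. It assumes `office31/` contains the three domain folders exactly as downloaded, and that each `images/` folder holds only class folders.

```python
import shutil
from pathlib import Path


def flatten_office31(root, domains=("amazon", "dslr", "webcam")):
    """Hypothetical helper: move class folders from <root>/<domain>/images/
    up to <root>/<domain>/, then delete the emptied images/ folder."""
    root = Path(root)
    for domain in domains:
        images_dir = root / domain / "images"
        if not images_dir.is_dir():
            continue  # already flattened, or domain folder missing
        # Collect the listing first so we never iterate while moving entries.
        for class_dir in list(images_dir.iterdir()):
            # e.g. office31/amazon/images/bike -> office31/amazon/bike
            shutil.move(str(class_dir), str(root / domain / class_dir.name))
        images_dir.rmdir()  # images/ is empty at this point
```

For example, call `flatten_office31(os.path.join(os.environ["DATA"], "office31"))`. The helper skips any domain whose `images/` folder is already gone, so running it twice is harmless.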
### Digit-5 Create a folder `$DATA/digit5` and download the dataset from [here](https://github.com/VisionLearningGroup/VisionLearningGroup.github.io/tree/master/M3SDA/code_MSDA_digit#digit-five-download) into this folder. This should give you ``` digit5/ |–– Digit-Five/ ``` Then, run the following command in your terminal under `Dassl.pytorch/datasets/da`, ```bash python digit5.py $DATA/digit5 ``` This will extract the data and organize the file structure as ``` digit5/ |–– Digit-Five/ |–– mnist/ |–– mnist_m/ |–– usps/ |–– svhn/ |–– syn/ ``` ### DomainNet Download link: http://ai.bu.edu/M3SDA/ (please download the cleaned version of the split files). File structure: ``` domainnet/ |–– clipart/ |–– infograph/ |–– painting/ |–– quickdraw/ |–– real/ |–– sketch/ |–– splits/ | |–– clipart_train.txt | |–– clipart_test.txt | |–– ... ``` ### miniDomainNet You need to download the DomainNet dataset first. The miniDomainNet split files can be downloaded from this [google drive](https://drive.google.com/open?id=15rrLDCrzyi6ZY-1vJar3u7plgLe4COL7). After the zip file is extracted, you should have the folder `$DATA/domainnet/splits_mini/`. ## Domain Generalization ### PACS Download link: [google drive](https://drive.google.com/open?id=1m4X4fROCCXMO0lRLrr6Zz9Vb3974NWhE). File structure: ``` pacs/ |–– images/ |–– splits/ ``` You do not necessarily have to manually download this dataset. Once you run ``tools/train.py``, the code will detect whether the dataset exists and automatically download the dataset to ``$DATA`` if it is missing. This also applies to VLCS, Office-Home-DG, and Digits-DG. ### VLCS Download link: [google drive](https://drive.google.com/file/d/1r0WL5DDqKfSPp9E3tRENwHaXNs1olLZd/view?usp=sharing) (credit to https://github.com/fmcarlucci/JigenDG#vlcs) File structure: ``` VLCS/ |–– CALTECH/ |–– LABELME/ |–– PASCAL/ |–– SUN/ ``` ### Office-Home-DG Download link: [google drive](https://drive.google.com/open?id=1gkbf_KaxoBws-GWT3XIPZ7BnkqbAxIFa).
File structure: ``` office_home_dg/ |–– art/ |–– clipart/ |–– product/ |–– real_world/ ``` ### Digits-DG Download link: [google drive](https://drive.google.com/open?id=15V7EsHfCcfbKgsDmzQKj_DfXt_XYp_P7). File structure: ``` digits_dg/ |–– mnist/ |–– mnist_m/ |–– svhn/ |–– syn/ ``` ### Digit-Single Follow the steps for [Digit-5](#digit-5) to organize the dataset. ### CIFAR-10-C First download the CIFAR-10-C dataset from https://zenodo.org/record/2535967#.YFxHEWQzb0o to, e.g., $DATA, and extract the file under the same directory. Then, navigate to `Dassl.pytorch/datasets/dg` and run the following command in your terminal ```bash python cifar_c.py $DATA/CIFAR-10-C ``` where the first argument denotes the path to the (uncompressed) CIFAR-10-C dataset. The script will extract images from the `.npy` files and save them to `cifar10_c/` created under $DATA. The file structure will look like ``` cifar10_c/ |–– brightness/ | |–– 1/ # 5 intensity levels in total | |–– 2/ | |–– 3/ | |–– 4/ | |–– 5/ |–– ... # 19 corruption types in total ``` Note that `cifar10_c/` only contains the test images. The training images are the normal CIFAR-10 images. See [CIFAR10/100 and SVHN](#cifar10100-and-svhn) for how to prepare the CIFAR-10 dataset. ### CIFAR-100-C First download the CIFAR-100-C dataset from https://zenodo.org/record/3555552#.YFxpQmQzb0o to, e.g., $DATA, and extract the file under the same directory. Then, navigate to `Dassl.pytorch/datasets/dg` and run the following command in your terminal ```bash python cifar_c.py $DATA/CIFAR-100-C ``` where the first argument denotes the path to the (uncompressed) CIFAR-100-C dataset. The script will extract images from the `.npy` files and save them to `cifar100_c/` created under $DATA. The file structure will look like ``` cifar100_c/ |–– brightness/ | |–– 1/ # 5 intensity levels in total | |–– 2/ | |–– 3/ | |–– 4/ | |–– 5/ |–– ... # 19 corruption types in total ``` Note that `cifar100_c/` only contains the test images.
The training images are the normal CIFAR-100 images. See [CIFAR10/100 and SVHN](#cifar10100-and-svhn) for how to prepare the CIFAR-100 dataset. ## Semi-Supervised Learning ### CIFAR10/100 and SVHN Run the following command in your terminal under `Dassl.pytorch/datasets/ssl`, ```bash python cifar10_cifar100_svhn.py $DATA ``` This will create three folders under `$DATA`, i.e. ``` cifar10/ |–– train/ |–– test/ cifar100/ |–– train/ |–– test/ svhn/ |–– train/ |–– test/ ``` ### STL10 Run the following command in your terminal under `Dassl.pytorch/datasets/ssl`, ```bash python stl10.py $DATA/stl10 ``` This will create a folder named `stl10` under `$DATA` and extract the data into three folders, i.e. `train`, `test` and `unlabeled`. Then, download the "Binary files" from http://ai.stanford.edu/~acoates/stl10/ and extract them under `stl10`. The file structure will look like ``` stl10/ |–– train/ |–– test/ |–– unlabeled/ |–– stl10_binary/ ``` ================================================ FILE: Dassl.ProGrad.pytorch/LICENSE ================================================ MIT License Copyright (c) 2020 Kaiyang Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: Dassl.ProGrad.pytorch/README.md ================================================ # Dassl ## Introduction Dassl is a [PyTorch](https://pytorch.org) toolbox initially developed for our project [Domain Adaptive Ensemble Learning (DAEL)](https://arxiv.org/abs/2003.07325) to support research in domain adaptation and generalization---since in DAEL we study how to unify these two problems in a single learning framework. Given that domain adaptation is closely related to semi-supervised learning---both study how to exploit unlabeled data---we also incorporate components that support research for the latter. Why the name "Dassl"? Dassl combines the initials of domain adaptation (DA) and semi-supervised learning (SSL), which sounds natural and informative. Dassl has a modular design and unified interfaces, allowing fast prototyping and experimentation with new DA/DG/SSL methods. With Dassl, a new method can be implemented with only a few lines of code. Don't believe it? Take a look at the [engine](https://github.com/KaiyangZhou/Dassl.pytorch/tree/master/dassl/engine) folder, which contains the implementations of many existing methods (then you will come back and star this repo). :-) Basically, Dassl is perfect for doing research in the following areas: - Domain adaptation - Domain generalization - Semi-supervised learning BUT, thanks to the neat design, Dassl can also be used as a codebase to develop any deep learning project, like [this](https://github.com/KaiyangZhou/CoOp). :-) A drawback of Dassl is that it doesn't (yet? hmm) support distributed multi-GPU training (Dassl uses `DataParallel` to wrap a model, which is less efficient than `DistributedDataParallel`).
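The modular design described above rests on name-based registries. Trainers, backbones, datasets and similar components are registered under a name and looked up from the config, which is how `--trainer SourceOnly` selects a class. The snippet below is a minimal, self-contained sketch of that pattern. `Registry`, `TRAINER_REGISTRY` and `build_trainer` here are simplified stand-ins written for illustration. The real ones live in `dassl/utils/registry.py` and `dassl/engine/build.py` and differ in detail. The config is also a plain dict instead of a yacs `CfgNode`.

```python
class Registry:
    """Hypothetical minimal name -> class registry (a stand-in, not dassl.utils.Registry)."""

    def __init__(self, name):
        self.name = name
        self._obj_map = {}

    def register(self, force=False):
        """Return a decorator that stores the decorated class under its __name__."""

        def decorator(obj):
            key = obj.__name__
            if key in self._obj_map and not force:
                raise KeyError(f"'{key}' is already registered in {self.name}")
            self._obj_map[key] = obj
            return obj

        return decorator

    def get(self, name):
        if name not in self._obj_map:
            raise KeyError(
                f"'{name}' not found in {self.name}; "
                f"available: {sorted(self._obj_map)}"
            )
        return self._obj_map[name]


TRAINER_REGISTRY = Registry("TRAINER")


@TRAINER_REGISTRY.register()
class SourceOnly:
    """Toy trainer; a real Dassl trainer subclasses a base trainer from dassl/engine/trainer.py."""

    def __init__(self, cfg):
        self.cfg = cfg


def build_trainer(cfg):
    # Look up the class by the name given on the command line, e.g. --trainer SourceOnly.
    return TRAINER_REGISTRY.get(cfg["trainer"])(cfg)


trainer = build_trainer({"trainer": "SourceOnly"})
print(type(trainer).__name__)
```

In this sketch, registering a second class under an existing name raises an error unless `force=True` is passed. This mirrors the `*.register(force=True)` option listed in the changelog.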
We don't provide detailed documentation for Dassl, unlike another [project](https://kaiyangzhou.github.io/deep-person-reid/) of ours. This is because Dassl is developed for research purposes and, as researchers, we think it's important to be able to read source code, and we highly encourage you to do so---definitely not because we are lazy. :-) ## What's new - Mar 2022: A new domain generalization method [EFDM](https://arxiv.org/abs/2203.07740) developed by [Yabin Zhang (PolyU)](https://ybzh.github.io/) and to appear at CVPR'22 is added to this repo. See [here](https://github.com/KaiyangZhou/Dassl.pytorch/pull/36) for more details. - Feb 2022: In case you don't know, a class in the painting domain of DomainNet (the official splits) only has test images (no training images), which could affect performance. See section 4.a in our [paper](https://arxiv.org/abs/2003.07325) for more details. - Oct 2021: `v0.5.0`: **Important changes** made to `transforms.py`. 1) `center_crop` becomes a default transform in testing (applied after resizing the smaller edge to a certain size to keep the image aspect ratio). 2) For training, `Resize(cfg.INPUT.SIZE)` is deactivated when `random_crop` or `random_resized_crop` is used. These changes won't make any difference to the training transforms used in existing config files, nor to the testing transforms unless the raw images are not square (the only difference is that now the image aspect ratio is respected). - Oct 2021: `v0.4.3`: Copies the attributes in `self.dm` (data manager) to `SimpleTrainer` and makes `self.dm` optional, which means from now on, you can build data loaders from any source you like rather than being forced to use `DataManager`. - Sep 2021: `v0.4.2`: An important update is to set `drop_last=is_train and len(data_source)>=batch_size` when constructing a data loader, to avoid a 0-length data loader when the training set is smaller than one batch.
- Aug 2021: `v0.4.0`: The most noteworthy update is adding the learning rate warmup scheduler. The implementation is detailed [here](https://github.com/KaiyangZhou/Dassl.pytorch/blob/master/dassl/optim/lr_scheduler.py#L10) and the config variables are specified [here](https://github.com/KaiyangZhou/Dassl.pytorch/blob/master/dassl/config/defaults.py#L171). - Jul 2021: `v0.3.4`: Adds a new function `generate_fewshot_dataset()` to the base dataset class, which allows for the generation of a few-shot learning setting. One can customize a few-shot dataset by specifying `_C.DATASET.NUM_SHOTS` and passing it to `generate_fewshot_dataset()`. - Jul 2021: `v0.3.2`: Adds `_C.INPUT.INTERPOLATION` (default: `bilinear`). Available interpolation modes are `bilinear`, `nearest`, and `bicubic`. - Jul 2021 `v0.3.1`: Now you can use `*.register(force=True)` to replace previously registered modules. - Jul 2021 `v0.3.0`: Allows deploying the model with the best validation performance for the final test (for the purpose of model selection). Specifically, a new config variable named `_C.TEST.FINAL_MODEL` is introduced, which takes either `"last_step"` (default) or `"best_val"`. When set to `"best_val"`, the model will be evaluated on the `val` set after each epoch and the one with the best validation performance will be saved and used for the final test (see this [code](https://github.com/KaiyangZhou/Dassl.pytorch/blob/master/dassl/engine/trainer.py#L412)). - Jul 2021 `v0.2.7`: Adds attribute `classnames` to the base dataset class. Now you can get a list of class names ordered by numeric labels by calling `trainer.dm.dataset.classnames`. - Jun 2021 `v0.2.6`: Merges `MixStyle2` into `MixStyle`. A new variable `self.mix` is used to switch between random mixing and cross-domain mixing. Please see [this](https://github.com/KaiyangZhou/Dassl.pytorch/issues/23) for more details on the new features.
- Jun 2021 `v0.2.5`: Fixes a [bug](https://github.com/KaiyangZhou/Dassl.pytorch/commit/29881c7faee7405f80f5f674de4bbbf80d5dc77a) in the calculation of per-class recognition accuracy. - Jun 2021 `v0.2.4`: Adds `extend_cfg(cfg)` to `train.py`. This function is particularly useful when you build your own methods on top of Dassl.pytorch and need to define some custom variables. Please see the repository [mixstyle-release](https://github.com/KaiyangZhou/mixstyle-release) or [ssdg-benchmark](https://github.com/KaiyangZhou/ssdg-benchmark) for examples. - Jun 2021 New benchmarks for semi-supervised domain generalization at https://github.com/KaiyangZhou/ssdg-benchmark. - Apr 2021 Did you know you can use `tools/parse_test_res.py` to read the log files and automatically calculate and print out the results, including mean and standard deviation? Check the instructions in `tools/parse_test_res.py` for more details. - Apr 2021 `v0.2.3`: A [MixStyle](https://openreview.net/forum?id=6xHJ37MVxxp) layer can now be deactivated or activated by using `model.apply(deactivate_mixstyle)` or `model.apply(activate_mixstyle)` without modifying the source code. See [dassl/modeling/ops/mixstyle.py](dassl/modeling/ops/mixstyle.py) for the details. - Apr 2021 `v0.2.2`: Adds `RandomClassSampler`, which forms a minibatch by sampling a certain number of classes and a certain number of images from each of them (the code is modified from [Torchreid](https://github.com/KaiyangZhou/deep-person-reid)). - Apr 2021 `v0.2.1`: Slightly adjusts the ordering in `setup_cfg()` (see `tools/train.py`). - Apr 2021 `v0.2.0`: Adds `_C.DATASET.ALL_AS_UNLABELED` (for the SSL setting) to the config variable list. When this variable is set to `True`, all labeled data will be included in the unlabeled data set. - Apr 2021 `v0.1.9`: Adds [VLCS](https://people.csail.mit.edu/torralba/publications/datasets_cvpr11.pdf) to the benchmark datasets (see `dassl/data/datasets/dg/vlcs.py`).
- Mar 2021 `v0.1.8`: Allows `optim` and `sched` to be `None` in `register_model()`. - Mar 2021 `v0.1.7`: Adds [MixStyle](https://openreview.net/forum?id=6xHJ37MVxxp) models to [dassl/modeling/backbone/resnet.py](dassl/modeling/backbone/resnet.py). The training configs in `configs/trainers/dg/vanilla` can be used to train MixStyle models. - Mar 2021 `v0.1.6`: Adds [CIFAR-10/100-C](https://arxiv.org/abs/1807.01697) to the benchmark datasets for evaluating a model's robustness to image corruptions. - Mar 2021 We have just released a survey on domain generalization at https://arxiv.org/abs/2103.02503, which summarizes ten years of development in this topic, covering the history, related problems, datasets, methodologies, potential directions, and so on. - Jan 2021 Our recent work, [MixStyle](https://openreview.net/forum?id=6xHJ37MVxxp) (mixing instance-level feature statistics of samples of different domains for improving domain generalization), has been accepted to ICLR'21. The code is available at https://github.com/KaiyangZhou/mixstyle-release where the cross-domain image classification part is based on Dassl.pytorch. - May 2020 `v0.1.3`: Adds the `Digit-Single` dataset for benchmarking single-source DG methods. The corresponding CNN model is [dassl/modeling/backbone/cnn_digitsingle.py](dassl/modeling/backbone/cnn_digitsingle.py) and the dataset config file is [configs/datasets/dg/digit_single.yaml](configs/datasets/dg/digit_single.yaml). See [Volpi et al. NIPS'18](https://arxiv.org/abs/1805.12018) for how to do evaluation. - May 2020 `v0.1.2`: 1) Adds [EfficientNet](https://arxiv.org/abs/1905.11946) models (B0-B7) (credit to https://github.com/lukemelas/EfficientNet-PyTorch). To use EfficientNet, set `MODEL.BACKBONE.NAME` to `efficientnet_b{N}` where `N={0, ..., 7}`. 2) `dassl/modeling/models` is renamed to `dassl/modeling/network` (`build_model()` to `build_network()` and `MODEL_REGISTRY` to `NETWORK_REGISTRY`).
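The v0.4.2 changelog entry sets `drop_last=is_train and len(data_source)>=batch_size` to avoid a 0-length loader. The sketch below shows why. It reproduces the batch count that a PyTorch batch sampler reports for a given `drop_last` setting: floor division when the last partial batch is dropped, ceiling division otherwise. The helper names `should_drop_last` and `num_batches` are hypothetical. The snippet uses plain Python so it runs without PyTorch.

```python
def should_drop_last(is_train, num_samples, batch_size):
    """The v0.4.2 rule: drop the last incomplete batch only when training,
    and only if the data source holds at least one full batch."""
    return is_train and num_samples >= batch_size


def num_batches(num_samples, batch_size, drop_last):
    """Batches per epoch, counted the way a PyTorch BatchSampler does."""
    if drop_last:
        return num_samples // batch_size
    return (num_samples + batch_size - 1) // batch_size  # ceiling division


for is_train, n in [(True, 10), (True, 100), (False, 100)]:
    drop = should_drop_last(is_train, n, batch_size=32)
    print(is_train, n, drop, num_batches(n, 32, drop))
```

With 10 training images and `batch_size=32`, an unconditional `drop_last=True` would yield 0 batches per epoch. The rule keeps the single partial batch instead, while larger training sets still drop their incomplete last batch.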
## Overview Dassl has implemented the following methods: - Single-source domain adaptation - [Semi-supervised Domain Adaptation via Minimax Entropy (ICCV'19)](https://arxiv.org/abs/1904.06487) [[dassl/engine/da/mme.py](dassl/engine/da/mme.py)] - [Maximum Classifier Discrepancy for Unsupervised Domain Adaptation (CVPR'18)](https://arxiv.org/abs/1712.02560) [[dassl/engine/da/mcd.py](dassl/engine/da/mcd.py)] - [Self-ensembling for visual domain adaptation (ICLR'18)](https://arxiv.org/abs/1706.05208) [[dassl/engine/da/self_ensembling.py](dassl/engine/da/self_ensembling.py)] - [Revisiting Batch Normalization For Practical Domain Adaptation (ICLR-W'17)](https://arxiv.org/abs/1603.04779) [[dassl/engine/da/adabn.py](dassl/engine/da/adabn.py)] - [Adversarial Discriminative Domain Adaptation (CVPR'17)](https://arxiv.org/abs/1702.05464) [[dassl/engine/da/adda.py](dassl/engine/da/adda.py)] - [Domain-Adversarial Training of Neural Networks (JMLR'16)](https://arxiv.org/abs/1505.07818) [[dassl/engine/da/dann.py](dassl/engine/da/dann.py)] - Multi-source domain adaptation - [Domain Adaptive Ensemble Learning](https://arxiv.org/abs/2003.07325) [[dassl/engine/da/dael.py](dassl/engine/da/dael.py)] - [Moment Matching for Multi-Source Domain Adaptation (ICCV'19)](https://arxiv.org/abs/1812.01754) [[dassl/engine/da/m3sda.py](dassl/engine/da/m3sda.py)] - Domain generalization - [Exact Feature Distribution Matching for Arbitrary Style Transfer and Domain Generalization (CVPR'22)](https://arxiv.org/abs/2203.07740) [[dassl/modeling/ops/efdmix.py](dassl/modeling/ops/efdmix.py)] - [Domain Generalization with MixStyle (ICLR'21)](https://openreview.net/forum?id=6xHJ37MVxxp) [[dassl/modeling/ops/mixstyle.py](dassl/modeling/ops/mixstyle.py)] - [Deep Domain-Adversarial Image Generation for Domain Generalisation (AAAI'20)](https://arxiv.org/abs/2003.06054) [[dassl/engine/dg/ddaig.py](dassl/engine/dg/ddaig.py)] - [Generalizing Across Domains via Cross-Gradient
Training (ICLR'18)](https://arxiv.org/abs/1804.10745) [[dassl/engine/dg/crossgrad.py](dassl/engine/dg/crossgrad.py)] - Semi-supervised learning - [FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence](https://arxiv.org/abs/2001.07685) [[dassl/engine/ssl/fixmatch.py](dassl/engine/ssl/fixmatch.py)] - [MixMatch: A Holistic Approach to Semi-Supervised Learning (NeurIPS'19)](https://arxiv.org/abs/1905.02249) [[dassl/engine/ssl/mixmatch.py](dassl/engine/ssl/mixmatch.py)] - [Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results (NeurIPS'17)](https://arxiv.org/abs/1703.01780) [[dassl/engine/ssl/mean_teacher.py](dassl/engine/ssl/mean_teacher.py)] - [Semi-supervised Learning by Entropy Minimization (NeurIPS'04)](http://papers.nips.cc/paper/2740-semi-supervised-learning-by-entropy-minimization.pdf) [[dassl/engine/ssl/entmin.py](dassl/engine/ssl/entmin.py)] *Feel free to make a [PR](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork) to add your methods here to make it easier for others to benchmark!* Dassl supports the following datasets: - Domain adaptation - [Office-31](https://scalable.mpi-inf.mpg.de/files/2013/04/saenko_eccv_2010.pdf) - [Office-Home](http://hemanthdv.org/OfficeHome-Dataset/) - [VisDA17](http://ai.bu.edu/visda-2017/) - [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html)-[STL10](https://cs.stanford.edu/~acoates/stl10/) - [Digit-5](https://github.com/VisionLearningGroup/VisionLearningGroup.github.io/tree/master/M3SDA/code_MSDA_digit#digit-five-download) - [DomainNet](http://ai.bu.edu/M3SDA/) - [miniDomainNet](https://arxiv.org/abs/2003.07325) - Domain generalization - [PACS](https://arxiv.org/abs/1710.03077) - [VLCS](https://people.csail.mit.edu/torralba/publications/datasets_cvpr11.pdf) - [Office-Home](http://hemanthdv.org/OfficeHome-Dataset/) - 
[Digits-DG](https://arxiv.org/abs/2003.06054) - [Digit-Single](https://arxiv.org/abs/1805.12018) - [CIFAR-10-C](https://arxiv.org/abs/1807.01697) - [CIFAR-100-C](https://arxiv.org/abs/1807.01697) - Semi-supervised learning - [CIFAR10/100](https://www.cs.toronto.edu/~kriz/cifar.html.) - [SVHN](http://ufldl.stanford.edu/housenumbers/) - [STL10](https://cs.stanford.edu/~acoates/stl10/) ## Get started ### Installation Make sure [conda](https://www.anaconda.com/distribution/) is installed properly. ```bash # Clone this repo git clone https://github.com/KaiyangZhou/Dassl.pytorch.git cd Dassl.pytorch/ # Create a conda environment conda create -n dassl python=3.7 # Activate the environment conda activate dassl # Install dependencies pip install -r requirements.txt # Install torch (version >= 1.7.1) and torchvision conda install pytorch torchvision cudatoolkit=10.1 -c pytorch # Install this library (no need to re-build if the source code is modified) python setup.py develop ``` Follow the instructions in [DATASETS.md](./DATASETS.md) to preprocess the datasets. ### Training The main interface is implemented in `tools/train.py`, which basically does 1. initialize the config with `cfg = setup_cfg(args)` where `args` contains the command-line input (see `tools/train.py` for the list of input arguments); 2. instantiate a `trainer` with `build_trainer(cfg)` which loads the dataset and builds a deep neural network model; 3. call `trainer.train()` for training and evaluating the model. Below we provide an example for training a source-only baseline on the popular domain adaptation dataset, Office-31, ```bash CUDA_VISIBLE_DEVICES=0 python tools/train.py \ --root $DATA \ --trainer SourceOnly \ --source-domains amazon \ --target-domains webcam \ --dataset-config-file configs/datasets/da/office31.yaml \ --config-file configs/trainers/da/source_only/office31.yaml \ --output-dir output/source_only_office31 ``` `$DATA` denotes the location where datasets are installed. 
`--dataset-config-file` loads the common setting for the dataset (Office-31 in this case) such as image size and model architecture. `--config-file` loads the algorithm-specific setting such as hyper-parameters and optimization parameters. To use multiple sources, namely the multi-source domain adaptation task, one just needs to add more sources to `--source-domains`. For instance, to train a source-only baseline on miniDomainNet, one can do ```bash CUDA_VISIBLE_DEVICES=0 python tools/train.py \ --root $DATA \ --trainer SourceOnly \ --source-domains clipart painting real \ --target-domains sketch \ --dataset-config-file configs/datasets/da/mini_domainnet.yaml \ --config-file configs/trainers/da/source_only/mini_domainnet.yaml \ --output-dir output/source_only_minidn ``` After the training finishes, the model weights will be saved under the specified output directory, along with a log file and a tensorboard file for visualization. To print out the results saved in the log file (so you do not need to exhaustively go through all log files and calculate the mean/std by yourself), you can use `tools/parse_test_res.py`. The instruction can be found in the code. For other trainers such as `MCD`, you can set `--trainer MCD` while keeping the config file unchanged, i.e. using the same training parameters as `SourceOnly` (in the simplest case). To modify the hyper-parameters in MCD, like `N_STEP_F` (number of steps to update the feature extractor), you can append `TRAINER.MCD.N_STEP_F 4` to the existing input arguments (otherwise the default value will be used). Alternatively, you can create a new `.yaml` config file to store your custom setting. See [here](https://github.com/KaiyangZhou/Dassl.pytorch/blob/master/dassl/config/defaults.py#L176) for a complete list of algorithm-specific hyper-parameters. ### Test Model testing can be done by using `--eval-only`, which asks the code to run `trainer.test()`. 
You also need to provide the trained model and specify which model file (i.e. saved at which epoch) to use. For example, to use `model.pth.tar-20` saved at `output/source_only_office31/model`, you can do ```bash CUDA_VISIBLE_DEVICES=0 python tools/train.py \ --root $DATA \ --trainer SourceOnly \ --source-domains amazon \ --target-domains webcam \ --dataset-config-file configs/datasets/da/office31.yaml \ --config-file configs/trainers/da/source_only/office31.yaml \ --output-dir output/source_only_office31_test \ --eval-only \ --model-dir output/source_only_office31 \ --load-epoch 20 ``` Note that `--model-dir` takes as input the directory path which was specified in `--output-dir` in the training stage. ### Write a new trainer A good practice is to go through `dassl/engine/trainer.py` to get familar with the base trainer classes, which provide generic functions and training loops. To write a trainer class for domain adaptation or semi-supervised learning, the new class can subclass `TrainerXU`. For domain generalization, the new class can subclass `TrainerX`. In particular, `TrainerXU` and `TrainerX` mainly differ in whether using a data loader for unlabeled data. With the base classes, a new trainer may only need to implement the `forward_backward()` method, which performs loss computation and model update. See `dassl/enigne/da/source_only.py` for example. ### Add a new backbone/head/network `backbone` corresponds to a convolutional neural network model which performs feature extraction. `head` (which is an optional module) is mounted on top of `backbone` for further processing, which can be, for example, a MLP. `backbone` and `head` are basic building blocks for constructing a `SimpleNet()` (see `dassl/engine/trainer.py`) which serves as the primary model for a task. `network` contains custom neural network models, such as an image generator. 
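The way these pieces fit together can be illustrated with a small pure-Python sketch. Everything in it is hypothetical: `Registry`, `toy_backbone`, `toy_mlp` and `build_simple_net` are stand-ins written for this example (dicts instead of PyTorch modules), not Dassl's own classes. It is meant to show two things under stated assumptions: how a decorator-based registry lets a config string such as `MODEL.BACKBONE.NAME` select a builder, and the assumed rule for the classifier's input size, namely the head's last hidden layer when `MODEL.HEAD.NAME` and `MODEL.HEAD.HIDDEN_LAYERS` are set, otherwise the backbone's `out_features`.

```python
class Registry:
    """Minimal name -> builder map with a decorator-based register() (illustrative only)."""

    def __init__(self, name):
        self.name = name
        self._obj_map = {}

    def register(self):
        def deco(fn):
            if fn.__name__ in self._obj_map:
                raise KeyError(f"{fn.__name__} is already registered in {self.name}")
            self._obj_map[fn.__name__] = fn
            return fn
        return deco

    def get(self, name):
        if name not in self._obj_map:
            raise KeyError(f"{name} not found in the {self.name} registry")
        return self._obj_map[name]


BACKBONE_REGISTRY = Registry("BACKBONE")
HEAD_REGISTRY = Registry("HEAD")


@BACKBONE_REGISTRY.register()
def toy_backbone(**kwargs):
    # Stand-in for a CNN: only reports the size of the feature vector it produces
    return {"out_features": 2048}


@HEAD_REGISTRY.register()
def toy_mlp(in_features, hidden_layers, **kwargs):
    # Stand-in for an MLP head: its output size is the last hidden layer
    return {"in_features": in_features, "out_features": hidden_layers[-1]}


def build_simple_net(backbone_name, head_name, hidden_layers, num_classes):
    """Chain backbone -> optional head -> linear classifier; return the dimensions."""
    backbone = BACKBONE_REGISTRY.get(backbone_name)()
    fdim = backbone["out_features"]
    if head_name and hidden_layers:
        head = HEAD_REGISTRY.get(head_name)(in_features=fdim, hidden_layers=hidden_layers)
        fdim = head["out_features"]
    # The classifier maps fdim features to num_classes logits
    return {"feature_dim": fdim, "classifier": (fdim, num_classes)}


print(build_simple_net("toy_backbone", "toy_mlp", [256], 31))
print(build_simple_net("toy_backbone", "", [], 31))
```

With settings like those in `configs/datasets/da/office31.yaml` (a `resnet50` backbone, whose 2048-d output is mimicked here, plus an `mlp` head with `HIDDEN_LAYERS: [256]`), the classifier in this sketch maps 256 features to the 31 Office-31 classes; without a head it consumes the backbone's 2048 features directly.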
To add a new module, namely a backbone/head/network, you need to first register the module using the corresponding `registry`, i.e. `BACKBONE_REGISTRY` for `backbone`, `HEAD_REGISTRY` for `head` and `NETWORK_REGISTRY` for `network`. Note that for a new `backbone`, we require the model to subclass `Backbone` as defined in `dassl/modeling/backbone/backbone.py` and to specify the `self._out_features` attribute.

We provide an example below of how to add a new `backbone`.

```python
import torch.nn as nn

from dassl.modeling import Backbone, BACKBONE_REGISTRY


class MyBackbone(Backbone):

    def __init__(self):
        super().__init__()
        # Create layers (here: one conv layer followed by global average pooling)
        self.conv = nn.Conv2d(3, 2048, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)

        # Dimension of the feature vector returned by forward()
        self._out_features = 2048

    def forward(self, x):
        # Extract and return features of shape (batch_size, 2048)
        x = self.conv(x)
        return self.pool(x).flatten(1)


@BACKBONE_REGISTRY.register()
def my_backbone(**kwargs):
    # Extra keyword arguments passed by the builder are accepted and ignored here
    return MyBackbone()
```

Then, you can set `MODEL.BACKBONE.NAME` to `my_backbone` to use your own architecture. For more details, please refer to the source code in `dassl/modeling`.

### Add a dataset

An example code structure is shown below. Make sure you subclass `DatasetBase` and register the dataset with `@DATASET_REGISTRY.register()`. All you need to do is load `train_x`, `train_u` (optional), `val` (optional) and `test`; `train_u` and `val` can be `None` or simply ignored. Each of these variables contains a list of `Datum` objects. A `Datum` object (implemented [here](https://github.com/KaiyangZhou/Dassl.pytorch/blob/master/dassl/data/datasets/base_dataset.py#L12)) contains information for a single image, such as `impath` (string) and `label` (int).

```python
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase


@DATASET_REGISTRY.register()
class NewDataset(DatasetBase):

    dataset_dir = ''

    def __init__(self, cfg):

        train_x = ...
        train_u = ...  # optional, can be None
        val = ...  # optional, can be None
        test = ...

        super().__init__(train_x=train_x, train_u=train_u, val=val, test=test)
```

We suggest you take a look at the dataset code in projects built on top of Dassl, such as [this one](https://github.com/KaiyangZhou/CoOp).

## Relevant Research

We would like to share here our research relevant to Dassl.

- [Domain Adaptive Ensemble Learning](https://arxiv.org/abs/2003.07325), TIP, 2021.
- [MixStyle Neural Networks for Domain Generalization and Adaptation](https://arxiv.org/abs/2107.02053), arXiv preprint, 2021.
- [Semi-Supervised Domain Generalization with Stochastic StyleMatch](https://arxiv.org/abs/2106.00592), arXiv preprint, 2021.
- [Domain Generalization in Vision: A Survey](https://arxiv.org/abs/2103.02503), arXiv preprint, 2021.
- [Domain Generalization with MixStyle](https://openreview.net/forum?id=6xHJ37MVxxp), in ICLR 2021.
- [Learning to Generate Novel Domains for Domain Generalization](https://arxiv.org/abs/2007.03304), in ECCV 2020.
- [Deep Domain-Adversarial Image Generation for Domain Generalisation](https://arxiv.org/abs/2003.06054), in AAAI 2020.

## Citation

If you find this code useful to your research, please give credit to the following paper:

```
@article{zhou2020domain,
  title={Domain Adaptive Ensemble Learning},
  author={Zhou, Kaiyang and Yang, Yongxin and Qiao, Yu and Xiang, Tao},
  journal={IEEE Transactions on Image Processing (TIP)},
  year={2021}
}
```

================================================
FILE: Dassl.ProGrad.pytorch/configs/README.md
================================================
The `datasets/` folder contains dataset-specific config files which define the standard protocols (e.g., image size, data augmentation, network architecture) used by most papers. The `trainers/` folder contains method-specific config files which define optimization algorithms (e.g., optimizer, epoch) and hyperparameter settings.
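To make the split between the two folders concrete, here is a minimal sketch of how a dataset config and a trainer config could be layered on top of the defaults in `dassl/config/defaults.py`, followed by `KEY VALUE` pairs from the command line (such as `TRAINER.MCD.N_STEP_F 4` in the top-level README). It uses plain nested dicts instead of yacs `CfgNode`, the helper names `deep_merge` and `merge_from_list_sketch` are hypothetical, and the merge order (defaults, then dataset config, then trainer config, then command-line pairs) is an assumption about what `setup_cfg()` in `tools/train.py` does. The example values are trimmed copies of `datasets/dg/digits_dg.yaml` and `trainers/dg/ddaig/digits_dg.yaml`, which disagree on `INPUT.PIXEL_MEAN`.

```python
import ast
import copy


def deep_merge(base, override):
    """Recursively merge `override` into a copy of `base`; later values win."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


def merge_from_list_sketch(cfg, opts):
    """Apply flat [KEY, VALUE, KEY, VALUE, ...] overrides, e.g. from the command line."""
    if len(opts) % 2 != 0:
        raise ValueError("overrides must come in KEY VALUE pairs")
    cfg = copy.deepcopy(cfg)
    for full_key, raw_value in zip(opts[0::2], opts[1::2]):
        *parents, leaf = full_key.split(".")
        node = cfg
        for name in parents:
            node = node[name]  # unknown sections raise KeyError
        if leaf not in node:
            raise KeyError(f"unknown config key: {full_key}")
        try:
            value = ast.literal_eval(raw_value)  # "4" -> 4, "[0.5, 0.5]" -> list
        except (ValueError, SyntaxError):
            value = raw_value  # bare words such as sgd stay strings
        node[leaf] = value
    return cfg


# Heavily trimmed defaults (values taken from defaults.py)
defaults = {
    "INPUT": {"PIXEL_MEAN": [0.485, 0.456, 0.406]},
    "OPTIM": {"NAME": "adam", "LR": 0.0003},
    "TRAINER": {"MCD": {"N_STEP_F": 4}},
}
dataset_cfg = {"INPUT": {"PIXEL_MEAN": [0.5, 0.5, 0.5]}}  # datasets/dg/digits_dg.yaml
trainer_cfg = {                                            # trainers/dg/ddaig/digits_dg.yaml
    "INPUT": {"PIXEL_MEAN": [0.0, 0.0, 0.0]},
    "OPTIM": {"NAME": "sgd", "LR": 0.05},
}

cfg = deep_merge(deep_merge(defaults, dataset_cfg), trainer_cfg)
cfg = merge_from_list_sketch(cfg, ["TRAINER.MCD.N_STEP_F", "2"])
print(cfg["INPUT"]["PIXEL_MEAN"], cfg["OPTIM"], cfg["TRAINER"]["MCD"]["N_STEP_F"])
```

Under this assumed order the trainer file wins on conflicts, so a DDAIG run on Digits-DG would normalize with a zero mean rather than the dataset file's 0.5, and command-line pairs override both. This sketch rejects keys that are not already defined instead of silently creating them, so a typo in an override surfaces as an error.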
================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/da/cifar_stl.yaml
================================================
INPUT:
  SIZE: (32, 32)
  PIXEL_MEAN: [0.5, 0.5, 0.5]
  PIXEL_STD: [0.5, 0.5, 0.5]

DATASET:
  NAME: "CIFARSTL"

================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/da/digit5.yaml
================================================
INPUT:
  SIZE: (32, 32)
  PIXEL_MEAN: [0.5, 0.5, 0.5]
  PIXEL_STD: [0.5, 0.5, 0.5]
  TRANSFORMS: ["normalize"]

DATASET:
  NAME: "Digit5"

MODEL:
  BACKBONE:
    NAME: "cnn_digit5_m3sda"

================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/da/domainnet.yaml
================================================
INPUT:
  SIZE: (224, 224)
  TRANSFORMS: ["random_flip", "random_translation", "normalize"]

DATASET:
  NAME: "DomainNet"

MODEL:
  BACKBONE:
    NAME: "resnet101"

================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/da/mini_domainnet.yaml
================================================
INPUT:
  SIZE: (96, 96)
  TRANSFORMS: ["random_flip", "random_translation", "normalize"]

DATASET:
  NAME: "miniDomainNet"

MODEL:
  BACKBONE:
    NAME: "resnet18"

================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/da/office31.yaml
================================================
INPUT:
  SIZE: (224, 224)
  TRANSFORMS: ["random_flip", "random_translation", "normalize"]

DATASET:
  NAME: "Office31"

MODEL:
  BACKBONE:
    NAME: "resnet50"
  HEAD:
    NAME: "mlp"
    HIDDEN_LAYERS: [256]
    DROPOUT: 0.
================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/da/office_home.yaml
================================================
INPUT:
  SIZE: (224, 224)

DATASET:
  NAME: "OfficeHome"

================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/da/visda17.yaml
================================================
INPUT:
  SIZE: (224, 224)
  TRANSFORMS: ["random_flip", "center_crop", "normalize"]

DATASET:
  NAME: "VisDA17"

MODEL:
  BACKBONE:
    NAME: "resnet101"

TEST:
  PER_CLASS_RESULT: True

================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/dg/cifar100_c.yaml
================================================
INPUT:
  SIZE: (32, 32)
  TRANSFORMS: ["random_flip", "random_crop", "normalize"]
  PIXEL_MEAN: [0.5, 0.5, 0.5]
  PIXEL_STD: [0.5, 0.5, 0.5]

DATASET:
  NAME: "CIFAR100C"
  CIFAR_C_TYPE: "fog"
  CIFAR_C_LEVEL: 5

MODEL:
  BACKBONE:
    NAME: "wide_resnet_16_4"

================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/dg/cifar10_c.yaml
================================================
INPUT:
  SIZE: (32, 32)
  TRANSFORMS: ["random_flip", "random_crop", "normalize"]
  PIXEL_MEAN: [0.5, 0.5, 0.5]
  PIXEL_STD: [0.5, 0.5, 0.5]

DATASET:
  NAME: "CIFAR10C"
  CIFAR_C_TYPE: "fog"
  CIFAR_C_LEVEL: 5

MODEL:
  BACKBONE:
    NAME: "wide_resnet_16_4"

================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/dg/digit_single.yaml
================================================
INPUT:
  SIZE: (32, 32)
  TRANSFORMS: ["normalize"]
  PIXEL_MEAN: [0.5, 0.5, 0.5]
  PIXEL_STD: [0.5, 0.5, 0.5]

DATASET:
  NAME: "DigitSingle"

MODEL:
  BACKBONE:
    NAME: "cnn_digitsingle"

================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/dg/digits_dg.yaml
================================================
INPUT:
  SIZE: (32, 32)
  TRANSFORMS: ["normalize"]
  PIXEL_MEAN: [0.5, 0.5, 0.5]
  PIXEL_STD: [0.5, 0.5, 0.5]

DATASET:
  NAME: "DigitsDG"

MODEL:
  BACKBONE:
    NAME: "cnn_digitsdg"

================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/dg/office_home_dg.yaml
================================================
INPUT:
  SIZE: (224, 224)
  TRANSFORMS: ["random_flip", "random_translation", "normalize"]

DATASET:
  NAME: "OfficeHomeDG"

MODEL:
  BACKBONE:
    NAME: "resnet18"
    PRETRAINED: True

================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/dg/pacs.yaml
================================================
INPUT:
  SIZE: (224, 224)
  TRANSFORMS: ["random_flip", "random_translation", "normalize"]

DATASET:
  NAME: "PACS"

MODEL:
  BACKBONE:
    NAME: "resnet18"
    PRETRAINED: True

================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/dg/vlcs.yaml
================================================
INPUT:
  SIZE: (224, 224)
  TRANSFORMS: ["random_flip", "random_translation", "normalize"]

DATASET:
  NAME: "VLCS"

MODEL:
  BACKBONE:
    NAME: "resnet18"
    PRETRAINED: True

================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/ssl/cifar10.yaml
================================================
INPUT:
  SIZE: (32, 32)
  TRANSFORMS: ["random_flip", "random_crop", "normalize"]
  PIXEL_MEAN: [0.5, 0.5, 0.5]
  PIXEL_STD: [0.5, 0.5, 0.5]

DATASET:
  NAME: "CIFAR10"
  NUM_LABELED: 4000
  VAL_PERCENT: 0.

MODEL:
  BACKBONE:
    NAME: "wide_resnet_28_2"

================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/ssl/cifar100.yaml
================================================
INPUT:
  SIZE: (32, 32)
  TRANSFORMS: ["random_flip", "random_crop", "normalize"]
  PIXEL_MEAN: [0.5, 0.5, 0.5]
  PIXEL_STD: [0.5, 0.5, 0.5]
  CROP_PADDING: 4

DATASET:
  NAME: "CIFAR100"
  NUM_LABELED: 10000
  VAL_PERCENT: 0.

MODEL:
  BACKBONE:
    NAME: "wide_resnet_28_2"

================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/ssl/stl10.yaml
================================================
INPUT:
  SIZE: (96, 96)
  TRANSFORMS: ["random_flip", "random_crop", "normalize"]
  PIXEL_MEAN: [0.5, 0.5, 0.5]
  PIXEL_STD: [0.5, 0.5, 0.5]
  CROP_PADDING: 4

DATASET:
  NAME: "STL10"
  STL10_FOLD: 0

MODEL:
  BACKBONE:
    NAME: "wide_resnet_28_2"

================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/ssl/svhn.yaml
================================================
INPUT:
  SIZE: (32, 32)
  TRANSFORMS: ["random_crop", "normalize"]
  PIXEL_MEAN: [0.5, 0.5, 0.5]
  PIXEL_STD: [0.5, 0.5, 0.5]
  CROP_PADDING: 4

DATASET:
  NAME: "SVHN"
  NUM_LABELED: 1000
  VAL_PERCENT: 0.

MODEL:
  BACKBONE:
    NAME: "wide_resnet_28_2"

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/da/dael/digit5.yaml
================================================
DATALOADER:
  TRAIN_X:
    SAMPLER: "RandomDomainSampler"
    BATCH_SIZE: 256
  TRAIN_U:
    SAME_AS_X: False
    BATCH_SIZE: 64
  TEST:
    BATCH_SIZE: 256

OPTIM:
  NAME: "sgd"
  LR: 0.05
  STEPSIZE: [30]
  MAX_EPOCH: 30
  LR_SCHEDULER: "cosine"

TRAINER:
  DAEL:
    STRONG_TRANSFORMS: ["randaugment2", "normalize"]

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/da/dael/domainnet.yaml
================================================
DATALOADER:
  NUM_WORKERS: 4
  TRAIN_X:
    SAMPLER: "RandomDomainSampler"
    BATCH_SIZE: 30
  TRAIN_U:
    SAME_AS_X: False
    BATCH_SIZE: 6
  TEST:
    BATCH_SIZE: 30

OPTIM:
  NAME: "sgd"
  LR: 0.002
  MAX_EPOCH: 40
  LR_SCHEDULER: "cosine"

TRAINER:
  DAEL:
    STRONG_TRANSFORMS: ["random_flip", "cutout", "randaugment2", "normalize"]

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/da/dael/mini_domainnet.yaml
================================================
DATALOADER:
  NUM_WORKERS: 8
  TRAIN_X:
    SAMPLER: "RandomDomainSampler"
    BATCH_SIZE: 192
  TRAIN_U:
    SAME_AS_X: False
    BATCH_SIZE: 64
  TEST:
    BATCH_SIZE: 200

OPTIM:
  NAME: "sgd"
  LR: 0.005
  MAX_EPOCH: 60
  LR_SCHEDULER: "cosine"

TRAINER:
  DAEL:
    STRONG_TRANSFORMS: ["random_flip", "cutout", "randaugment2", "normalize"]

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/da/m3sda/digit5.yaml
================================================
DATALOADER:
  TRAIN_X:
    SAMPLER: "RandomDomainSampler"
    BATCH_SIZE: 256
  TRAIN_U:
    SAME_AS_X: False
    BATCH_SIZE: 64
  TEST:
    BATCH_SIZE: 256

OPTIM:
  NAME: "sgd"
  LR: 0.05
  STEPSIZE: [30]
  MAX_EPOCH: 30
  LR_SCHEDULER: "cosine"

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/da/m3sda/domainnet.yaml
================================================
DATALOADER:
  NUM_WORKERS: 4
  TRAIN_X:
    SAMPLER: "RandomDomainSampler"
    BATCH_SIZE: 30
  TRAIN_U:
    SAME_AS_X: False
    BATCH_SIZE: 6
  TEST:
    BATCH_SIZE: 30

OPTIM:
  NAME: "sgd"
  LR: 0.002
  MAX_EPOCH: 40
  LR_SCHEDULER: "cosine"

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/da/m3sda/mini_domainnet.yaml
================================================
DATALOADER:
  NUM_WORKERS: 8
  TRAIN_X:
    SAMPLER: "RandomDomainSampler"
    BATCH_SIZE: 192
  TRAIN_U:
    SAME_AS_X: False
    BATCH_SIZE: 64
  TEST:
    BATCH_SIZE: 200

OPTIM:
  NAME: "sgd"
  LR: 0.005
  MAX_EPOCH: 60
  LR_SCHEDULER: "cosine"

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/da/source_only/digit5.yaml
================================================
DATALOADER:
  TRAIN_X:
    BATCH_SIZE: 256
  TEST:
    BATCH_SIZE: 256

OPTIM:
  NAME: "sgd"
  LR: 0.05
  STEPSIZE: [30]
  MAX_EPOCH: 30
  LR_SCHEDULER: "cosine"

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/da/source_only/mini_domainnet.yaml
================================================
DATALOADER:
  NUM_WORKERS: 8
  TRAIN_X:
    BATCH_SIZE: 128
  TEST:
    BATCH_SIZE: 128

OPTIM:
  NAME: "sgd"
  LR: 0.005
  MAX_EPOCH: 60
  LR_SCHEDULER: "cosine"
================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/da/source_only/office31.yaml
================================================
DATALOADER:
  TRAIN_X:
    BATCH_SIZE: 32
  TEST:
    BATCH_SIZE: 32

OPTIM:
  NAME: "sgd"
  LR: 0.002
  STEPSIZE: [20]
  MAX_EPOCH: 20

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/da/source_only/visda17.yaml
================================================
DATALOADER:
  TRAIN_X:
    BATCH_SIZE: 32
  TEST:
    BATCH_SIZE: 32

OPTIM:
  NAME: "sgd"
  LR: 0.0001
  STEPSIZE: [2]
  MAX_EPOCH: 2

TRAIN:
  PRINT_FREQ: 50
  COUNT_ITER: "train_u"

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/dg/dael/digits_dg.yaml
================================================
DATALOADER:
  TRAIN_X:
    SAMPLER: "RandomDomainSampler"
    BATCH_SIZE: 120
  TEST:
    BATCH_SIZE: 100

OPTIM:
  NAME: "sgd"
  LR: 0.05
  STEPSIZE: [20]
  MAX_EPOCH: 50

TRAINER:
  DAEL:
    STRONG_TRANSFORMS: ["randaugment2", "normalize"]

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/dg/dael/office_home_dg.yaml
================================================
DATALOADER:
  TRAIN_X:
    SAMPLER: "RandomDomainSampler"
    BATCH_SIZE: 30
  TEST:
    BATCH_SIZE: 100

OPTIM:
  NAME: "sgd"
  LR: 0.002
  MAX_EPOCH: 40
  LR_SCHEDULER: "cosine"

TRAINER:
  DAEL:
    STRONG_TRANSFORMS: ["random_flip", "cutout", "randaugment2", "normalize"]

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/dg/dael/pacs.yaml
================================================
DATALOADER:
  TRAIN_X:
    SAMPLER: "RandomDomainSampler"
    BATCH_SIZE: 30
  TEST:
    BATCH_SIZE: 100

OPTIM:
  NAME: "sgd"
  LR: 0.002
  MAX_EPOCH: 40
  LR_SCHEDULER: "cosine"

TRAINER:
  DAEL:
    STRONG_TRANSFORMS: ["random_flip", "cutout", "randaugment2", "normalize"]

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/dg/ddaig/digits_dg.yaml
================================================
INPUT:
  PIXEL_MEAN: [0., 0., 0.]
  PIXEL_STD: [1., 1., 1.]

DATALOADER:
  TRAIN_X:
    BATCH_SIZE: 128
  TEST:
    BATCH_SIZE: 128

OPTIM:
  NAME: "sgd"
  LR: 0.05
  STEPSIZE: [20]
  MAX_EPOCH: 50

TRAINER:
  DDAIG:
    G_ARCH: "fcn_3x32_gctx"
    LMDA: 0.3

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/dg/ddaig/office_home_dg.yaml
================================================
INPUT:
  PIXEL_MEAN: [0., 0., 0.]
  PIXEL_STD: [1., 1., 1.]

DATALOADER:
  TRAIN_X:
    BATCH_SIZE: 16
  TEST:
    BATCH_SIZE: 16

OPTIM:
  NAME: "sgd"
  LR: 0.0005
  STEPSIZE: [20]
  MAX_EPOCH: 25

TRAINER:
  DDAIG:
    G_ARCH: "fcn_3x64_gctx"
    WARMUP: 3
    LMDA: 0.3

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/dg/ddaig/pacs.yaml
================================================
INPUT:
  PIXEL_MEAN: [0., 0., 0.]
  PIXEL_STD: [1., 1., 1.]

DATALOADER:
  TRAIN_X:
    BATCH_SIZE: 16
  TEST:
    BATCH_SIZE: 16

OPTIM:
  NAME: "sgd"
  LR: 0.0005
  STEPSIZE: [20]
  MAX_EPOCH: 25

TRAINER:
  DDAIG:
    G_ARCH: "fcn_3x64_gctx"
    WARMUP: 3
    LMDA: 0.3

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/dg/vanilla/digits_dg.yaml
================================================
DATALOADER:
  TRAIN_X:
    BATCH_SIZE: 128
  TEST:
    BATCH_SIZE: 100
  NUM_WORKERS: 8

OPTIM:
  NAME: "sgd"
  LR: 0.05
  STEPSIZE: [20]
  MAX_EPOCH: 50

TRAIN:
  PRINT_FREQ: 20

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/dg/vanilla/mini_domainnet.yaml
================================================
DATALOADER:
  NUM_WORKERS: 8
  TRAIN_X:
    BATCH_SIZE: 128
  TEST:
    BATCH_SIZE: 128

OPTIM:
  NAME: "sgd"
  LR: 0.005
  MAX_EPOCH: 60
  LR_SCHEDULER: "cosine"

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/dg/vanilla/office_home_dg.yaml
================================================
DATALOADER:
  TRAIN_X:
    BATCH_SIZE: 64
  TEST:
    BATCH_SIZE: 100
  NUM_WORKERS: 8

OPTIM:
  NAME: "sgd"
  LR: 0.001
  MAX_EPOCH: 50
  LR_SCHEDULER: "cosine"

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/dg/vanilla/pacs.yaml
================================================
DATALOADER:
  TRAIN_X:
    BATCH_SIZE: 64
  TEST:
    BATCH_SIZE: 100
  NUM_WORKERS: 8

OPTIM:
  NAME: "sgd"
  LR: 0.001
  MAX_EPOCH: 50
  LR_SCHEDULER: "cosine"

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/ssl/fixmatch/cifar10.yaml
================================================
DATALOADER:
  TRAIN_X:
    BATCH_SIZE: 64
  TRAIN_U:
    SAME_AS_X: False
    BATCH_SIZE: 448
  TEST:
    BATCH_SIZE: 500

OPTIM:
  NAME: "sgd"
  LR: 0.05
  STEPSIZE: [4000]
  MAX_EPOCH: 4000
  LR_SCHEDULER: "cosine"

TRAIN:
  COUNT_ITER: "train_u"
  PRINT_FREQ: 10

TRAINER:
  FIXMATCH:
    STRONG_TRANSFORMS: ["random_flip", "randaugment_fixmatch", "normalize", "cutout"]

================================================
FILE: Dassl.ProGrad.pytorch/dassl/__init__.py
================================================
"""
Dassl
------
PyTorch toolbox for domain adaptation and semi-supervised learning.

URL: https://github.com/KaiyangZhou/Dassl.pytorch

@article{zhou2020domain,
  title={Domain Adaptive Ensemble Learning},
  author={Zhou, Kaiyang and Yang, Yongxin and Qiao, Yu and Xiang, Tao},
  journal={arXiv preprint arXiv:2003.07325},
  year={2020}
}
"""

__version__ = "0.5.0"
__author__ = "Kaiyang Zhou"
__homepage__ = "https://kaiyangzhou.github.io/"

================================================
FILE: Dassl.ProGrad.pytorch/dassl/config/__init__.py
================================================
from .defaults import _C as cfg_default


def get_cfg_default():
    return cfg_default.clone()

================================================
FILE: Dassl.ProGrad.pytorch/dassl/config/defaults.py
================================================
from yacs.config import CfgNode as CN

###########################
# Config definition
###########################

_C = CN()

_C.VERSION = 1

# Directory to save the output files (like log.txt and model weights)
_C.OUTPUT_DIR = "./output"
# Path to a directory where the files were saved previously
_C.RESUME = ""
# Set seed to negative value to randomize everything
# Set seed to positive value to use a fixed seed
_C.SEED = -1
_C.USE_CUDA = True
# Print detailed information
# E.g. trainer, dataset, and backbone
_C.VERBOSE = True

###########################
# Input
###########################
_C.INPUT = CN()
_C.INPUT.SIZE = (224, 224)
# Mode of interpolation in resize functions
_C.INPUT.INTERPOLATION = "bilinear"
# For available choices please refer to transforms.py
_C.INPUT.TRANSFORMS = ()
# If True, tfm_train and tfm_test will be None
_C.INPUT.NO_TRANSFORM = False
# Default mean and std come from ImageNet
_C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406]
_C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225]
# Padding for random crop
_C.INPUT.CROP_PADDING = 4
# Cutout
_C.INPUT.CUTOUT_N = 1
_C.INPUT.CUTOUT_LEN = 16
# Gaussian noise
_C.INPUT.GN_MEAN = 0.0
_C.INPUT.GN_STD = 0.15
# RandomAugment
_C.INPUT.RANDAUGMENT_N = 2
_C.INPUT.RANDAUGMENT_M = 10
# ColorJitter (brightness, contrast, saturation, hue)
_C.INPUT.COLORJITTER_B = 0.4
_C.INPUT.COLORJITTER_C = 0.4
_C.INPUT.COLORJITTER_S = 0.4
_C.INPUT.COLORJITTER_H = 0.1
# Random gray scale's probability
_C.INPUT.RGS_P = 0.2
# Gaussian blur
_C.INPUT.GB_P = 0.5  # probability of applying this operation
_C.INPUT.GB_K = 21  # kernel size (should be an odd number)

###########################
# Dataset
###########################
_C.DATASET = CN()
# Directory where datasets are stored
_C.DATASET.ROOT = ""
_C.DATASET.NAME = ""
# List of names of source domains
_C.DATASET.SOURCE_DOMAINS = ()
# List of names of target domains
_C.DATASET.TARGET_DOMAINS = ()
# Number of labeled instances in total
# Useful for the semi-supervised learning
_C.DATASET.NUM_LABELED = -1
# Number of images per class
_C.DATASET.NUM_SHOTS = -1
# Percentage of validation data (only used for SSL datasets)
# Set to 0 if do not want to use val data
# Using val data for hyperparameter tuning was done in Oliver et al. 2018
_C.DATASET.VAL_PERCENT = 0.1
# Fold index for STL-10 dataset (normal range is 0 - 9)
# Negative number means None
_C.DATASET.STL10_FOLD = -1
# CIFAR-10/100-C's corruption type and intensity level
_C.DATASET.CIFAR_C_TYPE = ""
_C.DATASET.CIFAR_C_LEVEL = 1
# Use all data in the unlabeled data set (e.g. FixMatch)
_C.DATASET.ALL_AS_UNLABELED = False

###########################
# Dataloader
###########################
_C.DATALOADER = CN()
_C.DATALOADER.NUM_WORKERS = 4
# Apply transformations to an image K times (during training)
_C.DATALOADER.K_TRANSFORMS = 1
# img0 denotes image tensor without augmentation
# Useful for consistency learning
_C.DATALOADER.RETURN_IMG0 = False
# Setting for the train_x data-loader
_C.DATALOADER.TRAIN_X = CN()
_C.DATALOADER.TRAIN_X.SAMPLER = "RandomSampler"
_C.DATALOADER.TRAIN_X.BATCH_SIZE = 32
# Parameter for RandomDomainSampler
# 0 or -1 means sampling from all domains
_C.DATALOADER.TRAIN_X.N_DOMAIN = 0
# Parameter of RandomClassSampler
# Number of instances per class
_C.DATALOADER.TRAIN_X.N_INS = 16

# Setting for the train_u data-loader
_C.DATALOADER.TRAIN_U = CN()
# Set to false if you want to have unique
# data loader params for train_u
_C.DATALOADER.TRAIN_U.SAME_AS_X = True
_C.DATALOADER.TRAIN_U.SAMPLER = "RandomSampler"
_C.DATALOADER.TRAIN_U.BATCH_SIZE = 32
_C.DATALOADER.TRAIN_U.N_DOMAIN = 0
_C.DATALOADER.TRAIN_U.N_INS = 16

# Setting for the test data-loader
_C.DATALOADER.TEST = CN()
_C.DATALOADER.TEST.SAMPLER = "SequentialSampler"
_C.DATALOADER.TEST.BATCH_SIZE = 32

###########################
# Model
###########################
_C.MODEL = CN()
# Path to model weights (for initialization)
_C.MODEL.INIT_WEIGHTS = ""
_C.MODEL.BACKBONE = CN()
_C.MODEL.BACKBONE.NAME = ""
_C.MODEL.BACKBONE.PRETRAINED = True
# Definition of embedding layers
_C.MODEL.HEAD = CN()
# If none, do not construct embedding layers, the
# backbone's output will be passed to the classifier
_C.MODEL.HEAD.NAME = ""
# Structure of hidden layers (a list), e.g. [512, 512]
# If undefined, no embedding layer will be constructed
_C.MODEL.HEAD.HIDDEN_LAYERS = ()
_C.MODEL.HEAD.ACTIVATION = "relu"
_C.MODEL.HEAD.BN = True
_C.MODEL.HEAD.DROPOUT = 0.0

###########################
# Optimization
###########################
_C.OPTIM = CN()
_C.OPTIM.NAME = "adam"
_C.OPTIM.LR = 0.0003
_C.OPTIM.WEIGHT_DECAY = 5e-4
_C.OPTIM.MOMENTUM = 0.9
_C.OPTIM.SGD_DAMPNING = 0
_C.OPTIM.SGD_NESTEROV = False
_C.OPTIM.RMSPROP_ALPHA = 0.99
_C.OPTIM.ADAM_BETA1 = 0.9
_C.OPTIM.ADAM_BETA2 = 0.999
# STAGED_LR allows different layers to have
# different lr, e.g. pre-trained base layers
# can be assigned a smaller lr than the new
# classification layer
_C.OPTIM.STAGED_LR = False
_C.OPTIM.NEW_LAYERS = ()
_C.OPTIM.BASE_LR_MULT = 0.1
# Learning rate scheduler
_C.OPTIM.LR_SCHEDULER = "single_step"
# -1 or 0 means the stepsize is equal to max_epoch
_C.OPTIM.STEPSIZE = (-1, )
_C.OPTIM.GAMMA = 0.1
_C.OPTIM.MAX_EPOCH = 10
# Set WARMUP_EPOCH larger than 0 to activate warmup training
_C.OPTIM.WARMUP_EPOCH = -1
# Either linear or constant
_C.OPTIM.WARMUP_TYPE = "linear"
# Constant learning rate when type=constant
_C.OPTIM.WARMUP_CONS_LR = 1e-5
# Minimum learning rate when type=linear
_C.OPTIM.WARMUP_MIN_LR = 1e-5
# Recount epoch for the next scheduler (last_epoch=-1)
# Otherwise last_epoch=warmup_epoch
_C.OPTIM.WARMUP_RECOUNT = True

###########################
# Train
###########################
_C.TRAIN = CN()
# How often (epoch) to save model during training
# Set to 0 or negative value to only save the last one
_C.TRAIN.CHECKPOINT_FREQ = 0
# How often (batch) to print training information
_C.TRAIN.PRINT_FREQ = 10
# Use 'train_x', 'train_u' or 'smaller_one' to count
# the number of iterations in an epoch (for DA and SSL)
_C.TRAIN.COUNT_ITER = "train_x"

###########################
# Test
###########################
_C.TEST = CN()
_C.TEST.EVALUATOR = "Classification"
_C.TEST.PER_CLASS_RESULT = False
# Compute confusion matrix, which will be saved
# to $OUTPUT_DIR/cmat.pt
_C.TEST.COMPUTE_CMAT = False
# If NO_TEST=True, no testing will be conducted
_C.TEST.NO_TEST = False
# Use test or val set for FINAL evaluation
_C.TEST.SPLIT = "test"
# Which model to test after training
# Either last_step or best_val
_C.TEST.FINAL_MODEL = "last_step"

###########################
# Trainer specifics
###########################
_C.TRAINER = CN()
_C.TRAINER.NAME = ""

# MCD
_C.TRAINER.MCD = CN()
_C.TRAINER.MCD.N_STEP_F = 4  # number of steps to train F
# MME
_C.TRAINER.MME = CN()
_C.TRAINER.MME.LMDA = 0.1  # weight for the entropy loss
# SelfEnsembling
_C.TRAINER.SE = CN()
_C.TRAINER.SE.EMA_ALPHA = 0.999
_C.TRAINER.SE.CONF_THRE = 0.95
_C.TRAINER.SE.RAMPUP = 300

# M3SDA
_C.TRAINER.M3SDA = CN()
_C.TRAINER.M3SDA.LMDA = 0.5  # weight for the moment distance loss
_C.TRAINER.M3SDA.N_STEP_F = 4  # follow MCD
# DAEL
_C.TRAINER.DAEL = CN()
_C.TRAINER.DAEL.WEIGHT_U = 0.5  # weight on the unlabeled loss
_C.TRAINER.DAEL.CONF_THRE = 0.95  # confidence threshold
_C.TRAINER.DAEL.STRONG_TRANSFORMS = ()

# CrossGrad
_C.TRAINER.CG = CN()
_C.TRAINER.CG.EPS_F = 1.0  # scaling parameter for D's gradients
_C.TRAINER.CG.EPS_D = 1.0  # scaling parameter for F's gradients
_C.TRAINER.CG.ALPHA_F = 0.5  # balancing weight for the label net's loss
_C.TRAINER.CG.ALPHA_D = 0.5  # balancing weight for the domain net's loss
# DDAIG
_C.TRAINER.DDAIG = CN()
_C.TRAINER.DDAIG.G_ARCH = ""  # generator's architecture
_C.TRAINER.DDAIG.LMDA = 0.3  # perturbation weight
_C.TRAINER.DDAIG.CLAMP = False  # clamp perturbation values
_C.TRAINER.DDAIG.CLAMP_MIN = -1.0
_C.TRAINER.DDAIG.CLAMP_MAX = 1.0
_C.TRAINER.DDAIG.WARMUP = 0
_C.TRAINER.DDAIG.ALPHA = 0.5  # balancing weight for the losses

# EntMin
_C.TRAINER.ENTMIN = CN()
_C.TRAINER.ENTMIN.LMDA = 1e-3  # weight on the entropy loss
# Mean Teacher
_C.TRAINER.MEANTEA = CN()
_C.TRAINER.MEANTEA.WEIGHT_U = 1.0  # weight on the unlabeled loss
_C.TRAINER.MEANTEA.EMA_ALPHA = 0.999
_C.TRAINER.MEANTEA.RAMPUP = 5  # epochs used to ramp up the loss_u weight
# MixMatch
_C.TRAINER.MIXMATCH = CN() _C.TRAINER.MIXMATCH.WEIGHT_U = 100.0 # weight on the unlabeled loss _C.TRAINER.MIXMATCH.TEMP = 2.0 # temperature for sharpening the probability _C.TRAINER.MIXMATCH.MIXUP_BETA = 0.75 _C.TRAINER.MIXMATCH.RAMPUP = 20000 # steps used to ramp up the loss_u weight # FixMatch _C.TRAINER.FIXMATCH = CN() _C.TRAINER.FIXMATCH.WEIGHT_U = 1.0 # weight on the unlabeled loss _C.TRAINER.FIXMATCH.CONF_THRE = 0.95 # confidence threshold _C.TRAINER.FIXMATCH.STRONG_TRANSFORMS = () ================================================ FILE: Dassl.ProGrad.pytorch/dassl/data/__init__.py ================================================ from .data_manager import DataManager, DatasetWrapper ================================================ FILE: Dassl.ProGrad.pytorch/dassl/data/data_manager.py ================================================ import torch import torchvision.transforms as T from PIL import Image from torch.utils.data import Dataset as TorchDataset from dassl.utils import read_image from .datasets import build_dataset from .samplers import build_sampler from .transforms import build_transform INTERPOLATION_MODES = { "bilinear": Image.BILINEAR, "bicubic": Image.BICUBIC, "nearest": Image.NEAREST, } def build_data_loader( cfg, sampler_type="SequentialSampler", data_source=None, batch_size=64, n_domain=0, n_ins=2, tfm=None, is_train=True, dataset_wrapper=None, ): # Build sampler sampler = build_sampler( sampler_type, cfg=cfg, data_source=data_source, batch_size=batch_size, n_domain=n_domain, n_ins=n_ins, ) if dataset_wrapper is None: dataset_wrapper = DatasetWrapper # Build data loader data_loader = torch.utils.data.DataLoader( dataset_wrapper(cfg, data_source, transform=tfm, is_train=is_train), batch_size=batch_size, sampler=sampler, num_workers=cfg.DATALOADER.NUM_WORKERS, drop_last=is_train and len(data_source) >= batch_size, pin_memory=(torch.cuda.is_available() and cfg.USE_CUDA), ) assert len(data_loader) > 0 return data_loader class DataManager: def 
__init__( self, cfg, custom_tfm_train=None, custom_tfm_test=None, dataset_wrapper=None ): # Load dataset dataset = build_dataset(cfg) # Build transform if custom_tfm_train is None: tfm_train = build_transform(cfg, is_train=True) else: print("* Using custom transform for training") tfm_train = custom_tfm_train if custom_tfm_test is None: tfm_test = build_transform(cfg, is_train=False) else: print("* Using custom transform for testing") tfm_test = custom_tfm_test # Build train_loader_x train_loader_x = build_data_loader( cfg, sampler_type=cfg.DATALOADER.TRAIN_X.SAMPLER, data_source=dataset.train_x, batch_size=cfg.DATALOADER.TRAIN_X.BATCH_SIZE, n_domain=cfg.DATALOADER.TRAIN_X.N_DOMAIN, n_ins=cfg.DATALOADER.TRAIN_X.N_INS, tfm=tfm_train, is_train=True, dataset_wrapper=dataset_wrapper, ) # Build train_loader_u train_loader_u = None if dataset.train_u: sampler_type_ = cfg.DATALOADER.TRAIN_U.SAMPLER batch_size_ = cfg.DATALOADER.TRAIN_U.BATCH_SIZE n_domain_ = cfg.DATALOADER.TRAIN_U.N_DOMAIN n_ins_ = cfg.DATALOADER.TRAIN_U.N_INS if cfg.DATALOADER.TRAIN_U.SAME_AS_X: sampler_type_ = cfg.DATALOADER.TRAIN_X.SAMPLER batch_size_ = cfg.DATALOADER.TRAIN_X.BATCH_SIZE n_domain_ = cfg.DATALOADER.TRAIN_X.N_DOMAIN n_ins_ = cfg.DATALOADER.TRAIN_X.N_INS train_loader_u = build_data_loader( cfg, sampler_type=sampler_type_, data_source=dataset.train_u, batch_size=batch_size_, n_domain=n_domain_, n_ins=n_ins_, tfm=tfm_train, is_train=True, dataset_wrapper=dataset_wrapper, ) # Build val_loader val_loader = None if dataset.val: val_loader = build_data_loader( cfg, sampler_type=cfg.DATALOADER.TEST.SAMPLER, data_source=dataset.val, batch_size=cfg.DATALOADER.TEST.BATCH_SIZE, tfm=tfm_test, is_train=False, dataset_wrapper=dataset_wrapper, ) # Build test_loader test_loader = build_data_loader( cfg, sampler_type=cfg.DATALOADER.TEST.SAMPLER, data_source=dataset.test, batch_size=cfg.DATALOADER.TEST.BATCH_SIZE, tfm=tfm_test, is_train=False, dataset_wrapper=dataset_wrapper, ) # Attributes self._num_classes 
= dataset.num_classes self._num_source_domains = len(cfg.DATASET.SOURCE_DOMAINS) self._lab2cname = dataset.lab2cname # Dataset and data-loaders self.dataset = dataset self.train_loader_x = train_loader_x self.train_loader_u = train_loader_u self.val_loader = val_loader self.test_loader = test_loader if cfg.VERBOSE: self.show_dataset_summary(cfg) @property def num_classes(self): return self._num_classes @property def num_source_domains(self): return self._num_source_domains @property def lab2cname(self): return self._lab2cname def show_dataset_summary(self, cfg): print("***** Dataset statistics *****") print(" Dataset: {}".format(cfg.DATASET.NAME)) if cfg.DATASET.SOURCE_DOMAINS: print(" Source domains: {}".format(cfg.DATASET.SOURCE_DOMAINS)) if cfg.DATASET.TARGET_DOMAINS: print(" Target domains: {}".format(cfg.DATASET.TARGET_DOMAINS)) print(" # classes: {:,}".format(self.num_classes)) print(" # train_x: {:,}".format(len(self.dataset.train_x))) if self.dataset.train_u: print(" # train_u: {:,}".format(len(self.dataset.train_u))) if self.dataset.val: print(" # val: {:,}".format(len(self.dataset.val))) print(" # test: {:,}".format(len(self.dataset.test))) class DatasetWrapper(TorchDataset): def __init__(self, cfg, data_source, transform=None, is_train=False): self.cfg = cfg self.data_source = data_source self.transform = transform # accept list (tuple) as input self.is_train = is_train # Augmenting an image K>1 times is only allowed during training self.k_tfm = cfg.DATALOADER.K_TRANSFORMS if is_train else 1 self.return_img0 = cfg.DATALOADER.RETURN_IMG0 if self.k_tfm > 1 and transform is None: raise ValueError( "Cannot augment the image {} times " "because transform is None".format(self.k_tfm) ) # Build transform that doesn't apply any data augmentation interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION] to_tensor = [] to_tensor += [T.Resize(cfg.INPUT.SIZE, interpolation=interp_mode)] to_tensor += [T.ToTensor()] if "normalize" in cfg.INPUT.TRANSFORMS: normalize = 
T.Normalize( mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD ) to_tensor += [normalize] self.to_tensor = T.Compose(to_tensor) def __len__(self): return len(self.data_source) def __getitem__(self, idx): item = self.data_source[idx] output = { "label": item.label, "domain": item.domain, "impath": item.impath } img0 = read_image(item.impath) if self.transform is not None: if isinstance(self.transform, (list, tuple)): for i, tfm in enumerate(self.transform): img = self._transform_image(tfm, img0) keyname = "img" if (i + 1) > 1: keyname += str(i + 1) output[keyname] = img else: img = self._transform_image(self.transform, img0) output["img"] = img if self.return_img0: output["img0"] = self.to_tensor(img0) return output def _transform_image(self, tfm, img0): img_list = [] for k in range(self.k_tfm): img_list.append(tfm(img0)) img = img_list if len(img) == 1: img = img[0] return img ================================================ FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/__init__.py ================================================ from .build import DATASET_REGISTRY, build_dataset # isort:skip from .base_dataset import Datum, DatasetBase # isort:skip from .da import * from .dg import * from .ssl import * ================================================ FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/base_dataset.py ================================================ import os import random import os.path as osp import tarfile import zipfile from collections import defaultdict import gdown from dassl.utils import check_isfile class Datum: """Data instance which defines the basic attributes. Args: impath (str): image path. label (int): class label. domain (int): domain label. classname (str): class name. 
""" def __init__(self, impath="", label=0, domain=0, classname=""): assert isinstance(impath, str) assert check_isfile(impath) self._impath = impath self._label = label self._domain = domain self._classname = classname @property def impath(self): return self._impath @property def label(self): return self._label @property def domain(self): return self._domain @property def classname(self): return self._classname class DatasetBase: """A unified dataset class for 1) domain adaptation 2) domain generalization 3) semi-supervised learning """ dataset_dir = "" # the directory where the dataset is stored domains = [] # string names of all domains def __init__(self, train_x=None, train_u=None, val=None, test=None): self._train_x = train_x # labeled training data self._train_u = train_u # unlabeled training data (optional) self._val = val # validation data (optional) self._test = test # test data self._num_classes = self.get_num_classes(train_x) self._lab2cname, self._classnames = self.get_lab2cname(train_x) @property def train_x(self): return self._train_x @property def train_u(self): return self._train_u @property def val(self): return self._val @property def test(self): return self._test @property def lab2cname(self): return self._lab2cname @property def classnames(self): return self._classnames @property def num_classes(self): return self._num_classes def get_num_classes(self, data_source): """Count number of classes. Args: data_source (list): a list of Datum objects. """ label_set = set() for item in data_source: label_set.add(item.label) return max(label_set) + 1 def get_lab2cname(self, data_source): """Get a label-to-classname mapping (dict). Args: data_source (list): a list of Datum objects. 
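
        Returns:
            tuple: ``(mapping, classnames)``. ``mapping`` is a dict from
            each label to its class name, and ``classnames`` lists the
            class names sorted by label. For example, items whose
            (label, classname) pairs are (1, "cat") and (0, "dog") give
            ``({0: "dog", 1: "cat"}, ["dog", "cat"])``.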
""" container = set() for item in data_source: container.add((item.label, item.classname)) mapping = {label: classname for label, classname in container} labels = list(mapping.keys()) labels.sort() classnames = [mapping[label] for label in labels] return mapping, classnames def check_input_domains(self, source_domains, target_domains): self.is_input_domain_valid(source_domains) self.is_input_domain_valid(target_domains) def is_input_domain_valid(self, input_domains): for domain in input_domains: if domain not in self.domains: raise ValueError( "Input domain must belong to {}, " "but got [{}]".format(self.domains, domain) ) def download_data(self, url, dst, from_gdrive=True): if not osp.exists(osp.dirname(dst)): os.makedirs(osp.dirname(dst)) if from_gdrive: gdown.download(url, dst, quiet=False) else: raise NotImplementedError print("Extracting file ...") try: tar = tarfile.open(dst) tar.extractall(path=osp.dirname(dst)) tar.close() except tarfile.ReadError: # not a tar archive, fall back to zip zip_ref = zipfile.ZipFile(dst, "r") zip_ref.extractall(osp.dirname(dst)) zip_ref.close() print("File extracted to {}".format(osp.dirname(dst))) def generate_fewshot_dataset( self, *data_sources, num_shots=-1, repeat=False ): """Generate a few-shot dataset (typically for the training set). This function is useful when one wants to evaluate a model in a few-shot learning setting where each class contains only a few images. Args: data_sources: each one is a list of Datum objects. num_shots (int): number of instances per class to sample. repeat (bool): repeat images if needed (default: False).
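
        Returns:
            The sampled list when a single data source is given,
            otherwise one sampled list per data source. If
            ``num_shots < 1``, the input is returned unchanged.

        Note:
            A class with fewer than ``num_shots`` items keeps all of
            them, unless ``repeat=True``, in which case ``num_shots``
            items are drawn from it with replacement.

        Example (illustrative; ``train`` and ``val`` are assumed to be
        lists of Datum objects, e.g. inside a subclass ``__init__``)::

            train, val = self.generate_fewshot_dataset(
                train, val, num_shots=16
            )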
""" if num_shots < 1: if len(data_sources) == 1: return data_sources[0] return data_sources print(f"Creating a {num_shots}-shot dataset") output = [] for data_source in data_sources: tracker = self.split_dataset_by_label(data_source) dataset = [] for label, items in tracker.items(): if len(items) >= num_shots: sampled_items = random.sample(items, num_shots) else: if repeat: sampled_items = random.choices(items, k=num_shots) else: sampled_items = items dataset.extend(sampled_items) output.append(dataset) if len(output) == 1: return output[0] return output def split_dataset_by_label(self, data_source): """Split a dataset, i.e. a list of Datum objects, into class-specific groups stored in a dictionary. Args: data_source (list): a list of Datum objects. """ output = defaultdict(list) for item in data_source: output[item.label].append(item) return output def split_dataset_by_domain(self, data_source): """Split a dataset, i.e. a list of Datum objects, into domain-specific groups stored in a dictionary. Args: data_source (list): a list of Datum objects. 
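
        Returns:
            defaultdict(list): maps each domain label to the Datum
            objects of that domain, in their original order.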
""" output = defaultdict(list) for item in data_source: output[item.domain].append(item) return output ================================================ FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/build.py ================================================ from dassl.utils import Registry, check_availability DATASET_REGISTRY = Registry("DATASET") def build_dataset(cfg): avai_datasets = DATASET_REGISTRY.registered_names() check_availability(cfg.DATASET.NAME, avai_datasets) if cfg.VERBOSE: print("Loading dataset: {}".format(cfg.DATASET.NAME)) return DATASET_REGISTRY.get(cfg.DATASET.NAME)(cfg) ================================================ FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/da/__init__.py ================================================ from .digit5 import Digit5 from .visda17 import VisDA17 from .cifarstl import CIFARSTL from .office31 import Office31 from .domainnet import DomainNet from .office_home import OfficeHome from .mini_domainnet import miniDomainNet ================================================ FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/da/cifarstl.py ================================================ import os.path as osp from dassl.utils import listdir_nohidden from ..build import DATASET_REGISTRY from ..base_dataset import Datum, DatasetBase @DATASET_REGISTRY.register() class CIFARSTL(DatasetBase): """CIFAR-10 and STL-10. CIFAR-10: - 60,000 32x32 colour images. - 10 classes, with 6,000 images per class. - 50,000 training images and 10,000 test images. - URL: https://www.cs.toronto.edu/~kriz/cifar.html. STL-10: - 10 classes: airplane, bird, car, cat, deer, dog, horse, monkey, ship, truck. - Images are 96x96 pixels, color. - 500 training images (10 pre-defined folds), 800 test images per class. - URL: https://cs.stanford.edu/~acoates/stl10/. Reference: - Krizhevsky. Learning Multiple Layers of Features from Tiny Images. Tech report. - Coates et al. An Analysis of Single Layer Networks in Unsupervised Feature Learning. AISTATS 2011. 
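
    Expected layout (inferred from ``_read_data``)::

        cifar_stl/<domain>/<split>/<label>_<name>/<image>

    where ``<domain>`` is ``cifar`` or ``stl`` and ``<split>`` is
    ``train`` or ``test``. The integer before the first underscore of
    each class folder name is used as the label, so a hypothetical
    folder ``0_airplane`` yields label 0.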
""" dataset_dir = "cifar_stl" domains = ["cifar", "stl"] def __init__(self, cfg): root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) self.dataset_dir = osp.join(root, self.dataset_dir) self.check_input_domains( cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS ) train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train") train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train") test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test") super().__init__(train_x=train_x, train_u=train_u, test=test) def _read_data(self, input_domains, split="train"): items = [] for domain, dname in enumerate(input_domains): data_dir = osp.join(self.dataset_dir, dname, split) class_names = listdir_nohidden(data_dir) for class_name in class_names: class_dir = osp.join(data_dir, class_name) imnames = listdir_nohidden(class_dir) label = int(class_name.split("_")[0]) for imname in imnames: impath = osp.join(class_dir, imname) item = Datum(impath=impath, label=label, domain=domain) items.append(item) return items ================================================ FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/da/digit5.py ================================================ import random import os.path as osp from dassl.utils import listdir_nohidden from ..build import DATASET_REGISTRY from ..base_dataset import Datum, DatasetBase # Folder names for train and test sets MNIST = {"train": "train_images", "test": "test_images"} MNIST_M = {"train": "train_images", "test": "test_images"} SVHN = {"train": "train_images", "test": "test_images"} SYN = {"train": "train_images", "test": "test_images"} USPS = {"train": "train_images", "test": "test_images"} def read_image_list(im_dir, n_max=None, n_repeat=None): items = [] for imname in listdir_nohidden(im_dir): imname_noext = osp.splitext(imname)[0] label = int(imname_noext.split("_")[1]) impath = osp.join(im_dir, imname) items.append((impath, label)) if n_max is not None: items = random.sample(items, n_max) if n_repeat is 
not None: items *= n_repeat return items def load_mnist(dataset_dir, split="train"): data_dir = osp.join(dataset_dir, MNIST[split]) n_max = 25000 if split == "train" else 9000 return read_image_list(data_dir, n_max=n_max) def load_mnist_m(dataset_dir, split="train"): data_dir = osp.join(dataset_dir, MNIST_M[split]) n_max = 25000 if split == "train" else 9000 return read_image_list(data_dir, n_max=n_max) def load_svhn(dataset_dir, split="train"): data_dir = osp.join(dataset_dir, SVHN[split]) n_max = 25000 if split == "train" else 9000 return read_image_list(data_dir, n_max=n_max) def load_syn(dataset_dir, split="train"): data_dir = osp.join(dataset_dir, SYN[split]) n_max = 25000 if split == "train" else 9000 return read_image_list(data_dir, n_max=n_max) def load_usps(dataset_dir, split="train"): data_dir = osp.join(dataset_dir, USPS[split]) n_repeat = 3 if split == "train" else None return read_image_list(data_dir, n_repeat=n_repeat) @DATASET_REGISTRY.register() class Digit5(DatasetBase): """Five digit datasets. It contains: - MNIST: hand-written digits. - MNIST-M: variant of MNIST with blended background. - SVHN: street view house numbers. - SYN: synthetic digits. - USPS: hand-written digits, slightly different from MNIST. For MNIST, MNIST-M, SVHN and SYN, we randomly sample 25,000 images from the training set and 9,000 images from the test set. For USPS, which has only 9,298 images in total, we use the entire dataset but replicate its training set 3 times to roughly match the training set size of the other domains. Reference: - Lecun et al. Gradient-based learning applied to document recognition. IEEE 1998. - Ganin et al. Domain-adversarial training of neural networks. JMLR 2016. - Netzer et al. Reading digits in natural images with unsupervised feature learning. NIPS-W 2011.
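
    Expected layout (inferred from ``read_image_list``)::

        digit5/<domain>/{train_images,test_images}/<image>

    The label is parsed from the second underscore-separated field of
    each file name (extension removed), so a hypothetical file
    ``00001_7.png`` yields label 7.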
""" dataset_dir = "digit5" domains = ["mnist", "mnist_m", "svhn", "syn", "usps"] def __init__(self, cfg): root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) self.dataset_dir = osp.join(root, self.dataset_dir) self.check_input_domains( cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS ) train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train") train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train") test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test") super().__init__(train_x=train_x, train_u=train_u, test=test) def _read_data(self, input_domains, split="train"): items = [] for domain, dname in enumerate(input_domains): func = "load_" + dname domain_dir = osp.join(self.dataset_dir, dname) items_d = eval(func)(domain_dir, split=split) for impath, label in items_d: item = Datum( impath=impath, label=label, domain=domain, classname=str(label) ) items.append(item) return items ================================================ FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/da/domainnet.py ================================================ import os.path as osp from ..build import DATASET_REGISTRY from ..base_dataset import Datum, DatasetBase @DATASET_REGISTRY.register() class DomainNet(DatasetBase): """DomainNet. Statistics: - 6 distinct domains: Clipart, Infograph, Painting, Quickdraw, Real, Sketch. - Around 0.6M images. - 345 categories. - URL: http://ai.bu.edu/M3SDA/. Special note: the t-shirt class (327) is missing in painting_train.txt. Reference: - Peng et al. Moment Matching for Multi-Source Domain Adaptation. ICCV 2019. 
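
    Expected layout (inferred from ``_read_data``): split files are
    read from ``domainnet/splits/<domain>_<split>.txt`` with one
    ``<impath> <label>`` pair per line. ``impath`` is relative to
    ``domainnet/`` and its second path component is taken as the class
    name, e.g. a line of the form ``clipart/<classname>/<image> <label>``.
    The test split of the source domains serves as the validation set.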
""" dataset_dir = "domainnet" domains = [ "clipart", "infograph", "painting", "quickdraw", "real", "sketch" ] def __init__(self, cfg): root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) self.dataset_dir = osp.join(root, self.dataset_dir) self.split_dir = osp.join(self.dataset_dir, "splits") self.check_input_domains( cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS ) train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train") train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train") val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="test") test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test") super().__init__(train_x=train_x, train_u=train_u, val=val, test=test) def _read_data(self, input_domains, split="train"): items = [] for domain, dname in enumerate(input_domains): filename = dname + "_" + split + ".txt" split_file = osp.join(self.split_dir, filename) with open(split_file, "r") as f: lines = f.readlines() for line in lines: line = line.strip() impath, label = line.split(" ") classname = impath.split("/")[1] impath = osp.join(self.dataset_dir, impath) label = int(label) item = Datum( impath=impath, label=label, domain=domain, classname=classname ) items.append(item) return items ================================================ FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/da/mini_domainnet.py ================================================ import os.path as osp from ..build import DATASET_REGISTRY from ..base_dataset import Datum, DatasetBase @DATASET_REGISTRY.register() class miniDomainNet(DatasetBase): """A subset of DomainNet. Reference: - Peng et al. Moment Matching for Multi-Source Domain Adaptation. ICCV 2019. - Zhou et al. Domain Adaptive Ensemble Learning. 
""" dataset_dir = "domainnet" domains = ["clipart", "painting", "real", "sketch"] def __init__(self, cfg): root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) self.dataset_dir = osp.join(root, self.dataset_dir) self.split_dir = osp.join(self.dataset_dir, "splits_mini") self.check_input_domains( cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS ) train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train") train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train") test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test") super().__init__(train_x=train_x, train_u=train_u, test=test) def _read_data(self, input_domains, split="train"): items = [] for domain, dname in enumerate(input_domains): filename = dname + "_" + split + ".txt" split_file = osp.join(self.split_dir, filename) with open(split_file, "r") as f: lines = f.readlines() for line in lines: line = line.strip() impath, label = line.split(" ") classname = impath.split("/")[1] impath = osp.join(self.dataset_dir, impath) label = int(label) item = Datum( impath=impath, label=label, domain=domain, classname=classname ) items.append(item) return items ================================================ FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/da/office31.py ================================================ import os.path as osp from dassl.utils import listdir_nohidden from ..build import DATASET_REGISTRY from ..base_dataset import Datum, DatasetBase @DATASET_REGISTRY.register() class Office31(DatasetBase): """Office-31. Statistics: - 4,110 images. - 31 classes related to office objects. - 3 domains: Amazon, Webcam, Dslr. - URL: https://people.eecs.berkeley.edu/~jhoffman/domainadapt/. Reference: - Saenko et al. Adapting visual category models to new domains. ECCV 2010. 
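
    Expected layout (inferred from ``_read_data``)::

        office31/<domain>/<classname>/<image>

    Labels are assigned by enumerating each domain's class folders in
    sorted order.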
""" dataset_dir = "office31" domains = ["amazon", "webcam", "dslr"] def __init__(self, cfg): root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) self.dataset_dir = osp.join(root, self.dataset_dir) self.check_input_domains( cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS ) train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS) train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS) test = self._read_data(cfg.DATASET.TARGET_DOMAINS) super().__init__(train_x=train_x, train_u=train_u, test=test) def _read_data(self, input_domains): items = [] for domain, dname in enumerate(input_domains): domain_dir = osp.join(self.dataset_dir, dname) class_names = listdir_nohidden(domain_dir) class_names.sort() for label, class_name in enumerate(class_names): class_path = osp.join(domain_dir, class_name) imnames = listdir_nohidden(class_path) for imname in imnames: impath = osp.join(class_path, imname) item = Datum( impath=impath, label=label, domain=domain, classname=class_name ) items.append(item) return items ================================================ FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/da/office_home.py ================================================ import os.path as osp from dassl.utils import listdir_nohidden from ..build import DATASET_REGISTRY from ..base_dataset import Datum, DatasetBase @DATASET_REGISTRY.register() class OfficeHome(DatasetBase): """Office-Home. Statistics: - Around 15,500 images. - 65 classes related to office and home objects. - 4 domains: Art, Clipart, Product, Real World. - URL: http://hemanthdv.org/OfficeHome-Dataset/. Reference: - Venkateswara et al. Deep Hashing Network for Unsupervised Domain Adaptation. CVPR 2017. 
""" dataset_dir = "office_home" domains = ["art", "clipart", "product", "real_world"] def __init__(self, cfg): root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) self.dataset_dir = osp.join(root, self.dataset_dir) self.check_input_domains( cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS ) train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS) train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS) test = self._read_data(cfg.DATASET.TARGET_DOMAINS) super().__init__(train_x=train_x, train_u=train_u, test=test) def _read_data(self, input_domains): items = [] for domain, dname in enumerate(input_domains): domain_dir = osp.join(self.dataset_dir, dname) class_names = listdir_nohidden(domain_dir) class_names.sort() for label, class_name in enumerate(class_names): class_path = osp.join(domain_dir, class_name) imnames = listdir_nohidden(class_path) for imname in imnames: impath = osp.join(class_path, imname) item = Datum( impath=impath, label=label, domain=domain, classname=class_name.lower(), ) items.append(item) return items ================================================ FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/da/visda17.py ================================================ import os.path as osp from ..build import DATASET_REGISTRY from ..base_dataset import Datum, DatasetBase @DATASET_REGISTRY.register() class VisDA17(DatasetBase): """VisDA17. Focusing on simulation-to-reality domain shift. URL: http://ai.bu.edu/visda-2017/. Reference: - Peng et al. VisDA: The Visual Domain Adaptation Challenge. ArXiv 2017. 
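
    Expected layout (inferred from ``_read_data``): the synthetic domain
    is listed in ``visda17/train/image_list.txt`` and the real domain in
    ``visda17/validation/image_list.txt``. Each line holds an
    ``<impath> <label>`` pair; ``impath`` is relative to the list's
    folder and its first path component is used as the class name.
    The real domain serves as both ``train_u`` and ``test``.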
""" dataset_dir = "visda17" domains = ["synthetic", "real"] def __init__(self, cfg): root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) self.dataset_dir = osp.join(root, self.dataset_dir) self.check_input_domains( cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS ) train_x = self._read_data("synthetic") train_u = self._read_data("real") test = self._read_data("real") super().__init__(train_x=train_x, train_u=train_u, test=test) def _read_data(self, dname): filedir = "train" if dname == "synthetic" else "validation" image_list = osp.join(self.dataset_dir, filedir, "image_list.txt") items = [] # There is only one source domain domain = 0 with open(image_list, "r") as f: lines = f.readlines() for line in lines: line = line.strip() impath, label = line.split(" ") classname = impath.split("/")[0] impath = osp.join(self.dataset_dir, filedir, impath) label = int(label) item = Datum( impath=impath, label=label, domain=domain, classname=classname ) items.append(item) return items ================================================ FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/dg/__init__.py ================================================ from .pacs import PACS from .vlcs import VLCS from .cifar_c import CIFAR10C, CIFAR100C from .digits_dg import DigitsDG from .digit_single import DigitSingle from .office_home_dg import OfficeHomeDG ================================================ FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/dg/cifar_c.py ================================================ import os.path as osp from dassl.utils import listdir_nohidden from ..build import DATASET_REGISTRY from ..base_dataset import Datum, DatasetBase AVAI_C_TYPES = [ "brightness", "contrast", "defocus_blur", "elastic_transform", "fog", "frost", "gaussian_blur", "gaussian_noise", "glass_blur", "impulse_noise", "jpeg_compression", "motion_blur", "pixelate", "saturate", "shot_noise", "snow", "spatter", "speckle_noise", "zoom_blur", ] @DATASET_REGISTRY.register() class 
CIFAR10C(DatasetBase): """CIFAR-10 -> CIFAR-10-C. Dataset link: https://zenodo.org/record/2535967#.YFwtV2Qzb0o Statistics: - 2 domains: the normal CIFAR-10 vs. a corrupted CIFAR-10 - 10 categories Reference: - Hendrycks et al. Benchmarking neural network robustness to common corruptions and perturbations. ICLR 2019. """ dataset_dir = "" domains = ["cifar10", "cifar10_c"] def __init__(self, cfg): root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) self.dataset_dir = root self.check_input_domains( cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS ) source_domain = cfg.DATASET.SOURCE_DOMAINS[0] target_domain = cfg.DATASET.TARGET_DOMAINS[0] assert source_domain == self.domains[0] assert target_domain == self.domains[1] c_type = cfg.DATASET.CIFAR_C_TYPE c_level = cfg.DATASET.CIFAR_C_LEVEL if not c_type: raise ValueError( "Please specify DATASET.CIFAR_C_TYPE in the config file" ) assert ( c_type in AVAI_C_TYPES ), f'C_TYPE is expected to belong to {AVAI_C_TYPES}, but got "{c_type}"' assert 1 <= c_level <= 5 train_dir = osp.join(self.dataset_dir, source_domain, "train") test_dir = osp.join( self.dataset_dir, target_domain, c_type, str(c_level) ) if not osp.exists(test_dir): raise ValueError(f"Test directory not found: {test_dir}") train = self._read_data(train_dir) test = self._read_data(test_dir) super().__init__(train_x=train, test=test) def _read_data(self, data_dir): class_names = listdir_nohidden(data_dir) class_names.sort() items = [] for label, class_name in enumerate(class_names): class_dir = osp.join(data_dir, class_name) imnames = listdir_nohidden(class_dir) for imname in imnames: impath = osp.join(class_dir, imname) item = Datum(impath=impath, label=label, domain=0) items.append(item) return items @DATASET_REGISTRY.register() class CIFAR100C(CIFAR10C): """CIFAR-100 -> CIFAR-100-C. Dataset link: https://zenodo.org/record/3555552#.YFxpQmQzb0o Statistics: - 2 domains: the normal CIFAR-100 vs. a corrupted CIFAR-100 - 100 categories Reference: - Hendrycks et al.
Benchmarking neural network robustness to common corruptions and perturbations. ICLR 2019. """ dataset_dir = "" domains = ["cifar100", "cifar100_c"] def __init__(self, cfg): super().__init__(cfg) ================================================ FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/dg/digit_single.py ================================================ import os.path as osp from dassl.utils import listdir_nohidden from ..build import DATASET_REGISTRY from ..base_dataset import Datum, DatasetBase # Folder names for train and test sets MNIST = {"train": "train_images", "test": "test_images"} MNIST_M = {"train": "train_images", "test": "test_images"} SVHN = {"train": "train_images", "test": "test_images"} SYN = {"train": "train_images", "test": "test_images"} USPS = {"train": "train_images", "test": "test_images"} def read_image_list(im_dir, n_max=None, n_repeat=None): items = [] for imname in listdir_nohidden(im_dir): imname_noext = osp.splitext(imname)[0] label = int(imname_noext.split("_")[1]) impath = osp.join(im_dir, imname) items.append((impath, label)) if n_max is not None: # Note that the sampling process is NOT random, # which follows that in Volpi et al. NIPS'18. 
items = items[:n_max] if n_repeat is not None: items *= n_repeat return items def load_mnist(dataset_dir, split="train"): data_dir = osp.join(dataset_dir, MNIST[split]) n_max = 10000 if split == "train" else None return read_image_list(data_dir, n_max=n_max) def load_mnist_m(dataset_dir, split="train"): data_dir = osp.join(dataset_dir, MNIST_M[split]) n_max = 10000 if split == "train" else None return read_image_list(data_dir, n_max=n_max) def load_svhn(dataset_dir, split="train"): data_dir = osp.join(dataset_dir, SVHN[split]) n_max = 10000 if split == "train" else None return read_image_list(data_dir, n_max=n_max) def load_syn(dataset_dir, split="train"): data_dir = osp.join(dataset_dir, SYN[split]) n_max = 10000 if split == "train" else None return read_image_list(data_dir, n_max=n_max) def load_usps(dataset_dir, split="train"): data_dir = osp.join(dataset_dir, USPS[split]) return read_image_list(data_dir) @DATASET_REGISTRY.register() class DigitSingle(DatasetBase): """Digit recognition datasets for single-source domain generalization. There are five digit datasets: - MNIST: hand-written digits. - MNIST-M: variant of MNIST with blended background. - SVHN: street view house numbers. - SYN: synthetic digits. - USPS: hand-written digits, slightly different from MNIST. Protocol: Volpi et al. train a model using 10,000 images from MNIST and evaluate the model on the test split of the other four datasets. However, the code does not restrict the source to MNIST; any of the five datasets can be used as the source. Note that for MNIST, MNIST-M, SVHN and SYN only the first 10,000 training images returned by the directory listing are used (not a random sample), while USPS is used in full. Reference: - Lecun et al. Gradient-based learning applied to document recognition. IEEE 1998. - Ganin et al. Domain-adversarial training of neural networks. JMLR 2016. - Netzer et al. Reading digits in natural images with unsupervised feature learning. NIPS-W 2011. - Volpi et al. Generalizing to Unseen Domains via Adversarial Data Augmentation. NIPS 2018.
""" # Reuse the digit-5 folder instead of creating a new folder dataset_dir = "digit5" domains = ["mnist", "mnist_m", "svhn", "syn", "usps"] def __init__(self, cfg): root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) self.dataset_dir = osp.join(root, self.dataset_dir) self.check_input_domains( cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS ) train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train") val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="test") test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test") super().__init__(train_x=train, val=val, test=test) def _read_data(self, input_domains, split="train"): items = [] for domain, dname in enumerate(input_domains): func = "load_" + dname domain_dir = osp.join(self.dataset_dir, dname) items_d = eval(func)(domain_dir, split=split) for impath, label in items_d: item = Datum(impath=impath, label=label, domain=domain) items.append(item) return items ================================================ FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/dg/digits_dg.py ================================================ import glob import os.path as osp from dassl.utils import listdir_nohidden from ..build import DATASET_REGISTRY from ..base_dataset import Datum, DatasetBase @DATASET_REGISTRY.register() class DigitsDG(DatasetBase): """Digits-DG. It contains 4 digit datasets: - MNIST: hand-written digits. - MNIST-M: variant of MNIST with blended background. - SVHN: street view house number. - SYN: synthetic digits. Reference: - Lecun et al. Gradient-based learning applied to document recognition. IEEE 1998. - Ganin et al. Domain-adversarial training of neural networks. JMLR 2016. - Netzer et al. Reading digits in natural images with unsupervised feature learning. NIPS-W 2011. - Zhou et al. Deep Domain-Adversarial Image Generation for Domain Generalisation. AAAI 2020. 
""" dataset_dir = "digits_dg" domains = ["mnist", "mnist_m", "svhn", "syn"] data_url = "https://drive.google.com/uc?id=15V7EsHfCcfbKgsDmzQKj_DfXt_XYp_P7" def __init__(self, cfg): root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) self.dataset_dir = osp.join(root, self.dataset_dir) if not osp.exists(self.dataset_dir): dst = osp.join(root, "digits_dg.zip") self.download_data(self.data_url, dst, from_gdrive=True) self.check_input_domains( cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS ) train = self.read_data( self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "train" ) val = self.read_data( self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "val" ) test = self.read_data( self.dataset_dir, cfg.DATASET.TARGET_DOMAINS, "all" ) super().__init__(train_x=train, val=val, test=test) @staticmethod def read_data(dataset_dir, input_domains, split): def _load_data_from_directory(directory): folders = listdir_nohidden(directory) folders.sort() items_ = [] for label, folder in enumerate(folders): impaths = glob.glob(osp.join(directory, folder, "*.jpg")) for impath in impaths: items_.append((impath, label)) return items_ items = [] for domain, dname in enumerate(input_domains): if split == "all": train_dir = osp.join(dataset_dir, dname, "train") impath_label_list = _load_data_from_directory(train_dir) val_dir = osp.join(dataset_dir, dname, "val") impath_label_list += _load_data_from_directory(val_dir) else: split_dir = osp.join(dataset_dir, dname, split) impath_label_list = _load_data_from_directory(split_dir) for impath, label in impath_label_list: class_name = impath.split("/")[-2].lower() item = Datum( impath=impath, label=label, domain=domain, classname=class_name ) items.append(item) return items ================================================ FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/dg/office_home_dg.py ================================================ import os.path as osp from ..build import DATASET_REGISTRY from .digits_dg import DigitsDG from ..base_dataset import 
DatasetBase @DATASET_REGISTRY.register() class OfficeHomeDG(DatasetBase): """Office-Home. Statistics: - Around 15,500 images. - 65 classes related to office and home objects. - 4 domains: Art, Clipart, Product, Real World. - URL: http://hemanthdv.org/OfficeHome-Dataset/. Reference: - Venkateswara et al. Deep Hashing Network for Unsupervised Domain Adaptation. CVPR 2017. """ dataset_dir = "office_home_dg" domains = ["art", "clipart", "product", "real_world"] data_url = "https://drive.google.com/uc?id=1gkbf_KaxoBws-GWT3XIPZ7BnkqbAxIFa" def __init__(self, cfg): root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) self.dataset_dir = osp.join(root, self.dataset_dir) if not osp.exists(self.dataset_dir): dst = osp.join(root, "office_home_dg.zip") self.download_data(self.data_url, dst, from_gdrive=True) self.check_input_domains( cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS ) train = DigitsDG.read_data( self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "train" ) val = DigitsDG.read_data( self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "val" ) test = DigitsDG.read_data( self.dataset_dir, cfg.DATASET.TARGET_DOMAINS, "all" ) super().__init__(train_x=train, val=val, test=test) ================================================ FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/dg/pacs.py ================================================ import os.path as osp from ..build import DATASET_REGISTRY from ..base_dataset import Datum, DatasetBase @DATASET_REGISTRY.register() class PACS(DatasetBase): """PACS. Statistics: - 4 domains: Photo (1,670), Art (2,048), Cartoon (2,344), Sketch (3,929). - 7 categories: dog, elephant, giraffe, guitar, horse, house and person. Reference: - Li et al. Deeper, broader and artier domain generalization. ICCV 2017. 
""" dataset_dir = "pacs" domains = ["art_painting", "cartoon", "photo", "sketch"] data_url = "https://drive.google.com/uc?id=1m4X4fROCCXMO0lRLrr6Zz9Vb3974NWhE" # the following images contain errors and should be ignored _error_paths = ["sketch/dog/n02103406_4068-1.png"] def __init__(self, cfg): root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) self.dataset_dir = osp.join(root, self.dataset_dir) self.image_dir = osp.join(self.dataset_dir, "images") self.split_dir = osp.join(self.dataset_dir, "splits") if not osp.exists(self.dataset_dir): dst = osp.join(root, "pacs.zip") self.download_data(self.data_url, dst, from_gdrive=True) self.check_input_domains( cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS ) train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "train") val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "crossval") test = self._read_data(cfg.DATASET.TARGET_DOMAINS, "all") super().__init__(train_x=train, val=val, test=test) def _read_data(self, input_domains, split): items = [] for domain, dname in enumerate(input_domains): if split == "all": file_train = osp.join( self.split_dir, dname + "_train_kfold.txt" ) impath_label_list = self._read_split_pacs(file_train) file_val = osp.join( self.split_dir, dname + "_crossval_kfold.txt" ) impath_label_list += self._read_split_pacs(file_val) else: file = osp.join( self.split_dir, dname + "_" + split + "_kfold.txt" ) impath_label_list = self._read_split_pacs(file) for impath, label in impath_label_list: classname = impath.split("/")[-2] item = Datum( impath=impath, label=label, domain=domain, classname=classname ) items.append(item) return items def _read_split_pacs(self, split_file): items = [] with open(split_file, "r") as f: lines = f.readlines() for line in lines: line = line.strip() impath, label = line.split(" ") if impath in self._error_paths: continue impath = osp.join(self.image_dir, impath) label = int(label) - 1 items.append((impath, label)) return items 
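The PACS loader reads split files in which each line holds an image path relative to `images/` followed by a 1-based label. The sketch below mirrors `_read_split_pacs` (plus the class-name extraction done in `_read_data`) on an in-memory list, so the label shift and the error-path filter are easy to see. It is an illustration only: the function name `parse_pacs_split`, the image directory, and the example split lines are hypothetical.

```python
import os.path as osp

# Hypothetical stand-ins for PACS.image_dir and PACS._error_paths.
IMAGE_DIR = "/data/pacs/images"
ERROR_PATHS = ["sketch/dog/n02103406_4068-1.png"]


def parse_pacs_split(lines, image_dir=IMAGE_DIR, error_paths=ERROR_PATHS):
    """Turn "<relative path> <1-based label>" lines into (impath, label, classname)."""
    items = []
    for line in lines:
        impath, label = line.strip().split(" ")
        if impath in error_paths:
            continue  # known-corrupt image, skipped as in _read_split_pacs
        # The class name is the parent folder, as in PACS._read_data.
        classname = impath.split("/")[-2]
        # Split-file labels start at 1; shift them to start at 0.
        items.append((osp.join(image_dir, impath), int(label) - 1, classname))
    return items


# Hypothetical split-file lines, for illustration only.
lines = [
    "photo/dog/img_0001.jpg 1\n",
    "sketch/dog/n02103406_4068-1.png 1\n",
    "photo/person/img_0002.jpg 7\n",
]
for item in parse_pacs_split(lines):
    print(item)
# ('/data/pacs/images/photo/dog/img_0001.jpg', 0, 'dog')
# ('/data/pacs/images/photo/person/img_0002.jpg', 6, 'person')
```

For the `"all"` split used on target domains, the loader simply concatenates the parsed `_train_kfold.txt` and `_crossval_kfold.txt` lists for each domain.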
================================================ FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/dg/vlcs.py ================================================ import glob import os.path as osp from dassl.utils import listdir_nohidden from ..build import DATASET_REGISTRY from ..base_dataset import Datum, DatasetBase @DATASET_REGISTRY.register() class VLCS(DatasetBase): """VLCS. Statistics: - 4 domains: CALTECH, LABELME, PASCAL, SUN - 5 categories: bird, car, chair, dog, and person. Reference: - Torralba and Efros. Unbiased look at dataset bias. CVPR 2011. """ dataset_dir = "VLCS" domains = ["caltech", "labelme", "pascal", "sun"] data_url = "https://drive.google.com/uc?id=1r0WL5DDqKfSPp9E3tRENwHaXNs1olLZd" def __init__(self, cfg): root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) self.dataset_dir = osp.join(root, self.dataset_dir) if not osp.exists(self.dataset_dir): dst = osp.join(root, "vlcs.zip") self.download_data(self.data_url, dst, from_gdrive=True) self.check_input_domains( cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS ) train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "train") val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "crossval") test = self._read_data(cfg.DATASET.TARGET_DOMAINS, "test") super().__init__(train_x=train, val=val, test=test) def _read_data(self, input_domains, split): items = [] for domain, dname in enumerate(input_domains): dname = dname.upper() path = osp.join(self.dataset_dir, dname, split) folders = listdir_nohidden(path) folders.sort() for label, folder in enumerate(folders): impaths = glob.glob(osp.join(path, folder, "*.jpg")) for impath in impaths: item = Datum(impath=impath, label=label, domain=domain) items.append(item) return items ================================================ FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/ssl/__init__.py ================================================ from .svhn import SVHN from .cifar import CIFAR10, CIFAR100 from .stl10 import STL10 
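`read_image_list` in `digit_single.py` takes each label from the file name itself, then caps a split at `n_max` images by plain truncation (not random sampling) and optionally repeats the list. A minimal sketch of that rule on a list of names, assuming file names shaped like `<id>_<label>.<ext>`. The function name `parse_digit_items`, the example names, and the directory are hypothetical.

```python
import os.path as osp


def parse_digit_items(imnames, im_dir, n_max=None, n_repeat=None):
    """Mirror the labelling rule of read_image_list on a list of file names."""
    items = []
    for imname in imnames:
        imname_noext = osp.splitext(imname)[0]
        label = int(imname_noext.split("_")[1])  # second "_"-separated field
        items.append((osp.join(im_dir, imname), label))
    if n_max is not None:
        # Deterministic cap: keep the first n_max items in listing order.
        items = items[:n_max]
    if n_repeat is not None:
        items *= n_repeat  # list repetition, not resampling
    return items


names = ["00000_5.jpg", "00001_0.jpg", "00002_4.jpg"]
items = parse_digit_items(names, "digit5/mnist/train_images", n_max=2, n_repeat=2)
print([label for _, label in items])  # [5, 0, 5, 0]
```

Because the cap is a slice, which images survive it depends on the order in which the names were listed; the loaders call it with `n_max = 10000` for the MNIST, MNIST-M, SVHN and SYN training splits, and without a cap for USPS.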
================================================ FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/ssl/cifar.py ================================================ import math import random import os.path as osp from dassl.utils import listdir_nohidden from ..build import DATASET_REGISTRY from ..base_dataset import Datum, DatasetBase @DATASET_REGISTRY.register() class CIFAR10(DatasetBase): """CIFAR10 for SSL. Reference: - Krizhevsky. Learning Multiple Layers of Features from Tiny Images. Tech report. """ dataset_dir = "cifar10" def __init__(self, cfg): root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) self.dataset_dir = osp.join(root, self.dataset_dir) train_dir = osp.join(self.dataset_dir, "train") test_dir = osp.join(self.dataset_dir, "test") assert cfg.DATASET.NUM_LABELED > 0 train_x, train_u, val = self._read_data_train( train_dir, cfg.DATASET.NUM_LABELED, cfg.DATASET.VAL_PERCENT ) test = self._read_data_test(test_dir) if cfg.DATASET.ALL_AS_UNLABELED: train_u = train_u + train_x if len(val) == 0: val = None super().__init__(train_x=train_x, train_u=train_u, val=val, test=test) def _read_data_train(self, data_dir, num_labeled, val_percent): class_names = listdir_nohidden(data_dir) class_names.sort() num_labeled_per_class = num_labeled / len(class_names) items_x, items_u, items_v = [], [], [] for label, class_name in enumerate(class_names): class_dir = osp.join(data_dir, class_name) imnames = listdir_nohidden(class_dir) # Split into train and val following Oliver et al. 
2018 # Set cfg.DATASET.VAL_PERCENT to 0 to not use val data num_val = math.floor(len(imnames) * val_percent) imnames_train = imnames[num_val:] imnames_val = imnames[:num_val] # Note we do shuffle after split random.shuffle(imnames_train) for i, imname in enumerate(imnames_train): impath = osp.join(class_dir, imname) item = Datum(impath=impath, label=label) if (i + 1) <= num_labeled_per_class: items_x.append(item) else: items_u.append(item) for imname in imnames_val: impath = osp.join(class_dir, imname) item = Datum(impath=impath, label=label) items_v.append(item) return items_x, items_u, items_v def _read_data_test(self, data_dir): class_names = listdir_nohidden(data_dir) class_names.sort() items = [] for label, class_name in enumerate(class_names): class_dir = osp.join(data_dir, class_name) imnames = listdir_nohidden(class_dir) for imname in imnames: impath = osp.join(class_dir, imname) item = Datum(impath=impath, label=label) items.append(item) return items @DATASET_REGISTRY.register() class CIFAR100(CIFAR10): """CIFAR100 for SSL. Reference: - Krizhevsky. Learning Multiple Layers of Features from Tiny Images. Tech report. """ dataset_dir = "cifar100" def __init__(self, cfg): super().__init__(cfg) ================================================ FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/ssl/stl10.py ================================================ import numpy as np import os.path as osp from dassl.utils import listdir_nohidden from ..build import DATASET_REGISTRY from ..base_dataset import Datum, DatasetBase @DATASET_REGISTRY.register() class STL10(DatasetBase): """STL-10 dataset. Description: - 10 classes: airplane, bird, car, cat, deer, dog, horse, monkey, ship, truck. - Images are 96x96 pixels, color. - 500 training images per class, 800 test images per class. - 100,000 unlabeled images for unsupervised learning. Reference: - Coates et al. An Analysis of Single Layer Networks in Unsupervised Feature Learning. AISTATS 2011. 
""" dataset_dir = "stl10" def __init__(self, cfg): root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) self.dataset_dir = osp.join(root, self.dataset_dir) train_dir = osp.join(self.dataset_dir, "train") test_dir = osp.join(self.dataset_dir, "test") unlabeled_dir = osp.join(self.dataset_dir, "unlabeled") fold_file = osp.join( self.dataset_dir, "stl10_binary", "fold_indices.txt" ) # Only use the first five splits assert 0 <= cfg.DATASET.STL10_FOLD <= 4 train_x = self._read_data_train( train_dir, cfg.DATASET.STL10_FOLD, fold_file ) train_u = self._read_data_all(unlabeled_dir) test = self._read_data_all(test_dir) if cfg.DATASET.ALL_AS_UNLABELED: train_u = train_u + train_x super().__init__(train_x=train_x, train_u=train_u, test=test) def _read_data_train(self, data_dir, fold, fold_file): imnames = listdir_nohidden(data_dir) imnames.sort() items = [] list_idx = list(range(len(imnames))) if fold >= 0: with open(fold_file, "r") as f: str_idx = f.read().splitlines()[fold] # Fold indices point into the 5,000 training images, so uint8 would overflow list_idx = np.fromstring(str_idx, dtype=np.int64, sep=" ") for i in list_idx: imname = imnames[i] impath = osp.join(data_dir, imname) label = osp.splitext(imname)[0].split("_")[1] label = int(label) item = Datum(impath=impath, label=label) items.append(item) return items def _read_data_all(self, data_dir): imnames = listdir_nohidden(data_dir) items = [] for imname in imnames: impath = osp.join(data_dir, imname) label = osp.splitext(imname)[0].split("_")[1] if label == "none": label = -1 else: label = int(label) item = Datum(impath=impath, label=label) items.append(item) return items ================================================ FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/ssl/svhn.py ================================================ from .cifar import CIFAR10 from ..build import DATASET_REGISTRY @DATASET_REGISTRY.register() class SVHN(CIFAR10): """SVHN for SSL. Reference: - Netzer et al. Reading Digits in Natural Images with Unsupervised Feature Learning. NIPS-W 2011.
""" dataset_dir = "svhn" def __init__(self, cfg): super().__init__(cfg) ================================================ FILE: Dassl.ProGrad.pytorch/dassl/data/samplers.py ================================================ import copy import numpy as np import random from collections import defaultdict from torch.utils.data.sampler import Sampler, RandomSampler, SequentialSampler class RandomDomainSampler(Sampler): """Randomly samples N domains each with K images to form a minibatch of size N*K. Args: data_source (list): list of Datums. batch_size (int): batch size. n_domain (int): number of domains to sample in a minibatch. """ def __init__(self, data_source, batch_size, n_domain): self.data_source = data_source # Keep track of image indices for each domain self.domain_dict = defaultdict(list) for i, item in enumerate(data_source): self.domain_dict[item.domain].append(i) self.domains = list(self.domain_dict.keys()) # Make sure each domain has equal number of images if n_domain is None or n_domain <= 0: n_domain = len(self.domains) assert batch_size % n_domain == 0 self.n_img_per_domain = batch_size // n_domain self.batch_size = batch_size # n_domain denotes number of domains sampled in a minibatch self.n_domain = n_domain self.length = len(list(self.__iter__())) def __iter__(self): domain_dict = copy.deepcopy(self.domain_dict) final_idxs = [] stop_sampling = False while not stop_sampling: selected_domains = random.sample(self.domains, self.n_domain) for domain in selected_domains: idxs = domain_dict[domain] selected_idxs = random.sample(idxs, self.n_img_per_domain) final_idxs.extend(selected_idxs) for idx in selected_idxs: domain_dict[domain].remove(idx) remaining = len(domain_dict[domain]) if remaining < self.n_img_per_domain: stop_sampling = True return iter(final_idxs) def __len__(self): return self.length class SeqDomainSampler(Sampler): """Sequential domain sampler, which randomly samples K images from each domain to form a minibatch. 
Args: data_source (list): list of Datums. batch_size (int): batch size. """ def __init__(self, data_source, batch_size): self.data_source = data_source # Keep track of image indices for each domain self.domain_dict = defaultdict(list) for i, item in enumerate(data_source): self.domain_dict[item.domain].append(i) self.domains = list(self.domain_dict.keys()) self.domains.sort() # Make sure each domain has equal number of images n_domain = len(self.domains) assert batch_size % n_domain == 0 self.n_img_per_domain = batch_size // n_domain self.batch_size = batch_size # n_domain denotes number of domains sampled in a minibatch self.n_domain = n_domain self.length = len(list(self.__iter__())) def __iter__(self): domain_dict = copy.deepcopy(self.domain_dict) final_idxs = [] stop_sampling = False while not stop_sampling: for domain in self.domains: idxs = domain_dict[domain] selected_idxs = random.sample(idxs, self.n_img_per_domain) final_idxs.extend(selected_idxs) for idx in selected_idxs: domain_dict[domain].remove(idx) remaining = len(domain_dict[domain]) if remaining < self.n_img_per_domain: stop_sampling = True return iter(final_idxs) def __len__(self): return self.length class RandomClassSampler(Sampler): """Randomly samples N classes each with K instances to form a minibatch of size N*K. Modified from https://github.com/KaiyangZhou/deep-person-reid. Args: data_source (list): list of Datums. batch_size (int): batch size. n_ins (int): number of instances per class to sample in a minibatch. 
""" def __init__(self, data_source, batch_size, n_ins): if batch_size < n_ins: raise ValueError( "batch_size={} must be no less " "than n_ins={}".format(batch_size, n_ins) ) self.data_source = data_source self.batch_size = batch_size self.n_ins = n_ins self.ncls_per_batch = self.batch_size // self.n_ins self.index_dic = defaultdict(list) for index, item in enumerate(data_source): self.index_dic[item.label].append(index) self.labels = list(self.index_dic.keys()) assert len(self.labels) >= self.ncls_per_batch # estimate number of images in an epoch self.length = len(list(self.__iter__())) def __iter__(self): batch_idxs_dict = defaultdict(list) for label in self.labels: idxs = copy.deepcopy(self.index_dic[label]) if len(idxs) < self.n_ins: idxs = np.random.choice(idxs, size=self.n_ins, replace=True) random.shuffle(idxs) batch_idxs = [] for idx in idxs: batch_idxs.append(idx) if len(batch_idxs) == self.n_ins: batch_idxs_dict[label].append(batch_idxs) batch_idxs = [] avai_labels = copy.deepcopy(self.labels) final_idxs = [] while len(avai_labels) >= self.ncls_per_batch: selected_labels = random.sample(avai_labels, self.ncls_per_batch) for label in selected_labels: batch_idxs = batch_idxs_dict[label].pop(0) final_idxs.extend(batch_idxs) if len(batch_idxs_dict[label]) == 0: avai_labels.remove(label) return iter(final_idxs) def __len__(self): return self.length def build_sampler( sampler_type, cfg=None, data_source=None, batch_size=32, n_domain=0, n_ins=16 ): if sampler_type == "RandomSampler": return RandomSampler(data_source) elif sampler_type == "SequentialSampler": return SequentialSampler(data_source) elif sampler_type == "RandomDomainSampler": return RandomDomainSampler(data_source, batch_size, n_domain) elif sampler_type == "SeqDomainSampler": return SeqDomainSampler(data_source, batch_size) elif sampler_type == "RandomClassSampler": return RandomClassSampler(data_source, batch_size, n_ins) else: raise ValueError("Unknown sampler type: {}".format(sampler_type)) 
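`RandomDomainSampler` builds each minibatch from `n_domain` randomly chosen domains with `batch_size // n_domain` images each, drawing indices without replacement until some domain can no longer fill its share. Below is a function-form sketch of its `__iter__` on a plain list of domain ids. The name `domain_balanced_indices` and the seeded `random.Random` are assumptions added for reproducibility; the class itself uses the global `random` module.

```python
import copy
import random
from collections import defaultdict


def domain_balanced_indices(domains, batch_size, n_domain, seed=0):
    """Sketch of RandomDomainSampler.__iter__ over a list of per-item domain ids."""
    rng = random.Random(seed)
    domain_dict = defaultdict(list)  # domain id -> dataset indices
    for i, d in enumerate(domains):
        domain_dict[d].append(i)
    all_domains = list(domain_dict.keys())
    assert batch_size % n_domain == 0
    n_per_domain = batch_size // n_domain

    pools = copy.deepcopy(domain_dict)  # indices not yet drawn this epoch
    final_idxs = []
    stop_sampling = False
    while not stop_sampling:
        for d in rng.sample(all_domains, n_domain):
            picked = rng.sample(pools[d], n_per_domain)
            final_idxs.extend(picked)
            for idx in picked:
                pools[d].remove(idx)
            if len(pools[d]) < n_per_domain:
                stop_sampling = True  # finish this minibatch, then stop
    return final_idxs


domains = [0] * 6 + [1] * 6 + [2] * 6
idxs = domain_balanced_indices(domains, batch_size=6, n_domain=3)
print(len(idxs))  # 18: three minibatches, each with two images per domain
```

Because sampling stops as soon as any domain runs short, leftover images in larger domains are not visited in that epoch; this is also why the class estimates `__len__` by running `__iter__` once in its constructor.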
================================================ FILE: Dassl.ProGrad.pytorch/dassl/data/transforms/__init__.py ================================================ from .transforms import build_transform ================================================ FILE: Dassl.ProGrad.pytorch/dassl/data/transforms/autoaugment.py ================================================ """ Source: https://github.com/DeepVoltaire/AutoAugment """ import numpy as np import random from PIL import Image, ImageOps, ImageEnhance class ImageNetPolicy: """Randomly choose one of the best 25 Sub-policies on ImageNet. Example: >>> policy = ImageNetPolicy() >>> transformed = policy(image) Example as a PyTorch Transform: >>> transform=transforms.Compose([ >>> transforms.Resize(256), >>> ImageNetPolicy(), >>> transforms.ToTensor()]) """ def __init__(self, fillcolor=(128, 128, 128)): self.policies = [ SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor), SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor), SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor), SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor), SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor), SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor), SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor), SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor), SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor), SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor), SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor), SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor), SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor), SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor), SubPolicy(0.6, "shearX", 5,
1.0, "equalize", 9, fillcolor), SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor), SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor), ] def __call__(self, img): policy_idx = random.randint(0, len(self.policies) - 1) return self.policies[policy_idx](img) def __repr__(self): return "AutoAugment ImageNet Policy" class CIFAR10Policy: """Randomly choose one of the best 25 Sub-policies on CIFAR10. Example: >>> policy = CIFAR10Policy() >>> transformed = policy(image) Example as a PyTorch Transform: >>> transform=transforms.Compose([ >>> transforms.Resize(256), >>> CIFAR10Policy(), >>> transforms.ToTensor()]) """ def __init__(self, fillcolor=(128, 128, 128)): self.policies = [ SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor), SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor), SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor), SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor), SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor), SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor), SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor), SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor), SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor), SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor), SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor), SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor), SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor), SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor), SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor), SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor), SubPolicy(0.2, "equalize", 8, 0.6, "equalize", 4, 
fillcolor), SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor), SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor), SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor), SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor), SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor), SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor), SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor), SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor), ] def __call__(self, img): policy_idx = random.randint(0, len(self.policies) - 1) return self.policies[policy_idx](img) def __repr__(self): return "AutoAugment CIFAR10 Policy" class SVHNPolicy: """Randomly choose one of the best 25 Sub-policies on SVHN. Example: >>> policy = SVHNPolicy() >>> transformed = policy(image) Example as a PyTorch Transform: >>> transform=transforms.Compose([ >>> transforms.Resize(256), >>> SVHNPolicy(), >>> transforms.ToTensor()]) """ def __init__(self, fillcolor=(128, 128, 128)): self.policies = [ SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor), SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor), SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor), SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor), SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor), SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor), SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor), SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor), SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor), SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor), SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor), SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor), SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor), SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor), SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor), SubPolicy(0.8, "invert", 5, 0.0, 
"translateY", 2, fillcolor), SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor), SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor), SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor), SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor), SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor), SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor), SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor), SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor), SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor), ] def __call__(self, img): policy_idx = random.randint(0, len(self.policies) - 1) return self.policies[policy_idx](img) def __repr__(self): return "AutoAugment SVHN Policy" class SubPolicy(object): def __init__( self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128), ): ranges = { "shearX": np.linspace(0, 0.3, 10), "shearY": np.linspace(0, 0.3, 10), "translateX": np.linspace(0, 150 / 331, 10), "translateY": np.linspace(0, 150 / 331, 10), "rotate": np.linspace(0, 30, 10), "color": np.linspace(0.0, 0.9, 10), "posterize": np.round(np.linspace(8, 4, 10), 0).astype(int), "solarize": np.linspace(256, 0, 10), "contrast": np.linspace(0.0, 0.9, 10), "sharpness": np.linspace(0.0, 0.9, 10), "brightness": np.linspace(0.0, 0.9, 10), "autocontrast": [0] * 10, "equalize": [0] * 10, "invert": [0] * 10, } # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand def rotate_with_fill(img, magnitude): rot = img.convert("RGBA").rotate(magnitude) return Image.composite( rot, Image.new("RGBA", rot.size, (128, ) * 4), rot ).convert(img.mode) func = { "shearX": lambda img, magnitude: img.transform( img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), Image.BICUBIC, fillcolor=fillcolor, ), "shearY": lambda img, magnitude: img.transform( img.size, Image.AFFINE, (1, 0, 0, magnitude *
random.choice([-1, 1]), 1, 0), Image.BICUBIC, fillcolor=fillcolor, ), "translateX": lambda img, magnitude: img.transform( img.size, Image.AFFINE, ( 1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0 ), fillcolor=fillcolor, ), "translateY": lambda img, magnitude: img.transform( img.size, Image.AFFINE, ( 1, 0, 0, 0, 1, magnitude * img.size[1] * random. choice([-1, 1]) ), fillcolor=fillcolor, ), "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude), "color": lambda img, magnitude: ImageEnhance.Color(img). enhance(1 + magnitude * random.choice([-1, 1])), "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude), "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude), "contrast": lambda img, magnitude: ImageEnhance.Contrast(img). enhance(1 + magnitude * random.choice([-1, 1])), "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img). enhance(1 + magnitude * random.choice([-1, 1])), "brightness": lambda img, magnitude: ImageEnhance.Brightness(img). 
enhance(1 + magnitude * random.choice([-1, 1])), "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img), "equalize": lambda img, magnitude: ImageOps.equalize(img), "invert": lambda img, magnitude: ImageOps.invert(img), } self.p1 = p1 self.operation1 = func[operation1] self.magnitude1 = ranges[operation1][magnitude_idx1] self.p2 = p2 self.operation2 = func[operation2] self.magnitude2 = ranges[operation2][magnitude_idx2] def __call__(self, img): if random.random() < self.p1: img = self.operation1(img, self.magnitude1) if random.random() < self.p2: img = self.operation2(img, self.magnitude2) return img ================================================ FILE: Dassl.ProGrad.pytorch/dassl/data/transforms/randaugment.py ================================================ """ Credit to 1) https://github.com/ildoonet/pytorch-randaugment 2) https://github.com/kakaobrain/fast-autoaugment """ import numpy as np import random import PIL import torch import PIL.ImageOps import PIL.ImageDraw import PIL.ImageEnhance from PIL import Image def ShearX(img, v): assert -0.3 <= v <= 0.3 if random.random() > 0.5: v = -v return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) def ShearY(img, v): assert -0.3 <= v <= 0.3 if random.random() > 0.5: v = -v return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45] assert -0.45 <= v <= 0.45 if random.random() > 0.5: v = -v v = v * img.size[0] return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] assert 0 <= v if random.random() > 0.5: v = -v return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45] assert -0.45 <= v <= 0.45 if random.random() > 0.5: v = -v v = v * img.size[1] return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) def TranslateYabs(img, v): # [-150, 150] => 
percentage: [-0.45, 0.45] assert 0 <= v if random.random() > 0.5: v = -v return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) def Rotate(img, v): assert -30 <= v <= 30 if random.random() > 0.5: v = -v return img.rotate(v) def AutoContrast(img, _): return PIL.ImageOps.autocontrast(img) def Invert(img, _): return PIL.ImageOps.invert(img) def Equalize(img, _): return PIL.ImageOps.equalize(img) def Flip(img, _): return PIL.ImageOps.mirror(img) def Solarize(img, v): assert 0 <= v <= 256 return PIL.ImageOps.solarize(img, v) def SolarizeAdd(img, addition=0, threshold=128): img_np = np.array(img).astype(int) img_np = img_np + addition img_np = np.clip(img_np, 0, 255) img_np = img_np.astype(np.uint8) img = Image.fromarray(img_np) return PIL.ImageOps.solarize(img, threshold) def Posterize(img, v): assert 4 <= v <= 8 v = int(v) return PIL.ImageOps.posterize(img, v) def Contrast(img, v): assert 0.0 <= v <= 2.0 return PIL.ImageEnhance.Contrast(img).enhance(v) def Color(img, v): assert 0.0 <= v <= 2.0 return PIL.ImageEnhance.Color(img).enhance(v) def Brightness(img, v): assert 0.0 <= v <= 2.0 return PIL.ImageEnhance.Brightness(img).enhance(v) def Sharpness(img, v): assert 0.0 <= v <= 2.0 return PIL.ImageEnhance.Sharpness(img).enhance(v) def Cutout(img, v): # [0, 60] => percentage: [0, 0.2] assert 0.0 <= v <= 0.2 if v <= 0.0: return img v = v * img.size[0] return CutoutAbs(img, v) def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2] # assert 0 <= v <= 20 if v < 0: return img w, h = img.size x0 = np.random.uniform(w) y0 = np.random.uniform(h) x0 = int(max(0, x0 - v/2.0)) y0 = int(max(0, y0 - v/2.0)) x1 = min(w, x0 + v) y1 = min(h, y0 + v) xy = (x0, y0, x1, y1) color = (125, 123, 114) # color = (0, 0, 0) img = img.copy() PIL.ImageDraw.Draw(img).rectangle(xy, color) return img def SamplePairing(imgs): # [0, 0.4] def f(img1, v): i = np.random.choice(len(imgs)) img2 = PIL.Image.fromarray(imgs[i]) return PIL.Image.blend(img1, img2, v) return f def Identity(img,
v): return img class Lighting: """Lighting noise (AlexNet-style PCA-based noise).""" def __init__(self, alphastd, eigval, eigvec): self.alphastd = alphastd self.eigval = torch.Tensor(eigval) self.eigvec = torch.Tensor(eigvec) def __call__(self, img): if self.alphastd == 0: return img alpha = img.new().resize_(3).normal_(0, self.alphastd) rgb = ( self.eigvec.type_as(img).clone().mul( alpha.view(1, 3).expand(3, 3) ).mul(self.eigval.view(1, 3).expand(3, 3)).sum(1).squeeze() ) return img.add(rgb.view(3, 1, 1).expand_as(img)) class CutoutDefault: """ Reference: https://github.com/quark0/darts/blob/master/cnn/utils.py """ def __init__(self, length): self.length = length def __call__(self, img): h, w = img.size(1), img.size(2) mask = np.ones((h, w), np.float32) y = np.random.randint(h) x = np.random.randint(w) y1 = np.clip(y - self.length // 2, 0, h) y2 = np.clip(y + self.length // 2, 0, h) x1 = np.clip(x - self.length // 2, 0, w) x2 = np.clip(x + self.length // 2, 0, w) mask[y1:y2, x1:x2] = 0.0 mask = torch.from_numpy(mask) mask = mask.expand_as(img) img *= mask return img def randaugment_list(): # 16 operations and their ranges # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57 # augs = [ # (Identity, 0., 1.0), # (ShearX, 0., 0.3), # 0 # (ShearY, 0., 0.3), # 1 # (TranslateX, 0., 0.33), # 2 # (TranslateY, 0., 0.33), # 3 # (Rotate, 0, 30), # 4 # (AutoContrast, 0, 1), # 5 # (Invert, 0, 1), # 6 # (Equalize, 0, 1), # 7 # (Solarize, 0, 110), # 8 # (Posterize, 4, 8), # 9 # # (Contrast, 0.1, 1.9), # 10 # (Color, 0.1, 1.9), # 11 # (Brightness, 0.1, 1.9), # 12 # (Sharpness, 0.1, 1.9), # 13 # # (Cutout, 0, 0.2), # 14 # # (SamplePairing(imgs), 0, 0.4) # 15 # ] # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505 augs = [ (AutoContrast, 0, 1), (Equalize, 0, 1), (Invert, 0, 1), (Rotate, 0, 30), (Posterize, 4, 8), (Solarize, 0, 256), (SolarizeAdd, 0, 110), (Color,
0.1, 1.9), (Contrast, 0.1, 1.9), (Brightness, 0.1, 1.9), (Sharpness, 0.1, 1.9), (ShearX, 0.0, 0.3), (ShearY, 0.0, 0.3), (CutoutAbs, 0, 40), (TranslateXabs, 0.0, 100), (TranslateYabs, 0.0, 100), ] return augs def randaugment_list2(): augs = [ (AutoContrast, 0, 1), (Brightness, 0.1, 1.9), (Color, 0.1, 1.9), (Contrast, 0.1, 1.9), (Equalize, 0, 1), (Identity, 0, 1), (Invert, 0, 1), (Posterize, 4, 8), (Rotate, -30, 30), (Sharpness, 0.1, 1.9), (ShearX, -0.3, 0.3), (ShearY, -0.3, 0.3), (Solarize, 0, 256), (TranslateX, -0.3, 0.3), (TranslateY, -0.3, 0.3), ] return augs def fixmatch_list(): # https://arxiv.org/abs/2001.07685 augs = [ (AutoContrast, 0, 1), (Brightness, 0.05, 0.95), (Color, 0.05, 0.95), (Contrast, 0.05, 0.95), (Equalize, 0, 1), (Identity, 0, 1), (Posterize, 4, 8), (Rotate, -30, 30), (Sharpness, 0.05, 0.95), (ShearX, -0.3, 0.3), (ShearY, -0.3, 0.3), (Solarize, 0, 256), (TranslateX, -0.3, 0.3), (TranslateY, -0.3, 0.3), ] return augs class RandAugment: def __init__(self, n=2, m=10): assert 0 <= m <= 30 self.n = n self.m = m self.augment_list = randaugment_list() def __call__(self, img): ops = random.choices(self.augment_list, k=self.n) for op, minval, maxval in ops: val = (self.m / 30) * (maxval-minval) + minval img = op(img, val) return img class RandAugment2: def __init__(self, n=2, p=0.6): self.n = n self.p = p self.augment_list = randaugment_list2() def __call__(self, img): ops = random.choices(self.augment_list, k=self.n) for op, minval, maxval in ops: if random.random() > self.p: continue m = random.random() val = m * (maxval-minval) + minval img = op(img, val) return img class RandAugmentFixMatch: def __init__(self, n=2): self.n = n self.augment_list = fixmatch_list() def __call__(self, img): ops = random.choices(self.augment_list, k=self.n) for op, minval, maxval in ops: m = random.random() val = m * (maxval-minval) + minval img = op(img, val) return img ================================================ FILE: 
Dassl.ProGrad.pytorch/dassl/data/transforms/transforms.py ================================================ import numpy as np import random import torch from PIL import Image from torchvision.transforms import ( Resize, Compose, ToTensor, Normalize, CenterCrop, RandomCrop, ColorJitter, RandomApply, GaussianBlur, RandomGrayscale, RandomResizedCrop, RandomHorizontalFlip ) from .autoaugment import SVHNPolicy, CIFAR10Policy, ImageNetPolicy from .randaugment import RandAugment, RandAugment2, RandAugmentFixMatch AVAI_CHOICES = [ "random_flip", "random_resized_crop", "normalize", "instance_norm", "random_crop", "random_translation", "center_crop", # This has become a default operation for test "cutout", "imagenet_policy", "cifar10_policy", "svhn_policy", "randaugment", "randaugment_fixmatch", "randaugment2", "gaussian_noise", "colorjitter", "randomgrayscale", "gaussian_blur", ] INTERPOLATION_MODES = { "bilinear": Image.BILINEAR, "bicubic": Image.BICUBIC, "nearest": Image.NEAREST, } class Random2DTranslation: """Given an image of (height, width), we resize it to (height*1.125, width*1.125), and then perform random cropping. Args: height (int): target image height. width (int): target image width. p (float, optional): probability that this operation takes place. Default is 0.5. interpolation (int, optional): desired interpolation. 
Default is ``PIL.Image.BILINEAR`` """ def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR): self.height = height self.width = width self.p = p self.interpolation = interpolation def __call__(self, img): if random.uniform(0, 1) > self.p: return img.resize((self.width, self.height), self.interpolation) new_width = int(round(self.width * 1.125)) new_height = int(round(self.height * 1.125)) resized_img = img.resize((new_width, new_height), self.interpolation) x_maxrange = new_width - self.width y_maxrange = new_height - self.height x1 = int(round(random.uniform(0, x_maxrange))) y1 = int(round(random.uniform(0, y_maxrange))) cropped_img = resized_img.crop( (x1, y1, x1 + self.width, y1 + self.height) ) return cropped_img class InstanceNormalization: """Normalize data using per-channel mean and standard deviation. Reference: - Ulyanov et al. Instance normalization: The missing ingredient for fast stylization. arXiv 2016. - Shu et al. A DIRT-T Approach to Unsupervised Domain Adaptation. ICLR 2018. """ def __init__(self, eps=1e-8): self.eps = eps def __call__(self, img): C, H, W = img.shape img_re = img.reshape(C, H * W) mean = img_re.mean(1).view(C, 1, 1) std = img_re.std(1).view(C, 1, 1) return (img-mean) / (std + self.eps) class Cutout: """Randomly mask out one or more patches from an image. https://github.com/uoguelph-mlrg/Cutout Args: n_holes (int, optional): number of patches to cut out of each image. Default is 1. length (int, optional): length (in pixels) of each square patch. Default is 16. """ def __init__(self, n_holes=1, length=16): self.n_holes = n_holes self.length = length def __call__(self, img): """ Args: img (Tensor): tensor image of size (C, H, W). Returns: Tensor: image with n_holes of dimension length x length cut out of it.
""" h = img.size(1) w = img.size(2) mask = np.ones((h, w), np.float32) for n in range(self.n_holes): y = np.random.randint(h) x = np.random.randint(w) y1 = np.clip(y - self.length // 2, 0, h) y2 = np.clip(y + self.length // 2, 0, h) x1 = np.clip(x - self.length // 2, 0, w) x2 = np.clip(x + self.length // 2, 0, w) mask[y1:y2, x1:x2] = 0.0 mask = torch.from_numpy(mask) mask = mask.expand_as(img) return img * mask class GaussianNoise: """Add gaussian noise.""" def __init__(self, mean=0, std=0.15, p=0.5): self.mean = mean self.std = std self.p = p def __call__(self, img): if random.uniform(0, 1) > self.p: return img noise = torch.randn(img.size()) * self.std + self.mean return img + noise def build_transform(cfg, is_train=True, choices=None): """Build transformation function. Args: cfg (CfgNode): config. is_train (bool, optional): for training (True) or test (False). Default is True. choices (list, optional): list of strings which will overwrite cfg.INPUT.TRANSFORMS if given. Default is None. 
""" if cfg.INPUT.NO_TRANSFORM: print("Note: no transform is applied!") return None if choices is None: choices = cfg.INPUT.TRANSFORMS for choice in choices: assert choice in AVAI_CHOICES target_size = f"{cfg.INPUT.SIZE[0]}x{cfg.INPUT.SIZE[1]}" normalize = Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD) if is_train: return _build_transform_train(cfg, choices, target_size, normalize) else: return _build_transform_test(cfg, choices, target_size, normalize) def _build_transform_train(cfg, choices, target_size, normalize): print("Building transform_train") tfm_train = [] interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION] # Make sure the image size matches the target size conditions = [] conditions += ["random_crop" not in choices] conditions += ["random_resized_crop" not in choices] if all(conditions): print(f"+ resize to {target_size}") tfm_train += [Resize(cfg.INPUT.SIZE, interpolation=interp_mode)] if "random_translation" in choices: print("+ random translation") tfm_train += [ Random2DTranslation(cfg.INPUT.SIZE[0], cfg.INPUT.SIZE[1]) ] if "random_crop" in choices: crop_padding = cfg.INPUT.CROP_PADDING print("+ random crop (padding = {})".format(crop_padding)) tfm_train += [RandomCrop(cfg.INPUT.SIZE, padding=crop_padding)] if "random_resized_crop" in choices: print(f"+ random resized crop (size={cfg.INPUT.SIZE})") tfm_train += [ RandomResizedCrop(cfg.INPUT.SIZE, interpolation=interp_mode) ] if "center_crop" in choices: print(f"+ center crop (size={cfg.INPUT.SIZE})") tfm_train += [CenterCrop(cfg.INPUT.SIZE)] if "random_flip" in choices: print("+ random flip") tfm_train += [RandomHorizontalFlip()] if "imagenet_policy" in choices: print("+ imagenet policy") tfm_train += [ImageNetPolicy()] if "cifar10_policy" in choices: print("+ cifar10 policy") tfm_train += [CIFAR10Policy()] if "svhn_policy" in choices: print("+ svhn policy") tfm_train += [SVHNPolicy()] if "randaugment" in choices: n_ = cfg.INPUT.RANDAUGMENT_N m_ = cfg.INPUT.RANDAUGMENT_M print("+ 
randaugment (n={}, m={})".format(n_, m_)) tfm_train += [RandAugment(n_, m_)] if "randaugment_fixmatch" in choices: n_ = cfg.INPUT.RANDAUGMENT_N print("+ randaugment_fixmatch (n={})".format(n_)) tfm_train += [RandAugmentFixMatch(n_)] if "randaugment2" in choices: n_ = cfg.INPUT.RANDAUGMENT_N print("+ randaugment2 (n={})".format(n_)) tfm_train += [RandAugment2(n_)] if "colorjitter" in choices: print("+ color jitter") tfm_train += [ ColorJitter( brightness=cfg.INPUT.COLORJITTER_B, contrast=cfg.INPUT.COLORJITTER_C, saturation=cfg.INPUT.COLORJITTER_S, hue=cfg.INPUT.COLORJITTER_H, ) ] if "randomgrayscale" in choices: print("+ random gray scale") tfm_train += [RandomGrayscale(p=cfg.INPUT.RGS_P)] if "gaussian_blur" in choices: print(f"+ gaussian blur (kernel={cfg.INPUT.GB_K})") tfm_train += [ RandomApply([GaussianBlur(cfg.INPUT.GB_K)], p=cfg.INPUT.GB_P) ] print("+ to torch tensor of range [0, 1]") tfm_train += [ToTensor()] if "cutout" in choices: cutout_n = cfg.INPUT.CUTOUT_N cutout_len = cfg.INPUT.CUTOUT_LEN print("+ cutout (n_holes={}, length={})".format(cutout_n, cutout_len)) tfm_train += [Cutout(cutout_n, cutout_len)] if "normalize" in choices: print( "+ normalization (mean={}, " "std={})".format(cfg.INPUT.PIXEL_MEAN, cfg.INPUT.PIXEL_STD) ) tfm_train += [normalize] if "gaussian_noise" in choices: print( "+ gaussian noise (mean={}, std={})".format( cfg.INPUT.GN_MEAN, cfg.INPUT.GN_STD ) ) tfm_train += [GaussianNoise(cfg.INPUT.GN_MEAN, cfg.INPUT.GN_STD)] if "instance_norm" in choices: print("+ instance normalization") tfm_train += [InstanceNormalization()] tfm_train = Compose(tfm_train) return tfm_train def _build_transform_test(cfg, choices, target_size, normalize): print("Building transform_test") tfm_test = [] interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION] print(f"+ resize the smaller edge to {max(cfg.INPUT.SIZE)}") tfm_test += [Resize(max(cfg.INPUT.SIZE), interpolation=interp_mode)] print(f"+ {target_size} center crop") tfm_test += 
[CenterCrop(cfg.INPUT.SIZE)] print("+ to torch tensor of range [0, 1]") tfm_test += [ToTensor()] if "normalize" in choices: print( "+ normalization (mean={}, " "std={})".format(cfg.INPUT.PIXEL_MEAN, cfg.INPUT.PIXEL_STD) ) tfm_test += [normalize] if "instance_norm" in choices: print("+ instance normalization") tfm_test += [InstanceNormalization()] tfm_test = Compose(tfm_test) return tfm_test ================================================ FILE: Dassl.ProGrad.pytorch/dassl/engine/__init__.py ================================================ from .build import TRAINER_REGISTRY, build_trainer # isort:skip from .trainer import TrainerX, TrainerXU, TrainerBase, SimpleTrainer, SimpleNet # isort:skip from .da import * from .dg import * from .ssl import * ================================================ FILE: Dassl.ProGrad.pytorch/dassl/engine/build.py ================================================ from dassl.utils import Registry, check_availability TRAINER_REGISTRY = Registry("TRAINER") def build_trainer(cfg): avai_trainers = TRAINER_REGISTRY.registered_names() check_availability(cfg.TRAINER.NAME, avai_trainers) if cfg.VERBOSE: print("Loading trainer: {}".format(cfg.TRAINER.NAME)) return TRAINER_REGISTRY.get(cfg.TRAINER.NAME)(cfg) ================================================ FILE: Dassl.ProGrad.pytorch/dassl/engine/da/__init__.py ================================================ from .mcd import MCD from .mme import MME from .adda import ADDA from .dael import DAEL from .dann import DANN from .adabn import AdaBN from .m3sda import M3SDA from .source_only import SourceOnly from .self_ensembling import SelfEnsembling ================================================ FILE: Dassl.ProGrad.pytorch/dassl/engine/da/adabn.py ================================================ import torch from dassl.utils import check_isfile from dassl.engine import TRAINER_REGISTRY, TrainerXU @TRAINER_REGISTRY.register() class AdaBN(TrainerXU): """Adaptive Batch Normalization. 
https://arxiv.org/abs/1603.04779. """ def __init__(self, cfg): super().__init__(cfg) self.done_reset_bn_stats = False def check_cfg(self, cfg): assert check_isfile( cfg.MODEL.INIT_WEIGHTS ), "The weights of source model must be provided" def before_epoch(self): if not self.done_reset_bn_stats: for m in self.model.modules(): classname = m.__class__.__name__ if classname.find("BatchNorm") != -1: m.reset_running_stats() self.done_reset_bn_stats = True def forward_backward(self, batch_x, batch_u): input_u = batch_u["img"].to(self.device) with torch.no_grad(): self.model(input_u) return None ================================================ FILE: Dassl.ProGrad.pytorch/dassl/engine/da/adda.py ================================================ import copy import torch import torch.nn as nn from dassl.optim import build_optimizer, build_lr_scheduler from dassl.utils import check_isfile, count_num_param, open_specified_layers from dassl.engine import TRAINER_REGISTRY, TrainerXU from dassl.modeling import build_head @TRAINER_REGISTRY.register() class ADDA(TrainerXU): """Adversarial Discriminative Domain Adaptation. https://arxiv.org/abs/1702.05464. 
""" def __init__(self, cfg): super().__init__(cfg) self.open_layers = ["backbone"] if isinstance(self.model.head, nn.Module): self.open_layers.append("head") self.source_model = copy.deepcopy(self.model) self.source_model.eval() for param in self.source_model.parameters(): param.requires_grad_(False) self.build_critic() self.bce = nn.BCEWithLogitsLoss() def check_cfg(self, cfg): assert check_isfile( cfg.MODEL.INIT_WEIGHTS ), "The weights of source model must be provided" def build_critic(self): cfg = self.cfg print("Building critic network") fdim = self.model.fdim critic_body = build_head( "mlp", verbose=cfg.VERBOSE, in_features=fdim, hidden_layers=[fdim, fdim // 2], activation="leaky_relu", ) self.critic = nn.Sequential(critic_body, nn.Linear(fdim // 2, 1)) print("# params: {:,}".format(count_num_param(self.critic))) self.critic.to(self.device) self.optim_c = build_optimizer(self.critic, cfg.OPTIM) self.sched_c = build_lr_scheduler(self.optim_c, cfg.OPTIM) self.register_model("critic", self.critic, self.optim_c, self.sched_c) def forward_backward(self, batch_x, batch_u): open_specified_layers(self.model, self.open_layers) input_x, _, input_u = self.parse_batch_train(batch_x, batch_u) domain_x = torch.ones(input_x.shape[0], 1).to(self.device) domain_u = torch.zeros(input_u.shape[0], 1).to(self.device) _, feat_x = self.source_model(input_x, return_feature=True) _, feat_u = self.model(input_u, return_feature=True) logit_xd = self.critic(feat_x) logit_ud = self.critic(feat_u.detach()) loss_critic = self.bce(logit_xd, domain_x) loss_critic += self.bce(logit_ud, domain_u) self.model_backward_and_update(loss_critic, "critic") logit_ud = self.critic(feat_u) loss_model = self.bce(logit_ud, 1 - domain_u) self.model_backward_and_update(loss_model, "model") loss_summary = { "loss_critic": loss_critic.item(), "loss_model": loss_model.item(), } if (self.batch_idx + 1) == self.num_batches: self.update_lr() return loss_summary ================================================ 
FILE: Dassl.ProGrad.pytorch/dassl/engine/da/dael.py ================================================ import torch import torch.nn as nn from dassl.data import DataManager from dassl.optim import build_optimizer, build_lr_scheduler from dassl.utils import count_num_param from dassl.engine import TRAINER_REGISTRY, TrainerXU from dassl.metrics import compute_accuracy from dassl.engine.trainer import SimpleNet from dassl.data.transforms import build_transform from dassl.modeling.ops.utils import create_onehot class Experts(nn.Module): def __init__(self, n_source, fdim, num_classes): super().__init__() self.linears = nn.ModuleList( [nn.Linear(fdim, num_classes) for _ in range(n_source)] ) self.softmax = nn.Softmax(dim=1) def forward(self, i, x): x = self.linears[i](x) x = self.softmax(x) return x @TRAINER_REGISTRY.register() class DAEL(TrainerXU): """Domain Adaptive Ensemble Learning. https://arxiv.org/abs/2003.07325. """ def __init__(self, cfg): super().__init__(cfg) n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE if n_domain <= 0: n_domain = self.num_source_domains self.split_batch = batch_size // n_domain self.n_domain = n_domain self.weight_u = cfg.TRAINER.DAEL.WEIGHT_U self.conf_thre = cfg.TRAINER.DAEL.CONF_THRE def check_cfg(self, cfg): assert cfg.DATALOADER.TRAIN_X.SAMPLER == "RandomDomainSampler" assert not cfg.DATALOADER.TRAIN_U.SAME_AS_X assert len(cfg.TRAINER.DAEL.STRONG_TRANSFORMS) > 0 def build_data_loader(self): cfg = self.cfg tfm_train = build_transform(cfg, is_train=True) custom_tfm_train = [tfm_train] choices = cfg.TRAINER.DAEL.STRONG_TRANSFORMS tfm_train_strong = build_transform(cfg, is_train=True, choices=choices) custom_tfm_train += [tfm_train_strong] dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train) self.train_loader_x = dm.train_loader_x self.train_loader_u = dm.train_loader_u self.val_loader = dm.val_loader self.test_loader = dm.test_loader self.num_classes = dm.num_classes 
self.num_source_domains = dm.num_source_domains self.lab2cname = dm.lab2cname def build_model(self): cfg = self.cfg print("Building F") self.F = SimpleNet(cfg, cfg.MODEL, 0) self.F.to(self.device) print("# params: {:,}".format(count_num_param(self.F))) self.optim_F = build_optimizer(self.F, cfg.OPTIM) self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM) self.register_model("F", self.F, self.optim_F, self.sched_F) fdim = self.F.fdim print("Building E") self.E = Experts(self.num_source_domains, fdim, self.num_classes) self.E.to(self.device) print("# params: {:,}".format(count_num_param(self.E))) self.optim_E = build_optimizer(self.E, cfg.OPTIM) self.sched_E = build_lr_scheduler(self.optim_E, cfg.OPTIM) self.register_model("E", self.E, self.optim_E, self.sched_E) def forward_backward(self, batch_x, batch_u): parsed_data = self.parse_batch_train(batch_x, batch_u) input_x, input_x2, label_x, domain_x, input_u, input_u2 = parsed_data input_x = torch.split(input_x, self.split_batch, 0) input_x2 = torch.split(input_x2, self.split_batch, 0) label_x = torch.split(label_x, self.split_batch, 0) domain_x = torch.split(domain_x, self.split_batch, 0) domain_x = [d[0].item() for d in domain_x] # Generate pseudo label with torch.no_grad(): feat_u = self.F(input_u) pred_u = [] for k in range(self.num_source_domains): pred_uk = self.E(k, feat_u) pred_uk = pred_uk.unsqueeze(1) pred_u.append(pred_uk) pred_u = torch.cat(pred_u, 1) # (B, K, C) # Get the highest probability and index (label) for each expert experts_max_p, experts_max_idx = pred_u.max(2) # (B, K) # Get the most confident expert max_expert_p, max_expert_idx = experts_max_p.max(1) # (B) pseudo_label_u = [] for i, experts_label in zip(max_expert_idx, experts_max_idx): pseudo_label_u.append(experts_label[i]) pseudo_label_u = torch.stack(pseudo_label_u, 0) pseudo_label_u = create_onehot(pseudo_label_u, self.num_classes) pseudo_label_u = pseudo_label_u.to(self.device) label_u_mask = (max_expert_p >= self.conf_thre).float() 
loss_x = 0 loss_cr = 0 acc_x = 0 feat_x = [self.F(x) for x in input_x] feat_x2 = [self.F(x) for x in input_x2] feat_u2 = self.F(input_u2) for feat_xi, feat_x2i, label_xi, i in zip( feat_x, feat_x2, label_x, domain_x ): cr_s = [j for j in domain_x if j != i] # Learning expert pred_xi = self.E(i, feat_xi) loss_x += (-label_xi * torch.log(pred_xi + 1e-5)).sum(1).mean() expert_label_xi = pred_xi.detach() acc_x += compute_accuracy(pred_xi.detach(), label_xi.max(1)[1])[0].item() # Consistency regularization cr_pred = [] for j in cr_s: pred_j = self.E(j, feat_x2i) pred_j = pred_j.unsqueeze(1) cr_pred.append(pred_j) cr_pred = torch.cat(cr_pred, 1) cr_pred = cr_pred.mean(1) loss_cr += ((cr_pred - expert_label_xi)**2).sum(1).mean() loss_x /= self.n_domain loss_cr /= self.n_domain acc_x /= self.n_domain # Unsupervised loss pred_u = [] for k in range(self.num_source_domains): pred_uk = self.E(k, feat_u2) pred_uk = pred_uk.unsqueeze(1) pred_u.append(pred_uk) pred_u = torch.cat(pred_u, 1) pred_u = pred_u.mean(1) l_u = (-pseudo_label_u * torch.log(pred_u + 1e-5)).sum(1) loss_u = (l_u * label_u_mask).mean() loss = 0 loss += loss_x loss += loss_cr loss += loss_u * self.weight_u self.model_backward_and_update(loss) loss_summary = { "loss_x": loss_x.item(), "acc_x": acc_x, "loss_cr": loss_cr.item(), "loss_u": loss_u.item(), } if (self.batch_idx + 1) == self.num_batches: self.update_lr() return loss_summary def parse_batch_train(self, batch_x, batch_u): input_x = batch_x["img"] input_x2 = batch_x["img2"] label_x = batch_x["label"] domain_x = batch_x["domain"] input_u = batch_u["img"] input_u2 = batch_u["img2"] label_x = create_onehot(label_x, self.num_classes) input_x = input_x.to(self.device) input_x2 = input_x2.to(self.device) label_x = label_x.to(self.device) input_u = input_u.to(self.device) input_u2 = input_u2.to(self.device) return input_x, input_x2, label_x, domain_x, input_u, input_u2 def model_inference(self, input): f = self.F(input) p = [] for k in 
range(self.num_source_domains): p_k = self.E(k, f) p_k = p_k.unsqueeze(1) p.append(p_k) p = torch.cat(p, 1) p = p.mean(1) return p ================================================ FILE: Dassl.ProGrad.pytorch/dassl/engine/da/dann.py ================================================ import numpy as np import torch import torch.nn as nn from dassl.optim import build_optimizer, build_lr_scheduler from dassl.utils import count_num_param from dassl.engine import TRAINER_REGISTRY, TrainerXU from dassl.metrics import compute_accuracy from dassl.modeling import build_head from dassl.modeling.ops import ReverseGrad @TRAINER_REGISTRY.register() class DANN(TrainerXU): """Domain-Adversarial Neural Networks. https://arxiv.org/abs/1505.07818. """ def __init__(self, cfg): super().__init__(cfg) self.build_critic() self.ce = nn.CrossEntropyLoss() self.bce = nn.BCEWithLogitsLoss() def build_critic(self): cfg = self.cfg print("Building critic network") fdim = self.model.fdim critic_body = build_head( "mlp", verbose=cfg.VERBOSE, in_features=fdim, hidden_layers=[fdim, fdim], activation="leaky_relu", ) self.critic = nn.Sequential(critic_body, nn.Linear(fdim, 1)) print("# params: {:,}".format(count_num_param(self.critic))) self.critic.to(self.device) self.optim_c = build_optimizer(self.critic, cfg.OPTIM) self.sched_c = build_lr_scheduler(self.optim_c, cfg.OPTIM) self.register_model("critic", self.critic, self.optim_c, self.sched_c) self.revgrad = ReverseGrad() def forward_backward(self, batch_x, batch_u): input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u) domain_x = torch.ones(input_x.shape[0], 1).to(self.device) domain_u = torch.zeros(input_u.shape[0], 1).to(self.device) global_step = self.batch_idx + self.epoch * self.num_batches progress = global_step / (self.max_epoch * self.num_batches) lmda = 2 / (1 + np.exp(-10 * progress)) - 1 logit_x, feat_x = self.model(input_x, return_feature=True) _, feat_u = self.model(input_u, return_feature=True) loss_x = self.ce(logit_x, 
label_x) feat_x = self.revgrad(feat_x, grad_scaling=lmda) feat_u = self.revgrad(feat_u, grad_scaling=lmda) output_xd = self.critic(feat_x) output_ud = self.critic(feat_u) loss_d = self.bce(output_xd, domain_x) + self.bce(output_ud, domain_u) loss = loss_x + loss_d self.model_backward_and_update(loss) loss_summary = { "loss_x": loss_x.item(), "acc_x": compute_accuracy(logit_x, label_x)[0].item(), "loss_d": loss_d.item(), } if (self.batch_idx + 1) == self.num_batches: self.update_lr() return loss_summary ================================================ FILE: Dassl.ProGrad.pytorch/dassl/engine/da/m3sda.py ================================================ import torch import torch.nn as nn from torch.nn import functional as F from dassl.optim import build_optimizer, build_lr_scheduler from dassl.utils import count_num_param from dassl.engine import TRAINER_REGISTRY, TrainerXU from dassl.engine.trainer import SimpleNet class PairClassifiers(nn.Module): def __init__(self, fdim, num_classes): super().__init__() self.c1 = nn.Linear(fdim, num_classes) self.c2 = nn.Linear(fdim, num_classes) def forward(self, x): z1 = self.c1(x) if not self.training: return z1 z2 = self.c2(x) return z1, z2 @TRAINER_REGISTRY.register() class M3SDA(TrainerXU): """Moment Matching for Multi-Source Domain Adaptation. https://arxiv.org/abs/1812.01754. 
""" def __init__(self, cfg): super().__init__(cfg) n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE if n_domain <= 0: n_domain = self.num_source_domains self.split_batch = batch_size // n_domain self.n_domain = n_domain self.n_step_F = cfg.TRAINER.M3SDA.N_STEP_F self.lmda = cfg.TRAINER.M3SDA.LMDA def check_cfg(self, cfg): assert cfg.DATALOADER.TRAIN_X.SAMPLER == "RandomDomainSampler" assert not cfg.DATALOADER.TRAIN_U.SAME_AS_X def build_model(self): cfg = self.cfg print("Building F") self.F = SimpleNet(cfg, cfg.MODEL, 0) self.F.to(self.device) print("# params: {:,}".format(count_num_param(self.F))) self.optim_F = build_optimizer(self.F, cfg.OPTIM) self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM) self.register_model("F", self.F, self.optim_F, self.sched_F) fdim = self.F.fdim print("Building C") self.C = nn.ModuleList( [ PairClassifiers(fdim, self.num_classes) for _ in range(self.num_source_domains) ] ) self.C.to(self.device) print("# params: {:,}".format(count_num_param(self.C))) self.optim_C = build_optimizer(self.C, cfg.OPTIM) self.sched_C = build_lr_scheduler(self.optim_C, cfg.OPTIM) self.register_model("C", self.C, self.optim_C, self.sched_C) def forward_backward(self, batch_x, batch_u): parsed = self.parse_batch_train(batch_x, batch_u) input_x, label_x, domain_x, input_u = parsed input_x = torch.split(input_x, self.split_batch, 0) label_x = torch.split(label_x, self.split_batch, 0) domain_x = torch.split(domain_x, self.split_batch, 0) domain_x = [d[0].item() for d in domain_x] # Step A loss_x = 0 feat_x = [] for x, y, d in zip(input_x, label_x, domain_x): f = self.F(x) z1, z2 = self.C[d](f) loss_x += F.cross_entropy(z1, y) + F.cross_entropy(z2, y) feat_x.append(f) loss_x /= self.n_domain feat_u = self.F(input_u) loss_msda = self.moment_distance(feat_x, feat_u) loss_step_A = loss_x + loss_msda * self.lmda self.model_backward_and_update(loss_step_A) # Step B with torch.no_grad(): feat_u = self.F(input_u) loss_x, 
loss_dis = 0, 0 for x, y, d in zip(input_x, label_x, domain_x): with torch.no_grad(): f = self.F(x) z1, z2 = self.C[d](f) loss_x += F.cross_entropy(z1, y) + F.cross_entropy(z2, y) z1, z2 = self.C[d](feat_u) p1 = F.softmax(z1, 1) p2 = F.softmax(z2, 1) loss_dis += self.discrepancy(p1, p2) loss_x /= self.n_domain loss_dis /= self.n_domain loss_step_B = loss_x - loss_dis self.model_backward_and_update(loss_step_B, "C") # Step C for _ in range(self.n_step_F): feat_u = self.F(input_u) loss_dis = 0 for d in domain_x: z1, z2 = self.C[d](feat_u) p1 = F.softmax(z1, 1) p2 = F.softmax(z2, 1) loss_dis += self.discrepancy(p1, p2) loss_dis /= self.n_domain loss_step_C = loss_dis self.model_backward_and_update(loss_step_C, "F") loss_summary = { "loss_step_A": loss_step_A.item(), "loss_step_B": loss_step_B.item(), "loss_step_C": loss_step_C.item(), } if (self.batch_idx + 1) == self.num_batches: self.update_lr() return loss_summary def moment_distance(self, x, u): # x (list): a list of feature matrix. # u (torch.Tensor): feature matrix. x_mean = [xi.mean(0) for xi in x] u_mean = u.mean(0) dist1 = self.pairwise_distance(x_mean, u_mean) x_var = [xi.var(0) for xi in x] u_var = u.var(0) dist2 = self.pairwise_distance(x_var, u_var) return (dist1+dist2) / 2 def pairwise_distance(self, x, u): # x (list): a list of feature vector. # u (torch.Tensor): feature vector. 
dist = 0 count = 0 for xi in x: dist += self.euclidean(xi, u) count += 1 for i in range(len(x) - 1): for j in range(i + 1, len(x)): dist += self.euclidean(x[i], x[j]) count += 1 return dist / count def euclidean(self, input1, input2): return ((input1 - input2)**2).sum().sqrt() def discrepancy(self, y1, y2): return (y1 - y2).abs().mean() def parse_batch_train(self, batch_x, batch_u): input_x = batch_x["img"] label_x = batch_x["label"] domain_x = batch_x["domain"] input_u = batch_u["img"] input_x = input_x.to(self.device) label_x = label_x.to(self.device) input_u = input_u.to(self.device) return input_x, label_x, domain_x, input_u def model_inference(self, input): f = self.F(input) p = 0 for C_i in self.C: z = C_i(f) p += F.softmax(z, 1) p = p / len(self.C) return p ================================================ FILE: Dassl.ProGrad.pytorch/dassl/engine/da/mcd.py ================================================ import torch import torch.nn as nn from torch.nn import functional as F from dassl.optim import build_optimizer, build_lr_scheduler from dassl.utils import count_num_param from dassl.engine import TRAINER_REGISTRY, TrainerXU from dassl.engine.trainer import SimpleNet @TRAINER_REGISTRY.register() class MCD(TrainerXU): """Maximum Classifier Discrepancy. https://arxiv.org/abs/1712.02560. 
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.n_step_F = cfg.TRAINER.MCD.N_STEP_F

    def build_model(self):
        cfg = self.cfg

        print("Building F")
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.F.to(self.device)
        print("# params: {:,}".format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model("F", self.F, self.optim_F, self.sched_F)
        fdim = self.F.fdim

        print("Building C1")
        self.C1 = nn.Linear(fdim, self.num_classes)
        self.C1.to(self.device)
        print("# params: {:,}".format(count_num_param(self.C1)))
        self.optim_C1 = build_optimizer(self.C1, cfg.OPTIM)
        self.sched_C1 = build_lr_scheduler(self.optim_C1, cfg.OPTIM)
        self.register_model("C1", self.C1, self.optim_C1, self.sched_C1)

        print("Building C2")
        self.C2 = nn.Linear(fdim, self.num_classes)
        self.C2.to(self.device)
        print("# params: {:,}".format(count_num_param(self.C2)))
        self.optim_C2 = build_optimizer(self.C2, cfg.OPTIM)
        self.sched_C2 = build_lr_scheduler(self.optim_C2, cfg.OPTIM)
        self.register_model("C2", self.C2, self.optim_C2, self.sched_C2)

    def forward_backward(self, batch_x, batch_u):
        parsed = self.parse_batch_train(batch_x, batch_u)
        input_x, label_x, input_u = parsed

        # Step A
        feat_x = self.F(input_x)
        logit_x1 = self.C1(feat_x)
        logit_x2 = self.C2(feat_x)
        loss_x1 = F.cross_entropy(logit_x1, label_x)
        loss_x2 = F.cross_entropy(logit_x2, label_x)
        loss_step_A = loss_x1 + loss_x2
        self.model_backward_and_update(loss_step_A)

        # Step B
        with torch.no_grad():
            feat_x = self.F(input_x)
        logit_x1 = self.C1(feat_x)
        logit_x2 = self.C2(feat_x)
        loss_x1 = F.cross_entropy(logit_x1, label_x)
        loss_x2 = F.cross_entropy(logit_x2, label_x)
        loss_x = loss_x1 + loss_x2

        with torch.no_grad():
            feat_u = self.F(input_u)
        pred_u1 = F.softmax(self.C1(feat_u), 1)
        pred_u2 = F.softmax(self.C2(feat_u), 1)
        loss_dis = self.discrepancy(pred_u1, pred_u2)

        loss_step_B = loss_x - loss_dis
        self.model_backward_and_update(loss_step_B, ["C1", "C2"])

        # Step C
        for _ in range(self.n_step_F):
            feat_u = self.F(input_u)
            pred_u1 = F.softmax(self.C1(feat_u), 1)
            pred_u2 = F.softmax(self.C2(feat_u), 1)
            loss_step_C = self.discrepancy(pred_u1, pred_u2)
            self.model_backward_and_update(loss_step_C, "F")

        loss_summary = {
            "loss_step_A": loss_step_A.item(),
            "loss_step_B": loss_step_B.item(),
            "loss_step_C": loss_step_C.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def discrepancy(self, y1, y2):
        return (y1 - y2).abs().mean()

    def model_inference(self, input):
        feat = self.F(input)
        return self.C1(feat)

================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/da/mme.py
================================================
import torch
import torch.nn as nn
from torch.nn import functional as F

from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling.ops import ReverseGrad
from dassl.engine.trainer import SimpleNet


class Prototypes(nn.Module):

    def __init__(self, fdim, num_classes, temp=0.05):
        super().__init__()
        self.prototypes = nn.Linear(fdim, num_classes, bias=False)
        self.temp = temp

    def forward(self, x):
        x = F.normalize(x, p=2, dim=1)
        out = self.prototypes(x)
        out = out / self.temp
        return out


@TRAINER_REGISTRY.register()
class MME(TrainerXU):
    """Minimax Entropy.

    https://arxiv.org/abs/1904.06487.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.lmda = cfg.TRAINER.MME.LMDA

    def build_model(self):
        cfg = self.cfg

        print("Building F")
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.F.to(self.device)
        print("# params: {:,}".format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model("F", self.F, self.optim_F, self.sched_F)

        print("Building C")
        self.C = Prototypes(self.F.fdim, self.num_classes)
        self.C.to(self.device)
        print("# params: {:,}".format(count_num_param(self.C)))
        self.optim_C = build_optimizer(self.C, cfg.OPTIM)
        self.sched_C = build_lr_scheduler(self.optim_C, cfg.OPTIM)
        self.register_model("C", self.C, self.optim_C, self.sched_C)

        self.revgrad = ReverseGrad()

    def forward_backward(self, batch_x, batch_u):
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)

        feat_x = self.F(input_x)
        logit_x = self.C(feat_x)
        loss_x = F.cross_entropy(logit_x, label_x)
        self.model_backward_and_update(loss_x)

        feat_u = self.F(input_u)
        feat_u = self.revgrad(feat_u)
        logit_u = self.C(feat_u)
        prob_u = F.softmax(logit_u, 1)
        loss_u = -(-prob_u * torch.log(prob_u + 1e-5)).sum(1).mean()
        self.model_backward_and_update(loss_u * self.lmda)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(logit_x, label_x)[0].item(),
            "loss_u": loss_u.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def model_inference(self, input):
        return self.C(self.F(input))

================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/da/self_ensembling.py
================================================
import copy
from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling.ops.utils import sigmoid_rampup, ema_model_update


@TRAINER_REGISTRY.register()
class SelfEnsembling(TrainerXU):
    """Self-ensembling for visual domain adaptation.

    https://arxiv.org/abs/1706.05208.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.ema_alpha = cfg.TRAINER.SE.EMA_ALPHA
        self.conf_thre = cfg.TRAINER.SE.CONF_THRE
        self.rampup = cfg.TRAINER.SE.RAMPUP

        self.teacher = copy.deepcopy(self.model)
        self.teacher.train()
        for param in self.teacher.parameters():
            param.requires_grad_(False)

    def check_cfg(self, cfg):
        assert cfg.DATALOADER.K_TRANSFORMS == 2

    def forward_backward(self, batch_x, batch_u):
        global_step = self.batch_idx + self.epoch * self.num_batches
        parsed = self.parse_batch_train(batch_x, batch_u)
        input_x, label_x, input_u1, input_u2 = parsed

        logit_x = self.model(input_x)
        loss_x = F.cross_entropy(logit_x, label_x)

        prob_u = F.softmax(self.model(input_u1), 1)
        t_prob_u = F.softmax(self.teacher(input_u2), 1)
        loss_u = ((prob_u - t_prob_u)**2).sum(1)

        if self.conf_thre:
            max_prob = t_prob_u.max(1)[0]
            mask = (max_prob > self.conf_thre).float()
            loss_u = (loss_u * mask).mean()
        else:
            weight_u = sigmoid_rampup(global_step, self.rampup)
            loss_u = loss_u.mean() * weight_u

        loss = loss_x + loss_u
        self.model_backward_and_update(loss)

        ema_alpha = min(1 - 1 / (global_step+1), self.ema_alpha)
        ema_model_update(self.model, self.teacher, ema_alpha)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(logit_x, label_x)[0].item(),
            "loss_u": loss_u.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        input_x = batch_x["img"][0]
        label_x = batch_x["label"]
        input_u = batch_u["img"]
        input_u1, input_u2 = input_u

        input_x = input_x.to(self.device)
        label_x = label_x.to(self.device)
        input_u1 = input_u1.to(self.device)
        input_u2 = input_u2.to(self.device)

        return input_x, label_x, input_u1, input_u2

================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/da/source_only.py
================================================
from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy


@TRAINER_REGISTRY.register()
class SourceOnly(TrainerXU):
    """Baseline model for domain adaptation, which is
    trained using source data only.
    """

    def forward_backward(self, batch_x, batch_u):
        input, label = self.parse_batch_train(batch_x, batch_u)
        output = self.model(input)
        loss = F.cross_entropy(output, label)
        self.model_backward_and_update(loss)

        loss_summary = {
            "loss": loss.item(),
            "acc": compute_accuracy(output, label)[0].item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        input = batch_x["img"]
        label = batch_x["label"]
        input = input.to(self.device)
        label = label.to(self.device)
        return input, label

================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/dg/__init__.py
================================================
from .ddaig import DDAIG
from .daeldg import DAELDG
from .vanilla import Vanilla
from .crossgrad import CrossGrad

================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/dg/crossgrad.py
================================================
import torch
from torch.nn import functional as F

from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.engine.trainer import SimpleNet


@TRAINER_REGISTRY.register()
class CrossGrad(TrainerX):
    """Cross-gradient training.

    https://arxiv.org/abs/1804.10745.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.eps_f = cfg.TRAINER.CG.EPS_F
        self.eps_d = cfg.TRAINER.CG.EPS_D
        self.alpha_f = cfg.TRAINER.CG.ALPHA_F
        self.alpha_d = cfg.TRAINER.CG.ALPHA_D

    def build_model(self):
        cfg = self.cfg

        print("Building F")
        self.F = SimpleNet(cfg, cfg.MODEL, self.num_classes)
        self.F.to(self.device)
        print("# params: {:,}".format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model("F", self.F, self.optim_F, self.sched_F)

        print("Building D")
        self.D = SimpleNet(cfg, cfg.MODEL, self.num_source_domains)
        self.D.to(self.device)
        print("# params: {:,}".format(count_num_param(self.D)))
        self.optim_D = build_optimizer(self.D, cfg.OPTIM)
        self.sched_D = build_lr_scheduler(self.optim_D, cfg.OPTIM)
        self.register_model("D", self.D, self.optim_D, self.sched_D)

    def forward_backward(self, batch):
        input, label, domain = self.parse_batch_train(batch)

        input.requires_grad = True

        # Compute domain perturbation
        loss_d = F.cross_entropy(self.D(input), domain)
        loss_d.backward()
        grad_d = torch.clamp(input.grad.data, min=-0.1, max=0.1)
        input_d = input.data + self.eps_f * grad_d

        # Compute label perturbation
        input.grad.data.zero_()
        loss_f = F.cross_entropy(self.F(input), label)
        loss_f.backward()
        grad_f = torch.clamp(input.grad.data, min=-0.1, max=0.1)
        input_f = input.data + self.eps_d * grad_f

        input = input.detach()

        # Update label net
        loss_f1 = F.cross_entropy(self.F(input), label)
        loss_f2 = F.cross_entropy(self.F(input_d), label)
        loss_f = (1 - self.alpha_f) * loss_f1 + self.alpha_f * loss_f2
        self.model_backward_and_update(loss_f, "F")

        # Update domain net
        loss_d1 = F.cross_entropy(self.D(input), domain)
        loss_d2 = F.cross_entropy(self.D(input_f), domain)
        loss_d = (1 - self.alpha_d) * loss_d1 + self.alpha_d * loss_d2
        self.model_backward_and_update(loss_d, "D")

        loss_summary = {"loss_f": loss_f.item(), "loss_d": loss_d.item()}

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def model_inference(self, input):
        return self.F(input)

================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/dg/daeldg.py
================================================
import torch
import torch.nn as nn

from dassl.data import DataManager
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy
from dassl.engine.trainer import SimpleNet
from dassl.data.transforms import build_transform
from dassl.modeling.ops.utils import create_onehot


class Experts(nn.Module):

    def __init__(self, n_source, fdim, num_classes):
        super().__init__()
        self.linears = nn.ModuleList(
            [nn.Linear(fdim, num_classes) for _ in range(n_source)]
        )
        self.softmax = nn.Softmax(dim=1)

    def forward(self, i, x):
        x = self.linears[i](x)
        x = self.softmax(x)
        return x


@TRAINER_REGISTRY.register()
class DAELDG(TrainerX):
    """Domain Adaptive Ensemble Learning.

    DG version: only use labeled source data.

    https://arxiv.org/abs/2003.07325.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN
        batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
        if n_domain <= 0:
            n_domain = self.num_source_domains
        self.split_batch = batch_size // n_domain
        self.n_domain = n_domain
        self.conf_thre = cfg.TRAINER.DAEL.CONF_THRE

    def check_cfg(self, cfg):
        assert cfg.DATALOADER.TRAIN_X.SAMPLER == "RandomDomainSampler"
        assert len(cfg.TRAINER.DAEL.STRONG_TRANSFORMS) > 0

    def build_data_loader(self):
        cfg = self.cfg
        tfm_train = build_transform(cfg, is_train=True)
        custom_tfm_train = [tfm_train]
        choices = cfg.TRAINER.DAEL.STRONG_TRANSFORMS
        tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
        custom_tfm_train += [tfm_train_strong]
        dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
        self.train_loader_x = dm.train_loader_x
        self.train_loader_u = dm.train_loader_u
        self.val_loader = dm.val_loader
        self.test_loader = dm.test_loader
        self.num_classes = dm.num_classes
        self.num_source_domains = dm.num_source_domains
        self.lab2cname = dm.lab2cname

    def build_model(self):
        cfg = self.cfg

        print("Building F")
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.F.to(self.device)
        print("# params: {:,}".format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model("F", self.F, self.optim_F, self.sched_F)
        fdim = self.F.fdim

        print("Building E")
        self.E = Experts(self.num_source_domains, fdim, self.num_classes)
        self.E.to(self.device)
        print("# params: {:,}".format(count_num_param(self.E)))
        self.optim_E = build_optimizer(self.E, cfg.OPTIM)
        self.sched_E = build_lr_scheduler(self.optim_E, cfg.OPTIM)
        self.register_model("E", self.E, self.optim_E, self.sched_E)

    def forward_backward(self, batch):
        parsed_data = self.parse_batch_train(batch)
        input, input2, label, domain = parsed_data

        input = torch.split(input, self.split_batch, 0)
        input2 = torch.split(input2, self.split_batch, 0)
        label = torch.split(label, self.split_batch, 0)
        domain = torch.split(domain, self.split_batch, 0)
        domain = [d[0].item() for d in domain]

        loss_x = 0
        loss_cr = 0
        acc = 0

        feat = [self.F(x) for x in input]
        feat2 = [self.F(x) for x in input2]

        for feat_i, feat2_i, label_i, i in zip(feat, feat2, label, domain):
            cr_s = [j for j in domain if j != i]

            # Learning expert
            pred_i = self.E(i, feat_i)
            loss_x += (-label_i * torch.log(pred_i + 1e-5)).sum(1).mean()
            expert_label_i = pred_i.detach()
            acc += compute_accuracy(pred_i.detach(),
                                    label_i.max(1)[1])[0].item()

            # Consistency regularization
            cr_pred = []
            for j in cr_s:
                pred_j = self.E(j, feat2_i)
                pred_j = pred_j.unsqueeze(1)
                cr_pred.append(pred_j)
            cr_pred = torch.cat(cr_pred, 1)
            cr_pred = cr_pred.mean(1)
            loss_cr += ((cr_pred - expert_label_i)**2).sum(1).mean()

        loss_x /= self.n_domain
        loss_cr /= self.n_domain
        acc /= self.n_domain

        loss = 0
        loss += loss_x
        loss += loss_cr
        self.model_backward_and_update(loss)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc": acc,
            "loss_cr": loss_cr.item()
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch):
        input = batch["img"]
        input2 = batch["img2"]
        label = batch["label"]
        domain = batch["domain"]

        label = create_onehot(label, self.num_classes)

        input = input.to(self.device)
        input2 = input2.to(self.device)
        label = label.to(self.device)

        return input, input2, label, domain

    def model_inference(self, input):
        f = self.F(input)
        p = []
        for k in range(self.num_source_domains):
            p_k = self.E(k, f)
            p_k = p_k.unsqueeze(1)
            p.append(p_k)
        p = torch.cat(p, 1)
        p = p.mean(1)
        return p

================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/dg/ddaig.py
================================================
import torch
from torch.nn import functional as F

from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.modeling import build_network
from dassl.engine.trainer import SimpleNet


@TRAINER_REGISTRY.register()
class DDAIG(TrainerX):
    """Deep Domain-Adversarial Image Generation.

    https://arxiv.org/abs/2003.06054.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.lmda = cfg.TRAINER.DDAIG.LMDA
        self.clamp = cfg.TRAINER.DDAIG.CLAMP
        self.clamp_min = cfg.TRAINER.DDAIG.CLAMP_MIN
        self.clamp_max = cfg.TRAINER.DDAIG.CLAMP_MAX
        self.warmup = cfg.TRAINER.DDAIG.WARMUP
        self.alpha = cfg.TRAINER.DDAIG.ALPHA

    def build_model(self):
        cfg = self.cfg

        print("Building F")
        self.F = SimpleNet(cfg, cfg.MODEL, self.num_classes)
        self.F.to(self.device)
        print("# params: {:,}".format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model("F", self.F, self.optim_F, self.sched_F)

        print("Building D")
        self.D = SimpleNet(cfg, cfg.MODEL, self.num_source_domains)
        self.D.to(self.device)
        print("# params: {:,}".format(count_num_param(self.D)))
        self.optim_D = build_optimizer(self.D, cfg.OPTIM)
        self.sched_D = build_lr_scheduler(self.optim_D, cfg.OPTIM)
        self.register_model("D", self.D, self.optim_D, self.sched_D)

        print("Building G")
        self.G = build_network(cfg.TRAINER.DDAIG.G_ARCH, verbose=cfg.VERBOSE)
        self.G.to(self.device)
        print("# params: {:,}".format(count_num_param(self.G)))
        self.optim_G = build_optimizer(self.G, cfg.OPTIM)
        self.sched_G = build_lr_scheduler(self.optim_G, cfg.OPTIM)
        self.register_model("G", self.G, self.optim_G, self.sched_G)

    def forward_backward(self, batch):
        input, label, domain = self.parse_batch_train(batch)

        #############
        # Update G
        #############
        input_p = self.G(input, lmda=self.lmda)
        if self.clamp:
            input_p = torch.clamp(
                input_p, min=self.clamp_min, max=self.clamp_max
            )
        loss_g = 0
        # Minimize label loss
        loss_g += F.cross_entropy(self.F(input_p), label)
        # Maximize domain loss
        loss_g -= F.cross_entropy(self.D(input_p), domain)
        self.model_backward_and_update(loss_g, "G")

        # Perturb data with new G
        with torch.no_grad():
            input_p = self.G(input, lmda=self.lmda)
            if self.clamp:
                input_p = torch.clamp(
                    input_p, min=self.clamp_min, max=self.clamp_max
                )

        #############
        # Update F
        #############
        loss_f = F.cross_entropy(self.F(input), label)
        if (self.epoch + 1) > self.warmup:
            loss_fp = F.cross_entropy(self.F(input_p), label)
            loss_f = (1.0 - self.alpha) * loss_f + self.alpha * loss_fp
        self.model_backward_and_update(loss_f, "F")

        #############
        # Update D
        #############
        loss_d = F.cross_entropy(self.D(input), domain)
        self.model_backward_and_update(loss_d, "D")

        loss_summary = {
            "loss_g": loss_g.item(),
            "loss_f": loss_f.item(),
            "loss_d": loss_d.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def model_inference(self, input):
        return self.F(input)

================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/dg/vanilla.py
================================================
from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy


@TRAINER_REGISTRY.register()
class Vanilla(TrainerX):
    """Vanilla baseline."""

    def forward_backward(self, batch):
        input, label = self.parse_batch_train(batch)
        output = self.model(input)
        loss = F.cross_entropy(output, label)
        self.model_backward_and_update(loss)

        loss_summary = {
            "loss": loss.item(),
            "acc": compute_accuracy(output, label)[0].item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch):
        input = batch["img"]
        label = batch["label"]
        input = input.to(self.device)
        label = label.to(self.device)
        return input, label

================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/ssl/__init__.py
================================================
from .entmin import EntMin
from .fixmatch import FixMatch
from .mixmatch import MixMatch
from .mean_teacher import MeanTeacher
from .sup_baseline import SupBaseline
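The CrossGrad trainer in `dassl/engine/dg/crossgrad.py` perturbs each input along the clamped input-gradient of the *other* network's loss: the domain classifier's gradient, scaled by `eps_f`, produces `input_d` for training the label net, and the label net's gradient, scaled by `eps_d`, produces `input_f` for training the domain net. Each net is then trained on a convex mix of its clean and perturbed losses (`alpha_f`, `alpha_d`); DDAIG uses the same `(1 - alpha) * clean + alpha * perturbed` blend after its warm-up. The sketch below reproduces only these elementwise steps on plain Python lists, assuming a flattened input and a precomputed gradient. It is a framework-free illustration: `clamp`, `perturb`, and `mix_losses` are hypothetical helper names, not Dassl APIs.

```python
def clamp(value, low=-0.1, high=0.1):
    """Clip a scalar into [low, high], mirroring torch.clamp(min=-0.1, max=0.1)."""
    return max(low, min(high, value))


def perturb(x, grad, eps):
    """Elementwise x + eps * clamp(grad), as CrossGrad forms input_d / input_f."""
    return [xi + eps * clamp(gi) for xi, gi in zip(x, grad)]


def mix_losses(loss_clean, loss_perturbed, alpha):
    """Convex combination (1 - alpha) * clean + alpha * perturbed."""
    return (1 - alpha) * loss_clean + alpha * loss_perturbed


if __name__ == "__main__":
    # A toy 3-value "image" and a domain-loss gradient; the entries -0.3 and
    # 2.0 exceed the clamp range and are clipped to -0.1 and +0.1.
    x = [0.5, -0.2, 1.0]
    grad_d = [0.05, -0.3, 2.0]
    x_d = perturb(x, grad_d, eps=1.0)
    print(x_d)  # approximately [0.55, -0.3, 1.1]
    print(mix_losses(2.0, 4.0, alpha=0.5))
```

Clamping the gradient before scaling bounds how far any single input value can move per step regardless of the gradient's magnitude, which is why the trainer only needs the scalar `eps_*` values to control perturbation strength.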

================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/ssl/entmin.py
================================================
import torch
from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy


@TRAINER_REGISTRY.register()
class EntMin(TrainerXU):
    """Entropy Minimization.

    http://papers.nips.cc/paper/2740-semi-supervised-learning-by-entropy-minimization.pdf.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.lmda = cfg.TRAINER.ENTMIN.LMDA

    def forward_backward(self, batch_x, batch_u):
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)

        output_x = self.model(input_x)
        loss_x = F.cross_entropy(output_x, label_x)

        output_u = F.softmax(self.model(input_u), 1)
        loss_u = (-output_u * torch.log(output_u + 1e-5)).sum(1).mean()

        loss = loss_x + loss_u * self.lmda

        self.model_backward_and_update(loss)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(output_x, label_x)[0].item(),
            "loss_u": loss_u.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/ssl/fixmatch.py
================================================
import torch
from torch.nn import functional as F

from dassl.data import DataManager
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.data.transforms import build_transform


@TRAINER_REGISTRY.register()
class FixMatch(TrainerXU):
    """FixMatch: Simplifying Semi-Supervised Learning with
    Consistency and Confidence.

    https://arxiv.org/abs/2001.07685.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.weight_u = cfg.TRAINER.FIXMATCH.WEIGHT_U
        self.conf_thre = cfg.TRAINER.FIXMATCH.CONF_THRE

    def check_cfg(self, cfg):
        assert len(cfg.TRAINER.FIXMATCH.STRONG_TRANSFORMS) > 0

    def build_data_loader(self):
        cfg = self.cfg
        tfm_train = build_transform(cfg, is_train=True)
        custom_tfm_train = [tfm_train]
        choices = cfg.TRAINER.FIXMATCH.STRONG_TRANSFORMS
        tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
        custom_tfm_train += [tfm_train_strong]
        self.dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
        self.train_loader_x = self.dm.train_loader_x
        self.train_loader_u = self.dm.train_loader_u
        self.val_loader = self.dm.val_loader
        self.test_loader = self.dm.test_loader
        self.num_classes = self.dm.num_classes

    def assess_y_pred_quality(self, y_pred, y_true, mask):
        n_masked_correct = (y_pred.eq(y_true).float() * mask).sum()
        acc_thre = n_masked_correct / (mask.sum() + 1e-5)
        acc_raw = y_pred.eq(y_true).sum() / y_pred.numel()  # raw accuracy
        keep_rate = mask.sum() / mask.numel()
        output = {
            "acc_thre": acc_thre,
            "acc_raw": acc_raw,
            "keep_rate": keep_rate
        }
        return output

    def forward_backward(self, batch_x, batch_u):
        parsed_data = self.parse_batch_train(batch_x, batch_u)
        input_x, input_x2, label_x, input_u, input_u2, label_u = parsed_data
        input_u = torch.cat([input_x, input_u], 0)
        input_u2 = torch.cat([input_x2, input_u2], 0)
        n_x = input_x.size(0)

        # Generate pseudo labels
        with torch.no_grad():
            output_u = F.softmax(self.model(input_u), 1)
            max_prob, label_u_pred = output_u.max(1)
            mask_u = (max_prob >= self.conf_thre).float()

            # Evaluate pseudo labels' accuracy
            y_u_pred_stats = self.assess_y_pred_quality(
                label_u_pred[n_x:], label_u, mask_u[n_x:]
            )

        # Supervised loss
        output_x = self.model(input_x)
        loss_x = F.cross_entropy(output_x, label_x)

        # Unsupervised loss
        output_u = self.model(input_u2)
        loss_u = F.cross_entropy(output_u, label_u_pred, reduction="none")
        loss_u = (loss_u * mask_u).mean()

        loss = loss_x + loss_u * self.weight_u
        self.model_backward_and_update(loss)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(output_x, label_x)[0].item(),
            "loss_u": loss_u.item(),
            "y_u_pred_acc_raw": y_u_pred_stats["acc_raw"],
            "y_u_pred_acc_thre": y_u_pred_stats["acc_thre"],
            "y_u_pred_keep": y_u_pred_stats["keep_rate"],
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        input_x = batch_x["img"]
        input_x2 = batch_x["img2"]
        label_x = batch_x["label"]
        input_u = batch_u["img"]
        input_u2 = batch_u["img2"]
        # label_u is used only for evaluating pseudo labels' accuracy
        label_u = batch_u["label"]

        input_x = input_x.to(self.device)
        input_x2 = input_x2.to(self.device)
        label_x = label_x.to(self.device)
        input_u = input_u.to(self.device)
        input_u2 = input_u2.to(self.device)
        label_u = label_u.to(self.device)

        return input_x, input_x2, label_x, input_u, input_u2, label_u

================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/ssl/mean_teacher.py
================================================
import copy
from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling.ops.utils import sigmoid_rampup, ema_model_update


@TRAINER_REGISTRY.register()
class MeanTeacher(TrainerXU):
    """Mean teacher.

    https://arxiv.org/abs/1703.01780.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.weight_u = cfg.TRAINER.MEANTEA.WEIGHT_U
        self.ema_alpha = cfg.TRAINER.MEANTEA.EMA_ALPHA
        self.rampup = cfg.TRAINER.MEANTEA.RAMPUP

        self.teacher = copy.deepcopy(self.model)
        self.teacher.train()
        for param in self.teacher.parameters():
            param.requires_grad_(False)

    def forward_backward(self, batch_x, batch_u):
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)

        logit_x = self.model(input_x)
        loss_x = F.cross_entropy(logit_x, label_x)

        target_u = F.softmax(self.teacher(input_u), 1)
        prob_u = F.softmax(self.model(input_u), 1)
        loss_u = ((prob_u - target_u)**2).sum(1).mean()

        weight_u = self.weight_u * sigmoid_rampup(self.epoch, self.rampup)
        loss = loss_x + loss_u*weight_u
        self.model_backward_and_update(loss)

        global_step = self.batch_idx + self.epoch * self.num_batches
        ema_alpha = min(1 - 1 / (global_step+1), self.ema_alpha)
        ema_model_update(self.model, self.teacher, ema_alpha)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(logit_x, label_x)[0].item(),
            "loss_u": loss_u.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/ssl/mixmatch.py
================================================
import torch
from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.modeling.ops import mixup
from dassl.modeling.ops.utils import (
    sharpen_prob, create_onehot, linear_rampup, shuffle_index
)


@TRAINER_REGISTRY.register()
class MixMatch(TrainerXU):
    """MixMatch: A Holistic Approach to Semi-Supervised Learning.

    https://arxiv.org/abs/1905.02249.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.weight_u = cfg.TRAINER.MIXMATCH.WEIGHT_U
        self.temp = cfg.TRAINER.MIXMATCH.TEMP
        self.beta = cfg.TRAINER.MIXMATCH.MIXUP_BETA
        self.rampup = cfg.TRAINER.MIXMATCH.RAMPUP

    def check_cfg(self, cfg):
        assert cfg.DATALOADER.K_TRANSFORMS > 1

    def forward_backward(self, batch_x, batch_u):
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)
        num_x = input_x.shape[0]

        global_step = self.batch_idx + self.epoch * self.num_batches
        weight_u = self.weight_u * linear_rampup(global_step, self.rampup)

        # Generate pseudo-label for unlabeled data
        with torch.no_grad():
            output_u = 0
            for input_ui in input_u:
                output_ui = F.softmax(self.model(input_ui), 1)
                output_u += output_ui
            output_u /= len(input_u)
            label_u = sharpen_prob(output_u, self.temp)
            label_u = [label_u] * len(input_u)
            label_u = torch.cat(label_u, 0)
            input_u = torch.cat(input_u, 0)

        # Combine and shuffle labeled and unlabeled data
        input_xu = torch.cat([input_x, input_u], 0)
        label_xu = torch.cat([label_x, label_u], 0)
        input_xu, label_xu = shuffle_index(input_xu, label_xu)

        # Mixup
        input_x, label_x = mixup(
            input_x,
            input_xu[:num_x],
            label_x,
            label_xu[:num_x],
            self.beta,
            preserve_order=True,
        )

        input_u, label_u = mixup(
            input_u,
            input_xu[num_x:],
            label_u,
            label_xu[num_x:],
            self.beta,
            preserve_order=True,
        )

        # Compute losses
        output_x = F.softmax(self.model(input_x), 1)
        loss_x = (-label_x * torch.log(output_x + 1e-5)).sum(1).mean()

        output_u = F.softmax(self.model(input_u), 1)
        loss_u = ((label_u - output_u)**2).mean()

        loss = loss_x + loss_u*weight_u
        self.model_backward_and_update(loss)

        loss_summary = {"loss_x": loss_x.item(), "loss_u": loss_u.item()}

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        input_x = batch_x["img"][0]
        label_x = batch_x["label"]
        label_x = create_onehot(label_x, self.num_classes)
        input_u = batch_u["img"]

        input_x = input_x.to(self.device)
        label_x = label_x.to(self.device)
        input_u = [input_ui.to(self.device) for input_ui in input_u]

        return input_x, label_x, input_u

================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/ssl/sup_baseline.py
================================================
from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy


@TRAINER_REGISTRY.register()
class SupBaseline(TrainerXU):
    """Supervised Baseline."""

    def forward_backward(self, batch_x, batch_u):
        input, label = self.parse_batch_train(batch_x, batch_u)
        output = self.model(input)
        loss = F.cross_entropy(output, label)
        self.model_backward_and_update(loss)

        loss_summary = {
            "loss": loss.item(),
            "acc": compute_accuracy(output, label)[0].item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        input = batch_x["img"]
        label = batch_x["label"]
        input = input.to(self.device)
        label = label.to(self.device)
        return input, label

================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/trainer.py
================================================
import json
import time
import numpy as np
import os.path as osp
import datetime
from collections import OrderedDict
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

from dassl.data import DataManager
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import (
    MetricMeter, AverageMeter, tolist_if_not, count_num_param, load_checkpoint,
    save_checkpoint, mkdir_if_missing, resume_from_checkpoint,
    load_pretrained_weights
)
from dassl.modeling import build_head, build_backbone
from dassl.evaluation import build_evaluator


class SimpleNet(nn.Module):
    """A simple neural network composed of a CNN backbone
    and optionally a head such as mlp for classification.
""" def __init__(self, cfg, model_cfg, num_classes, **kwargs): super().__init__() self.backbone = build_backbone( model_cfg.BACKBONE.NAME, verbose=cfg.VERBOSE, pretrained=model_cfg.BACKBONE.PRETRAINED, **kwargs, ) fdim = self.backbone.out_features self.head = None if model_cfg.HEAD.NAME and model_cfg.HEAD.HIDDEN_LAYERS: self.head = build_head( model_cfg.HEAD.NAME, verbose=cfg.VERBOSE, in_features=fdim, hidden_layers=model_cfg.HEAD.HIDDEN_LAYERS, activation=model_cfg.HEAD.ACTIVATION, bn=model_cfg.HEAD.BN, dropout=model_cfg.HEAD.DROPOUT, **kwargs, ) fdim = self.head.out_features self.classifier = None if num_classes > 0: self.classifier = nn.Linear(fdim, num_classes) self._fdim = fdim @property def fdim(self): return self._fdim def forward(self, x, return_feature=False): f = self.backbone(x) if self.head is not None: f = self.head(f) if self.classifier is None: return f y = self.classifier(f) if return_feature: return y, f return y class TrainerBase: """Base class for iterative trainer.""" def __init__(self): self._models = OrderedDict() self._optims = OrderedDict() self._scheds = OrderedDict() self._writer = None def register_model(self, name="model", model=None, optim=None, sched=None): if self.__dict__.get("_models") is None: raise AttributeError( "Cannot assign model before super().__init__() call" ) if self.__dict__.get("_optims") is None: raise AttributeError( "Cannot assign optim before super().__init__() call" ) if self.__dict__.get("_scheds") is None: raise AttributeError( "Cannot assign sched before super().__init__() call" ) assert name not in self._models, "Found duplicate model names" self._models[name] = model self._optims[name] = optim self._scheds[name] = sched def get_model_names(self, names=None): names_real = list(self._models.keys()) if names is not None: names = tolist_if_not(names) for name in names: assert name in names_real return names else: return names_real def save_model(self, epoch, directory, is_best=False, model_name=""): names = 
self.get_model_names() for name in names: model_dict = self._models[name].state_dict() optim_dict = None if self._optims[name] is not None: optim_dict = self._optims[name].state_dict() sched_dict = None if self._scheds[name] is not None: sched_dict = self._scheds[name].state_dict() save_checkpoint( { "state_dict": model_dict, "epoch": epoch + 1, "optimizer": optim_dict, "scheduler": sched_dict, }, osp.join(directory, name), is_best=is_best, model_name=model_name, ) def resume_model_if_exist(self, directory): names = self.get_model_names() file_missing = False for name in names: path = osp.join(directory, name) if not osp.exists(path): file_missing = True break if file_missing: print("No checkpoint found, train from scratch") return 0 print( 'Found checkpoint in "{}". Will resume training'.format(directory) ) for name in names: path = osp.join(directory, name) start_epoch = resume_from_checkpoint( path, self._models[name], self._optims[name], self._scheds[name] ) return start_epoch def load_model(self, directory, epoch=None): if not directory: print( "Note that load_model() is skipped as no pretrained " "model is given (ignore this if it's done on purpose)" ) return names = self.get_model_names() # By default, the best model is loaded model_file = "model-best.pth.tar" if epoch is not None: model_file = "model.pth.tar-" + str(epoch) for name in names: model_path = osp.join(directory, name, model_file) if not osp.exists(model_path): raise FileNotFoundError( 'Model not found at "{}"'.format(model_path) ) checkpoint = load_checkpoint(model_path) state_dict = checkpoint["state_dict"] epoch = checkpoint["epoch"] print( "Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch) ) self._models[name].load_state_dict(state_dict) def set_model_mode(self, mode="train", names=None): names = self.get_model_names(names) for name in names: if mode == "train": self._models[name].train() elif mode in ["test", "eval"]: self._models[name].eval() else: raise 
KeyError def update_lr(self, names=None): names = self.get_model_names(names) for name in names: if self._scheds[name] is not None: self._scheds[name].step() def detect_anomaly(self, loss): if not torch.isfinite(loss).all(): raise FloatingPointError("Loss is infinite or NaN!") def init_writer(self, log_dir): if self.__dict__.get("_writer") is None or self._writer is None: print( "Initializing summary writer for tensorboard " "with log_dir={}".format(log_dir) ) self._writer = SummaryWriter(log_dir=log_dir) def close_writer(self): if self._writer is not None: self._writer.close() def write_scalar(self, tag, scalar_value, global_step=None): if self._writer is None: # Do nothing if writer is not initialized # Note that writer is only used when training is needed pass else: self._writer.add_scalar(tag, scalar_value, global_step) def train(self, start_epoch, max_epoch): """Generic training loops.""" self.start_epoch = start_epoch self.max_epoch = max_epoch self.before_train() for self.epoch in range(self.start_epoch, self.max_epoch): self.before_epoch() self.run_epoch() self.after_epoch() self.after_train() def before_train(self): pass def after_train(self): pass def before_epoch(self): pass def after_epoch(self): pass def run_epoch(self): raise NotImplementedError def test(self): raise NotImplementedError def parse_batch_train(self, batch): raise NotImplementedError def parse_batch_test(self, batch): raise NotImplementedError def forward_backward(self, batch): raise NotImplementedError def model_inference(self, input): raise NotImplementedError def model_zero_grad(self, names=None): names = self.get_model_names(names) for name in names: if self._optims[name] is not None: self._optims[name].zero_grad() def model_backward(self, loss): self.detect_anomaly(loss) loss.backward() def model_update(self, names=None): names = self.get_model_names(names) for name in names: if self._optims[name] is not None: self._optims[name].step() def model_backward_and_update(self, loss, 
names=None): self.model_zero_grad(names) self.model_backward(loss) self.model_update(names) def prograd_backward_and_update( self, loss_a, loss_b, lambda_=1, names=None ): # loss_a must decrease; loss_b only needs to not increase self.model_zero_grad(names) # get the names of the registered models names = self.get_model_names(names) # backward loss_b self.detect_anomaly(loss_b) loss_b.backward(retain_graph=True) # store a copy of loss_b's gradients, kept separately for each model b_grads = {} for name in names: b_grads[name] = [p.grad.clone() for p in self._models[name].parameters()] # clear gradients before backpropagating loss_a (no optimizer step yet) for name in names: self._optims[name].zero_grad() # backward loss_a self.detect_anomaly(loss_a) loss_a.backward() for name in names: for p, b_grad in zip(self._models[name].parameters(), b_grads[name]): # if the two gradients conflict (negative cosine similarity), remove the component of a_grad along b_grad b_grad_norm = b_grad / torch.linalg.norm(b_grad) a_grad = p.grad.clone() a_grad_norm = a_grad / torch.linalg.norm(a_grad) if torch.dot(a_grad_norm.flatten(), b_grad_norm.flatten()) < 0: p.grad = a_grad - lambda_ * torch.dot( a_grad.flatten(), b_grad_norm.flatten() ) * b_grad_norm # optimizer step for name in names: self._optims[name].step() class SimpleTrainer(TrainerBase): """A simple trainer class implementing generic functions.""" def __init__(self, cfg): super().__init__() self.check_cfg(cfg) if torch.cuda.is_available() and cfg.USE_CUDA: self.device = torch.device("cuda") else: self.device = torch.device("cpu") # Save as attributes some frequently used variables self.start_epoch = self.epoch = 0 self.max_epoch = cfg.OPTIM.MAX_EPOCH self.output_dir = cfg.OUTPUT_DIR self.cfg = cfg self.build_data_loader() self.build_model() self.evaluator = build_evaluator(cfg, lab2cname=self.lab2cname) self.best_result = -np.inf def check_cfg(self, cfg): """Check whether some variables are set correctly for the trainer (optional).
For example, a trainer might require a particular sampler for training such as 'RandomDomainSampler', which can be checked with: assert cfg.DATALOADER.SAMPLER_TRAIN == 'RandomDomainSampler' """ pass def build_data_loader(self): """Create essential data-related attributes. A re-implementation of this method must create the same attributes (except self.dm). """ dm = DataManager(self.cfg) self.train_loader_x = dm.train_loader_x self.train_loader_u = dm.train_loader_u # optional, can be None self.val_loader = dm.val_loader # optional, can be None self.test_loader = dm.test_loader self.num_classes = dm.num_classes self.num_source_domains = dm.num_source_domains self.lab2cname = dm.lab2cname # dict {label: classname} self.dm = dm def build_model(self): """Build and register model. The default builds a classification model along with its optimizer and scheduler. Custom trainers can re-implement this method if necessary. """ cfg = self.cfg print("Building model") self.model = SimpleNet(cfg, cfg.MODEL, self.num_classes) if cfg.MODEL.INIT_WEIGHTS: load_pretrained_weights(self.model, cfg.MODEL.INIT_WEIGHTS) self.model.to(self.device) print("# params: {:,}".format(count_num_param(self.model))) self.optim = build_optimizer(self.model, cfg.OPTIM) self.sched = build_lr_scheduler(self.optim, cfg.OPTIM) self.register_model("model", self.model, self.optim, self.sched) device_count = torch.cuda.device_count() if device_count > 1: print( f"Detected {device_count} GPUs.
Wrap the model with nn.DataParallel" ) self.model = nn.DataParallel(self.model) def train(self): super().train(self.start_epoch, self.max_epoch) def before_train(self): directory = self.cfg.OUTPUT_DIR if self.cfg.RESUME: directory = self.cfg.RESUME self.start_epoch = self.resume_model_if_exist(directory) # Initialize summary writer writer_dir = osp.join(self.output_dir, "tensorboard") mkdir_if_missing(writer_dir) self.init_writer(writer_dir) # Remember the starting time (for computing the elapsed time) self.time_start = time.time() def after_train(self): print("Finished training") do_test = not self.cfg.TEST.NO_TEST if do_test: if self.cfg.TEST.FINAL_MODEL == "best_val": print("Deploy the model with the best val performance") self.load_model(self.output_dir) self.test() # Show elapsed time elapsed = round(time.time() - self.time_start) elapsed = str(datetime.timedelta(seconds=elapsed)) print("Elapsed: {}".format(elapsed)) # Close writer self.close_writer() def after_epoch(self): last_epoch = (self.epoch + 1) == self.max_epoch do_test = not self.cfg.TEST.NO_TEST meet_checkpoint_freq = ( (self.epoch + 1) % self.cfg.TRAIN.CHECKPOINT_FREQ == 0 if self.cfg.TRAIN.CHECKPOINT_FREQ > 0 else False ) if do_test and self.cfg.TEST.FINAL_MODEL == "best_val": curr_result = self.test(split="val") is_best = curr_result > self.best_result if is_best: self.best_result = curr_result self.save_model( self.epoch, self.output_dir, model_name="model-best.pth.tar" ) if meet_checkpoint_freq or last_epoch: self.save_model(self.epoch, self.output_dir) @torch.no_grad() def output_test(self, split=None): """Testing pipeline that also saves each image's prediction and ground truth to output.json.""" self.set_model_mode("eval") self.evaluator.reset() output_file = osp.join(self.cfg.OUTPUT_DIR, 'output.json') res_json = {} if split is None: split = self.cfg.TEST.SPLIT if split == "val" and self.val_loader is not None: data_loader = self.val_loader print("Do evaluation on {} set".format(split)) else: data_loader = self.test_loader
print("Do evaluation on test set") for batch_idx, batch in enumerate(tqdm(data_loader)): img_path = batch['impath'] input, label = self.parse_batch_test(batch) output = self.model_inference(input) self.evaluator.process(output, label) for i in range(len(img_path)): res_json[img_path[i]] = { 'predict': output[i].cpu().numpy().tolist(), 'gt': label[i].cpu().numpy().tolist() } with open(output_file, 'w') as f: json.dump(res_json, f) results = self.evaluator.evaluate() for k, v in results.items(): tag = "{}/{}".format(split, k) self.write_scalar(tag, v, self.epoch) return list(results.values())[0] @torch.no_grad() def test(self, split=None): """A generic testing pipeline.""" self.set_model_mode("eval") self.evaluator.reset() if split is None: split = self.cfg.TEST.SPLIT if split == "val" and self.val_loader is not None: data_loader = self.val_loader print("Do evaluation on {} set".format(split)) else: data_loader = self.test_loader print("Do evaluation on test set") for batch_idx, batch in enumerate(tqdm(data_loader)): input, label = self.parse_batch_test(batch) output = self.model_inference(input) self.evaluator.process(output, label) results = self.evaluator.evaluate() for k, v in results.items(): tag = "{}/{}".format(split, k) self.write_scalar(tag, v, self.epoch) return list(results.values())[0] def model_inference(self, input): return self.model(input) def parse_batch_test(self, batch): input = batch["img"] label = batch["label"] input = input.to(self.device) label = label.to(self.device) return input, label def get_current_lr(self, names=None): names = self.get_model_names(names) name = names[0] return self._optims[name].param_groups[0]["lr"] class TrainerXU(SimpleTrainer): """A base trainer using both labeled and unlabeled data. In the context of domain adaptation, labeled and unlabeled data come from source and target domains respectively. When it comes to semi-supervised learning, all data comes from the same domain. 
""" def run_epoch(self): self.set_model_mode("train") losses = MetricMeter() batch_time = AverageMeter() data_time = AverageMeter() # Decide to iterate over labeled or unlabeled dataset len_train_loader_x = len(self.train_loader_x) len_train_loader_u = len(self.train_loader_u) if self.cfg.TRAIN.COUNT_ITER == "train_x": self.num_batches = len_train_loader_x elif self.cfg.TRAIN.COUNT_ITER == "train_u": self.num_batches = len_train_loader_u elif self.cfg.TRAIN.COUNT_ITER == "smaller_one": self.num_batches = min(len_train_loader_x, len_train_loader_u) else: raise ValueError train_loader_x_iter = iter(self.train_loader_x) train_loader_u_iter = iter(self.train_loader_u) end = time.time() for self.batch_idx in range(self.num_batches): try: batch_x = next(train_loader_x_iter) except StopIteration: train_loader_x_iter = iter(self.train_loader_x) batch_x = next(train_loader_x_iter) try: batch_u = next(train_loader_u_iter) except StopIteration: train_loader_u_iter = iter(self.train_loader_u) batch_u = next(train_loader_u_iter) data_time.update(time.time() - end) loss_summary = self.forward_backward(batch_x, batch_u) batch_time.update(time.time() - end) losses.update(loss_summary) if ( self.batch_idx + 1 ) % self.cfg.TRAIN.PRINT_FREQ == 0 or self.num_batches < self.cfg.TRAIN.PRINT_FREQ: nb_remain = 0 nb_remain += self.num_batches - self.batch_idx - 1 nb_remain += ( self.max_epoch - self.epoch - 1 ) * self.num_batches eta_seconds = batch_time.avg * nb_remain eta = str(datetime.timedelta(seconds=int(eta_seconds))) print( "epoch [{0}/{1}][{2}/{3}]\t" "time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" "data {data_time.val:.3f} ({data_time.avg:.3f})\t" "eta {eta}\t" "{losses}\t" "lr {lr:.6e}".format( self.epoch + 1, self.max_epoch, self.batch_idx + 1, self.num_batches, batch_time=batch_time, data_time=data_time, eta=eta, losses=losses, lr=self.get_current_lr(), ) ) n_iter = self.epoch * self.num_batches + self.batch_idx for name, meter in losses.meters.items(): 
self.write_scalar("train/" + name, meter.avg, n_iter) self.write_scalar("train/lr", self.get_current_lr(), n_iter) end = time.time() def parse_batch_train(self, batch_x, batch_u): input_x = batch_x["img"] label_x = batch_x["label"] input_u = batch_u["img"] input_x = input_x.to(self.device) label_x = label_x.to(self.device) input_u = input_u.to(self.device) return input_x, label_x, input_u class TrainerX(SimpleTrainer): """A base trainer using labeled data only.""" def run_epoch(self): self.set_model_mode("train") losses = MetricMeter() batch_time = AverageMeter() data_time = AverageMeter() self.num_batches = len(self.train_loader_x) end = time.time() for self.batch_idx, batch in enumerate(self.train_loader_x): data_time.update(time.time() - end) loss_summary = self.forward_backward(batch) batch_time.update(time.time() - end) losses.update(loss_summary) if ( self.batch_idx + 1 ) % self.cfg.TRAIN.PRINT_FREQ == 0 or self.num_batches < self.cfg.TRAIN.PRINT_FREQ: nb_remain = 0 nb_remain += self.num_batches - self.batch_idx - 1 nb_remain += ( self.max_epoch - self.epoch - 1 ) * self.num_batches eta_seconds = batch_time.avg * nb_remain eta = str(datetime.timedelta(seconds=int(eta_seconds))) print( "epoch [{0}/{1}][{2}/{3}]\t" "time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" "data {data_time.val:.3f} ({data_time.avg:.3f})\t" "eta {eta}\t" "{losses}\t" "lr {lr:.6e}".format( self.epoch + 1, self.max_epoch, self.batch_idx + 1, self.num_batches, batch_time=batch_time, data_time=data_time, eta=eta, losses=losses, lr=self.get_current_lr(), ) ) n_iter = self.epoch * self.num_batches + self.batch_idx for name, meter in losses.meters.items(): self.write_scalar("train/" + name, meter.avg, n_iter) self.write_scalar("train/lr", self.get_current_lr(), n_iter) end = time.time() def parse_batch_train(self, batch): input = batch["img"] label = batch["label"] domain = batch["domain"] input = input.to(self.device) label = label.to(self.device) domain = domain.to(self.device) return 
input, label, domain ================================================ FILE: Dassl.ProGrad.pytorch/dassl/evaluation/__init__.py ================================================ from .build import build_evaluator, EVALUATOR_REGISTRY # isort:skip from .evaluator import EvaluatorBase, Classification ================================================ FILE: Dassl.ProGrad.pytorch/dassl/evaluation/build.py ================================================ from dassl.utils import Registry, check_availability EVALUATOR_REGISTRY = Registry("EVALUATOR") def build_evaluator(cfg, **kwargs): avai_evaluators = EVALUATOR_REGISTRY.registered_names() check_availability(cfg.TEST.EVALUATOR, avai_evaluators) if cfg.VERBOSE: print("Loading evaluator: {}".format(cfg.TEST.EVALUATOR)) return EVALUATOR_REGISTRY.get(cfg.TEST.EVALUATOR)(cfg, **kwargs) ================================================ FILE: Dassl.ProGrad.pytorch/dassl/evaluation/evaluator.py ================================================ import numpy as np import os.path as osp from collections import OrderedDict, defaultdict import torch from sklearn.metrics import f1_score, confusion_matrix from .build import EVALUATOR_REGISTRY class EvaluatorBase: """Base evaluator.""" def __init__(self, cfg): self.cfg = cfg def reset(self): raise NotImplementedError def process(self, mo, gt): raise NotImplementedError def evaluate(self): raise NotImplementedError @EVALUATOR_REGISTRY.register() class Classification(EvaluatorBase): """Evaluator for classification.""" def __init__(self, cfg, lab2cname=None, **kwargs): super().__init__(cfg) self._lab2cname = lab2cname self._correct = 0 self._total = 0 self._per_class_res = None self._y_true = [] self._y_pred = [] if cfg.TEST.PER_CLASS_RESULT: assert lab2cname is not None self._per_class_res = defaultdict(list) def reset(self): self._correct = 0 self._total = 0 self._y_true = [] self._y_pred = [] if self._per_class_res is not None: self._per_class_res = defaultdict(list) def process(self, mo, gt): 
# mo (torch.Tensor): model output [batch, num_classes] # gt (torch.LongTensor): ground truth [batch] pred = mo.max(1)[1] matches = pred.eq(gt).float() self._correct += int(matches.sum().item()) self._total += gt.shape[0] self._y_true.extend(gt.data.cpu().numpy().tolist()) self._y_pred.extend(pred.data.cpu().numpy().tolist()) if self._per_class_res is not None: for i, label in enumerate(gt): label = label.item() matches_i = int(matches[i].item()) self._per_class_res[label].append(matches_i) def evaluate(self): results = OrderedDict() acc = 100.0 * self._correct / self._total err = 100.0 - acc macro_f1 = 100.0 * f1_score( self._y_true, self._y_pred, average="macro", labels=np.unique(self._y_true) ) # The first value will be returned by trainer.test() results["accuracy"] = acc results["error_rate"] = err results["macro_f1"] = macro_f1 print( "=> result\n" f"* total: {self._total:,}\n" f"* correct: {self._correct:,}\n" f"* accuracy: {acc:.2f}%\n" f"* error: {err:.2f}%\n" f"* macro_f1: {macro_f1:.2f}%" ) if self._per_class_res is not None: labels = list(self._per_class_res.keys()) labels.sort() print("=> per-class result") accs = [] for label in labels: classname = self._lab2cname[label] res = self._per_class_res[label] correct = sum(res) total = len(res) acc = 100.0 * correct / total accs.append(acc) print( "* class: {} ({})\t" "total: {:,}\t" "correct: {:,}\t" "acc: {:.2f}%".format( label, classname, total, correct, acc ) ) mean_acc = np.mean(accs) print("* average: {:.2f}%".format(mean_acc)) results["perclass_accuracy"] = mean_acc if self.cfg.TEST.COMPUTE_CMAT: cmat = confusion_matrix( self._y_true, self._y_pred, normalize="true" ) save_path = osp.join(self.cfg.OUTPUT_DIR, "cmat.pt") torch.save(cmat, save_path) print('Confusion matrix is saved to "{}"'.format(save_path)) return results ================================================ FILE: Dassl.ProGrad.pytorch/dassl/metrics/__init__.py ================================================ from .accuracy import 
compute_accuracy from .distance import ( cosine_distance, compute_distance_matrix, euclidean_squared_distance ) ================================================ FILE: Dassl.ProGrad.pytorch/dassl/metrics/accuracy.py ================================================ def compute_accuracy(output, target, topk=(1, )): """Computes the accuracy over the k top predictions for the specified values of k. Args: output (torch.Tensor): prediction matrix with shape (batch_size, num_classes). target (torch.LongTensor): ground truth labels with shape (batch_size). topk (tuple, optional): accuracy at top-k will be computed. For example, topk=(1, 5) means accuracy at top-1 and top-5 will be computed. Returns: list: accuracy at top-k. """ maxk = max(topk) batch_size = target.size(0) if isinstance(output, (tuple, list)): output = output[0] _, pred = output.topk(maxk, 1, True, True) pred = pred.t() correct = pred.eq(target.view(1, -1).expand_as(pred)) res = [] for k in topk: # reshape (not view): correct is non-contiguous after pred.t() correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) acc = correct_k.mul_(100.0 / batch_size) res.append(acc) return res ================================================ FILE: Dassl.ProGrad.pytorch/dassl/metrics/distance.py ================================================ """ Source: https://github.com/KaiyangZhou/deep-person-reid """ import torch from torch.nn import functional as F def compute_distance_matrix(input1, input2, metric="euclidean"): """A wrapper function for computing distance matrix. Each input matrix has the shape (n_data, feature_dim). Args: input1 (torch.Tensor): 2-D feature matrix. input2 (torch.Tensor): 2-D feature matrix. metric (str, optional): "euclidean" or "cosine". Default is "euclidean". Returns: torch.Tensor: distance matrix.
""" # check input assert isinstance(input1, torch.Tensor) assert isinstance(input2, torch.Tensor) assert input1.dim() == 2, "Expected 2-D tensor, but got {}-D".format( input1.dim() ) assert input2.dim() == 2, "Expected 2-D tensor, but got {}-D".format( input2.dim() ) assert input1.size(1) == input2.size(1) if metric == "euclidean": distmat = euclidean_squared_distance(input1, input2) elif metric == "cosine": distmat = cosine_distance(input1, input2) else: raise ValueError( "Unknown distance metric: {}. " 'Please choose either "euclidean" or "cosine"'.format(metric) ) return distmat def euclidean_squared_distance(input1, input2): """Computes squared Euclidean distance. Args: input1 (torch.Tensor): 2-D feature matrix. input2 (torch.Tensor): 2-D feature matrix. Returns: torch.Tensor: distance matrix. """ m, n = input1.size(0), input2.size(0) mat1 = torch.pow(input1, 2).sum(dim=1, keepdim=True).expand(m, n) mat2 = torch.pow(input2, 2).sum(dim=1, keepdim=True).expand(n, m).t() distmat = mat1 + mat2 # distmat = 1 * distmat + (-2) * input1 @ input2.t() distmat.addmm_(input1, input2.t(), beta=1, alpha=-2) return distmat def cosine_distance(input1, input2): """Computes cosine distance. Args: input1 (torch.Tensor): 2-D feature matrix. input2 (torch.Tensor): 2-D feature matrix. Returns: torch.Tensor: distance matrix.
""" input1_normed = F.normalize(input1, p=2, dim=1) input2_normed = F.normalize(input2, p=2, dim=1) distmat = 1 - torch.mm(input1_normed, input2_normed.t()) return distmat ================================================ FILE: Dassl.ProGrad.pytorch/dassl/modeling/__init__.py ================================================ from .head import HEAD_REGISTRY, build_head from .network import NETWORK_REGISTRY, build_network from .backbone import BACKBONE_REGISTRY, Backbone, build_backbone ================================================ FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/__init__.py ================================================ from .build import build_backbone, BACKBONE_REGISTRY # isort:skip from .backbone import Backbone # isort:skip from .vgg import vgg16 from .resnet import ( resnet18, resnet34, resnet50, resnet101, resnet152, resnet18_ms_l1, resnet50_ms_l1, resnet18_ms_l12, resnet50_ms_l12, resnet101_ms_l1, resnet18_ms_l123, resnet50_ms_l123, resnet101_ms_l12, resnet101_ms_l123, resnet18_efdmix_l1, resnet50_efdmix_l1, resnet18_efdmix_l12, resnet50_efdmix_l12, resnet101_efdmix_l1, resnet18_efdmix_l123, resnet50_efdmix_l123, resnet101_efdmix_l12, resnet101_efdmix_l123 ) from .alexnet import alexnet from .mobilenetv2 import mobilenetv2 from .wide_resnet import wide_resnet_16_4, wide_resnet_28_2 from .cnn_digitsdg import cnn_digitsdg from .efficientnet import ( efficientnet_b0, efficientnet_b1, efficientnet_b2, efficientnet_b3, efficientnet_b4, efficientnet_b5, efficientnet_b6, efficientnet_b7 ) from .shufflenetv2 import ( shufflenet_v2_x0_5, shufflenet_v2_x1_0, shufflenet_v2_x1_5, shufflenet_v2_x2_0 ) from .cnn_digitsingle import cnn_digitsingle from .preact_resnet18 import preact_resnet18 from .cnn_digit5_m3sda import cnn_digit5_m3sda ================================================ FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/alexnet.py ================================================ import torch import torch.nn as nn import 
torch.utils.model_zoo as model_zoo from .build import BACKBONE_REGISTRY from .backbone import Backbone model_urls = { "alexnet": "https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth", } class AlexNet(Backbone): def __init__(self): super().__init__() self.features = nn.Sequential( nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=3, stride=2), nn.Conv2d(64, 192, kernel_size=5, padding=2), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=3, stride=2), nn.Conv2d(192, 384, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=3, stride=2), ) self.avgpool = nn.AdaptiveAvgPool2d((6, 6)) # Note that self.classifier outputs features rather than logits self.classifier = nn.Sequential( nn.Dropout(), nn.Linear(256 * 6 * 6, 4096), nn.ReLU(inplace=True), nn.Dropout(), nn.Linear(4096, 4096), nn.ReLU(inplace=True), ) self._out_features = 4096 def forward(self, x): x = self.features(x) x = self.avgpool(x) x = torch.flatten(x, 1) return self.classifier(x) def init_pretrained_weights(model, model_url): pretrain_dict = model_zoo.load_url(model_url) model.load_state_dict(pretrain_dict, strict=False) @BACKBONE_REGISTRY.register() def alexnet(pretrained=True, **kwargs): model = AlexNet() if pretrained: init_pretrained_weights(model, model_urls["alexnet"]) return model ================================================ FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/backbone.py ================================================ import torch.nn as nn class Backbone(nn.Module): def __init__(self): super().__init__() def forward(self): pass @property def out_features(self): """Output feature dimension.""" if self.__dict__.get("_out_features") is None: return None return self._out_features ================================================ FILE: 
Dassl.ProGrad.pytorch/dassl/modeling/backbone/build.py ================================================ from dassl.utils import Registry, check_availability BACKBONE_REGISTRY = Registry("BACKBONE") def build_backbone(name, verbose=True, **kwargs): avai_backbones = BACKBONE_REGISTRY.registered_names() check_availability(name, avai_backbones) if verbose: print("Backbone: {}".format(name)) return BACKBONE_REGISTRY.get(name)(**kwargs) ================================================ FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/cnn_digit5_m3sda.py ================================================ """ Reference https://github.com/VisionLearningGroup/VisionLearningGroup.github.io/tree/master/M3SDA """ import torch.nn as nn from torch.nn import functional as F from .build import BACKBONE_REGISTRY from .backbone import Backbone class FeatureExtractor(Backbone): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(3, 64, kernel_size=5, stride=1, padding=2) self.bn1 = nn.BatchNorm2d(64) self.conv2 = nn.Conv2d(64, 64, kernel_size=5, stride=1, padding=2) self.bn2 = nn.BatchNorm2d(64) self.conv3 = nn.Conv2d(64, 128, kernel_size=5, stride=1, padding=2) self.bn3 = nn.BatchNorm2d(128) self.fc1 = nn.Linear(8192, 3072) self.bn1_fc = nn.BatchNorm1d(3072) self.fc2 = nn.Linear(3072, 2048) self.bn2_fc = nn.BatchNorm1d(2048) self._out_features = 2048 def _check_input(self, x): H, W = x.shape[2:] assert ( H == 32 and W == 32 ), "Input to network must be 32x32, " "but got {}x{}".format(H, W) def forward(self, x): self._check_input(x) x = F.relu(self.bn1(self.conv1(x))) x = F.max_pool2d(x, stride=2, kernel_size=3, padding=1) x = F.relu(self.bn2(self.conv2(x))) x = F.max_pool2d(x, stride=2, kernel_size=3, padding=1) x = F.relu(self.bn3(self.conv3(x))) x = x.view(x.size(0), 8192) x = F.relu(self.bn1_fc(self.fc1(x))) x = F.dropout(x, training=self.training) x = F.relu(self.bn2_fc(self.fc2(x))) return x @BACKBONE_REGISTRY.register() def cnn_digit5_m3sda(**kwargs): """ This 
architecture was used for the Digit-5 dataset in: - Peng et al. Moment Matching for Multi-Source Domain Adaptation. ICCV 2019. """ return FeatureExtractor() ================================================ FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/cnn_digitsdg.py ================================================ import torch.nn as nn from torch.nn import functional as F from dassl.utils import init_network_weights from .build import BACKBONE_REGISTRY from .backbone import Backbone class Convolution(nn.Module): def __init__(self, c_in, c_out): super().__init__() self.conv = nn.Conv2d(c_in, c_out, 3, stride=1, padding=1) self.relu = nn.ReLU(True) def forward(self, x): return self.relu(self.conv(x)) class ConvNet(Backbone): def __init__(self, c_hidden=64): super().__init__() self.conv1 = Convolution(3, c_hidden) self.conv2 = Convolution(c_hidden, c_hidden) self.conv3 = Convolution(c_hidden, c_hidden) self.conv4 = Convolution(c_hidden, c_hidden) self._out_features = 2**2 * c_hidden def _check_input(self, x): H, W = x.shape[2:] assert ( H == 32 and W == 32 ), "Input to network must be 32x32, " "but got {}x{}".format(H, W) def forward(self, x): self._check_input(x) x = self.conv1(x) x = F.max_pool2d(x, 2) x = self.conv2(x) x = F.max_pool2d(x, 2) x = self.conv3(x) x = F.max_pool2d(x, 2) x = self.conv4(x) x = F.max_pool2d(x, 2) return x.view(x.size(0), -1) @BACKBONE_REGISTRY.register() def cnn_digitsdg(**kwargs): """ This architecture was used for DigitsDG dataset in: - Zhou et al. Deep Domain-Adversarial Image Generation for Domain Generalisation. AAAI 2020. 
""" model = ConvNet(c_hidden=64) init_network_weights(model, init_type="kaiming") return model ================================================ FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/cnn_digitsingle.py ================================================ """ This model is built based on https://github.com/ricvolpi/generalize-unseen-domains/blob/master/model.py """ import torch.nn as nn from torch.nn import functional as F from dassl.utils import init_network_weights from .build import BACKBONE_REGISTRY from .backbone import Backbone class CNN(Backbone): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(3, 64, 5) self.conv2 = nn.Conv2d(64, 128, 5) self.fc3 = nn.Linear(5 * 5 * 128, 1024) self.fc4 = nn.Linear(1024, 1024) self._out_features = 1024 def _check_input(self, x): H, W = x.shape[2:] assert ( H == 32 and W == 32 ), "Input to network must be 32x32, " "but got {}x{}".format(H, W) def forward(self, x): self._check_input(x) x = self.conv1(x) x = F.relu(x) x = F.max_pool2d(x, 2) x = self.conv2(x) x = F.relu(x) x = F.max_pool2d(x, 2) x = x.view(x.size(0), -1) x = self.fc3(x) x = F.relu(x) x = self.fc4(x) x = F.relu(x) return x @BACKBONE_REGISTRY.register() def cnn_digitsingle(**kwargs): model = CNN() init_network_weights(model, init_type="kaiming") return model ================================================ FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/efficientnet/__init__.py ================================================ """ Source: https://github.com/lukemelas/EfficientNet-PyTorch. 
""" __version__ = "0.6.4" from .model import ( EfficientNet, efficientnet_b0, efficientnet_b1, efficientnet_b2, efficientnet_b3, efficientnet_b4, efficientnet_b5, efficientnet_b6, efficientnet_b7 ) from .utils import ( BlockArgs, BlockDecoder, GlobalParams, efficientnet, get_model_params ) ================================================ FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/efficientnet/model.py ================================================ import torch from torch import nn from torch.nn import functional as F from .utils import ( Swish, MemoryEfficientSwish, drop_connect, round_filters, round_repeats, get_model_params, efficientnet_params, get_same_padding_conv2d, load_pretrained_weights, calculate_output_image_size ) from ..build import BACKBONE_REGISTRY from ..backbone import Backbone class MBConvBlock(nn.Module): """ Mobile Inverted Residual Bottleneck Block Args: block_args (namedtuple): BlockArgs, see above global_params (namedtuple): GlobalParam, see above Attributes: has_se (bool): Whether the block contains a Squeeze and Excitation layer. 
""" def __init__(self, block_args, global_params, image_size=None): super().__init__() self._block_args = block_args self._bn_mom = 1 - global_params.batch_norm_momentum self._bn_eps = global_params.batch_norm_epsilon self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1) self.id_skip = block_args.id_skip # skip connection and drop connect # Expansion phase inp = self._block_args.input_filters # number of input channels oup = ( self._block_args.input_filters * self._block_args.expand_ratio ) # number of output channels if self._block_args.expand_ratio != 1: Conv2d = get_same_padding_conv2d(image_size=image_size) self._expand_conv = Conv2d( in_channels=inp, out_channels=oup, kernel_size=1, bias=False ) self._bn0 = nn.BatchNorm2d( num_features=oup, momentum=self._bn_mom, eps=self._bn_eps ) # image_size = calculate_output_image_size(image_size, 1) <-- this would do nothing # Depthwise convolution phase k = self._block_args.kernel_size s = self._block_args.stride Conv2d = get_same_padding_conv2d(image_size=image_size) self._depthwise_conv = Conv2d( in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise kernel_size=k, stride=s, bias=False, ) self._bn1 = nn.BatchNorm2d( num_features=oup, momentum=self._bn_mom, eps=self._bn_eps ) image_size = calculate_output_image_size(image_size, s) # Squeeze and Excitation layer, if desired if self.has_se: Conv2d = get_same_padding_conv2d(image_size=(1, 1)) num_squeezed_channels = max( 1, int( self._block_args.input_filters * self._block_args.se_ratio ) ) self._se_reduce = Conv2d( in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1 ) self._se_expand = Conv2d( in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1 ) # Output phase final_oup = self._block_args.output_filters Conv2d = get_same_padding_conv2d(image_size=image_size) self._project_conv = Conv2d( in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False ) self._bn2 = 
nn.BatchNorm2d( num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps ) self._swish = MemoryEfficientSwish() def forward(self, inputs, drop_connect_rate=None): """ :param inputs: input tensor :param drop_connect_rate: drop connect rate (float, between 0 and 1) :return: output of block """ # Expansion and Depthwise Convolution x = inputs if self._block_args.expand_ratio != 1: x = self._swish(self._bn0(self._expand_conv(inputs))) x = self._swish(self._bn1(self._depthwise_conv(x))) # Squeeze and Excitation if self.has_se: x_squeezed = F.adaptive_avg_pool2d(x, 1) x_squeezed = self._se_expand( self._swish(self._se_reduce(x_squeezed)) ) x = torch.sigmoid(x_squeezed) * x x = self._bn2(self._project_conv(x)) # Skip connection and drop connect input_filters, output_filters = ( self._block_args.input_filters, self._block_args.output_filters, ) if ( self.id_skip and self._block_args.stride == 1 and input_filters == output_filters ): if drop_connect_rate: x = drop_connect( x, p=drop_connect_rate, training=self.training ) x = x + inputs # skip connection return x def set_swish(self, memory_efficient=True): """Sets swish function as memory efficient (for training) or standard (for export)""" self._swish = MemoryEfficientSwish() if memory_efficient else Swish() class EfficientNet(Backbone): """ An EfficientNet model. 
Most easily loaded with the .from_name or .from_pretrained methods Args: blocks_args (list): A list of BlockArgs to construct blocks global_params (namedtuple): A set of GlobalParams shared between blocks Example: model = EfficientNet.from_pretrained('efficientnet-b0') """ def __init__(self, blocks_args=None, global_params=None): super().__init__() assert isinstance(blocks_args, list), "blocks_args should be a list" assert len(blocks_args) > 0, "blocks_args must not be empty" self._global_params = global_params self._blocks_args = blocks_args # Batch norm parameters bn_mom = 1 - self._global_params.batch_norm_momentum bn_eps = self._global_params.batch_norm_epsilon # Get stem static or dynamic convolution depending on image size image_size = global_params.image_size Conv2d = get_same_padding_conv2d(image_size=global_params.image_size) # Stem in_channels = 3 # rgb out_channels = round_filters( 32, self._global_params ) # number of output channels self._conv_stem = Conv2d( in_channels, out_channels, kernel_size=3, stride=2, bias=False ) self._bn0 = nn.BatchNorm2d( num_features=out_channels, momentum=bn_mom, eps=bn_eps ) image_size = calculate_output_image_size(image_size, 2) # Build blocks self._blocks = nn.ModuleList([]) for block_args in self._blocks_args: # Scale input/output filters by the width multiplier and the number of repeats by the depth multiplier. block_args = block_args._replace( input_filters=round_filters( block_args.input_filters, self._global_params ), output_filters=round_filters( block_args.output_filters, self._global_params ), num_repeat=round_repeats( block_args.num_repeat, self._global_params ), ) # The first block needs to take care of stride and filter size increase.
self._blocks.append( MBConvBlock( block_args, self._global_params, image_size=image_size ) ) image_size = calculate_output_image_size( image_size, block_args.stride ) if block_args.num_repeat > 1: block_args = block_args._replace( input_filters=block_args.output_filters, stride=1 ) for _ in range(block_args.num_repeat - 1): self._blocks.append( MBConvBlock( block_args, self._global_params, image_size=image_size ) ) # image_size = calculate_output_image_size(image_size, block_args.stride) # ? # Head in_channels = block_args.output_filters # output of final block out_channels = round_filters(1280, self._global_params) Conv2d = get_same_padding_conv2d(image_size=image_size) self._conv_head = Conv2d( in_channels, out_channels, kernel_size=1, bias=False ) self._bn1 = nn.BatchNorm2d( num_features=out_channels, momentum=bn_mom, eps=bn_eps ) # Pooling and dropout (the final linear layer is disabled) self._avg_pooling = nn.AdaptiveAvgPool2d(1) self._dropout = nn.Dropout(self._global_params.dropout_rate) # self._fc = nn.Linear(out_channels, self._global_params.num_classes) self._swish = MemoryEfficientSwish() self._out_features = out_channels def set_swish(self, memory_efficient=True): """Sets swish function as memory efficient (for training) or standard (for export)""" self._swish = MemoryEfficientSwish() if memory_efficient else Swish() for block in self._blocks: block.set_swish(memory_efficient) def extract_features(self, inputs): """Returns output of the final convolution layer""" # Stem x = self._swish(self._bn0(self._conv_stem(inputs))) # Blocks for idx, block in enumerate(self._blocks): drop_connect_rate = self._global_params.drop_connect_rate if drop_connect_rate: drop_connect_rate *= float(idx) / len(self._blocks) x = block(x, drop_connect_rate=drop_connect_rate) # Head x = self._swish(self._bn1(self._conv_head(x))) return x def forward(self, inputs): """ Calls extract_features to extract features, then applies global average pooling and dropout and returns the flattened feature vector. The final linear layer is disabled in this backbone, so no logits are produced.
""" bs = inputs.size(0) # Convolution layers x = self.extract_features(inputs) # Pooling and dropout (no final linear layer) x = self._avg_pooling(x) x = x.view(bs, -1) x = self._dropout(x) # x = self._fc(x) return x @classmethod def from_name(cls, model_name, override_params=None): cls._check_model_name_is_valid(model_name) blocks_args, global_params = get_model_params( model_name, override_params ) return cls(blocks_args, global_params) @classmethod def from_pretrained( cls, model_name, advprop=False, num_classes=1000, in_channels=3 ): model = cls.from_name( model_name, override_params={"num_classes": num_classes} ) load_pretrained_weights( model, model_name, load_fc=(num_classes == 1000), advprop=advprop ) model._change_in_channels(in_channels) return model @classmethod def get_image_size(cls, model_name): cls._check_model_name_is_valid(model_name) _, _, res, _ = efficientnet_params(model_name) return res @classmethod def _check_model_name_is_valid(cls, model_name): """Validates model name.""" valid_models = ["efficientnet-b" + str(i) for i in range(9)] if model_name not in valid_models: raise ValueError( "model_name should be one of: " + ", ".join(valid_models) ) def _change_in_channels(self, in_channels): if in_channels != 3: Conv2d = get_same_padding_conv2d( image_size=self._global_params.image_size ) out_channels = round_filters(32, self._global_params) self._conv_stem = Conv2d( in_channels, out_channels, kernel_size=3, stride=2, bias=False ) def build_efficientnet(name, pretrained): if pretrained: return EfficientNet.from_pretrained("efficientnet-{}".format(name)) else: return EfficientNet.from_name("efficientnet-{}".format(name)) @BACKBONE_REGISTRY.register() def efficientnet_b0(pretrained=True, **kwargs): return build_efficientnet("b0", pretrained) @BACKBONE_REGISTRY.register() def efficientnet_b1(pretrained=True, **kwargs): return build_efficientnet("b1", pretrained) @BACKBONE_REGISTRY.register() def efficientnet_b2(pretrained=True, **kwargs): return
build_efficientnet("b2", pretrained) @BACKBONE_REGISTRY.register() def efficientnet_b3(pretrained=True, **kwargs): return build_efficientnet("b3", pretrained) @BACKBONE_REGISTRY.register() def efficientnet_b4(pretrained=True, **kwargs): return build_efficientnet("b4", pretrained) @BACKBONE_REGISTRY.register() def efficientnet_b5(pretrained=True, **kwargs): return build_efficientnet("b5", pretrained) @BACKBONE_REGISTRY.register() def efficientnet_b6(pretrained=True, **kwargs): return build_efficientnet("b6", pretrained) @BACKBONE_REGISTRY.register() def efficientnet_b7(pretrained=True, **kwargs): return build_efficientnet("b7", pretrained) ================================================ FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/efficientnet/utils.py ================================================ """ This file contains helper functions for building the model and for loading model parameters. These helper functions are built to mirror those in the official TensorFlow implementation. 
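A worked example of the compound-scaling helpers in this file. This is a standalone sketch, assuming the coefficients that efficientnet_params() lists for efficientnet-b4 (width 1.4, depth 1.8) and the default depth_divisor of 8; scale_filters() and scale_repeats() are hypothetical stand-ins that take the coefficient directly instead of a GlobalParams tuple, but otherwise repeat the arithmetic of round_filters() and round_repeats().

```python
import math


def scale_filters(filters, width_coefficient, depth_divisor=8, min_depth=None):
    # Same arithmetic as round_filters(): scale by the width coefficient
    if not width_coefficient:
        return filters
    filters *= width_coefficient
    min_depth = min_depth or depth_divisor
    # Round to the nearest multiple of depth_divisor
    new_filters = max(min_depth, int(filters + depth_divisor / 2) // depth_divisor * depth_divisor)
    # Never round down by more than 10 percent
    if new_filters < 0.9 * filters:
        new_filters += depth_divisor
    return int(new_filters)


def scale_repeats(repeats, depth_coefficient):
    # Same arithmetic as round_repeats(): scale by the depth coefficient and round up
    if not depth_coefficient:
        return repeats
    return int(math.ceil(depth_coefficient * repeats))


# efficientnet-b4 coefficients from efficientnet_params(): width 1.4, depth 1.8
print(scale_filters(32, 1.4))    # stem channels: 48
print(scale_filters(1280, 1.4))  # head channels: 1792
print(scale_repeats(3, 1.8))     # a 3-block stage becomes 6 blocks
```

Widths are rounded to the nearest multiple of 8, while repeats are always rounded up, so scaled models never lose depth.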
""" import re import math import collections from functools import partial import torch from torch import nn from torch.nn import functional as F from torch.utils import model_zoo ######################################################################## ############### HELPER FUNCTIONS FOR MODEL ARCHITECTURE ################ ######################################################################## # Parameters for the entire model (stem, all blocks, and head) GlobalParams = collections.namedtuple( "GlobalParams", [ "batch_norm_momentum", "batch_norm_epsilon", "dropout_rate", "num_classes", "width_coefficient", "depth_coefficient", "depth_divisor", "min_depth", "drop_connect_rate", "image_size", ], ) # Parameters for an individual model block BlockArgs = collections.namedtuple( "BlockArgs", [ "kernel_size", "num_repeat", "input_filters", "output_filters", "expand_ratio", "id_skip", "stride", "se_ratio", ], ) # Change namedtuple defaults GlobalParams.__new__.__defaults__ = (None, ) * len(GlobalParams._fields) BlockArgs.__new__.__defaults__ = (None, ) * len(BlockArgs._fields) class SwishImplementation(torch.autograd.Function): @staticmethod def forward(ctx, i): result = i * torch.sigmoid(i) ctx.save_for_backward(i) return result @staticmethod def backward(ctx, grad_output): i = ctx.saved_tensors[0] sigmoid_i = torch.sigmoid(i) return grad_output * (sigmoid_i * (1 + i * (1-sigmoid_i))) class MemoryEfficientSwish(nn.Module): def forward(self, x): return SwishImplementation.apply(x) class Swish(nn.Module): def forward(self, x): return x * torch.sigmoid(x) def round_filters(filters, global_params): """Calculate and round number of filters based on width multiplier.""" multiplier = global_params.width_coefficient if not multiplier: return filters divisor = global_params.depth_divisor min_depth = global_params.min_depth filters *= multiplier min_depth = min_depth or divisor new_filters = max(min_depth, int(filters + divisor/2) // divisor * divisor) if new_filters < 0.9 *
filters: # prevent rounding by more than 10% new_filters += divisor return int(new_filters) def round_repeats(repeats, global_params): """Round number of repeats based on depth multiplier.""" multiplier = global_params.depth_coefficient if not multiplier: return repeats return int(math.ceil(multiplier * repeats)) def drop_connect(inputs, p, training): """Drop connect: during training, drop the block output of each sample with probability p and rescale the kept samples by 1 / (1 - p).""" if not training: return inputs batch_size = inputs.shape[0] keep_prob = 1 - p random_tensor = keep_prob random_tensor += torch.rand( [batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device ) binary_tensor = torch.floor(random_tensor) output = inputs / keep_prob * binary_tensor return output def get_same_padding_conv2d(image_size=None): """Chooses static padding if you have specified an image size, and dynamic padding otherwise. Static padding is necessary for ONNX exporting of models.""" if image_size is None: return Conv2dDynamicSamePadding else: return partial(Conv2dStaticSamePadding, image_size=image_size) def get_width_and_height_from_size(x): """Obtains width and height from an int, list or tuple""" if isinstance(x, int): return x, x if isinstance(x, list) or isinstance(x, tuple): return x else: raise TypeError() def calculate_output_image_size(input_image_size, stride): """ Calculates the output image size when using a TensorFlow-style 'same'-padded convolution (Conv2dDynamicSamePadding or Conv2dStaticSamePadding) with a stride. Necessary for static padding. Thanks to mannatsingh for pointing this out.
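The padding used by the 'same'-padded convolutions can be checked by hand. A minimal plain-Python sketch, assuming a single spatial dimension and a hypothetical helper name same_padding(); it reproduces the pad_h / pad_w formula of Conv2dDynamicSamePadding and Conv2dStaticSamePadding together with the ceil(in / stride) output size that this function returns.

```python
import math


def same_padding(in_size, kernel_size, stride, dilation=1):
    # TensorFlow-style 'same': the output size is ceil(in / stride) regardless of kernel size
    out_size = math.ceil(in_size / stride)
    # Total padding needed so that the last window still fits inside the padded input
    total = max((out_size - 1) * stride + (kernel_size - 1) * dilation + 1 - in_size, 0)
    # An odd total puts the extra pixel on the bottom/right, like the F.pad / ZeroPad2d calls
    return out_size, (total // 2, total - total // 2)


print(same_padding(224, 3, 2))  # (112, (0, 1))
print(same_padding(112, 3, 1))  # (112, (1, 1))
```

The first case shows why the padding can be asymmetric: a stride-2 3x3 convolution on a 224-pixel input needs only one extra pixel, placed on the bottom/right side.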
""" if input_image_size is None: return None image_height, image_width = get_width_and_height_from_size( input_image_size ) stride = stride if isinstance(stride, int) else stride[0] image_height = int(math.ceil(image_height / stride)) image_width = int(math.ceil(image_width / stride)) return [image_height, image_width] class Conv2dDynamicSamePadding(nn.Conv2d): """2D Convolutions like TensorFlow, for a dynamic image size""" def __init__( self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True, ): super().__init__( in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias ) self.stride = self.stride if len(self.stride ) == 2 else [self.stride[0]] * 2 def forward(self, x): ih, iw = x.size()[-2:] kh, kw = self.weight.size()[-2:] sh, sw = self.stride oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) pad_h = max( (oh-1) * self.stride[0] + (kh-1) * self.dilation[0] + 1 - ih, 0 ) pad_w = max( (ow-1) * self.stride[1] + (kw-1) * self.dilation[1] + 1 - iw, 0 ) if pad_h > 0 or pad_w > 0: x = F.pad( x, [pad_w // 2, pad_w - pad_w//2, pad_h // 2, pad_h - pad_h//2] ) return F.conv2d( x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups, ) class Conv2dStaticSamePadding(nn.Conv2d): """2D Convolutions like TensorFlow, for a fixed image size""" def __init__( self, in_channels, out_channels, kernel_size, image_size=None, **kwargs ): super().__init__(in_channels, out_channels, kernel_size, **kwargs) self.stride = self.stride if len(self.stride ) == 2 else [self.stride[0]] * 2 # Calculate padding based on image size and save it assert image_size is not None ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size kh, kw = self.weight.size()[-2:] sh, sw = self.stride oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) pad_h = max( (oh-1) * self.stride[0] + (kh-1) * self.dilation[0] + 1 - ih, 0 ) pad_w = max( (ow-1) * self.stride[1] + (kw-1) * self.dilation[1] + 1 - iw, 0 ) if pad_h > 0 or pad_w 
> 0: self.static_padding = nn.ZeroPad2d( (pad_w // 2, pad_w - pad_w//2, pad_h // 2, pad_h - pad_h//2) ) else: self.static_padding = Identity() def forward(self, x): x = self.static_padding(x) x = F.conv2d( x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups, ) return x class Identity(nn.Module): def __init__(self): super().__init__() def forward(self, input): return input ######################################################################## ############## HELPER FUNCTIONS FOR LOADING MODEL PARAMS ############### ######################################################################## def efficientnet_params(model_name): """Map EfficientNet model name to parameter coefficients.""" params_dict = { # Coefficients: width,depth,res,dropout "efficientnet-b0": (1.0, 1.0, 224, 0.2), "efficientnet-b1": (1.0, 1.1, 240, 0.2), "efficientnet-b2": (1.1, 1.2, 260, 0.3), "efficientnet-b3": (1.2, 1.4, 300, 0.3), "efficientnet-b4": (1.4, 1.8, 380, 0.4), "efficientnet-b5": (1.6, 2.2, 456, 0.4), "efficientnet-b6": (1.8, 2.6, 528, 0.5), "efficientnet-b7": (2.0, 3.1, 600, 0.5), "efficientnet-b8": (2.2, 3.6, 672, 0.5), "efficientnet-l2": (4.3, 5.3, 800, 0.5), } return params_dict[model_name] class BlockDecoder(object): """Block Decoder for readability, straight from the official TensorFlow repository""" @staticmethod def _decode_block_string(block_string): """Gets a block through a string notation of arguments.""" assert isinstance(block_string, str) ops = block_string.split("_") options = {} for op in ops: splits = re.split(r"(\d.*)", op) if len(splits) >= 2: key, value = splits[:2] options[key] = value # Check stride assert ("s" in options and len(options["s"]) == 1) or ( len(options["s"]) == 2 and options["s"][0] == options["s"][1] ) return BlockArgs( kernel_size=int(options["k"]), num_repeat=int(options["r"]), input_filters=int(options["i"]), output_filters=int(options["o"]), expand_ratio=int(options["e"]), id_skip=("noskip" not in
block_string), se_ratio=float(options["se"]) if "se" in options else None, stride=[int(options["s"][0])], ) @staticmethod def _encode_block_string(block): """Encodes a block to a string.""" args = [ "r%d" % block.num_repeat, "k%d" % block.kernel_size, "s%d%d" % (block.stride[0], block.stride[-1]), "e%s" % block.expand_ratio, "i%d" % block.input_filters, "o%d" % block.output_filters, ] if block.se_ratio is not None and 0 < block.se_ratio <= 1: args.append("se%s" % block.se_ratio) if block.id_skip is False: args.append("noskip") return "_".join(args) @staticmethod def decode(string_list): """ Decodes a list of string notations to specify blocks inside the network. :param string_list: a list of strings, each string is a notation of block :return: a list of BlockArgs namedtuples of block args """ assert isinstance(string_list, list) blocks_args = [] for block_string in string_list: blocks_args.append(BlockDecoder._decode_block_string(block_string)) return blocks_args @staticmethod def encode(blocks_args): """ Encodes a list of BlockArgs to a list of strings.
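The block string notation handled by BlockDecoder can be illustrated with a standalone round trip. This is a minimal sketch, assuming plain dicts instead of BlockArgs namedtuples and the hypothetical helper names decode_block() and encode_block(); like _decode_block_string(), it splits each token at its first digit, keeps the stride as a single int, and writes it twice when encoding (for example s22).

```python
import re


def decode_block(block_string):
    # Split "r2_k3_s22_e6_i16_o24_se0.25" into {"r": "2", "k": "3", ...}
    options = {}
    for op in block_string.split("_"):
        splits = re.split("([0-9].*)", op)
        if len(splits) >= 2:
            key, value = splits[:2]
            options[key] = value
    return {
        "num_repeat": int(options["r"]),
        "kernel_size": int(options["k"]),
        "stride": int(options["s"][0]),
        "expand_ratio": int(options["e"]),
        "input_filters": int(options["i"]),
        "output_filters": int(options["o"]),
        "se_ratio": float(options["se"]) if "se" in options else None,
        "id_skip": "noskip" not in block_string,
    }


def encode_block(args):
    # Inverse of decode_block(); the stride is written twice, as in the string notation
    parts = [
        "r%d" % args["num_repeat"],
        "k%d" % args["kernel_size"],
        "s%d%d" % (args["stride"], args["stride"]),
        "e%s" % args["expand_ratio"],
        "i%d" % args["input_filters"],
        "o%d" % args["output_filters"],
    ]
    if args["se_ratio"] is not None and 0 < args["se_ratio"] <= 1:
        parts.append("se%s" % args["se_ratio"])
    if not args["id_skip"]:
        parts.append("noskip")
    return "_".join(parts)


print(encode_block(decode_block("r2_k3_s22_e6_i16_o24_se0.25")))  # r2_k3_s22_e6_i16_o24_se0.25
```

Guarding the se_ratio comparison with an explicit None check keeps blocks without a squeeze-and-excitation layer encodable.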
:param blocks_args: a list of BlockArgs namedtuples of block args :return: a list of strings, each string is a notation of block """ block_strings = [] for block in blocks_args: block_strings.append(BlockDecoder._encode_block_string(block)) return block_strings def efficientnet( width_coefficient=None, depth_coefficient=None, dropout_rate=0.2, drop_connect_rate=0.2, image_size=None, num_classes=1000, ): """Creates the block args and global params for an EfficientNet model.""" blocks_args = [ "r1_k3_s11_e1_i32_o16_se0.25", "r2_k3_s22_e6_i16_o24_se0.25", "r2_k5_s22_e6_i24_o40_se0.25", "r3_k3_s22_e6_i40_o80_se0.25", "r3_k5_s11_e6_i80_o112_se0.25", "r4_k5_s22_e6_i112_o192_se0.25", "r1_k3_s11_e6_i192_o320_se0.25", ] blocks_args = BlockDecoder.decode(blocks_args) global_params = GlobalParams( batch_norm_momentum=0.99, batch_norm_epsilon=1e-3, dropout_rate=dropout_rate, drop_connect_rate=drop_connect_rate, # data_format='channels_last', # removed, this is always true in PyTorch num_classes=num_classes, width_coefficient=width_coefficient, depth_coefficient=depth_coefficient, depth_divisor=8, min_depth=None, image_size=image_size, ) return blocks_args, global_params def get_model_params(model_name, override_params): """Get the block args and global params for a given model""" if model_name.startswith("efficientnet"): w, d, s, p = efficientnet_params(model_name) # note: all models have drop connect rate = 0.2 blocks_args, global_params = efficientnet( width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s ) else: raise NotImplementedError( "model name is not pre-defined: %s" % model_name ) if override_params: # ValueError will be raised here if override_params has fields not included in global_params.
global_params = global_params._replace(**override_params) return blocks_args, global_params url_map = { "efficientnet-b0": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth", "efficientnet-b1": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth", "efficientnet-b2": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth", "efficientnet-b3": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth", "efficientnet-b4": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth", "efficientnet-b5": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth", "efficientnet-b6": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth", "efficientnet-b7": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth", } url_map_advprop = { "efficientnet-b0": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth", "efficientnet-b1": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth", "efficientnet-b2": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth", "efficientnet-b3": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth", "efficientnet-b4": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth", "efficientnet-b5": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth", "efficientnet-b6": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth", 
"efficientnet-b7": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth", "efficientnet-b8": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth", } def load_pretrained_weights(model, model_name, load_fc=True, advprop=False): """Loads pretrained weights, downloading them on first use. The state dict is loaded with strict=False, so keys that do not match the model (such as the unused classifier weights) are skipped; load_fc is currently ignored.""" # AutoAugment or Advprop (different preprocessing) url_map_ = url_map_advprop if advprop else url_map state_dict = model_zoo.load_url(url_map_[model_name]) model.load_state_dict(state_dict, strict=False) """ if load_fc: model.load_state_dict(state_dict) else: state_dict.pop('_fc.weight') state_dict.pop('_fc.bias') res = model.load_state_dict(state_dict, strict=False) assert set(res.missing_keys) == set(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights' print('Loaded pretrained weights for {}'.format(model_name)) """ ================================================ FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/mobilenetv2.py ================================================ import torch.utils.model_zoo as model_zoo from torch import nn from .build import BACKBONE_REGISTRY from .backbone import Backbone model_urls = { "mobilenet_v2": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", } def _make_divisible(v, divisor, min_value=None): """ This function is taken from the original tf repo. It ensures that all layers have a channel number that is divisible by 8. It can be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py :param v: number of channels to round :param divisor: the result is a multiple of this value :param min_value: lower bound on the result (defaults to divisor) :return: the rounded number of channels, never more than 10% below v """ if min_value is None: min_value = divisor new_v = max(min_value, int(v + divisor/2) // divisor * divisor) # Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v: new_v += divisor return new_v class ConvBNReLU(nn.Sequential): def __init__( self, in_planes, out_planes, kernel_size=3, stride=1, groups=1 ): padding = (kernel_size-1) // 2 super().__init__( nn.Conv2d( in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False, ), nn.BatchNorm2d(out_planes), nn.ReLU6(inplace=True), ) class InvertedResidual(nn.Module): def __init__(self, inp, oup, stride, expand_ratio): super().__init__() self.stride = stride assert stride in [1, 2] hidden_dim = int(round(inp * expand_ratio)) self.use_res_connect = self.stride == 1 and inp == oup layers = [] if expand_ratio != 1: # pw layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) layers.extend( [ # dw ConvBNReLU( hidden_dim, hidden_dim, stride=stride, groups=hidden_dim ), # pw-linear nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), ] ) self.conv = nn.Sequential(*layers) def forward(self, x): if self.use_res_connect: return x + self.conv(x) else: return self.conv(x) class MobileNetV2(Backbone): def __init__( self, width_mult=1.0, inverted_residual_setting=None, round_nearest=8, block=None, ): """ MobileNet V2. 
Args: width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount inverted_residual_setting: Network structure as a list of [t, c, n, s] entries (expansion factor, output channels, number of repeats, stride) round_nearest (int): Round the number of channels in each layer to be a multiple of this number. Set to 1 to turn off rounding block: Module specifying inverted residual building block for mobilenet """ super().__init__() if block is None: block = InvertedResidual input_channel = 32 last_channel = 1280 if inverted_residual_setting is None: inverted_residual_setting = [ # t, c, n, s [1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 2], [6, 64, 4, 2], [6, 96, 3, 1], [6, 160, 3, 2], [6, 320, 1, 1], ] # only check the first element, assuming user knows t,c,n,s are required if ( len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4 ): raise ValueError( "inverted_residual_setting should be a non-empty list " "of 4-element lists, got {}". format(inverted_residual_setting) ) # building first layer input_channel = _make_divisible( input_channel * width_mult, round_nearest ) self.last_channel = _make_divisible( last_channel * max(1.0, width_mult), round_nearest ) features = [ConvBNReLU(3, input_channel, stride=2)] # building inverted residual blocks for t, c, n, s in inverted_residual_setting: output_channel = _make_divisible(c * width_mult, round_nearest) for i in range(n): stride = s if i == 0 else 1 features.append( block( input_channel, output_channel, stride, expand_ratio=t ) ) input_channel = output_channel # building last several layers features.append( ConvBNReLU(input_channel, self.last_channel, kernel_size=1) ) # make it nn.Sequential self.features = nn.Sequential(*features) self._out_features = self.last_channel # weight initialization for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode="fan_out") if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.BatchNorm2d): nn.init.ones_(m.weight) nn.init.zeros_(m.bias) elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01) nn.init.zeros_(m.bias) def _forward_impl(self, x): # This exists since TorchScript doesn't support inheritance, so the superclass method # (this one) needs to have a name other than `forward` that can be accessed in a subclass x = self.features(x) x = x.mean([2, 3]) return x def forward(self, x): return self._forward_impl(x) def init_pretrained_weights(model, model_url): """Initializes model with pretrained weights. Layers that don't match with pretrained layers in name or size are kept unchanged. """ if model_url is None: import warnings warnings.warn( "ImageNet pretrained weights are unavailable for this model" ) return pretrain_dict = model_zoo.load_url(model_url) model_dict = model.state_dict() pretrain_dict = { k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size() } model_dict.update(pretrain_dict) model.load_state_dict(model_dict) @BACKBONE_REGISTRY.register() def mobilenetv2(pretrained=True, **kwargs): model = MobileNetV2(**kwargs) if pretrained: init_pretrained_weights(model, model_urls["mobilenet_v2"]) return model ================================================ FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/preact_resnet18.py ================================================ import torch.nn as nn import torch.nn.functional as F from .build import BACKBONE_REGISTRY from .backbone import Backbone class PreActBlock(nn.Module): expansion = 1 def __init__(self, in_planes, planes, stride=1): super().__init__() self.bn1 = nn.BatchNorm2d(in_planes) self.conv1 = nn.Conv2d( in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False ) self.bn2 = nn.BatchNorm2d(planes) self.conv2 = nn.Conv2d( planes, planes, kernel_size=3, stride=1, padding=1, bias=False ) if stride != 1 or in_planes != self.expansion * planes: self.shortcut = nn.Sequential( nn.Conv2d( in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False, ) ) def forward(self, x): out = 
F.relu(self.bn1(x)) shortcut = self.shortcut(out) if hasattr(self, "shortcut") else x out = self.conv1(out) out = self.conv2(F.relu(self.bn2(out))) out += shortcut return out class PreActBottleneck(nn.Module): expansion = 4 def __init__(self, in_planes, planes, stride=1): super().__init__() self.bn1 = nn.BatchNorm2d(in_planes) self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) self.bn2 = nn.BatchNorm2d(planes) self.conv2 = nn.Conv2d( planes, planes, kernel_size=3, stride=stride, padding=1, bias=False ) self.bn3 = nn.BatchNorm2d(planes) self.conv3 = nn.Conv2d( planes, self.expansion * planes, kernel_size=1, bias=False ) if stride != 1 or in_planes != self.expansion * planes: self.shortcut = nn.Sequential( nn.Conv2d( in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False, ) ) def forward(self, x): out = F.relu(self.bn1(x)) shortcut = self.shortcut(out) if hasattr(self, "shortcut") else x out = self.conv1(out) out = self.conv2(F.relu(self.bn2(out))) out = self.conv3(F.relu(self.bn3(out))) out += shortcut return out class PreActResNet(Backbone): def __init__(self, block, num_blocks): super().__init__() self.in_planes = 64 self.conv1 = nn.Conv2d( 3, 64, kernel_size=3, stride=1, padding=1, bias=False ) self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) self._out_features = 512 * block.expansion def _make_layer(self, block, planes, num_blocks, stride): strides = [stride] + [1] * (num_blocks-1) layers = [] for stride in strides: layers.append(block(self.in_planes, planes, stride)) self.in_planes = planes * block.expansion return nn.Sequential(*layers) def forward(self, x): out = self.conv1(x) out = self.layer1(out) out = self.layer2(out) out = self.layer3(out) out = self.layer4(out) out = F.avg_pool2d(out, 4) out = 
out.view(out.size(0), -1) return out """ Preact-ResNet18 was used for the CIFAR10 and SVHN datasets (both are SSL tasks) in - Wang et al. Semi-Supervised Learning by Augmented Distribution Alignment. ICCV 2019. """ @BACKBONE_REGISTRY.register() def preact_resnet18(**kwargs): return PreActResNet(PreActBlock, [2, 2, 2, 2]) ================================================ FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/resnet.py ================================================ import torch.nn as nn import torch.utils.model_zoo as model_zoo from .build import BACKBONE_REGISTRY from .backbone import Backbone model_urls = { "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth", "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth", "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth", "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth", "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth", } def conv3x3(in_planes, out_planes, stride=1): """3x3 convolution with padding""" return nn.Conv2d( in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False ) class BasicBlock(nn.Module): expansion = 1 def __init__(self, inplanes, planes, stride=1, downsample=None): super().__init__() self.conv1 = conv3x3(inplanes, planes, stride) self.bn1 = nn.BatchNorm2d(planes) self.relu = nn.ReLU(inplace=True) self.conv2 = conv3x3(planes, planes) self.bn2 = nn.BatchNorm2d(planes) self.downsample = downsample self.stride = stride def forward(self, x): residual = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) if self.downsample is not None: residual = self.downsample(x) out += residual out = self.relu(out) return out class Bottleneck(nn.Module): expansion = 4 def __init__(self, inplanes, planes, stride=1, downsample=None): super().__init__() self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) self.bn1 = 
nn.BatchNorm2d(planes) self.conv2 = nn.Conv2d( planes, planes, kernel_size=3, stride=stride, padding=1, bias=False ) self.bn2 = nn.BatchNorm2d(planes) self.conv3 = nn.Conv2d( planes, planes * self.expansion, kernel_size=1, bias=False ) self.bn3 = nn.BatchNorm2d(planes * self.expansion) self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride def forward(self, x): residual = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) out = self.relu(out) out = self.conv3(out) out = self.bn3(out) if self.downsample is not None: residual = self.downsample(x) out += residual out = self.relu(out) return out class ResNet(Backbone): def __init__( self, block, layers, ms_class=None, ms_layers=[], ms_p=0.5, ms_a=0.1, **kwargs ): self.inplanes = 64 super().__init__() # backbone network self.conv1 = nn.Conv2d( 3, 64, kernel_size=7, stride=2, padding=3, bias=False ) self.bn1 = nn.BatchNorm2d(64) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = self._make_layer(block, 64, layers[0]) self.layer2 = self._make_layer(block, 128, layers[1], stride=2) self.layer3 = self._make_layer(block, 256, layers[2], stride=2) self.layer4 = self._make_layer(block, 512, layers[3], stride=2) self.global_avgpool = nn.AdaptiveAvgPool2d(1) self._out_features = 512 * block.expansion self.mixstyle = None if ms_layers: self.mixstyle = ms_class(p=ms_p, alpha=ms_a) for layer_name in ms_layers: assert layer_name in ["layer1", "layer2", "layer3"] print(f"Insert MixStyle after {ms_layers}") self.ms_layers = ms_layers self._init_params() def _make_layer(self, block, planes, blocks, stride=1): downsample = None if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( nn.Conv2d( self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False, ), nn.BatchNorm2d(planes * block.expansion), ) layers = [] 
layers.append(block(self.inplanes, planes, stride, downsample)) self.inplanes = planes * block.expansion for i in range(1, blocks): layers.append(block(self.inplanes, planes)) return nn.Sequential(*layers) def _init_params(self): for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_( m.weight, mode="fan_out", nonlinearity="relu" ) if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm1d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) elif isinstance(m, nn.Linear): nn.init.normal_(m.weight, 0, 0.01) if m.bias is not None: nn.init.constant_(m.bias, 0) def featuremaps(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) x = self.layer1(x) if "layer1" in self.ms_layers: x = self.mixstyle(x) x = self.layer2(x) if "layer2" in self.ms_layers: x = self.mixstyle(x) x = self.layer3(x) if "layer3" in self.ms_layers: x = self.mixstyle(x) return self.layer4(x) def forward(self, x): f = self.featuremaps(x) v = self.global_avgpool(f) return v.view(v.size(0), -1) def init_pretrained_weights(model, model_url): pretrain_dict = model_zoo.load_url(model_url) model.load_state_dict(pretrain_dict, strict=False) """ Residual network configurations: -- resnet18: block=BasicBlock, layers=[2, 2, 2, 2] resnet34: block=BasicBlock, layers=[3, 4, 6, 3] resnet50: block=Bottleneck, layers=[3, 4, 6, 3] resnet101: block=Bottleneck, layers=[3, 4, 23, 3] resnet152: block=Bottleneck, layers=[3, 8, 36, 3] """ @BACKBONE_REGISTRY.register() def resnet18(pretrained=True, **kwargs): model = ResNet(block=BasicBlock, layers=[2, 2, 2, 2]) if pretrained: init_pretrained_weights(model, model_urls["resnet18"]) return model @BACKBONE_REGISTRY.register() def resnet34(pretrained=True, **kwargs): model = ResNet(block=BasicBlock, layers=[3, 4, 6, 3]) if pretrained: init_pretrained_weights(model, model_urls["resnet34"]) return 
model @BACKBONE_REGISTRY.register() def resnet50(pretrained=True, **kwargs): model = ResNet(block=Bottleneck, layers=[3, 4, 6, 3]) if pretrained: init_pretrained_weights(model, model_urls["resnet50"]) return model @BACKBONE_REGISTRY.register() def resnet101(pretrained=True, **kwargs): model = ResNet(block=Bottleneck, layers=[3, 4, 23, 3]) if pretrained: init_pretrained_weights(model, model_urls["resnet101"]) return model @BACKBONE_REGISTRY.register() def resnet152(pretrained=True, **kwargs): model = ResNet(block=Bottleneck, layers=[3, 8, 36, 3]) if pretrained: init_pretrained_weights(model, model_urls["resnet152"]) return model """ Residual networks with mixstyle """ @BACKBONE_REGISTRY.register() def resnet18_ms_l123(pretrained=True, **kwargs): from dassl.modeling.ops import MixStyle model = ResNet( block=BasicBlock, layers=[2, 2, 2, 2], ms_class=MixStyle, ms_layers=["layer1", "layer2", "layer3"], ) if pretrained: init_pretrained_weights(model, model_urls["resnet18"]) return model @BACKBONE_REGISTRY.register() def resnet18_ms_l12(pretrained=True, **kwargs): from dassl.modeling.ops import MixStyle model = ResNet( block=BasicBlock, layers=[2, 2, 2, 2], ms_class=MixStyle, ms_layers=["layer1", "layer2"], ) if pretrained: init_pretrained_weights(model, model_urls["resnet18"]) return model @BACKBONE_REGISTRY.register() def resnet18_ms_l1(pretrained=True, **kwargs): from dassl.modeling.ops import MixStyle model = ResNet( block=BasicBlock, layers=[2, 2, 2, 2], ms_class=MixStyle, ms_layers=["layer1"] ) if pretrained: init_pretrained_weights(model, model_urls["resnet18"]) return model @BACKBONE_REGISTRY.register() def resnet50_ms_l123(pretrained=True, **kwargs): from dassl.modeling.ops import MixStyle model = ResNet( block=Bottleneck, layers=[3, 4, 6, 3], ms_class=MixStyle, ms_layers=["layer1", "layer2", "layer3"], ) if pretrained: init_pretrained_weights(model, model_urls["resnet50"]) return model @BACKBONE_REGISTRY.register() def resnet50_ms_l12(pretrained=True, **kwargs): 
from dassl.modeling.ops import MixStyle model = ResNet( block=Bottleneck, layers=[3, 4, 6, 3], ms_class=MixStyle, ms_layers=["layer1", "layer2"], ) if pretrained: init_pretrained_weights(model, model_urls["resnet50"]) return model @BACKBONE_REGISTRY.register() def resnet50_ms_l1(pretrained=True, **kwargs): from dassl.modeling.ops import MixStyle model = ResNet( block=Bottleneck, layers=[3, 4, 6, 3], ms_class=MixStyle, ms_layers=["layer1"] ) if pretrained: init_pretrained_weights(model, model_urls["resnet50"]) return model @BACKBONE_REGISTRY.register() def resnet101_ms_l123(pretrained=True, **kwargs): from dassl.modeling.ops import MixStyle model = ResNet( block=Bottleneck, layers=[3, 4, 23, 3], ms_class=MixStyle, ms_layers=["layer1", "layer2", "layer3"], ) if pretrained: init_pretrained_weights(model, model_urls["resnet101"]) return model @BACKBONE_REGISTRY.register() def resnet101_ms_l12(pretrained=True, **kwargs): from dassl.modeling.ops import MixStyle model = ResNet( block=Bottleneck, layers=[3, 4, 23, 3], ms_class=MixStyle, ms_layers=["layer1", "layer2"], ) if pretrained: init_pretrained_weights(model, model_urls["resnet101"]) return model @BACKBONE_REGISTRY.register() def resnet101_ms_l1(pretrained=True, **kwargs): from dassl.modeling.ops import MixStyle model = ResNet( block=Bottleneck, layers=[3, 4, 23, 3], ms_class=MixStyle, ms_layers=["layer1"] ) if pretrained: init_pretrained_weights(model, model_urls["resnet101"]) return model """ Residual networks with efdmix """ @BACKBONE_REGISTRY.register() def resnet18_efdmix_l123(pretrained=True, **kwargs): from dassl.modeling.ops import EFDMix model = ResNet( block=BasicBlock, layers=[2, 2, 2, 2], ms_class=EFDMix, ms_layers=["layer1", "layer2", "layer3"], ) if pretrained: init_pretrained_weights(model, model_urls["resnet18"]) return model @BACKBONE_REGISTRY.register() def resnet18_efdmix_l12(pretrained=True, **kwargs): from dassl.modeling.ops import EFDMix model = ResNet( block=BasicBlock, layers=[2, 2, 2, 2], 
ms_class=EFDMix, ms_layers=["layer1", "layer2"], ) if pretrained: init_pretrained_weights(model, model_urls["resnet18"]) return model @BACKBONE_REGISTRY.register() def resnet18_efdmix_l1(pretrained=True, **kwargs): from dassl.modeling.ops import EFDMix model = ResNet( block=BasicBlock, layers=[2, 2, 2, 2], ms_class=EFDMix, ms_layers=["layer1"] ) if pretrained: init_pretrained_weights(model, model_urls["resnet18"]) return model @BACKBONE_REGISTRY.register() def resnet50_efdmix_l123(pretrained=True, **kwargs): from dassl.modeling.ops import EFDMix model = ResNet( block=Bottleneck, layers=[3, 4, 6, 3], ms_class=EFDMix, ms_layers=["layer1", "layer2", "layer3"], ) if pretrained: init_pretrained_weights(model, model_urls["resnet50"]) return model @BACKBONE_REGISTRY.register() def resnet50_efdmix_l12(pretrained=True, **kwargs): from dassl.modeling.ops import EFDMix model = ResNet( block=Bottleneck, layers=[3, 4, 6, 3], ms_class=EFDMix, ms_layers=["layer1", "layer2"], ) if pretrained: init_pretrained_weights(model, model_urls["resnet50"]) return model @BACKBONE_REGISTRY.register() def resnet50_efdmix_l1(pretrained=True, **kwargs): from dassl.modeling.ops import EFDMix model = ResNet( block=Bottleneck, layers=[3, 4, 6, 3], ms_class=EFDMix, ms_layers=["layer1"] ) if pretrained: init_pretrained_weights(model, model_urls["resnet50"]) return model @BACKBONE_REGISTRY.register() def resnet101_efdmix_l123(pretrained=True, **kwargs): from dassl.modeling.ops import EFDMix model = ResNet( block=Bottleneck, layers=[3, 4, 23, 3], ms_class=EFDMix, ms_layers=["layer1", "layer2", "layer3"], ) if pretrained: init_pretrained_weights(model, model_urls["resnet101"]) return model @BACKBONE_REGISTRY.register() def resnet101_efdmix_l12(pretrained=True, **kwargs): from dassl.modeling.ops import EFDMix model = ResNet( block=Bottleneck, layers=[3, 4, 23, 3], ms_class=EFDMix, ms_layers=["layer1", "layer2"], ) if pretrained: init_pretrained_weights(model, model_urls["resnet101"]) return model 
@BACKBONE_REGISTRY.register() def resnet101_efdmix_l1(pretrained=True, **kwargs): from dassl.modeling.ops import EFDMix model = ResNet( block=Bottleneck, layers=[3, 4, 23, 3], ms_class=EFDMix, ms_layers=["layer1"] ) if pretrained: init_pretrained_weights(model, model_urls["resnet101"]) return model ================================================ FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/shufflenetv2.py ================================================ """ Code source: https://github.com/pytorch/vision """ import torch import torch.utils.model_zoo as model_zoo from torch import nn from .build import BACKBONE_REGISTRY from .backbone import Backbone model_urls = { "shufflenetv2_x0.5": "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", "shufflenetv2_x1.0": "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", "shufflenetv2_x1.5": None, "shufflenetv2_x2.0": None, } def channel_shuffle(x, groups): batchsize, num_channels, height, width = x.data.size() channels_per_group = num_channels // groups # reshape x = x.view(batchsize, groups, channels_per_group, height, width) x = torch.transpose(x, 1, 2).contiguous() # flatten x = x.view(batchsize, -1, height, width) return x class InvertedResidual(nn.Module): def __init__(self, inp, oup, stride): super().__init__() if not (1 <= stride <= 3): raise ValueError("illegal stride value") self.stride = stride branch_features = oup // 2 assert (self.stride != 1) or (inp == branch_features << 1) if self.stride > 1: self.branch1 = nn.Sequential( self.depthwise_conv( inp, inp, kernel_size=3, stride=self.stride, padding=1 ), nn.BatchNorm2d(inp), nn.Conv2d( inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False ), nn.BatchNorm2d(branch_features), nn.ReLU(inplace=True), ) self.branch2 = nn.Sequential( nn.Conv2d( inp if (self.stride > 1) else branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False, ), nn.BatchNorm2d(branch_features), nn.ReLU(inplace=True), 
self.depthwise_conv( branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1, ), nn.BatchNorm2d(branch_features), nn.Conv2d( branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False, ), nn.BatchNorm2d(branch_features), nn.ReLU(inplace=True), ) @staticmethod def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False): return nn.Conv2d( i, o, kernel_size, stride, padding, bias=bias, groups=i ) def forward(self, x): if self.stride == 1: x1, x2 = x.chunk(2, dim=1) out = torch.cat((x1, self.branch2(x2)), dim=1) else: out = torch.cat((self.branch1(x), self.branch2(x)), dim=1) out = channel_shuffle(out, 2) return out class ShuffleNetV2(Backbone): def __init__(self, stages_repeats, stages_out_channels, **kwargs): super().__init__() if len(stages_repeats) != 3: raise ValueError( "expected stages_repeats as list of 3 positive ints" ) if len(stages_out_channels) != 5: raise ValueError( "expected stages_out_channels as list of 5 positive ints" ) self._stage_out_channels = stages_out_channels input_channels = 3 output_channels = self._stage_out_channels[0] self.conv1 = nn.Sequential( nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False), nn.BatchNorm2d(output_channels), nn.ReLU(inplace=True), ) input_channels = output_channels self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) stage_names = ["stage{}".format(i) for i in [2, 3, 4]] for name, repeats, output_channels in zip( stage_names, stages_repeats, self._stage_out_channels[1:] ): seq = [InvertedResidual(input_channels, output_channels, 2)] for i in range(repeats - 1): seq.append( InvertedResidual(output_channels, output_channels, 1) ) setattr(self, name, nn.Sequential(*seq)) input_channels = output_channels output_channels = self._stage_out_channels[-1] self.conv5 = nn.Sequential( nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False), nn.BatchNorm2d(output_channels), nn.ReLU(inplace=True), ) self.global_avgpool = 
nn.AdaptiveAvgPool2d((1, 1)) self._out_features = output_channels def featuremaps(self, x): x = self.conv1(x) x = self.maxpool(x) x = self.stage2(x) x = self.stage3(x) x = self.stage4(x) x = self.conv5(x) return x def forward(self, x): f = self.featuremaps(x) v = self.global_avgpool(f) return v.view(v.size(0), -1) def init_pretrained_weights(model, model_url): """Initializes model with pretrained weights. Layers that don't match with pretrained layers in name or size are kept unchanged. """ if model_url is None: import warnings warnings.warn( "ImageNet pretrained weights are unavailable for this model" ) return pretrain_dict = model_zoo.load_url(model_url) model_dict = model.state_dict() pretrain_dict = { k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size() } model_dict.update(pretrain_dict) model.load_state_dict(model_dict) @BACKBONE_REGISTRY.register() def shufflenet_v2_x0_5(pretrained=True, **kwargs): model = ShuffleNetV2([4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) if pretrained: init_pretrained_weights(model, model_urls["shufflenetv2_x0.5"]) return model @BACKBONE_REGISTRY.register() def shufflenet_v2_x1_0(pretrained=True, **kwargs): model = ShuffleNetV2([4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) if pretrained: init_pretrained_weights(model, model_urls["shufflenetv2_x1.0"]) return model @BACKBONE_REGISTRY.register() def shufflenet_v2_x1_5(pretrained=True, **kwargs): model = ShuffleNetV2([4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) if pretrained: init_pretrained_weights(model, model_urls["shufflenetv2_x1.5"]) return model @BACKBONE_REGISTRY.register() def shufflenet_v2_x2_0(pretrained=True, **kwargs): model = ShuffleNetV2([4, 8, 4], [24, 244, 488, 976, 2048], **kwargs) if pretrained: init_pretrained_weights(model, model_urls["shufflenetv2_x2.0"]) return model ================================================ FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/vgg.py 
================================================ import torch import torch.nn as nn from .build import BACKBONE_REGISTRY from .backbone import Backbone try: from torch.hub import load_state_dict_from_url except ImportError: from torch.utils.model_zoo import load_url as load_state_dict_from_url model_urls = { "vgg11": "https://download.pytorch.org/models/vgg11-bbd30ac9.pth", "vgg13": "https://download.pytorch.org/models/vgg13-c768596a.pth", "vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth", "vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", "vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth", "vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth", "vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth", "vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth", } class VGG(Backbone): def __init__(self, features, init_weights=True): super().__init__() self.features = features self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) # Note that self.classifier outputs features rather than logits self.classifier = nn.Sequential( nn.Linear(512 * 7 * 7, 4096), nn.ReLU(True), nn.Dropout(), nn.Linear(4096, 4096), nn.ReLU(True), nn.Dropout(), ) self._out_features = 4096 if init_weights: self._initialize_weights() def forward(self, x): x = self.features(x) x = self.avgpool(x) x = torch.flatten(x, 1) return self.classifier(x) def _initialize_weights(self): for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_( m.weight, mode="fan_out", nonlinearity="relu" ) if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) elif isinstance(m, nn.Linear): nn.init.normal_(m.weight, 0, 0.01) nn.init.constant_(m.bias, 0) def make_layers(cfg, batch_norm=False): layers = [] in_channels = 3 for v in cfg: if v == "M": layers += [nn.MaxPool2d(kernel_size=2, stride=2)] else: conv2d = 
nn.Conv2d(in_channels, v, kernel_size=3, padding=1) if batch_norm: layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] else: layers += [conv2d, nn.ReLU(inplace=True)] in_channels = v return nn.Sequential(*layers) cfgs = { "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], "D": [ 64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M", ], "E": [ 64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M", ], } def _vgg(arch, cfg, batch_norm, pretrained): init_weights = False if pretrained else True model = VGG( make_layers(cfgs[cfg], batch_norm=batch_norm), init_weights=init_weights ) if pretrained: state_dict = load_state_dict_from_url(model_urls[arch], progress=True) model.load_state_dict(state_dict, strict=False) return model @BACKBONE_REGISTRY.register() def vgg16(pretrained=True, **kwargs): return _vgg("vgg16", "D", False, pretrained) ================================================ FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/wide_resnet.py ================================================ """ Modified from https://github.com/xternalz/WideResNet-pytorch """ import torch import torch.nn as nn import torch.nn.functional as F from .build import BACKBONE_REGISTRY from .backbone import Backbone class BasicBlock(nn.Module): def __init__(self, in_planes, out_planes, stride, dropRate=0.0): super().__init__() self.bn1 = nn.BatchNorm2d(in_planes) self.relu1 = nn.LeakyReLU(0.01, inplace=True) self.conv1 = nn.Conv2d( in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False ) self.bn2 = nn.BatchNorm2d(out_planes) self.relu2 = nn.LeakyReLU(0.01, inplace=True) self.conv2 = nn.Conv2d( out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False ) self.droprate = dropRate self.equalInOut = in_planes == out_planes self.convShortcut = ( (not self.equalInOut) and 
nn.Conv2d( in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False, ) or None ) def forward(self, x): if not self.equalInOut: x = self.relu1(self.bn1(x)) else: out = self.relu1(self.bn1(x)) out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) if self.droprate > 0: out = F.dropout(out, p=self.droprate, training=self.training) out = self.conv2(out) return torch.add(x if self.equalInOut else self.convShortcut(x), out) class NetworkBlock(nn.Module): def __init__( self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0 ): super().__init__() self.layer = self._make_layer( block, in_planes, out_planes, nb_layers, stride, dropRate ) def _make_layer( self, block, in_planes, out_planes, nb_layers, stride, dropRate ): layers = [] for i in range(int(nb_layers)): layers.append( block( i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate, ) ) return nn.Sequential(*layers) def forward(self, x): return self.layer(x) class WideResNet(Backbone): def __init__(self, depth, widen_factor, dropRate=0.0): super().__init__() nChannels = [ 16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor ] assert (depth-4) % 6 == 0 n = (depth-4) / 6 block = BasicBlock # 1st conv before any network block self.conv1 = nn.Conv2d( 3, nChannels[0], kernel_size=3, stride=1, padding=1, bias=False ) # 1st block self.block1 = NetworkBlock( n, nChannels[0], nChannels[1], block, 1, dropRate ) # 2nd block self.block2 = NetworkBlock( n, nChannels[1], nChannels[2], block, 2, dropRate ) # 3rd block self.block3 = NetworkBlock( n, nChannels[2], nChannels[3], block, 2, dropRate ) # global average pooling and classifier self.bn1 = nn.BatchNorm2d(nChannels[3]) self.relu = nn.LeakyReLU(0.01, inplace=True) self._out_features = nChannels[3] for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_( m.weight, mode="fan_out", nonlinearity="relu" ) elif isinstance(m, nn.BatchNorm2d): m.weight.data.fill_(1) m.bias.data.zero_() 
elif isinstance(m, nn.Linear): m.bias.data.zero_() def forward(self, x): out = self.conv1(x) out = self.block1(out) out = self.block2(out) out = self.block3(out) out = self.relu(self.bn1(out)) out = F.adaptive_avg_pool2d(out, 1) return out.view(out.size(0), -1) @BACKBONE_REGISTRY.register() def wide_resnet_28_2(**kwargs): return WideResNet(28, 2) @BACKBONE_REGISTRY.register() def wide_resnet_16_4(**kwargs): return WideResNet(16, 4) ================================================ FILE: Dassl.ProGrad.pytorch/dassl/modeling/head/__init__.py ================================================ from .build import build_head, HEAD_REGISTRY # isort:skip from .mlp import mlp ================================================ FILE: Dassl.ProGrad.pytorch/dassl/modeling/head/build.py ================================================ from dassl.utils import Registry, check_availability HEAD_REGISTRY = Registry("HEAD") def build_head(name, verbose=True, **kwargs): avai_heads = HEAD_REGISTRY.registered_names() check_availability(name, avai_heads) if verbose: print("Head: {}".format(name)) return HEAD_REGISTRY.get(name)(**kwargs) ================================================ FILE: Dassl.ProGrad.pytorch/dassl/modeling/head/mlp.py ================================================ import functools import torch.nn as nn from .build import HEAD_REGISTRY class MLP(nn.Module): def __init__( self, in_features=2048, hidden_layers=[], activation="relu", bn=True, dropout=0.0, ): super().__init__() if isinstance(hidden_layers, int): hidden_layers = [hidden_layers] assert len(hidden_layers) > 0 self.out_features = hidden_layers[-1] mlp = [] if activation == "relu": act_fn = functools.partial(nn.ReLU, inplace=True) elif activation == "leaky_relu": act_fn = functools.partial(nn.LeakyReLU, inplace=True) else: raise NotImplementedError for hidden_dim in hidden_layers: mlp += [nn.Linear(in_features, hidden_dim)] if bn: mlp += [nn.BatchNorm1d(hidden_dim)] mlp += [act_fn()] if dropout > 0: mlp += 
[nn.Dropout(dropout)] in_features = hidden_dim self.mlp = nn.Sequential(*mlp) def forward(self, x): return self.mlp(x) @HEAD_REGISTRY.register() def mlp(**kwargs): return MLP(**kwargs) ================================================ FILE: Dassl.ProGrad.pytorch/dassl/modeling/network/__init__.py ================================================ from .build import build_network, NETWORK_REGISTRY # isort:skip from .ddaig_fcn import ( fcn_3x32_gctx, fcn_3x64_gctx, fcn_3x32_gctx_stn, fcn_3x64_gctx_stn ) ================================================ FILE: Dassl.ProGrad.pytorch/dassl/modeling/network/build.py ================================================ from dassl.utils import Registry, check_availability NETWORK_REGISTRY = Registry("NETWORK") def build_network(name, verbose=True, **kwargs): avai_models = NETWORK_REGISTRY.registered_names() check_availability(name, avai_models) if verbose: print("Network: {}".format(name)) return NETWORK_REGISTRY.get(name)(**kwargs) ================================================ FILE: Dassl.ProGrad.pytorch/dassl/modeling/network/ddaig_fcn.py ================================================ """ Credit to: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix """ import functools import torch import torch.nn as nn from torch.nn import functional as F from .build import NETWORK_REGISTRY def init_network_weights(model, init_type="normal", gain=0.02): def _init_func(m): classname = m.__class__.__name__ if hasattr(m, "weight") and ( classname.find("Conv") != -1 or classname.find("Linear") != -1 ): if init_type == "normal": nn.init.normal_(m.weight.data, 0.0, gain) elif init_type == "xavier": nn.init.xavier_normal_(m.weight.data, gain=gain) elif init_type == "kaiming": nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_in") elif init_type == "orthogonal": nn.init.orthogonal_(m.weight.data, gain=gain) else: raise NotImplementedError( "initialization method {} is not implemented". 
format(init_type) ) if hasattr(m, "bias") and m.bias is not None: nn.init.constant_(m.bias.data, 0.0) elif classname.find("BatchNorm2d") != -1: nn.init.constant_(m.weight.data, 1.0) nn.init.constant_(m.bias.data, 0.0) elif classname.find("InstanceNorm2d") != -1: if m.weight is not None and m.bias is not None: nn.init.constant_(m.weight.data, 1.0) nn.init.constant_(m.bias.data, 0.0) model.apply(_init_func) def get_norm_layer(norm_type="instance"): if norm_type == "batch": norm_layer = functools.partial(nn.BatchNorm2d, affine=True) elif norm_type == "instance": norm_layer = functools.partial( nn.InstanceNorm2d, affine=False, track_running_stats=False ) elif norm_type == "none": norm_layer = None else: raise NotImplementedError( "normalization layer [%s] is not found" % norm_type ) return norm_layer class ResnetBlock(nn.Module): def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): super().__init__() self.conv_block = self.build_conv_block( dim, padding_type, norm_layer, use_dropout, use_bias ) def build_conv_block( self, dim, padding_type, norm_layer, use_dropout, use_bias ): conv_block = [] p = 0 if padding_type == "reflect": conv_block += [nn.ReflectionPad2d(1)] elif padding_type == "replicate": conv_block += [nn.ReplicationPad2d(1)] elif padding_type == "zero": p = 1 else: raise NotImplementedError( "padding [%s] is not implemented" % padding_type ) conv_block += [ nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True), ] if use_dropout: conv_block += [nn.Dropout(0.5)] p = 0 if padding_type == "reflect": conv_block += [nn.ReflectionPad2d(1)] elif padding_type == "replicate": conv_block += [nn.ReplicationPad2d(1)] elif padding_type == "zero": p = 1 else: raise NotImplementedError( "padding [%s] is not implemented" % padding_type ) conv_block += [ nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), ] return nn.Sequential(*conv_block) def forward(self, x): return x + 
self.conv_block(x) class LocNet(nn.Module): """Localization network.""" def __init__( self, input_nc, nc=32, n_blocks=3, use_dropout=False, padding_type="zero", image_size=32, ): super().__init__() backbone = [] backbone += [ nn.Conv2d( input_nc, nc, kernel_size=3, stride=2, padding=1, bias=False ) ] backbone += [nn.BatchNorm2d(nc)] backbone += [nn.ReLU(True)] for _ in range(n_blocks): backbone += [ ResnetBlock( nc, padding_type=padding_type, norm_layer=nn.BatchNorm2d, use_dropout=use_dropout, use_bias=False, ) ] backbone += [nn.MaxPool2d(2, stride=2)] self.backbone = nn.Sequential(*backbone) reduced_imsize = int(image_size * 0.5**(n_blocks + 1)) self.fc_loc = nn.Linear(nc * reduced_imsize**2, 2 * 2) def forward(self, x): x = self.backbone(x) x = x.view(x.size(0), -1) x = self.fc_loc(x) x = torch.tanh(x) x = x.view(-1, 2, 2) theta = x.data.new_zeros(x.size(0), 2, 3) theta[:, :, :2] = x return theta class FCN(nn.Module): """Fully convolutional network.""" def __init__( self, input_nc, output_nc, nc=32, n_blocks=3, norm_layer=nn.BatchNorm2d, use_dropout=False, padding_type="reflect", gctx=True, stn=False, image_size=32, ): super().__init__() backbone = [] p = 0 if padding_type == "reflect": backbone += [nn.ReflectionPad2d(1)] elif padding_type == "replicate": backbone += [nn.ReplicationPad2d(1)] elif padding_type == "zero": p = 1 else: raise NotImplementedError backbone += [ nn.Conv2d( input_nc, nc, kernel_size=3, stride=1, padding=p, bias=False ) ] backbone += [norm_layer(nc)] backbone += [nn.ReLU(True)] for _ in range(n_blocks): backbone += [ ResnetBlock( nc, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=False, ) ] self.backbone = nn.Sequential(*backbone) # global context fusion layer self.gctx_fusion = None if gctx: self.gctx_fusion = nn.Sequential( nn.Conv2d( 2 * nc, nc, kernel_size=1, stride=1, padding=0, bias=False ), norm_layer(nc), nn.ReLU(True), ) self.regress = nn.Sequential( nn.Conv2d( nc, output_nc, kernel_size=1, 
stride=1, padding=0, bias=True ), nn.Tanh(), ) self.locnet = None if stn: self.locnet = LocNet( input_nc, nc=nc, n_blocks=n_blocks, image_size=image_size ) def init_loc_layer(self): """Initialize the weights/bias with identity transformation.""" if self.locnet is not None: self.locnet.fc_loc.weight.data.zero_() self.locnet.fc_loc.bias.data.copy_( torch.tensor([1, 0, 0, 1], dtype=torch.float) ) def stn(self, x): """Spatial transformer network.""" theta = self.locnet(x) grid = F.affine_grid(theta, x.size()) return F.grid_sample(x, grid), theta def forward(self, x, lmda=1.0, return_p=False, return_stn_output=False): """ Args: x (torch.Tensor): input mini-batch. lmda (float): multiplier for perturbation. return_p (bool): return perturbation. return_stn_output (bool): return the output of stn. """ theta = None if self.locnet is not None: x, theta = self.stn(x) input = x x = self.backbone(x) if self.gctx_fusion is not None: c = F.adaptive_avg_pool2d(x, (1, 1)) c = c.expand_as(x) x = torch.cat([x, c], 1) x = self.gctx_fusion(x) p = self.regress(x) x_p = input + lmda*p if return_stn_output: return x_p, p, input if return_p: return x_p, p return x_p @NETWORK_REGISTRY.register() def fcn_3x32_gctx(**kwargs): norm_layer = get_norm_layer(norm_type="instance") net = FCN(3, 3, nc=32, n_blocks=3, norm_layer=norm_layer) init_network_weights(net, init_type="normal", gain=0.02) return net @NETWORK_REGISTRY.register() def fcn_3x64_gctx(**kwargs): norm_layer = get_norm_layer(norm_type="instance") net = FCN(3, 3, nc=64, n_blocks=3, norm_layer=norm_layer) init_network_weights(net, init_type="normal", gain=0.02) return net @NETWORK_REGISTRY.register() def fcn_3x32_gctx_stn(image_size=32, **kwargs): norm_layer = get_norm_layer(norm_type="instance") net = FCN( 3, 3, nc=32, n_blocks=3, norm_layer=norm_layer, stn=True, image_size=image_size ) init_network_weights(net, init_type="normal", gain=0.02) net.init_loc_layer() return net @NETWORK_REGISTRY.register() def 
fcn_3x64_gctx_stn(image_size=224, **kwargs): norm_layer = get_norm_layer(norm_type="instance") net = FCN( 3, 3, nc=64, n_blocks=3, norm_layer=norm_layer, stn=True, image_size=image_size ) init_network_weights(net, init_type="normal", gain=0.02) net.init_loc_layer() return net ================================================ FILE: Dassl.ProGrad.pytorch/dassl/modeling/ops/__init__.py ================================================ from .mmd import MaximumMeanDiscrepancy from .dsbn import DSBN1d, DSBN2d from .mixup import mixup from .efdmix import ( EFDMix, random_efdmix, activate_efdmix, run_with_efdmix, deactivate_efdmix, crossdomain_efdmix, run_without_efdmix ) from .mixstyle import ( MixStyle, random_mixstyle, activate_mixstyle, run_with_mixstyle, deactivate_mixstyle, crossdomain_mixstyle, run_without_mixstyle ) from .transnorm import TransNorm1d, TransNorm2d from .sequential2 import Sequential2 from .reverse_grad import ReverseGrad from .cross_entropy import cross_entropy from .optimal_transport import SinkhornDivergence, MinibatchEnergyDistance ================================================ FILE: Dassl.ProGrad.pytorch/dassl/modeling/ops/cross_entropy.py ================================================ import torch from torch.nn import functional as F def cross_entropy(input, target, label_smooth=0, reduction="mean"): """Cross entropy loss. Args: input (torch.Tensor): logit matrix with shape of (batch, num_classes). target (torch.LongTensor): integer class labels with shape of (batch,). label_smooth (float, optional): label smoothing hyper-parameter. Default is 0. reduction (str, optional): how the losses for a mini-batch will be aggregated. Choices are 'mean', 'sum' and 'none'. Default is 'mean'.
""" num_classes = input.shape[1] log_prob = F.log_softmax(input, dim=1) zeros = torch.zeros(log_prob.size()) target = zeros.scatter_(1, target.unsqueeze(1).data.cpu(), 1) target = target.type_as(input) target = (1-label_smooth) * target + label_smooth/num_classes loss = (-target * log_prob).sum(1) if reduction == "mean": return loss.mean() elif reduction == "sum": return loss.sum() elif reduction == "none": return loss else: raise ValueError ================================================ FILE: Dassl.ProGrad.pytorch/dassl/modeling/ops/dsbn.py ================================================ import torch.nn as nn class _DSBN(nn.Module): """Domain Specific Batch Normalization. Args: num_features (int): number of features. n_domain (int): number of domains. bn_type (str): type of bn. Choices are ['1d', '2d']. """ def __init__(self, num_features, n_domain, bn_type): super().__init__() if bn_type == "1d": BN = nn.BatchNorm1d elif bn_type == "2d": BN = nn.BatchNorm2d else: raise ValueError self.bn = nn.ModuleList(BN(num_features) for _ in range(n_domain)) self.valid_domain_idxs = list(range(n_domain)) self.n_domain = n_domain self.domain_idx = 0 def select_bn(self, domain_idx=0): assert domain_idx in self.valid_domain_idxs self.domain_idx = domain_idx def forward(self, x): return self.bn[self.domain_idx](x) class DSBN1d(_DSBN): def __init__(self, num_features, n_domain): super().__init__(num_features, n_domain, "1d") class DSBN2d(_DSBN): def __init__(self, num_features, n_domain): super().__init__(num_features, n_domain, "2d") ================================================ FILE: Dassl.ProGrad.pytorch/dassl/modeling/ops/efdmix.py ================================================ import random from contextlib import contextmanager import torch import torch.nn as nn def deactivate_efdmix(m): if type(m) == EFDMix: m.set_activation_status(False) def activate_efdmix(m): if type(m) == EFDMix: m.set_activation_status(True) def random_efdmix(m): if type(m) == EFDMix: 
m.update_mix_method("random") def crossdomain_efdmix(m): if type(m) == EFDMix: m.update_mix_method("crossdomain") @contextmanager def run_without_efdmix(model): # Assume EFDMix was initially activated try: model.apply(deactivate_efdmix) yield finally: model.apply(activate_efdmix) @contextmanager def run_with_efdmix(model, mix=None): # Assume EFDMix was initially deactivated if mix == "random": model.apply(random_efdmix) elif mix == "crossdomain": model.apply(crossdomain_efdmix) try: model.apply(activate_efdmix) yield finally: model.apply(deactivate_efdmix) class EFDMix(nn.Module): """EFDMix. Reference: Zhang et al. Exact Feature Distribution Matching for Arbitrary Style Transfer and Domain Generalization. CVPR 2022. """ def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix="random"): """ Args: p (float): probability of using EFDMix. alpha (float): parameter of the Beta distribution. eps (float): scaling parameter to avoid numerical issues. mix (str): how to mix, either "random" or "crossdomain". """ super().__init__() self.p = p self.beta = torch.distributions.Beta(alpha, alpha) self.eps = eps self.alpha = alpha self.mix = mix self._activated = True def __repr__(self): return ( f"EFDMix(p={self.p}, alpha={self.alpha}, eps={self.eps}, mix={self.mix})" ) def set_activation_status(self, status=True): self._activated = status def update_mix_method(self, mix="random"): self.mix = mix def forward(self, x): if not self.training or not self._activated: return x if random.random() > self.p: return x B, C, W, H = x.size(0), x.size(1), x.size(2), x.size(3) x_view = x.view(B, C, -1) value_x, index_x = torch.sort(x_view) # sort inputs lmda = self.beta.sample((B, 1, 1)) lmda = lmda.to(x.device) if self.mix == "random": # random shuffle perm = torch.randperm(B) elif self.mix == "crossdomain": # split into two halves and swap the order perm = torch.arange(B - 1, -1, -1) # inverse index perm_b, perm_a = perm.chunk(2) perm_b = perm_b[torch.randperm(perm_b.shape[0])] perm_a =
perm_a[torch.randperm(perm_a.shape[0])] perm = torch.cat([perm_b, perm_a], 0) else: raise NotImplementedError inverse_index = index_x.argsort(-1) x_view_copy = value_x[perm].gather(-1, inverse_index) new_x = x_view + (x_view_copy - x_view.detach()) * (1-lmda) return new_x.view(B, C, W, H) ================================================ FILE: Dassl.ProGrad.pytorch/dassl/modeling/ops/mixstyle.py ================================================ import random from contextlib import contextmanager import torch import torch.nn as nn def deactivate_mixstyle(m): if type(m) == MixStyle: m.set_activation_status(False) def activate_mixstyle(m): if type(m) == MixStyle: m.set_activation_status(True) def random_mixstyle(m): if type(m) == MixStyle: m.update_mix_method("random") def crossdomain_mixstyle(m): if type(m) == MixStyle: m.update_mix_method("crossdomain") @contextmanager def run_without_mixstyle(model): # Assume MixStyle was initially activated try: model.apply(deactivate_mixstyle) yield finally: model.apply(activate_mixstyle) @contextmanager def run_with_mixstyle(model, mix=None): # Assume MixStyle was initially deactivated if mix == "random": model.apply(random_mixstyle) elif mix == "crossdomain": model.apply(crossdomain_mixstyle) try: model.apply(activate_mixstyle) yield finally: model.apply(deactivate_mixstyle) class MixStyle(nn.Module): """MixStyle. Reference: Zhou et al. Domain Generalization with MixStyle. ICLR 2021. """ def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix="random"): """ Args: p (float): probability of using MixStyle. alpha (float): parameter of the Beta distribution. eps (float): scaling parameter to avoid numerical issues. mix (str): how to mix. 
""" super().__init__() self.p = p self.beta = torch.distributions.Beta(alpha, alpha) self.eps = eps self.alpha = alpha self.mix = mix self._activated = True def __repr__(self): return ( f"MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps}, mix={self.mix})" ) def set_activation_status(self, status=True): self._activated = status def update_mix_method(self, mix="random"): self.mix = mix def forward(self, x): if not self.training or not self._activated: return x if random.random() > self.p: return x B = x.size(0) mu = x.mean(dim=[2, 3], keepdim=True) var = x.var(dim=[2, 3], keepdim=True) sig = (var + self.eps).sqrt() mu, sig = mu.detach(), sig.detach() x_normed = (x-mu) / sig lmda = self.beta.sample((B, 1, 1, 1)) lmda = lmda.to(x.device) if self.mix == "random": # random shuffle perm = torch.randperm(B) elif self.mix == "crossdomain": # split into two halves and swap the order perm = torch.arange(B - 1, -1, -1) # inverse index perm_b, perm_a = perm.chunk(2) perm_b = perm_b[torch.randperm(perm_b.shape[0])] perm_a = perm_a[torch.randperm(perm_a.shape[0])] perm = torch.cat([perm_b, perm_a], 0) else: raise NotImplementedError mu2, sig2 = mu[perm], sig[perm] mu_mix = mu*lmda + mu2 * (1-lmda) sig_mix = sig*lmda + sig2 * (1-lmda) return x_normed*sig_mix + mu_mix ================================================ FILE: Dassl.ProGrad.pytorch/dassl/modeling/ops/mixup.py ================================================ import torch def mixup(x1, x2, y1, y2, beta, preserve_order=False): """Mixup. Args: x1 (torch.Tensor): data with shape of (b, c, h, w). x2 (torch.Tensor): data with shape of (b, c, h, w). y1 (torch.Tensor): label with shape of (b, n). y2 (torch.Tensor): label with shape of (b, n). beta (float): hyper-parameter for Beta sampling. preserve_order (bool): apply lmda=max(lmda, 1-lmda). Default is False. 
""" lmda = torch.distributions.Beta(beta, beta).sample([x1.shape[0], 1, 1, 1]) if preserve_order: lmda = torch.max(lmda, 1 - lmda) lmda = lmda.to(x1.device) xmix = x1*lmda + x2 * (1-lmda) lmda = lmda[:, :, 0, 0] ymix = y1*lmda + y2 * (1-lmda) return xmix, ymix ================================================ FILE: Dassl.ProGrad.pytorch/dassl/modeling/ops/mmd.py ================================================ import torch import torch.nn as nn from torch.nn import functional as F class MaximumMeanDiscrepancy(nn.Module): def __init__(self, kernel_type="rbf", normalize=False): super().__init__() self.kernel_type = kernel_type self.normalize = normalize def forward(self, x, y): # x, y: two batches of data with shape (batch, dim) # MMD^2(x, y) = k(x, x') - 2k(x, y) + k(y, y') if self.normalize: x = F.normalize(x, dim=1) y = F.normalize(y, dim=1) if self.kernel_type == "linear": return self.linear_mmd(x, y) elif self.kernel_type == "poly": return self.poly_mmd(x, y) elif self.kernel_type == "rbf": return self.rbf_mmd(x, y) else: raise NotImplementedError def linear_mmd(self, x, y): # k(x, y) = x^T y k_xx = self.remove_self_distance(torch.mm(x, x.t())) k_yy = self.remove_self_distance(torch.mm(y, y.t())) k_xy = torch.mm(x, y.t()) return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean() def poly_mmd(self, x, y, alpha=1.0, c=2.0, d=2): # k(x, y) = (alpha * x^T y + c)^d k_xx = self.remove_self_distance(torch.mm(x, x.t())) k_xx = (alpha*k_xx + c).pow(d) k_yy = self.remove_self_distance(torch.mm(y, y.t())) k_yy = (alpha*k_yy + c).pow(d) k_xy = torch.mm(x, y.t()) k_xy = (alpha*k_xy + c).pow(d) return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean() def rbf_mmd(self, x, y): # k_xx d_xx = self.euclidean_squared_distance(x, x) d_xx = self.remove_self_distance(d_xx) k_xx = self.rbf_kernel_mixture(d_xx) # k_yy d_yy = self.euclidean_squared_distance(y, y) d_yy = self.remove_self_distance(d_yy) k_yy = self.rbf_kernel_mixture(d_yy) # k_xy d_xy = self.euclidean_squared_distance(x, y) k_xy = 
self.rbf_kernel_mixture(d_xy) return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean() @staticmethod def rbf_kernel_mixture(exponent, sigmas=[1, 5, 10]): K = 0 for sigma in sigmas: gamma = 1.0 / (2.0 * sigma**2) K += torch.exp(-gamma * exponent) return K @staticmethod def remove_self_distance(distmat): tmp_list = [] for i, row in enumerate(distmat): row1 = torch.cat([row[:i], row[i + 1:]]) tmp_list.append(row1) return torch.stack(tmp_list) @staticmethod def euclidean_squared_distance(x, y): m, n = x.size(0), y.size(0) distmat = ( torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t() ) distmat.addmm_(x, y.t(), beta=1, alpha=-2) return distmat if __name__ == "__main__": mmd = MaximumMeanDiscrepancy(kernel_type="rbf") input1, input2 = torch.rand(3, 100), torch.rand(3, 100) d = mmd(input1, input2) print(d.item()) ================================================ FILE: Dassl.ProGrad.pytorch/dassl/modeling/ops/optimal_transport.py ================================================ import torch import torch.nn as nn from torch.nn import functional as F class OptimalTransport(nn.Module): @staticmethod def distance(batch1, batch2, dist_metric="cosine"): if dist_metric == "cosine": batch1 = F.normalize(batch1, p=2, dim=1) batch2 = F.normalize(batch2, p=2, dim=1) dist_mat = 1 - torch.mm(batch1, batch2.t()) elif dist_metric == "euclidean": m, n = batch1.size(0), batch2.size(0) dist_mat = ( torch.pow(batch1, 2).sum(dim=1, keepdim=True).expand(m, n) + torch.pow(batch2, 2).sum(dim=1, keepdim=True).expand(n, m).t() ) dist_mat.addmm_( batch1, batch2.t(), beta=1, alpha=-2 ) # squared euclidean distance elif dist_metric == "fast_euclidean": batch1 = batch1.unsqueeze(-2) batch2 = batch2.unsqueeze(-3) dist_mat = torch.sum((torch.abs(batch1 - batch2))**2, -1) else: raise ValueError( "Unknown cost function: {}.
Expected to " "be one of [cosine | euclidean | fast_euclidean]".format(dist_metric) ) return dist_mat class SinkhornDivergence(OptimalTransport): thre = 1e-3 def __init__( self, dist_metric="cosine", eps=0.01, max_iter=5, bp_to_sinkhorn=False ): super().__init__() self.dist_metric = dist_metric self.eps = eps self.max_iter = max_iter self.bp_to_sinkhorn = bp_to_sinkhorn def forward(self, x, y): # x, y: two batches of data with shape (batch, dim) W_xy = self.transport_cost(x, y) W_xx = self.transport_cost(x, x) W_yy = self.transport_cost(y, y) return 2*W_xy - W_xx - W_yy def transport_cost(self, x, y, return_pi=False): C = self.distance(x, y, dist_metric=self.dist_metric) pi = self.sinkhorn_iterate(C, self.eps, self.max_iter, self.thre) if not self.bp_to_sinkhorn: pi = pi.detach() cost = torch.sum(pi * C) if return_pi: return cost, pi return cost @staticmethod def sinkhorn_iterate(C, eps, max_iter, thre): nx, ny = C.shape mu = torch.ones(nx, dtype=C.dtype, device=C.device) * (1.0/nx) nu = torch.ones(ny, dtype=C.dtype, device=C.device) * (1.0/ny) u = torch.zeros_like(mu) v = torch.zeros_like(nu) def M(_C, _u, _v): """Modified cost for logarithmic updates.
Eq: M_{ij} = (-c_{ij} + u_i + v_j) / epsilon """ return (-_C + _u.unsqueeze(-1) + _v.unsqueeze(-2)) / eps real_iter = 0 # check if algorithm terminates before max_iter # Sinkhorn iterations for i in range(max_iter): u0 = u u = eps * ( torch.log(mu + 1e-8) - torch.logsumexp(M(C, u, v), dim=1) ) + u v = ( eps * ( torch.log(nu + 1e-8) - torch.logsumexp(M(C, u, v).permute(1, 0), dim=1) ) + v ) err = (u - u0).abs().sum() real_iter += 1 if err.item() < thre: break # Transport plan pi = diag(a)*K*diag(b) return torch.exp(M(C, u, v)) class MinibatchEnergyDistance(SinkhornDivergence): def __init__( self, dist_metric="cosine", eps=0.01, max_iter=5, bp_to_sinkhorn=False ): super().__init__( dist_metric=dist_metric, eps=eps, max_iter=max_iter, bp_to_sinkhorn=bp_to_sinkhorn, ) def forward(self, x, y): x1, x2 = torch.split(x, x.size(0) // 2, dim=0) y1, y2 = torch.split(y, y.size(0) // 2, dim=0) cost = 0 cost += self.transport_cost(x1, y1) cost += self.transport_cost(x1, y2) cost += self.transport_cost(x2, y1) cost += self.transport_cost(x2, y2) cost -= 2 * self.transport_cost(x1, x2) cost -= 2 * self.transport_cost(y1, y2) return cost if __name__ == "__main__": # example: https://dfdazac.github.io/sinkhorn.html import numpy as np n_points = 5 a = np.array([[i, 0] for i in range(n_points)]) b = np.array([[i, 1] for i in range(n_points)]) x = torch.tensor(a, dtype=torch.float) y = torch.tensor(b, dtype=torch.float) sinkhorn = SinkhornDivergence( dist_metric="euclidean", eps=0.01, max_iter=5 ) dist, pi = sinkhorn.transport_cost(x, y, True) print(dist.item()) print(pi) ================================================ FILE: Dassl.ProGrad.pytorch/dassl/modeling/ops/reverse_grad.py ================================================ import torch.nn as nn from torch.autograd import Function class _ReverseGrad(Function): @staticmethod def forward(ctx, input, grad_scaling): ctx.grad_scaling = grad_scaling return input.view_as(input) @staticmethod def backward(ctx, grad_output): grad_scaling =
ctx.grad_scaling return -grad_scaling * grad_output, None reverse_grad = _ReverseGrad.apply class ReverseGrad(nn.Module): """Gradient reversal layer. It acts as an identity layer in the forward, but reverses the sign of the gradient in the backward. """ def forward(self, x, grad_scaling=1.0): assert (grad_scaling >= 0), "grad_scaling must be non-negative, " "but got {}".format( grad_scaling ) return reverse_grad(x, grad_scaling) ================================================ FILE: Dassl.ProGrad.pytorch/dassl/modeling/ops/sequential2.py ================================================ import torch.nn as nn class Sequential2(nn.Sequential): """An alternative sequential container to nn.Sequential, which accepts an arbitrary number of input arguments. """ def forward(self, *inputs): for module in self._modules.values(): if isinstance(inputs, tuple): inputs = module(*inputs) else: inputs = module(inputs) return inputs ================================================ FILE: Dassl.ProGrad.pytorch/dassl/modeling/ops/transnorm.py ================================================ import torch import torch.nn as nn class _TransNorm(nn.Module): """Transferable normalization. Reference: - Wang et al. Transferable Normalization: Towards Improving Transferability of Deep Neural Networks. NeurIPS 2019. Args: num_features (int): number of features. eps (float): epsilon. momentum (float): value for updating running_mean and running_var. adaptive_alpha (bool): apply domain adaptive alpha. 
""" def __init__( self, num_features, eps=1e-5, momentum=0.1, adaptive_alpha=True ): super().__init__() self.num_features = num_features self.eps = eps self.momentum = momentum self.adaptive_alpha = adaptive_alpha self.register_buffer("running_mean_s", torch.zeros(num_features)) self.register_buffer("running_var_s", torch.ones(num_features)) self.register_buffer("running_mean_t", torch.zeros(num_features)) self.register_buffer("running_var_t", torch.ones(num_features)) self.weight = nn.Parameter(torch.ones(num_features)) self.bias = nn.Parameter(torch.zeros(num_features)) def resnet_running_stats(self): self.running_mean_s.zero_() self.running_var_s.fill_(1) self.running_mean_t.zero_() self.running_var_t.fill_(1) def reset_parameters(self): nn.init.ones_(self.weight) nn.init.zeros_(self.bias) def _check_input(self, x): raise NotImplementedError def _compute_alpha(self, mean_s, var_s, mean_t, var_t): C = self.num_features ratio_s = mean_s / (var_s + self.eps).sqrt() ratio_t = mean_t / (var_t + self.eps).sqrt() dist = (ratio_s - ratio_t).abs() dist_inv = 1 / (1+dist) return C * dist_inv / dist_inv.sum() def forward(self, input): self._check_input(input) C = self.num_features if input.dim() == 2: new_shape = (1, C) elif input.dim() == 4: new_shape = (1, C, 1, 1) else: raise ValueError weight = self.weight.view(*new_shape) bias = self.bias.view(*new_shape) if not self.training: mean_t = self.running_mean_t.view(*new_shape) var_t = self.running_var_t.view(*new_shape) output = (input-mean_t) / (var_t + self.eps).sqrt() output = output*weight + bias if self.adaptive_alpha: mean_s = self.running_mean_s.view(*new_shape) var_s = self.running_var_s.view(*new_shape) alpha = self._compute_alpha(mean_s, var_s, mean_t, var_t) alpha = alpha.reshape(*new_shape) output = (1 + alpha.detach()) * output return output input_s, input_t = torch.split(input, input.shape[0] // 2, dim=0) x_s = input_s.transpose(0, 1).reshape(C, -1) mean_s = x_s.mean(1) var_s = x_s.var(1) 
self.running_mean_s.mul_(self.momentum) self.running_mean_s.add_((1 - self.momentum) * mean_s.data) self.running_var_s.mul_(self.momentum) self.running_var_s.add_((1 - self.momentum) * var_s.data) mean_s = mean_s.reshape(*new_shape) var_s = var_s.reshape(*new_shape) output_s = (input_s-mean_s) / (var_s + self.eps).sqrt() output_s = output_s*weight + bias x_t = input_t.transpose(0, 1).reshape(C, -1) mean_t = x_t.mean(1) var_t = x_t.var(1) self.running_mean_t.mul_(self.momentum) self.running_mean_t.add_((1 - self.momentum) * mean_t.data) self.running_var_t.mul_(self.momentum) self.running_var_t.add_((1 - self.momentum) * var_t.data) mean_t = mean_t.reshape(*new_shape) var_t = var_t.reshape(*new_shape) output_t = (input_t-mean_t) / (var_t + self.eps).sqrt() output_t = output_t*weight + bias output = torch.cat([output_s, output_t], 0) if self.adaptive_alpha: alpha = self._compute_alpha(mean_s, var_s, mean_t, var_t) alpha = alpha.reshape(*new_shape) output = (1 + alpha.detach()) * output return output class TransNorm1d(_TransNorm): def _check_input(self, x): if x.dim() != 2: raise ValueError( "Expected the input to be 2-D, " "but got {}-D".format(x.dim()) ) class TransNorm2d(_TransNorm): def _check_input(self, x): if x.dim() != 4: raise ValueError( "Expected the input to be 4-D, " "but got {}-D".format(x.dim()) ) ================================================ FILE: Dassl.ProGrad.pytorch/dassl/modeling/ops/utils.py ================================================ import numpy as np import torch def sharpen_prob(p, temperature=2): """Sharpening probability with a temperature. Args: p (torch.Tensor): probability matrix (batch_size, n_classes) temperature (float): temperature. 
""" p = p.pow(temperature) return p / p.sum(1, keepdim=True) def reverse_index(data, label): """Reverse order.""" inv_idx = torch.arange(data.size(0) - 1, -1, -1).long() return data[inv_idx], label[inv_idx] def shuffle_index(data, label): """Shuffle order.""" rnd_idx = torch.randperm(data.shape[0]) return data[rnd_idx], label[rnd_idx] def create_onehot(label, num_classes): """Create one-hot tensor. We suggest using nn.functional.one_hot. Args: label (torch.Tensor): 1-D tensor. num_classes (int): number of classes. """ onehot = torch.zeros(label.shape[0], num_classes) return onehot.scatter(1, label.unsqueeze(1).data.cpu(), 1) def sigmoid_rampup(current, rampup_length): """Exponential rampup. Args: current (int): current step. rampup_length (int): maximum step. """ assert rampup_length > 0 current = np.clip(current, 0.0, rampup_length) phase = 1.0 - current/rampup_length return float(np.exp(-5.0 * phase * phase)) def linear_rampup(current, rampup_length): """Linear rampup. Args: current (int): current step. rampup_length (int): maximum step. """ assert rampup_length > 0 ratio = np.clip(current / rampup_length, 0.0, 1.0) return float(ratio) def ema_model_update(model, ema_model, alpha): """Exponential moving average of model parameters. Args: model (nn.Module): model being trained. ema_model (nn.Module): ema of the model. alpha (float): ema decay rate. 
""" for ema_param, param in zip(ema_model.parameters(), model.parameters()): ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha) ================================================ FILE: Dassl.ProGrad.pytorch/dassl/optim/__init__.py ================================================ from .optimizer import build_optimizer from .lr_scheduler import build_lr_scheduler ================================================ FILE: Dassl.ProGrad.pytorch/dassl/optim/lr_scheduler.py ================================================ """ Modified from https://github.com/KaiyangZhou/deep-person-reid """ import torch from torch.optim.lr_scheduler import _LRScheduler AVAI_SCHEDS = ["single_step", "multi_step", "cosine"] class _BaseWarmupScheduler(_LRScheduler): def __init__( self, optimizer, successor, warmup_epoch, last_epoch=-1, verbose=False ): self.successor = successor self.warmup_epoch = warmup_epoch super().__init__(optimizer, last_epoch, verbose) def get_lr(self): raise NotImplementedError def step(self, epoch=None): if self.last_epoch >= self.warmup_epoch: self.successor.step(epoch) self._last_lr = self.successor.get_last_lr() else: super().step(epoch) class ConstantWarmupScheduler(_BaseWarmupScheduler): def __init__( self, optimizer, successor, warmup_epoch, cons_lr, last_epoch=-1, verbose=False ): self.cons_lr = cons_lr super().__init__( optimizer, successor, warmup_epoch, last_epoch, verbose ) def get_lr(self): if self.last_epoch >= self.warmup_epoch: return self.successor.get_last_lr() return [self.cons_lr for _ in self.base_lrs] class LinearWarmupScheduler(_BaseWarmupScheduler): def __init__( self, optimizer, successor, warmup_epoch, min_lr, last_epoch=-1, verbose=False ): self.min_lr = min_lr super().__init__( optimizer, successor, warmup_epoch, last_epoch, verbose ) def get_lr(self): if self.last_epoch >= self.warmup_epoch: return self.successor.get_last_lr() if self.last_epoch == 0: return [self.min_lr for _ in self.base_lrs] return [ lr * self.last_epoch / 
self.warmup_epoch for lr in self.base_lrs ] def build_lr_scheduler(optimizer, optim_cfg): """A function wrapper for building a learning rate scheduler. Args: optimizer (Optimizer): an Optimizer. optim_cfg (CfgNode): optimization config. """ lr_scheduler = optim_cfg.LR_SCHEDULER stepsize = optim_cfg.STEPSIZE gamma = optim_cfg.GAMMA max_epoch = optim_cfg.MAX_EPOCH if lr_scheduler not in AVAI_SCHEDS: raise ValueError( "Unsupported scheduler: {}. Must be one of {}".format( lr_scheduler, AVAI_SCHEDS ) ) if lr_scheduler == "single_step": if isinstance(stepsize, (list, tuple)): stepsize = stepsize[-1] if not isinstance(stepsize, int): raise TypeError( "For single_step lr_scheduler, stepsize must " "be an integer, but got {}".format(type(stepsize)) ) if stepsize <= 0: stepsize = max_epoch scheduler = torch.optim.lr_scheduler.StepLR( optimizer, step_size=stepsize, gamma=gamma ) elif lr_scheduler == "multi_step": if not isinstance(stepsize, (list, tuple)): raise TypeError( "For multi_step lr_scheduler, stepsize must " "be a list, but got {}".format(type(stepsize)) ) scheduler = torch.optim.lr_scheduler.MultiStepLR( optimizer, milestones=stepsize, gamma=gamma ) elif lr_scheduler == "cosine": scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( optimizer, float(max_epoch) ) if optim_cfg.WARMUP_EPOCH > 0: if not optim_cfg.WARMUP_RECOUNT: scheduler.last_epoch = optim_cfg.WARMUP_EPOCH if optim_cfg.WARMUP_TYPE == "constant": scheduler = ConstantWarmupScheduler( optimizer, scheduler, optim_cfg.WARMUP_EPOCH, optim_cfg.WARMUP_CONS_LR ) elif optim_cfg.WARMUP_TYPE == "linear": scheduler = LinearWarmupScheduler( optimizer, scheduler, optim_cfg.WARMUP_EPOCH, optim_cfg.WARMUP_MIN_LR ) else: raise ValueError return scheduler ================================================ FILE: Dassl.ProGrad.pytorch/dassl/optim/optimizer.py ================================================ """ Modified from https://github.com/KaiyangZhou/deep-person-reid """ import warnings import torch import torch.nn 
as nn from .radam import RAdam AVAI_OPTIMS = ["adam", "amsgrad", "sgd", "rmsprop", "radam", "adamw"] def build_optimizer(model, optim_cfg): """A function wrapper for building an optimizer. Args: model (nn.Module or iterable): model. optim_cfg (CfgNode): optimization config. """ optim = optim_cfg.NAME lr = optim_cfg.LR weight_decay = optim_cfg.WEIGHT_DECAY momentum = optim_cfg.MOMENTUM sgd_dampening = optim_cfg.SGD_DAMPNING sgd_nesterov = optim_cfg.SGD_NESTEROV rmsprop_alpha = optim_cfg.RMSPROP_ALPHA adam_beta1 = optim_cfg.ADAM_BETA1 adam_beta2 = optim_cfg.ADAM_BETA2 staged_lr = optim_cfg.STAGED_LR new_layers = optim_cfg.NEW_LAYERS base_lr_mult = optim_cfg.BASE_LR_MULT if optim not in AVAI_OPTIMS: raise ValueError( "Unsupported optim: {}. Must be one of {}".format( optim, AVAI_OPTIMS ) ) if staged_lr: if not isinstance(model, nn.Module): raise TypeError( "When staged_lr is True, model given to " "build_optimizer() must be an instance of nn.Module" ) if isinstance(model, nn.DataParallel): model = model.module if not new_layers: warnings.warn( "new_layers is empty, therefore, staged_lr is useless" ) new_layers = [] elif isinstance(new_layers, str): new_layers = [new_layers] base_params = [] base_layers = [] new_params = [] for name, module in model.named_children(): if name in new_layers: new_params += [p for p in module.parameters()] else: base_params += [p for p in module.parameters()] base_layers.append(name) param_groups = [ { "params": base_params, "lr": lr * base_lr_mult }, { "params": new_params }, ] else: if isinstance(model, nn.Module): param_groups = model.parameters() else: param_groups = model if optim == "adam": optimizer = torch.optim.Adam( param_groups, lr=lr, weight_decay=weight_decay, betas=(adam_beta1, adam_beta2), ) elif optim == "amsgrad": optimizer = torch.optim.Adam( param_groups, lr=lr, weight_decay=weight_decay, betas=(adam_beta1, adam_beta2), amsgrad=True, ) elif optim == "sgd": optimizer = torch.optim.SGD( param_groups, lr=lr, momentum=momentum,
weight_decay=weight_decay, dampening=sgd_dampening, nesterov=sgd_nesterov, ) elif optim == "rmsprop": optimizer = torch.optim.RMSprop( param_groups, lr=lr, momentum=momentum, weight_decay=weight_decay, alpha=rmsprop_alpha, ) elif optim == "radam": optimizer = RAdam( param_groups, lr=lr, weight_decay=weight_decay, betas=(adam_beta1, adam_beta2), ) elif optim == "adamw": optimizer = torch.optim.AdamW( param_groups, lr=lr, weight_decay=weight_decay, betas=(adam_beta1, adam_beta2), ) return optimizer ================================================ FILE: Dassl.ProGrad.pytorch/dassl/optim/radam.py ================================================ """ Imported from: https://github.com/LiyuanLucasLiu/RAdam https://arxiv.org/abs/1908.03265 @article{liu2019radam, title={On the Variance of the Adaptive Learning Rate and Beyond}, author={Liu, Liyuan and Jiang, Haoming and He, Pengcheng and Chen, Weizhu and Liu, Xiaodong and Gao, Jianfeng and Han, Jiawei}, journal={arXiv preprint arXiv:1908.03265}, year={2019} } """ import math import torch from torch.optim.optimizer import Optimizer class RAdam(Optimizer): def __init__( self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True, ): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: raise ValueError( "Invalid beta parameter at index 0: {}".format(betas[0]) ) if not 0.0 <= betas[1] < 1.0: raise ValueError( "Invalid beta parameter at index 1: {}".format(betas[1]) ) self.degenerated_to_sgd = degenerated_to_sgd defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) self.buffer = [[None, None, None] for ind in range(10)] super(RAdam, self).__init__(params, defaults) def __setstate__(self, state): super(RAdam, self).__setstate__(state) def step(self, closure=None): loss = None if closure is not None: loss = closure() for group in self.param_groups: for p 
in group["params"]: if p.grad is None: continue grad = p.grad.data.float() if grad.is_sparse: raise RuntimeError( "RAdam does not support sparse gradients" ) p_data_fp32 = p.data.float() state = self.state[p] if len(state) == 0: state["step"] = 0 state["exp_avg"] = torch.zeros_like(p_data_fp32) state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) else: state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32) state["exp_avg_sq"] = state["exp_avg_sq"].type_as( p_data_fp32 ) exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] beta1, beta2 = group["betas"] exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) state["step"] += 1 buffered = self.buffer[int(state["step"] % 10)] if state["step"] == buffered[0]: N_sma, step_size = buffered[1], buffered[2] else: buffered[0] = state["step"] beta2_t = beta2**state["step"] N_sma_max = 2 / (1-beta2) - 1 N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1-beta2_t) buffered[1] = N_sma # more conservative since it's an approximated value if N_sma >= 5: step_size = math.sqrt( (1-beta2_t) * (N_sma-4) / (N_sma_max-4) * (N_sma-2) / N_sma * N_sma_max / (N_sma_max-2) ) / (1 - beta1**state["step"]) elif self.degenerated_to_sgd: step_size = 1.0 / (1 - beta1**state["step"]) else: step_size = -1 buffered[2] = step_size # more conservative since it's an approximated value if N_sma >= 5: if group["weight_decay"] != 0: p_data_fp32.add_( p_data_fp32, alpha=-group["weight_decay"] * group["lr"] ) denom = exp_avg_sq.sqrt().add_(group["eps"]) p_data_fp32.addcdiv_( exp_avg, denom, value=-step_size * group["lr"] ) p.data.copy_(p_data_fp32) elif step_size > 0: if group["weight_decay"] != 0: p_data_fp32.add_( p_data_fp32, alpha=-group["weight_decay"] * group["lr"] ) p_data_fp32.add_(exp_avg, alpha=-step_size * group["lr"]) p.data.copy_(p_data_fp32) return loss class PlainRAdam(Optimizer): def __init__( self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True, ): if not 0.0 <= lr: raise
ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: raise ValueError( "Invalid beta parameter at index 0: {}".format(betas[0]) ) if not 0.0 <= betas[1] < 1.0: raise ValueError( "Invalid beta parameter at index 1: {}".format(betas[1]) ) self.degenerated_to_sgd = degenerated_to_sgd defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) super(PlainRAdam, self).__init__(params, defaults) def __setstate__(self, state): super(PlainRAdam, self).__setstate__(state) def step(self, closure=None): loss = None if closure is not None: loss = closure() for group in self.param_groups: for p in group["params"]: if p.grad is None: continue grad = p.grad.data.float() if grad.is_sparse: raise RuntimeError( "RAdam does not support sparse gradients" ) p_data_fp32 = p.data.float() state = self.state[p] if len(state) == 0: state["step"] = 0 state["exp_avg"] = torch.zeros_like(p_data_fp32) state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) else: state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32) state["exp_avg_sq"] = state["exp_avg_sq"].type_as( p_data_fp32 ) exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] beta1, beta2 = group["betas"] exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) state["step"] += 1 beta2_t = beta2**state["step"] N_sma_max = 2 / (1-beta2) - 1 N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1-beta2_t) # more conservative since it's an approximated value if N_sma >= 5: if group["weight_decay"] != 0: p_data_fp32.add_( p_data_fp32, alpha=-group["weight_decay"] * group["lr"] ) step_size = ( group["lr"] * math.sqrt( (1-beta2_t) * (N_sma-4) / (N_sma_max-4) * (N_sma-2) / N_sma * N_sma_max / (N_sma_max-2) ) / (1 - beta1**state["step"]) ) denom = exp_avg_sq.sqrt().add_(group["eps"]) p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size) p.data.copy_(p_data_fp32) elif self.degenerated_to_sgd: if
group["weight_decay"] != 0: p_data_fp32.add_( p_data_fp32, alpha=-group["weight_decay"] * group["lr"] ) step_size = group["lr"] / (1 - beta1**state["step"]) p_data_fp32.add_(exp_avg, alpha=-step_size) p.data.copy_(p_data_fp32) return loss class AdamW(Optimizer): def __init__( self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup=0 ): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: raise ValueError( "Invalid beta parameter at index 0: {}".format(betas[0]) ) if not 0.0 <= betas[1] < 1.0: raise ValueError( "Invalid beta parameter at index 1: {}".format(betas[1]) ) defaults = dict( lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, warmup=warmup ) super(AdamW, self).__init__(params, defaults) def __setstate__(self, state): super(AdamW, self).__setstate__(state) def step(self, closure=None): loss = None if closure is not None: loss = closure() for group in self.param_groups: for p in group["params"]: if p.grad is None: continue grad = p.grad.data.float() if grad.is_sparse: raise RuntimeError( "Adam does not support sparse gradients, please consider SparseAdam instead" ) p_data_fp32 = p.data.float() state = self.state[p] if len(state) == 0: state["step"] = 0 state["exp_avg"] = torch.zeros_like(p_data_fp32) state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) else: state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32) state["exp_avg_sq"] = state["exp_avg_sq"].type_as( p_data_fp32 ) exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] beta1, beta2 = group["betas"] state["step"] += 1 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) denom = exp_avg_sq.sqrt().add_(group["eps"]) bias_correction1 = 1 - beta1**state["step"] bias_correction2 = 1 - beta2**state["step"] if group["warmup"] > state["step"]: scheduled_lr = 1e-8 + state["step"] * group["lr"] / group["warmup"] else:
scheduled_lr = group["lr"] step_size = ( scheduled_lr * math.sqrt(bias_correction2) / bias_correction1 ) if group["weight_decay"] != 0: p_data_fp32.add_( -group["weight_decay"] * scheduled_lr, p_data_fp32 ) p_data_fp32.addcdiv_(-step_size, exp_avg, denom) p.data.copy_(p_data_fp32) return loss ================================================ FILE: Dassl.ProGrad.pytorch/dassl/utils/__init__.py ================================================ from .tools import * from .logger import * from .meters import * from .registry import * from .torchtools import * ================================================ FILE: Dassl.ProGrad.pytorch/dassl/utils/logger.py ================================================ import os import sys import time import os.path as osp from .tools import mkdir_if_missing __all__ = ["Logger", "setup_logger"] class Logger: """Write console output to external text file. Imported from ``_ Args: fpath (str): directory to save logging file. Examples:: >>> import sys >>> import os.path as osp >>> save_dir = 'output/experiment-1' >>> log_name = 'train.log' >>> sys.stdout = Logger(osp.join(save_dir, log_name)) """ def __init__(self, fpath=None): self.console = sys.stdout self.file = None if fpath is not None: mkdir_if_missing(osp.dirname(fpath)) self.file = open(fpath, "w") def __del__(self): self.close() def __enter__(self): pass def __exit__(self, *args): self.close() def write(self, msg): self.console.write(msg) if self.file is not None: self.file.write(msg) def flush(self): self.console.flush() if self.file is not None: self.file.flush() os.fsync(self.file.fileno()) def close(self): self.console.close() if self.file is not None: self.file.close() def setup_logger(output=None): if output is None: return if output.endswith(".txt") or output.endswith(".log"): fpath = output else: fpath = osp.join(output, "log.txt") if osp.exists(fpath): # make sure the existing log file is not over-written fpath += time.strftime("-%Y-%m-%d-%H-%M-%S") sys.stdout = 
Logger(fpath) ================================================ FILE: Dassl.ProGrad.pytorch/dassl/utils/meters.py ================================================ from collections import defaultdict import torch __all__ = ["AverageMeter", "MetricMeter"] class AverageMeter: """Compute and store the average and current value. Examples:: >>> # 1. Initialize a meter to record loss >>> losses = AverageMeter() >>> # 2. Update meter after every mini-batch update >>> losses.update(loss_value, batch_size) """ def __init__(self, ema=False): """ Args: ema (bool, optional): apply exponential moving average. """ self.ema = ema self.reset() def reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): if isinstance(val, torch.Tensor): val = val.item() self.val = val self.sum += val * n self.count += n if self.ema: self.avg = self.avg * 0.9 + self.val * 0.1 else: self.avg = self.sum / self.count class MetricMeter: """Store the average and current value for a set of metrics. Examples:: >>> # 1. Create an instance of MetricMeter >>> metric = MetricMeter() >>> # 2. Update using a dictionary as input >>> input_dict = {'loss_1': value_1, 'loss_2': value_2} >>> metric.update(input_dict) >>> # 3. 
Convert to string and print >>> print(str(metric)) """ def __init__(self, delimiter="\t"): self.meters = defaultdict(AverageMeter) self.delimiter = delimiter def update(self, input_dict): if input_dict is None: return if not isinstance(input_dict, dict): raise TypeError( "Input to MetricMeter.update() must be a dictionary" ) for k, v in input_dict.items(): if isinstance(v, torch.Tensor): v = v.item() self.meters[k].update(v) def __str__(self): output_str = [] for name, meter in self.meters.items(): output_str.append(f"{name} {meter.val:.4f} ({meter.avg:.4f})") return self.delimiter.join(output_str) ================================================ FILE: Dassl.ProGrad.pytorch/dassl/utils/registry.py ================================================ """ Modified from https://github.com/facebookresearch/fvcore """ __all__ = ["Registry"] class Registry: """A registry providing name -> object mapping, to support custom modules. To create a registry (e.g. a backbone registry): .. code-block:: python BACKBONE_REGISTRY = Registry('BACKBONE') To register an object: .. code-block:: python @BACKBONE_REGISTRY.register() class MyBackbone(nn.Module): ... Or: .. 
code-block:: python BACKBONE_REGISTRY.register(MyBackbone) """ def __init__(self, name): self._name = name self._obj_map = dict() def _do_register(self, name, obj, force=False): if name in self._obj_map and not force: raise KeyError( 'An object named "{}" was already ' 'registered in "{}" registry'.format(name, self._name) ) self._obj_map[name] = obj def register(self, obj=None, force=False): if obj is None: # Used as a decorator def wrapper(fn_or_class): name = fn_or_class.__name__ self._do_register(name, fn_or_class, force=force) return fn_or_class return wrapper # Used as a function call name = obj.__name__ self._do_register(name, obj, force=force) def get(self, name): if name not in self._obj_map: raise KeyError( 'Object name "{}" does not exist ' 'in "{}" registry'.format(name, self._name) ) return self._obj_map[name] def registered_names(self): return list(self._obj_map.keys()) ================================================ FILE: Dassl.ProGrad.pytorch/dassl/utils/tools.py ================================================ """ Modified from https://github.com/KaiyangZhou/deep-person-reid """ import os import sys import json import time import errno import numpy as np import random import os.path as osp import warnings from difflib import SequenceMatcher import PIL import torch from PIL import Image __all__ = [ "mkdir_if_missing", "check_isfile", "read_json", "write_json", "set_random_seed", "download_url", "read_image", "collect_env_info", "listdir_nohidden", "get_most_similar_str_to_a_from_b", "check_availability", "tolist_if_not", ] def mkdir_if_missing(dirname): """Create dirname if it is missing.""" if not osp.exists(dirname): try: os.makedirs(dirname) except OSError as e: if e.errno != errno.EEXIST: raise def check_isfile(fpath): """Check if the given path is a file. Args: fpath (str): file path. 
Returns: bool """ isfile = osp.isfile(fpath) if not isfile: warnings.warn('No file found at "{}"'.format(fpath)) return isfile def read_json(fpath): """Read json file from a path.""" with open(fpath, "r") as f: obj = json.load(f) return obj def write_json(obj, fpath): """Writes to a json file.""" mkdir_if_missing(osp.dirname(fpath)) with open(fpath, "w") as f: json.dump(obj, f, indent=4, separators=(",", ": ")) def set_random_seed(seed): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) def download_url(url, dst): """Download file from a url to a destination. Args: url (str): url to download file. dst (str): destination path. """ from six.moves import urllib print('* url="{}"'.format(url)) print('* destination="{}"'.format(dst)) def _reporthook(count, block_size, total_size): global start_time if count == 0: start_time = time.time() return duration = time.time() - start_time progress_size = int(count * block_size) speed = int(progress_size / (1024*duration)) percent = int(count * block_size * 100 / total_size) sys.stdout.write( "\r...%d%%, %d MB, %d KB/s, %d seconds passed" % (percent, progress_size / (1024*1024), speed, duration) ) sys.stdout.flush() urllib.request.urlretrieve(url, dst, _reporthook) sys.stdout.write("\n") def read_image(path): """Read image from path using ``PIL.Image``. Args: path (str): path to an image. Returns: PIL image """ if not osp.exists(path): raise IOError("No file exists at {}".format(path)) while True: try: img = Image.open(path).convert("RGB") return img except IOError: print( "Cannot read image from {}, " "probably due to heavy IO. Will re-try".format(path) ) def collect_env_info(): """Return env info as a string. 
Code source: github.com/facebookresearch/maskrcnn-benchmark """ from torch.utils.collect_env import get_pretty_env_info env_str = get_pretty_env_info() env_str += "\n Pillow ({})".format(PIL.__version__) return env_str def listdir_nohidden(path, sort=False): """List non-hidden items in a directory. Args: path (str): directory path. sort (bool): whether to sort the items. """ items = [f for f in os.listdir(path) if not f.startswith(".")] if sort: items.sort() return items def get_most_similar_str_to_a_from_b(a, b): """Return the string in b that is most similar to a. Args: a (str): probe string. b (list): a list of candidate strings. """ highest_sim = 0 chosen = None for candidate in b: sim = SequenceMatcher(None, a, candidate).ratio() if sim >= highest_sim: highest_sim = sim chosen = candidate return chosen def check_availability(requested, available): """Check if an element is available in a list. Args: requested (str): probe string. available (list): a list of available strings. """ if requested not in available: psb_ans = get_most_similar_str_to_a_from_b(requested, available) raise ValueError( "The requested item must be one of " "{}, but got [{}] " "(did you mean [{}]?)".format(available, requested, psb_ans) ) def tolist_if_not(x): """Wrap x in a list if it is not already a list.""" if not isinstance(x, list): x = [x] return x ================================================ FILE: Dassl.ProGrad.pytorch/dassl/utils/torchtools.py ================================================ """ Modified from https://github.com/KaiyangZhou/deep-person-reid """ import pickle import shutil import os.path as osp import warnings from functools import partial from collections import OrderedDict import torch import torch.nn as nn from .tools import mkdir_if_missing __all__ = [ "save_checkpoint", "load_checkpoint", "resume_from_checkpoint", "open_all_layers", "open_specified_layers", "count_num_param", "load_pretrained_weights", "init_network_weights", ] def save_checkpoint( state, save_dir, is_best=False,
remove_module_from_keys=True, model_name="" ): r"""Save checkpoint. Args: state (dict): checkpoint dictionary; must contain "state_dict" and "epoch". save_dir (str): directory to save checkpoint. is_best (bool, optional): if True, this checkpoint will be copied and named ``model-best.pth.tar``. Default is False. remove_module_from_keys (bool, optional): whether to remove "module." from layer names. Default is True. model_name (str, optional): file name of the saved checkpoint. Defaults to ``model.pth.tar-<epoch>``. Examples:: >>> state = { >>> 'state_dict': model.state_dict(), >>> 'epoch': 10, >>> 'optimizer': optimizer.state_dict() >>> } >>> save_checkpoint(state, 'log/my_model') """ mkdir_if_missing(save_dir) if remove_module_from_keys: # remove 'module.' in state_dict's keys state_dict = state["state_dict"] new_state_dict = OrderedDict() for k, v in state_dict.items(): if k.startswith("module."): k = k[7:] new_state_dict[k] = v state["state_dict"] = new_state_dict # save model epoch = state["epoch"] if not model_name: model_name = "model.pth.tar-" + str(epoch) fpath = osp.join(save_dir, model_name) torch.save(state, fpath) print('Checkpoint saved to "{}"'.format(fpath)) # save current model name checkpoint_file = osp.join(save_dir, "checkpoint") checkpoint = open(checkpoint_file, "w+") checkpoint.write("{}\n".format(osp.basename(fpath))) checkpoint.close() if is_best: best_fpath = osp.join(osp.dirname(fpath), "model-best.pth.tar") shutil.copy(fpath, best_fpath) print('Best checkpoint saved to "{}"'.format(best_fpath)) def load_checkpoint(fpath): r"""Load checkpoint. ``UnicodeDecodeError`` is handled, so checkpoints saved with Python 2 can be loaded in Python 3. Args: fpath (str): path to checkpoint.
Returns: dict Examples:: >>> fpath = 'log/my_model/model.pth.tar-10' >>> checkpoint = load_checkpoint(fpath) """ if fpath is None: raise ValueError("File path is None") if not osp.exists(fpath): raise FileNotFoundError('No file found at "{}"'.format(fpath)) map_location = None if torch.cuda.is_available() else "cpu" try: checkpoint = torch.load(fpath, map_location=map_location) except UnicodeDecodeError: pickle.load = partial(pickle.load, encoding="latin1") pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1") checkpoint = torch.load( fpath, pickle_module=pickle, map_location=map_location ) except Exception: print('Unable to load checkpoint from "{}"'.format(fpath)) raise return checkpoint def resume_from_checkpoint(fdir, model, optimizer=None, scheduler=None): r"""Resume training from a checkpoint. This loads (1) the model weights, (2) the optimizer ``state_dict`` if ``optimizer`` is not None, and (3) the scheduler ``state_dict`` if ``scheduler`` is not None. Args: fdir (str): directory where the model was saved. model (nn.Module): model. optimizer (Optimizer, optional): an optimizer. scheduler (Scheduler, optional): a learning-rate scheduler. Returns: int: start_epoch.
Examples:: >>> fdir = 'log/my_model' >>> start_epoch = resume_from_checkpoint(fdir, model, optimizer, scheduler) """ with open(osp.join(fdir, "checkpoint"), "r") as checkpoint: model_name = checkpoint.readlines()[0].strip("\n") fpath = osp.join(fdir, model_name) print('Loading checkpoint from "{}"'.format(fpath)) checkpoint = load_checkpoint(fpath) model.load_state_dict(checkpoint["state_dict"]) print("Loaded model weights") if optimizer is not None and "optimizer" in checkpoint.keys(): optimizer.load_state_dict(checkpoint["optimizer"]) print("Loaded optimizer") if scheduler is not None and "scheduler" in checkpoint.keys(): scheduler.load_state_dict(checkpoint["scheduler"]) print("Loaded scheduler") start_epoch = checkpoint["epoch"] print("Previous epoch: {}".format(start_epoch)) return start_epoch def adjust_learning_rate( optimizer, base_lr, epoch, stepsize=20, gamma=0.1, linear_decay=False, final_lr=0, max_epoch=100, ): r"""Adjust learning rate. Deprecated. """ if linear_decay: # linearly decay learning rate from base_lr to final_lr frac_done = epoch / max_epoch lr = frac_done*final_lr + (1.0-frac_done) * base_lr else: # decay learning rate by gamma for every stepsize lr = base_lr * (gamma**(epoch // stepsize)) for param_group in optimizer.param_groups: param_group["lr"] = lr def set_bn_to_eval(m): r"""Set BatchNorm layers to eval mode.""" # 1. no update for running mean and var # 2. scale and shift parameters are still trainable classname = m.__class__.__name__ if classname.find("BatchNorm") != -1: m.eval() def open_all_layers(model): r"""Open all layers in model for training. Examples:: >>> open_all_layers(model) """ model.train() for p in model.parameters(): p.requires_grad = True def open_specified_layers(model, open_layers): r"""Open specified layers in model for training while keeping other layers frozen. Args: model (nn.Module): neural net model. open_layers (str or list): layers open for training. Examples:: >>> # Only model.classifier will be updated. 
>>> open_layers = 'classifier' >>> open_specified_layers(model, open_layers) >>> # Only model.fc and model.classifier will be updated. >>> open_layers = ['fc', 'classifier'] >>> open_specified_layers(model, open_layers) """ if isinstance(model, nn.DataParallel): model = model.module if isinstance(open_layers, str): open_layers = [open_layers] for layer in open_layers: assert hasattr( model, layer ), '"{}" is not an attribute of the model, please provide the correct name'.format( layer ) for name, module in model.named_children(): if name in open_layers: module.train() for p in module.parameters(): p.requires_grad = True else: module.eval() for p in module.parameters(): p.requires_grad = False def count_num_param(model): r"""Count the number of parameters in a model. Args: model (nn.Module): network model. Examples:: >>> model_size = count_num_param(model) """ return sum(p.numel() for p in model.parameters()) def load_pretrained_weights(model, weight_path): r"""Load pretrained weights into a model. Features:: - Incompatible layers (unmatched in name or size) will be ignored. - Can automatically deal with keys containing "module.". Args: model (nn.Module): network model. weight_path (str): path to pretrained weights. Examples:: >>> weight_path = 'log/my_model/model-best.pth.tar' >>> load_pretrained_weights(model, weight_path) """ checkpoint = load_checkpoint(weight_path) if "state_dict" in checkpoint: state_dict = checkpoint["state_dict"] else: state_dict = checkpoint model_dict = model.state_dict() new_state_dict = OrderedDict() matched_layers, discarded_layers = [], [] for k, v in state_dict.items(): if k.startswith("module."): k = k[7:] # discard module.
if k in model_dict and model_dict[k].size() == v.size(): new_state_dict[k] = v matched_layers.append(k) else: discarded_layers.append(k) model_dict.update(new_state_dict) model.load_state_dict(model_dict) if len(matched_layers) == 0: warnings.warn( 'The pretrained weights "{}" cannot be loaded, ' "please check the key names manually " "(** ignored and continue **)".format(weight_path) ) else: print( 'Successfully loaded pretrained weights from "{}"'. format(weight_path) ) if len(discarded_layers) > 0: print( "** The following layers are discarded " "due to unmatched keys or layer size: {}". format(discarded_layers) ) def init_network_weights(model, init_type="normal", gain=0.02): def _init_func(m): classname = m.__class__.__name__ if hasattr(m, "weight") and ( classname.find("Conv") != -1 or classname.find("Linear") != -1 ): if init_type == "normal": nn.init.normal_(m.weight.data, 0.0, gain) elif init_type == "xavier": nn.init.xavier_normal_(m.weight.data, gain=gain) elif init_type == "kaiming": nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_in") elif init_type == "orthogonal": nn.init.orthogonal_(m.weight.data, gain=gain) else: raise NotImplementedError( "initialization method {} is not implemented". 
format(init_type) ) if hasattr(m, "bias") and m.bias is not None: nn.init.constant_(m.bias.data, 0.0) elif classname.find("BatchNorm") != -1: nn.init.constant_(m.weight.data, 1.0) nn.init.constant_(m.bias.data, 0.0) elif classname.find("InstanceNorm") != -1: if m.weight is not None and m.bias is not None: nn.init.constant_(m.weight.data, 1.0) nn.init.constant_(m.bias.data, 0.0) model.apply(_init_func) ================================================ FILE: Dassl.ProGrad.pytorch/datasets/da/cifar_stl.py ================================================ import sys import pprint as pp import os.path as osp from torchvision.datasets import STL10, CIFAR10 from dassl.utils import mkdir_if_missing cifar_label2name = { 0: "airplane", 1: "car", # the original name was 'automobile' 2: "bird", 3: "cat", 4: "deer", 5: "dog", 6: "frog", # conflict class 7: "horse", 8: "ship", 9: "truck", } stl_label2name = { 0: "airplane", 1: "bird", 2: "car", 3: "cat", 4: "deer", 5: "dog", 6: "horse", 7: "monkey", # conflict class 8: "ship", 9: "truck", } new_name2label = { "airplane": 0, "bird": 1, "car": 2, "cat": 3, "deer": 4, "dog": 5, "horse": 6, "ship": 7, "truck": 8, } def extract_and_save_image(dataset, save_dir, discard, label2name): if osp.exists(save_dir): print('Folder "{}" already exists'.format(save_dir)) return print('Extracting images to "{}" ...'.format(save_dir)) mkdir_if_missing(save_dir) for i in range(len(dataset)): img, label = dataset[i] if label == discard: continue class_name = label2name[label] label_new = new_name2label[class_name] class_dir = osp.join( save_dir, str(label_new).zfill(3) + "_" + class_name ) mkdir_if_missing(class_dir) impath = osp.join(class_dir, str(i + 1).zfill(5) + ".jpg") img.save(impath) def download_and_prepare(name, root, discarded_label, label2name): print("Dataset: {}".format(name)) print("Root: {}".format(root)) print("Old labels:") pp.pprint(label2name) print("Discarded label: {}".format(discarded_label)) print("New labels:") 
pp.pprint(new_name2label) if name == "cifar": train = CIFAR10(root, train=True, download=True) test = CIFAR10(root, train=False) else: train = STL10(root, split="train", download=True) test = STL10(root, split="test") train_dir = osp.join(root, name, "train") test_dir = osp.join(root, name, "test") extract_and_save_image(train, train_dir, discarded_label, label2name) extract_and_save_image(test, test_dir, discarded_label, label2name) if __name__ == "__main__": download_and_prepare("cifar", sys.argv[1], 6, cifar_label2name) download_and_prepare("stl", sys.argv[1], 7, stl_label2name) ================================================ FILE: Dassl.ProGrad.pytorch/datasets/da/digit5.py ================================================ import os import numpy as np import os.path as osp import argparse from PIL import Image from scipy.io import loadmat def mkdir_if_missing(directory): if not osp.exists(directory): os.makedirs(directory) def extract_and_save(data, label, save_dir): for i, (x, y) in enumerate(zip(data, label)): if x.shape[2] == 1: x = np.repeat(x, 3, axis=2) if y == 10: y = 0 x = Image.fromarray(x, mode="RGB") save_path = osp.join( save_dir, str(i + 1).zfill(6) + "_" + str(y) + ".jpg" ) x.save(save_path) def load_mnist(data_dir, raw_data_dir): filepath = osp.join(raw_data_dir, "mnist_data.mat") data = loadmat(filepath) train_data = np.reshape(data["train_32"], (55000, 32, 32, 1)) test_data = np.reshape(data["test_32"], (10000, 32, 32, 1)) train_label = np.nonzero(data["label_train"])[1] test_label = np.nonzero(data["label_test"])[1] return train_data, test_data, train_label, test_label def load_mnist_m(data_dir, raw_data_dir): filepath = osp.join(raw_data_dir, "mnistm_with_label.mat") data = loadmat(filepath) train_data = data["train"] test_data = data["test"] train_label = np.nonzero(data["label_train"])[1] test_label = np.nonzero(data["label_test"])[1] return train_data, test_data, train_label, test_label def load_svhn(data_dir, raw_data_dir): train = 
loadmat(osp.join(raw_data_dir, "svhn_train_32x32.mat")) train_data = train["X"].transpose(3, 0, 1, 2) train_label = train["y"][:, 0] test = loadmat(osp.join(raw_data_dir, "svhn_test_32x32.mat")) test_data = test["X"].transpose(3, 0, 1, 2) test_label = test["y"][:, 0] return train_data, test_data, train_label, test_label def load_syn(data_dir, raw_data_dir): filepath = osp.join(raw_data_dir, "syn_number.mat") data = loadmat(filepath) train_data = data["train_data"] test_data = data["test_data"] train_label = data["train_label"][:, 0] test_label = data["test_label"][:, 0] return train_data, test_data, train_label, test_label def load_usps(data_dir, raw_data_dir): filepath = osp.join(raw_data_dir, "usps_28x28.mat") data = loadmat(filepath)["dataset"] train_data = data[0][0].transpose(0, 2, 3, 1) test_data = data[1][0].transpose(0, 2, 3, 1) train_data *= 255 test_data *= 255 train_data = train_data.astype(np.uint8) test_data = test_data.astype(np.uint8) train_label = data[0][1][:, 0] test_label = data[1][1][:, 0] return train_data, test_data, train_label, test_label def main(data_dir): data_dir = osp.abspath(osp.expanduser(data_dir)) raw_data_dir = osp.join(data_dir, "Digit-Five") if not osp.exists(data_dir): raise FileNotFoundError('"{}" does not exist'.format(data_dir)) datasets = ["mnist", "mnist_m", "svhn", "syn", "usps"] for name in datasets: print("Creating {}".format(name)) output = eval("load_" + name)(data_dir, raw_data_dir) train_data, test_data, train_label, test_label = output print("# train: {}".format(train_data.shape[0])) print("# test: {}".format(test_data.shape[0])) train_dir = osp.join(data_dir, name, "train_images") mkdir_if_missing(train_dir) test_dir = osp.join(data_dir, name, "test_images") mkdir_if_missing(test_dir) extract_and_save(train_data, train_label, train_dir) extract_and_save(test_data, test_label, test_dir) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "data_dir", type=str, help="directory containing 
Digit-Five/" ) args = parser.parse_args() main(args.data_dir) ================================================ FILE: Dassl.ProGrad.pytorch/datasets/da/visda17.sh ================================================ # ------------------------------------------------------------------------ # ROOT is the root directory where you put your domain datasets. # # Suppose you want to put the dataset under $DATA, which stores all the # domain datasets. Run the following command in your terminal to # download VisDA17: # # $ sh visda17.sh $DATA #------------------------------------------------------------------------ ROOT=$1 mkdir -p "$ROOT/visda17" cd "$ROOT/visda17" || exit 1 wget http://csr.bu.edu/ftp/visda17/clf/train.tar tar xvf train.tar wget http://csr.bu.edu/ftp/visda17/clf/validation.tar tar xvf validation.tar wget http://csr.bu.edu/ftp/visda17/clf/test.tar tar xvf test.tar wget https://raw.githubusercontent.com/VisionLearningGroup/taskcv-2017-public/master/classification/data/image_list.txt -O test/image_list.txt ================================================ FILE: Dassl.ProGrad.pytorch/datasets/dg/cifar_c.py ================================================ """ This script - creates a folder named "cifar10_c" (or "cifar100_c") under the same directory as 'CIFAR-10-C' (or 'CIFAR-100-C') - extracts images from .npy files and saves them as .jpg.
""" import os import sys import numpy as np import os.path as osp from PIL import Image from dassl.utils import mkdir_if_missing def extract_and_save(images, labels, level, dst): # level denotes the corruption intensity level (0-based) assert 0 <= level <= 4 for i in range(10000): real_i = i + level*10000 im = Image.fromarray(images[real_i]) label = int(labels[real_i]) category_dir = osp.join(dst, str(label).zfill(3)) mkdir_if_missing(category_dir) save_path = osp.join(category_dir, str(i + 1).zfill(5) + ".jpg") im.save(save_path) def main(npy_folder): npy_folder = osp.abspath(osp.expanduser(npy_folder)) dataset_cap = osp.basename(npy_folder) assert dataset_cap in ["CIFAR-10-C", "CIFAR-100-C"] if dataset_cap == "CIFAR-10-C": dataset = "cifar10_c" else: dataset = "cifar100_c" if not osp.exists(npy_folder): raise FileNotFoundError('The given folder "{}" does not exist'.format(npy_folder)) root = osp.dirname(npy_folder) im_folder = osp.join(root, dataset) mkdir_if_missing(im_folder) dirnames = os.listdir(npy_folder) dirnames.remove("labels.npy") if "README.txt" in dirnames: dirnames.remove("README.txt") assert len(dirnames) == 19 labels = np.load(osp.join(npy_folder, "labels.npy")) for dirname in dirnames: corruption = dirname.split(".")[0] corruption_folder = osp.join(im_folder, corruption) mkdir_if_missing(corruption_folder) npy_filename = osp.join(npy_folder, dirname) images = np.load(npy_filename) assert images.shape[0] == 50000 for level in range(5): dst = osp.join(corruption_folder, str(level + 1)) mkdir_if_missing(dst) print('Saving images to "{}"'.format(dst)) extract_and_save(images, labels, level, dst) if __name__ == "__main__": # sys.argv[1] contains the path to CIFAR-10-C or CIFAR-100-C main(sys.argv[1]) ================================================ FILE: Dassl.ProGrad.pytorch/datasets/ssl/cifar10_cifar100_svhn.py ================================================ import sys import os.path as osp from torchvision.datasets import SVHN, CIFAR10, CIFAR100 from dassl.utils
import mkdir_if_missing def extract_and_save_image(dataset, save_dir): if osp.exists(save_dir): print('Folder "{}" already exists'.format(save_dir)) return print('Extracting images to "{}" ...'.format(save_dir)) mkdir_if_missing(save_dir) for i in range(len(dataset)): img, label = dataset[i] class_dir = osp.join(save_dir, str(label).zfill(3)) mkdir_if_missing(class_dir) impath = osp.join(class_dir, str(i + 1).zfill(5) + ".jpg") img.save(impath) def download_and_prepare(name, root): print("Dataset: {}".format(name)) print("Root: {}".format(root)) if name == "cifar10": train = CIFAR10(root, train=True, download=True) test = CIFAR10(root, train=False) elif name == "cifar100": train = CIFAR100(root, train=True, download=True) test = CIFAR100(root, train=False) elif name == "svhn": train = SVHN(root, split="train", download=True) test = SVHN(root, split="test", download=True) else: raise ValueError train_dir = osp.join(root, name, "train") test_dir = osp.join(root, name, "test") extract_and_save_image(train, train_dir) extract_and_save_image(test, test_dir) if __name__ == "__main__": download_and_prepare("cifar10", sys.argv[1]) download_and_prepare("cifar100", sys.argv[1]) download_and_prepare("svhn", sys.argv[1]) ================================================ FILE: Dassl.ProGrad.pytorch/datasets/ssl/stl10.py ================================================ import sys import os.path as osp from torchvision.datasets import STL10 from dassl.utils import mkdir_if_missing def extract_and_save_image(dataset, save_dir): if osp.exists(save_dir): print('Folder "{}" already exists'.format(save_dir)) return print('Extracting images to "{}" ...'.format(save_dir)) mkdir_if_missing(save_dir) for i in range(len(dataset)): img, label = dataset[i] if label == -1: label_name = "none" else: label_name = str(label) imname = str(i).zfill(6) + "_" + label_name + ".jpg" impath = osp.join(save_dir, imname) img.save(impath) def download_and_prepare(root): train = STL10(root, split="train", 
download=True) test = STL10(root, split="test") unlabeled = STL10(root, split="unlabeled") train_dir = osp.join(root, "train") test_dir = osp.join(root, "test") unlabeled_dir = osp.join(root, "unlabeled") extract_and_save_image(train, train_dir) extract_and_save_image(test, test_dir) extract_and_save_image(unlabeled, unlabeled_dir) if __name__ == "__main__": download_and_prepare(sys.argv[1]) ================================================ FILE: Dassl.ProGrad.pytorch/linter.sh ================================================ echo "Running isort" isort -y -sp . echo "Done" echo "Running yapf" yapf -i -r -vv -e build . echo "Done" echo "Running flake8" flake8 . echo "Done" ================================================ FILE: Dassl.ProGrad.pytorch/requirements.txt ================================================ flake8==3.7.9 yapf==0.29.0 isort==4.3.21 yacs gdown tb-nightly future scipy scikit-learn tqdm ================================================ FILE: Dassl.ProGrad.pytorch/setup.py ================================================ import numpy as np import os.path as osp from setuptools import setup, find_packages def readme(): with open('README.md') as f: content = f.read() return content def find_version(): version_file = 'dassl/__init__.py' with open(version_file, 'r') as f: exec(compile(f.read(), version_file, 'exec')) return locals()['__version__'] def numpy_include(): try: numpy_include = np.get_include() except AttributeError: numpy_include = np.get_numpy_include() return numpy_include def get_requirements(filename='requirements.txt'): here = osp.dirname(osp.realpath(__file__)) with open(osp.join(here, filename), 'r') as f: requires = [line.replace('\n', '') for line in f.readlines()] return requires setup( name='dassl', version=find_version(), description='Dassl: Domain adaptation and semi-supervised learning', author='Kaiyang Zhou', license='MIT', long_description=readme(), url='https://github.com/KaiyangZhou/Dassl.pytorch', packages=find_packages(), 
install_requires=get_requirements(), keywords=[ 'Domain Adaptation', 'Domain Generalization', 'Semi-Supervised Learning', 'Pytorch' ] ) ================================================ FILE: Dassl.ProGrad.pytorch/tools/parse_test_res.py ================================================ """ Goal --- 1. Read test results from log.txt files 2. Compute mean and std across different folders (seeds) Usage --- Assume the output files are saved under output/my_experiment, which contains results of different seeds, e.g., my_experiment/ seed1/ log.txt seed2/ log.txt seed3/ log.txt Run the following command from the root directory: $ python tools/parse_test_res.py output/my_experiment Add --ci95 to the arguments if you want the 95% confidence interval instead of the standard deviation: $ python tools/parse_test_res.py output/my_experiment --ci95 If my_experiment/ has the following structure, my_experiment/ exp-1/ seed1/ log.txt ... seed2/ log.txt ... seed3/ log.txt ... exp-2/ ... exp-3/ ... Run $ python tools/parse_test_res.py output/my_experiment --multi-exp """ import re import numpy as np import os.path as osp import argparse from collections import OrderedDict, defaultdict from dassl.utils import check_isfile, listdir_nohidden def compute_ci95(res): return 1.96 * np.std(res) / np.sqrt(len(res)) def parse_function(*metrics, directory="", args=None, end_signal=None): print(f"Parsing files in {directory}") subdirs = listdir_nohidden(directory, sort=True) outputs = [] for subdir in subdirs: fpath = osp.join(directory, subdir, "log.txt") assert check_isfile(fpath) good_to_go = False output = OrderedDict() with open(fpath, "r") as f: lines = f.readlines() for line in lines: line = line.strip() if line == end_signal: good_to_go = True for metric in metrics: match = metric["regex"].search(line) if match and good_to_go: if "file" not in output: output["file"] = fpath num = float(match.group(1)) name = metric["name"] output[name] = num if output: outputs.append(output) assert
len(outputs) > 0, f"Nothing found in {directory}" metrics_results = defaultdict(list) for output in outputs: msg = "" for key, value in output.items(): if isinstance(value, float): msg += f"{key}: {value:.2f}%. " else: msg += f"{key}: {value}. " if key != "file": metrics_results[key].append(value) print(msg) output_results = OrderedDict() print("===") print(f"Summary of directory: {directory}") for key, values in metrics_results.items(): avg = np.mean(values) std = compute_ci95(values) if args.ci95 else np.std(values) print(f"* {key}: {avg:.2f}% +- {std:.2f}%") output_results[key] = avg print("===") return output_results def main(args, end_signal): metric = { "name": args.keyword, "regex": re.compile(fr"\* {args.keyword}: ([\.\deE+-]+)%"), } if args.multi_exp: final_results = defaultdict(list) for directory in listdir_nohidden(args.directory, sort=True): directory = osp.join(args.directory, directory) results = parse_function( metric, directory=directory, args=args, end_signal=end_signal ) for key, value in results.items(): final_results[key].append(value) print("Average performance") for key, values in final_results.items(): avg = np.mean(values) print(f"* {key}: {avg:.2f}%") else: parse_function( metric, directory=args.directory, args=args, end_signal=end_signal ) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("directory", type=str, help="path to directory") parser.add_argument( "--ci95", action="store_true", help="compute 95%% confidence interval" ) parser.add_argument( "--test-log", action="store_true", help="parse test-only logs" ) parser.add_argument( "--multi-exp", action="store_true", help="parse multiple experiments" ) parser.add_argument( "--keyword", default="accuracy", type=str, help="which keyword to extract" ) args = parser.parse_args() end_signal = "Finished training" if args.test_log: end_signal = "=> result" main(args, end_signal) ================================================ FILE: 
Dassl.ProGrad.pytorch/tools/replace_text.py ================================================ """ Replace text in python files. """ import glob import os.path as osp import argparse import fileinput EXTENSION = ".py" def is_python_file(filename): ext = osp.splitext(filename)[1] return ext == EXTENSION def update_file(filename, text_to_search, replacement_text): print("Processing {}".format(filename)) with fileinput.FileInput(filename, inplace=True, backup="") as file: for line in file: print(line.replace(text_to_search, replacement_text), end="") def recursive_update(directory, text_to_search, replacement_text): filenames = glob.glob(osp.join(directory, "*")) for filename in filenames: if osp.isfile(filename): if not is_python_file(filename): continue update_file(filename, text_to_search, replacement_text) elif osp.isdir(filename): recursive_update(filename, text_to_search, replacement_text) else: raise NotImplementedError def main(): parser = argparse.ArgumentParser() parser.add_argument( "file_or_dir", type=str, help="path to file or directory" ) parser.add_argument("text_to_search", type=str, help="name to be replaced") parser.add_argument("replacement_text", type=str, help="new name") parser.add_argument( "--ext", type=str, default=".py", help="file extension" ) args = parser.parse_args() file_or_dir = args.file_or_dir text_to_search = args.text_to_search replacement_text = args.replacement_text extension = args.ext global EXTENSION EXTENSION = extension if osp.isfile(file_or_dir): if not is_python_file(file_or_dir): return update_file(file_or_dir, text_to_search, replacement_text) elif osp.isdir(file_or_dir): recursive_update(file_or_dir, text_to_search, replacement_text) else: raise NotImplementedError if __name__ == "__main__": main() ================================================ FILE: Dassl.ProGrad.pytorch/tools/train.py ================================================ import argparse import torch from dassl.utils import setup_logger, set_random_seed, 
collect_env_info from dassl.config import get_cfg_default from dassl.engine import build_trainer def print_args(args, cfg): print("***************") print("** Arguments **") print("***************") optkeys = list(args.__dict__.keys()) optkeys.sort() for key in optkeys: print("{}: {}".format(key, args.__dict__[key])) print("************") print("** Config **") print("************") print(cfg) def reset_cfg(cfg, args): if args.root: cfg.DATASET.ROOT = args.root if args.output_dir: cfg.OUTPUT_DIR = args.output_dir if args.resume: cfg.RESUME = args.resume if args.seed: cfg.SEED = args.seed if args.source_domains: cfg.DATASET.SOURCE_DOMAINS = args.source_domains if args.target_domains: cfg.DATASET.TARGET_DOMAINS = args.target_domains if args.transforms: cfg.INPUT.TRANSFORMS = args.transforms if args.trainer: cfg.TRAINER.NAME = args.trainer if args.backbone: cfg.MODEL.BACKBONE.NAME = args.backbone if args.head: cfg.MODEL.HEAD.NAME = args.head def extend_cfg(cfg): """ Add new config variables. E.g. from yacs.config import CfgNode as CN cfg.TRAINER.MY_MODEL = CN() cfg.TRAINER.MY_MODEL.PARAM_A = 1. cfg.TRAINER.MY_MODEL.PARAM_B = 0.5 cfg.TRAINER.MY_MODEL.PARAM_C = False """ pass def setup_cfg(args): cfg = get_cfg_default() extend_cfg(cfg) # 1. From the dataset config file if args.dataset_config_file: cfg.merge_from_file(args.dataset_config_file) # 2. From the method config file if args.config_file: cfg.merge_from_file(args.config_file) # 3. From input arguments reset_cfg(cfg, args) # 4. 
From optional input arguments cfg.merge_from_list(args.opts) cfg.freeze() return cfg def main(args): cfg = setup_cfg(args) if cfg.SEED >= 0: print("Setting fixed seed: {}".format(cfg.SEED)) set_random_seed(cfg.SEED) setup_logger(cfg.OUTPUT_DIR) if torch.cuda.is_available() and cfg.USE_CUDA: torch.backends.cudnn.benchmark = True print_args(args, cfg) print("Collecting env info ...") print("** System info **\n{}\n".format(collect_env_info())) trainer = build_trainer(cfg) if args.eval_only: trainer.load_model(args.model_dir, epoch=args.load_epoch) trainer.test() return if not args.no_train: trainer.train() if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--root", type=str, default="", help="path to dataset") parser.add_argument( "--output-dir", type=str, default="", help="output directory" ) parser.add_argument( "--resume", type=str, default="", help="checkpoint directory (from which the training resumes)", ) parser.add_argument( "--seed", type=int, default=-1, help="only positive value enables a fixed seed" ) parser.add_argument( "--source-domains", type=str, nargs="+", help="source domains for DA/DG" ) parser.add_argument( "--target-domains", type=str, nargs="+", help="target domains for DA/DG" ) parser.add_argument( "--transforms", type=str, nargs="+", help="data augmentation methods" ) parser.add_argument( "--config-file", type=str, default="", help="path to config file" ) parser.add_argument( "--dataset-config-file", type=str, default="", help="path to config file for dataset setup", ) parser.add_argument( "--trainer", type=str, default="", help="name of trainer" ) parser.add_argument( "--backbone", type=str, default="", help="name of CNN backbone" ) parser.add_argument("--head", type=str, default="", help="name of head") parser.add_argument( "--eval-only", action="store_true", help="evaluation only" ) parser.add_argument( "--model-dir", type=str, default="", help="load model from this directory for eval-only mode", ) 
parser.add_argument( "--load-epoch", type=int, help="load model weights at this epoch for evaluation" ) parser.add_argument( "--no-train", action="store_true", help="do not call trainer.train()" ) parser.add_argument( "opts", default=None, nargs=argparse.REMAINDER, help="modify config options using the command-line", ) args = parser.parse_args() main(args) ================================================ FILE: ProGrad.public/.gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ pip-wheel-metadata/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # Custom output/ debug.sh ================================================ FILE: ProGrad.public/DATASETS.md ================================================ # How to install datasets We suggest putting all datasets under the same folder (say `$DATA`) for easier management, and organizing them as described below so that the source code does not need to be modified. The file structure looks like ``` $DATA/ |–– imagenet/ |–– caltech-101/ |–– oxford_pets/ |–– stanford_cars/ ``` If you already have some datasets installed elsewhere, you can create symbolic links in `$DATA/dataset_name` that point to the original data to avoid duplicate downloads. Datasets list: - [ImageNet](#imagenet) - [Caltech101](#caltech101) - [OxfordPets](#oxfordpets) - [StanfordCars](#stanfordcars) - [Flowers102](#flowers102) - [Food101](#food101) - [FGVCAircraft](#fgvcaircraft) - [SUN397](#sun397) - [DTD](#dtd) - [EuroSAT](#eurosat) - [UCF101](#ucf101) - [ImageNetV2](#imagenetv2) - [ImageNet-Sketch](#imagenet-sketch) - [ImageNet-A](#imagenet-a) - [ImageNet-R](#imagenet-r) The instructions to prepare each dataset are detailed below. To ensure reproducibility and fair comparison for future work, we provide fixed train/val/test splits for all datasets except ImageNet, where the validation set is used as the test set. The fixed splits are either from the original datasets (if available) or created by us. ### ImageNet - Create a folder named `imagenet/` under `$DATA`. - Create `images/` under `imagenet/`.
- Download the dataset from the [official website](https://image-net.org/index.php) and extract the training and validation sets to `$DATA/imagenet/images`. The directory structure should look like ``` imagenet/ |–– images/ | |–– train/ # contains 1,000 folders like n01440764, n01443537, etc. | |–– val/ ``` - If you had downloaded the ImageNet dataset before, you can create symbolic links to map the training and validation sets to `$DATA/imagenet/images`. - Download the `classnames.txt` to `$DATA/imagenet/` from this [link](https://drive.google.com/file/d/1-61f_ol79pViBFDG_IDlUQSwoLcn2XXF/view?usp=sharing). The class names are copied from [CLIP](https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb). ### Caltech101 - Create a folder named `caltech-101/` under `$DATA`. - Download `101_ObjectCategories.tar.gz` from http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz and extract the file under `$DATA/caltech-101`. - Download `split_zhou_Caltech101.json` from this [link](https://drive.google.com/file/d/1hyarUivQE36mY6jSomru6Fjd-JzwcCzN/view?usp=sharing) and put it under `$DATA/caltech-101`. The directory structure should look like ``` caltech-101/ |–– 101_ObjectCategories/ |–– split_zhou_Caltech101.json ``` ### OxfordPets - Create a folder named `oxford_pets/` under `$DATA`. - Download the images from https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz. - Download the annotations from https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz. - Download `split_zhou_OxfordPets.json` from this [link](https://drive.google.com/file/d/1501r8Ber4nNKvmlFVQZ8SeUHTcdTTEqs/view?usp=sharing). The directory structure should look like ``` oxford_pets/ |–– images/ |–– annotations/ |–– split_zhou_OxfordPets.json ``` ### StanfordCars - Create a folder named `stanford_cars/` under `$DATA`. - Download the train images http://ai.stanford.edu/~jkrause/car196/cars_train.tgz. 
- Download the test images http://ai.stanford.edu/~jkrause/car196/cars_test.tgz. - Download the train labels https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz. - Download the test labels http://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat. - Download `split_zhou_StanfordCars.json` from this [link](https://drive.google.com/file/d/1ObCFbaAgVu0I-k_Au-gIUcefirdAuizT/view?usp=sharing). The directory structure should look like ``` stanford_cars/ |–– cars_test/ |–– cars_test_annos_withlabels.mat |–– cars_train/ |–– devkit/ |–– split_zhou_StanfordCars.json ``` ### Flowers102 - Create a folder named `oxford_flowers/` under `$DATA`. - Download the images and labels from https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz and https://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat respectively. - Download `cat_to_name.json` from [here](https://drive.google.com/file/d/1AkcxCXeK_RCGCEC_GvmWxjcjaNhu-at0/view?usp=sharing). - Download `split_zhou_OxfordFlowers.json` from [here](https://drive.google.com/file/d/1Pp0sRXzZFZq15zVOzKjKBu4A9i01nozT/view?usp=sharing). The directory structure should look like ``` oxford_flowers/ |–– cat_to_name.json |–– imagelabels.mat |–– jpg/ |–– split_zhou_OxfordFlowers.json ``` ### Food101 - Download the dataset from https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/ and extract the file `food-101.tar.gz` under `$DATA`, resulting in a folder named `$DATA/food-101/`. - Download `split_zhou_Food101.json` from [here](https://drive.google.com/file/d/1QK0tGi096I0Ba6kggatX1ee6dJFIcEJl/view?usp=sharing). The directory structure should look like ``` food-101/ |–– images/ |–– license_agreement.txt |–– meta/ |–– README.txt |–– split_zhou_Food101.json ``` ### FGVCAircraft - Download the data from https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz. - Extract `fgvc-aircraft-2013b.tar.gz` and keep only `data/`.
- Move `data/` to `$DATA` and rename the folder to `fgvc_aircraft/`. The directory structure should look like ``` fgvc_aircraft/ |–– images/ |–– ... # a bunch of .txt files ``` ### SUN397 - Create a folder named `sun397/` under `$DATA`. - Download the images http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz. - Download the partitions https://vision.princeton.edu/projects/2010/SUN/download/Partitions.zip. - Extract these files under `$DATA/sun397/`. - Download `split_zhou_SUN397.json` from this [link](https://drive.google.com/file/d/1y2RD81BYuiyvebdN-JymPfyWYcd8_MUq/view?usp=sharing). The directory structure should look like ``` sun397/ |–– SUN397/ |–– split_zhou_SUN397.json |–– ... # a bunch of .txt files ``` ### DTD - Download the dataset from https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz and extract it to `$DATA`. This should lead to `$DATA/dtd/`. - Download `split_zhou_DescribableTextures.json` from this [link](https://drive.google.com/file/d/1u3_QfB467jqHgNXC00UIzbLZRQCg2S7x/view?usp=sharing). The directory structure should look like ``` dtd/ |–– images/ |–– imdb/ |–– labels/ |–– split_zhou_DescribableTextures.json ``` ### EuroSAT - Create a folder named `eurosat/` under `$DATA`. - Download the dataset from http://madm.dfki.de/files/sentinel/EuroSAT.zip and extract it to `$DATA/eurosat/`. - Download `split_zhou_EuroSAT.json` from [here](https://drive.google.com/file/d/1Ip7yaCWFi0eaOFUGga0lUdVi_DDQth1o/view?usp=sharing). The directory structure should look like ``` eurosat/ |–– 2750/ |–– split_zhou_EuroSAT.json ``` ### UCF101 - Create a folder named `ucf101/` under `$DATA`. - Download the zip file `UCF-101-midframes.zip` from [here](https://drive.google.com/file/d/10Jqome3vtUA2keJkNanAiFpgbyC9Hc2O/view?usp=sharing) and extract it to `$DATA/ucf101/`. This zip file contains the extracted middle video frames. 
- Download `split_zhou_UCF101.json` from this [link](https://drive.google.com/file/d/1I0S0q91hJfsV9Gf4xDIjgDq4AqBNJb1y/view?usp=sharing). The directory structure should look like ``` ucf101/ |–– UCF-101-midframes/ |–– split_zhou_UCF101.json ``` ### ImageNetV2 - Create a folder named `imagenetv2/` under `$DATA`. - Go to this github repo https://github.com/modestyachts/ImageNetV2. - Download the matched-frequency dataset from https://s3-us-west-2.amazonaws.com/imagenetv2public/imagenetv2-matched-frequency.tar.gz and extract it to `$DATA/imagenetv2/`. - Copy `$DATA/imagenet/classnames.txt` to `$DATA/imagenetv2/`. The directory structure should look like ``` imagenetv2/ |–– imagenetv2-matched-frequency-format-val/ |–– classnames.txt ``` ### ImageNet-Sketch - Download the dataset from https://github.com/HaohanWang/ImageNet-Sketch. - Extract the dataset to `$DATA/imagenet-sketch`. - Copy `$DATA/imagenet/classnames.txt` to `$DATA/imagenet-sketch/`. The directory structure should look like ``` imagenet-sketch/ |–– images/ # contains 1,000 folders whose names have the format of n* |–– classnames.txt ``` ### ImageNet-A - Create a folder named `imagenet-adversarial/` under `$DATA`. - Download the dataset from https://github.com/hendrycks/natural-adv-examples and extract it to `$DATA/imagenet-adversarial/`. - Copy `$DATA/imagenet/classnames.txt` to `$DATA/imagenet-adversarial/`. The directory structure should look like ``` imagenet-adversarial/ |–– imagenet-a/ # contains 200 folders whose names have the format of n* |–– classnames.txt ``` ### ImageNet-R - Create a folder named `imagenet-rendition/` under `$DATA`. - Download the dataset from https://github.com/hendrycks/imagenet-r and extract it to `$DATA/imagenet-rendition/`. - Copy `$DATA/imagenet/classnames.txt` to `$DATA/imagenet-rendition/`. 
The directory structure should look like ``` imagenet-rendition/ |–– imagenet-r/ # contains 200 folders whose names have the format of n* |–– classnames.txt ``` ================================================ FILE: ProGrad.public/LICENSE ================================================ MIT License Copyright (c) 2021 Kaiyang Zhou Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: ProGrad.public/README.md ================================================ # How to Run ## GPU memory needed All experiments can be run on a single GPU. However, **to get results on ImageNet, the GPU needs more than 24 GB of memory.** Around 12 GB is enough for the other datasets. ## How to Install This code is built on top of the toolbox [Dassl.pytorch](https://github.com/KaiyangZhou/Dassl.pytorch). Since we have modified it, please install the provided Dassl.ProGrad.pytorch instead.
Go to the folder Dassl.ProGrad.pytorch provided in the appendix and prepare the environment as follows: ``` # Create a conda environment conda create -n dassl python=3.7 # Activate the environment conda activate dassl # Install dependencies pip install -r requirements.txt # Install torch (version >= 1.7.1) and torchvision # Please make sure you install the GPU version for speed. # For example: conda install pytorch torchvision cudatoolkit=10.1 -c pytorch # Install this library (no need to re-build if the source code is modified) python setup.py develop ``` After that, run `pip install -r requirements.txt` under `ProGrad.public/` to install a few more packages required by [CLIP](https://github.com/openai/CLIP) (this should be done while `dassl` is activated). Then, you are ready to go. Follow [DATASETS.md](DATASETS.md) to install the datasets. ## Few-shot setting on 11 datasets Basic format: ``` bash main.sh ${DATASET_NAME} ${CONFIG_NAME} end ${CONTEXT_TOKENS_NUMBER} ${SHOTS} False ``` For example, to run 1, 2, 4, 8, and 16 shots on stanford_cars, **CLIP + CoOp (M=16, end)**: - 1 shot: `bash main.sh stanford_cars rn50_ep50 end 16 1 False` - 2 shots: `bash main.sh stanford_cars rn50_ep100 end 16 2 False` - 4 shots: `bash main.sh stanford_cars rn50_ep100 end 16 4 False` - 8 shots: `bash main.sh stanford_cars rn50 end 16 8 False` - 16 shots: `bash main.sh stanford_cars rn50 end 16 16 False` **CLIP + CoOp + ProGrad**: **Note that the 8-shot and 16-shot results on Flowers102, DTD, and EuroSAT are obtained with lambda set to 0.8.** To reproduce these results from our paper, change the variable LAMBDA in prograd.sh from 1.0 to 0.8.
- 1 shot: `bash prograd.sh stanford_cars rn50_ep50 end 16 1 False` - 2 shots: `bash prograd.sh stanford_cars rn50_ep100 end 16 2 False` - 4 shots: `bash prograd.sh stanford_cars rn50_ep100 end 16 4 False` - 8 shots: `bash prograd.sh stanford_cars rn50 end 16 8 False` - 16 shots: `bash prograd.sh stanford_cars rn50 end 16 16 False` ``` output |–– caltech101/ | |–– CoOp/ | | |–– rn50_16shots/ | | | |–– nctx16_cscFalse_ctpend/ | | | | |–– seed1/ | | | | |–– seed2/ | | | | |–– seed3/ | | |–– rn50_8shots/ | | | |–– nctx16_cscFalse_ctpend/ | | | | |–– seed1/ | | | | |–– seed2/ | | | | |–– seed3/ ``` To calculate the average results for the folder `rn50_16shots/nctx16_cscFalse_ctpend/`, you can run ```bash python parse_test_res.py output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend ``` Then, you will see something like this in your terminal ```bash Parsing files in output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend file: output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend/seed1/log.txt. accuracy: 91.81%. error: 8.19%. file: output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend/seed2/log.txt. accuracy: 92.01%. error: 7.99%. file: output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend/seed3/log.txt. accuracy: 92.17%. error: 7.83%. === Summary of directory: output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend * accuracy: 92.00% +- 0.15% * error: 8.00% +- 0.15% === ``` **How to visualize nearest words for the learned context tokens?** All you need is `interpret_prompt.py`. Say the learned tokens are saved in `a/b/c/prompt_learner/model.pth.tar` and you would like to see the top-3 nearest words for each token. In this case, run `python interpret_prompt.py a/b/c/prompt_learner/model.pth.tar 3` ## Robustness to Distribution Shift To reproduce the robustness experiments, you can simply load the models learned on ImageNet and evaluate them on the following datasets: `imagenetv2`, `imagenet-sketch`, `imagenet-a` and `imagenet-r`. 
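These evaluations, like the few-shot runs above, are summarized with `parse_test_res.py`. The sketch below is a simplified, in-memory version of its core logic, under a few assumptions. Metric lines are matched with the same `\* accuracy: ...%` pattern the script builds. Lines before the end signal are skipped (`"=> result"` with `--test-log`, otherwise `"Finished training"`). The per-seed values are then reduced to a mean plus either the standard deviation or, with `--ci95`, `1.96 * std / sqrt(n)`. `parse_log` and `summarize` are hypothetical helper names, not part of the repository. The log contents are made up, except that they reuse the three caltech101 seed accuracies from the example output earlier in this README.

```python
import re

import numpy as np

# Same pattern main() in tools/parse_test_res.py builds for --keyword accuracy.
ACC_RE = re.compile(r"\* accuracy: ([\.\deE+-]+)%")


def parse_log(lines, end_signal="=> result"):
    """Hypothetical helper: last accuracy reported after end_signal, or None.

    Simplified from parse_function(): matches before end_signal are ignored,
    which is why eval-only logs need --test-log ("=> result") rather than
    the default "Finished training".
    """
    good_to_go = False
    acc = None
    for line in lines:
        line = line.strip()
        if line == end_signal:
            good_to_go = True
        match = ACC_RE.search(line)
        if match and good_to_go:
            acc = float(match.group(1))  # later matches overwrite earlier ones
    return acc


def compute_ci95(res):
    # Normal-approximation 95% CI half-width; np.std uses ddof=0 by default.
    return 1.96 * np.std(res) / np.sqrt(len(res))


def summarize(values, ci95=False):
    """Hypothetical helper: format mean +- spread like the script's summary."""
    spread = compute_ci95(values) if ci95 else np.std(values)
    return f"{np.mean(values):.2f}% +- {spread:.2f}%"


# Illustrative eval-only logs, one per seed (only the relevant lines shown).
seed_logs = [
    ["=> result", "* accuracy: 91.81%"],
    ["=> result", "* accuracy: 92.01%"],
    ["=> result", "* accuracy: 92.17%"],
]
accs = [parse_log(log) for log in seed_logs]
print(summarize(accs))             # 92.00% +- 0.15%
print(summarize(accs, ci95=True))  # 92.00% +- 0.17%
```

Because `np.std` defaults to `ddof=0`, the reported spread is the population standard deviation. With only three seeds, this is noticeably smaller than the sample standard deviation (`ddof=1`).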
The command is provided in `scripts/eval.sh`. The key arguments are `--model-dir`, `--load-epoch` and `--eval-only`. `--model-dir` indicates the directory where the models are saved (i.e., the entire folder containing `log.txt`, the tensorboard file and `prompt_learner/`). `--load-epoch` tells the code to load the model saved at a specific epoch, like `--load-epoch 50` for ImageNet. For example, to evaluate `CLIP + CoOp (M=16, end)` on ImageNetV2, you can do ```bash # No need to use rn50_ep50 here as no training is performed bash eval.sh imagenetv2 rn50 ``` If you want to get the results of our method, simply change the TRAINER to `ProGrad`. The default setting is `SHOTS=4`. Feel free to modify the script. Again, you can use `parse_test_res.py` to automate the calculation of average performance. This time you should append `--test-log`, e.g., `python parse_test_res.py directory --test-log`. ## Zero-Shot CLIP See `CoOp/scripts/zeroshot.sh`. ## Generalization From Base to New Classes You will need `base2new_train_main.sh`, `base2new_test_main.sh`, `base2new_train_prograd.sh`, and `base2new_test_prograd.sh`. The scripts with the prefix `base2new_train` train a model on base classes, while the ones with the prefix `base2new_test` evaluate the trained model on new classes. Both kinds of scripts take a single input argument, `DATASET`, which is a dataset name such as `imagenet` or `caltech101`. The valid names are the file names in `CoOp/configs/datasets/`. The scripts with the suffix `prograd.sh` are used for our proposed method, while those with the suffix `main.sh` are used for CoOp. Below we provide an example of how to train and evaluate the model on StanfordCars.
```bash bash base2new_train_prograd.sh stanford_cars bash base2new_test_prograd.sh stanford_cars ``` **If you want to test results on ImageNet, remember to change the CFG from "rn50_ep100" to "rn50_ep50", and change the LOADEP from 100 to 50 in the corresponding script.** When the evaluation is done, you can use `parse_test_res.py` to automatically calculate the average results. For instance, after you finish the evaluation using the aforementioned commands, you would get ``` output |–– base2new/ | |–– test_new/ | | |–– stanford_cars/ | | | |–– shots_16/ | | | | |–– CoCoOp/ | | | | | |–– rn50_ep100/ | | | | | | |–– seed1/ | | | | | | |–– seed2/ | | | | | | |–– seed3/ | |–– train_base/ | | |–– stanford_cars/ | | | |–– shots_16/ | | | | |–– CoCoOp/ | | | | | |–– rn50_ep100/ | | | | | | |–– seed1/ | | | | | | |–– seed2/ | | | | | | |–– seed3/ ``` Then, to get the average performance on the base classes, run ```bash python parse_test_res.py output/base2new/train_base/stanford_cars/shots_16/CoCoOp/rn50_ep100 ``` To get the average performance on the new classes, run ```bash python parse_test_res.py output/base2new/test_new/stanford_cars/shots_16/CoCoOp/rn50_ep100 --test-log ``` ================================================ FILE: ProGrad.public/clip/__init__.py ================================================ from .clip import * ================================================ FILE: ProGrad.public/clip/clip.py ================================================ import hashlib import os import re import urllib.request import warnings from typing import Union, List import torch from PIL import Image from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize from tqdm import tqdm from .model import build_model from .simple_tokenizer import SimpleTokenizer as _Tokenizer try: from torchvision.transforms import InterpolationMode BICUBIC = InterpolationMode.BICUBIC except ImportError: BICUBIC = Image.BICUBIC if tuple(int(v) for v in re.findall(r"\d+", torch.__version__)[:3]) < (1, 7, 1): 
warnings.warn("PyTorch version 1.7.1 or higher is recommended") __all__ = ["available_models", "load", "tokenize"] _tokenizer = _Tokenizer() _MODELS = { "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", } def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): os.makedirs(root, exist_ok=True) filename = os.path.basename(url) expected_sha256 = url.split("/")[-2] download_target = os.path.join(root, filename) if os.path.exists(download_target) and not os.path.isfile(download_target): raise RuntimeError( f"{download_target} exists and is not a regular file") if os.path.isfile(download_target): if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: return download_target else: warnings.warn( f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file" ) with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: while True: buffer = source.read(8192) if not buffer: break output.write(buffer) loop.update(len(buffer)) if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: raise 
RuntimeError( "Model has been downloaded but the SHA256 checksum does not match" ) return download_target def _transform(n_px): return Compose([ Resize(n_px, interpolation=BICUBIC), CenterCrop(n_px), lambda image: image.convert("RGB"), ToTensor(), Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), ]) def available_models() -> List[str]: """Returns the names of available CLIP models""" return list(_MODELS.keys()) def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False): """Load a CLIP model Parameters ---------- name : str A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict device : Union[str, torch.device] The device to put the loaded model on jit : bool Whether to load the optimized JIT model or more hackable non-JIT model (default). Returns ------- model : torch.nn.Module The CLIP model preprocess : Callable[[PIL.Image], torch.Tensor] A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input """ if name in _MODELS: model_path = _download(_MODELS[name]) elif os.path.isfile(name): model_path = name else: raise RuntimeError( f"Model {name} not found; available models = {available_models()}") try: # loading JIT archive model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() state_dict = None except RuntimeError: # loading saved state dict if jit: warnings.warn( f"File {model_path} is not a JIT archive. 
Loading as a state dict instead" ) jit = False state_dict = torch.load(model_path, map_location="cpu") if not jit: model = build_model(state_dict or model.state_dict()).to(device) if str(device) == "cpu": model.float() return model, _transform(model.visual.input_resolution) # patch the device names device_holder = torch.jit.trace( lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) device_node = [ n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n) ][-1] def patch_device(module): try: graphs = [module.graph] if hasattr(module, "graph") else [] except RuntimeError: graphs = [] if hasattr(module, "forward1"): graphs.append(module.forward1.graph) for graph in graphs: for node in graph.findAllNodes("prim::Constant"): if "value" in node.attributeNames() and str( node["value"]).startswith("cuda"): node.copyAttributes(device_node) model.apply(patch_device) patch_device(model.encode_image) patch_device(model.encode_text) # patch dtype to float32 on CPU if str(device) == "cpu": float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] float_node = float_input.node() def patch_float(module): try: graphs = [module.graph] if hasattr(module, "graph") else [] except RuntimeError: graphs = [] if hasattr(module, "forward1"): graphs.append(module.forward1.graph) for graph in graphs: for node in graph.findAllNodes("aten::to"): inputs = list(node.inputs()) for i in [ 1, 2 ]: # dtype can be the second or third argument to aten::to() if inputs[i].node()["value"] == 5: inputs[i].node().copyAttributes(float_node) model.apply(patch_float) patch_float(model.encode_image) patch_float(model.encode_text) model.float() return model, _transform(model.input_resolution.item()) def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: """ Returns the tokenized representation of given input string(s) 
Parameters ---------- texts : Union[str, List[str]] An input string or a list of input strings to tokenize context_length : int The context length to use; all CLIP models use 77 as the context length truncate: bool Whether to truncate the text in case its encoding is longer than the context length Returns ------- A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] """ if isinstance(texts, str): texts = [texts] sot_token = _tokenizer.encoder["<|startoftext|>"] eot_token = _tokenizer.encoder["<|endoftext|>"] all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) for i, tokens in enumerate(all_tokens): if len(tokens) > context_length: if truncate: tokens = tokens[:context_length] tokens[-1] = eot_token else: raise RuntimeError( f"Input {texts[i]} is too long for context length {context_length}" ) result[i, :len(tokens)] = torch.tensor(tokens) return result ================================================ FILE: ProGrad.public/clip/model.py ================================================ from collections import OrderedDict from typing import Tuple, Union import numpy as np import torch import torch.nn.functional as F from torch import nn class Bottleneck(nn.Module): expansion = 4 def __init__(self, inplanes, planes, stride=1): super().__init__() # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) self.bn1 = nn.BatchNorm2d(planes) self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(planes) self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) self.bn3 = nn.BatchNorm2d(planes * self.expansion) self.relu = nn.ReLU(inplace=True) self.downsample = None self.stride = stride if stride > 1 or inplanes != planes * Bottleneck.expansion: # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 self.downsample = nn.Sequential( OrderedDict([("-1", nn.AvgPool2d(stride)), ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), ("1", nn.BatchNorm2d(planes * self.expansion))])) def forward(self, x: torch.Tensor): identity = x out = self.relu(self.bn1(self.conv1(x))) out = self.relu(self.bn2(self.conv2(out))) out = self.avgpool(out) out = self.bn3(self.conv3(out)) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.relu(out) return out class AttentionPool2d(nn.Module): def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): super().__init__() self.positional_embedding = nn.Parameter( torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5) self.k_proj = nn.Linear(embed_dim, embed_dim) self.q_proj = nn.Linear(embed_dim, embed_dim) self.v_proj = nn.Linear(embed_dim, embed_dim) self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) self.num_heads = num_heads def forward(self, x): x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC x, _ = F.multi_head_attention_forward( query=x, key=x, value=x, 
embed_dim_to_check=x.shape[-1], num_heads=self.num_heads, q_proj_weight=self.q_proj.weight, k_proj_weight=self.k_proj.weight, v_proj_weight=self.v_proj.weight, in_proj_weight=None, in_proj_bias=torch.cat( [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), bias_k=None, bias_v=None, add_zero_attn=False, dropout_p=0, out_proj_weight=self.c_proj.weight, out_proj_bias=self.c_proj.bias, use_separate_proj_weight=True, training=self.training, need_weights=False) return x[0] class ModifiedResNet(nn.Module): """ A ResNet class that is similar to torchvision's but contains the following changes: - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 - The final pooling layer is a QKV attention instead of an average pool """ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): super().__init__() self.output_dim = output_dim self.input_resolution = input_resolution # the 3-layer stem self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(width // 2) self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(width // 2) self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) self.bn3 = nn.BatchNorm2d(width) self.avgpool = nn.AvgPool2d(2) self.relu = nn.ReLU(inplace=True) # residual layers self._inplanes = width # this is a *mutable* variable used during construction self.layer1 = self._make_layer(width, layers[0]) self.layer2 = self._make_layer(width * 2, layers[1], stride=2) self.layer3 = self._make_layer(width * 4, layers[2], stride=2) self.layer4 = self._make_layer(width * 8, layers[3], stride=2) embed_dim = width * 32 # the ResNet feature dimension self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) def _make_layer(self, planes, 
blocks, stride=1): layers = [Bottleneck(self._inplanes, planes, stride)] self._inplanes = planes * Bottleneck.expansion for _ in range(1, blocks): layers.append(Bottleneck(self._inplanes, planes)) return nn.Sequential(*layers) def forward(self, x): def stem(x): for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: x = self.relu(bn(conv(x))) x = self.avgpool(x) return x x = x.type(self.conv1.weight.dtype) x = stem(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = self.attnpool(x) return x class LayerNorm(nn.LayerNorm): """Subclass torch's LayerNorm to handle fp16.""" def forward(self, x: torch.Tensor): orig_type = x.dtype ret = super().forward(x.type(torch.float32)) return ret.type(orig_type) class QuickGELU(nn.Module): def forward(self, x: torch.Tensor): return x * torch.sigmoid(1.702 * x) class ResidualAttentionBlock(nn.Module): def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): super().__init__() self.attn = nn.MultiheadAttention(d_model, n_head) self.ln_1 = LayerNorm(d_model) self.mlp = nn.Sequential( OrderedDict([("c_fc", nn.Linear(d_model, d_model * 4)), ("gelu", QuickGELU()), ("c_proj", nn.Linear(d_model * 4, d_model))])) self.ln_2 = LayerNorm(d_model) self.attn_mask = attn_mask def attention(self, x: torch.Tensor): self.attn_mask = self.attn_mask.to( dtype=x.dtype, device=x.device) if self.attn_mask is not None else None return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] def forward(self, x: torch.Tensor): x = x + self.attention(self.ln_1(x)) x = x + self.mlp(self.ln_2(x)) return x class Transformer(nn.Module): def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): super().__init__() self.width = width self.layers = layers self.resblocks = nn.Sequential(*[ ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers) ]) def forward(self, x: torch.Tensor): return self.resblocks(x) class 
VisionTransformer(nn.Module): def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): super().__init__() self.input_resolution = input_resolution self.output_dim = output_dim self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) scale = width**-0.5 self.class_embedding = nn.Parameter(scale * torch.randn(width)) self.positional_embedding = nn.Parameter(scale * torch.randn( (input_resolution // patch_size)**2 + 1, width)) self.ln_pre = LayerNorm(width) self.transformer = Transformer(width, layers, heads) self.ln_post = LayerNorm(width) self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) def forward(self, x: torch.Tensor): x = self.conv1(x) # shape = [*, width, grid, grid] x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] x = torch.cat([ self.class_embedding.to(x.dtype) + torch.zeros( x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x ], dim=1) # shape = [*, grid ** 2 + 1, width] x = x + self.positional_embedding.to(x.dtype) x = self.ln_pre(x) x = x.permute(1, 0, 2) # NLD -> LND x = self.transformer(x) x = x.permute(1, 0, 2) # LND -> NLD x = self.ln_post(x[:, 0, :]) if self.proj is not None: x = x @ self.proj return x class CLIP(nn.Module): def __init__( self, embed_dim: int, # vision image_resolution: int, vision_layers: Union[Tuple[int, int, int, int], int], vision_width: int, vision_patch_size: int, # text context_length: int, vocab_size: int, transformer_width: int, transformer_heads: int, transformer_layers: int): super().__init__() self.context_length = context_length if isinstance(vision_layers, (tuple, list)): vision_heads = vision_width * 32 // 64 self.visual = ModifiedResNet(layers=vision_layers, output_dim=embed_dim, heads=vision_heads, input_resolution=image_resolution, width=vision_width) else: vision_heads = vision_width // 64 
self.visual = VisionTransformer(input_resolution=image_resolution, patch_size=vision_patch_size, width=vision_width, layers=vision_layers, heads=vision_heads, output_dim=embed_dim) self.transformer = Transformer(width=transformer_width, layers=transformer_layers, heads=transformer_heads, attn_mask=self.build_attention_mask()) self.vocab_size = vocab_size self.token_embedding = nn.Embedding(vocab_size, transformer_width) self.positional_embedding = nn.Parameter( torch.empty(self.context_length, transformer_width)) self.ln_final = LayerNorm(transformer_width) self.text_projection = nn.Parameter( torch.empty(transformer_width, embed_dim)) self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) self.initialize_parameters() def initialize_parameters(self): nn.init.normal_(self.token_embedding.weight, std=0.02) nn.init.normal_(self.positional_embedding, std=0.01) if isinstance(self.visual, ModifiedResNet): if self.visual.attnpool is not None: std = self.visual.attnpool.c_proj.in_features**-0.5 nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) for resnet_block in [ self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4 ]: for name, param in resnet_block.named_parameters(): if name.endswith("bn3.weight"): nn.init.zeros_(param) proj_std = (self.transformer.width**-0.5) * ( (2 * self.transformer.layers)**-0.5) attn_std = self.transformer.width**-0.5 fc_std = (2 * self.transformer.width)**-0.5 for block in self.transformer.resblocks: nn.init.normal_(block.attn.in_proj_weight, std=attn_std) nn.init.normal_(block.attn.out_proj.weight, std=proj_std) nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) if self.text_projection is not None: nn.init.normal_(self.text_projection, 
std=self.transformer.width**-0.5) def build_attention_mask(self): # lazily create causal attention mask, with full attention between the vision tokens # pytorch uses additive attention mask; fill with -inf mask = torch.empty(self.context_length, self.context_length) mask.fill_(float("-inf")) mask.triu_(1) # zero out the diagonal and everything below it return mask @property def dtype(self): return self.visual.conv1.weight.dtype def encode_image(self, image): return self.visual(image.type(self.dtype)) def encode_text(self, text): x = self.token_embedding(text).type( self.dtype) # [batch_size, n_ctx, d_model] x = x + self.positional_embedding.type(self.dtype) x = x.permute(1, 0, 2) # NLD -> LND x = self.transformer(x) x = x.permute(1, 0, 2) # LND -> NLD x = self.ln_final(x).type(self.dtype) # x.shape = [batch_size, n_ctx, transformer.width] # take features from the eot embedding (eot_token is the highest number in each sequence) x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection return x def forward(self, image, text): image_features = self.encode_image(image) text_features = self.encode_text(text) # normalized features image_features = image_features / image_features.norm(dim=-1, keepdim=True) text_features = text_features / text_features.norm(dim=-1, keepdim=True) # cosine similarity as logits logit_scale = self.logit_scale.exp() logits_per_image = logit_scale * image_features @ text_features.t() logits_per_text = logit_scale * text_features @ image_features.t() # shape = [global_batch_size, global_batch_size] return logits_per_image, logits_per_text def convert_weights(model: nn.Module): """Convert applicable model parameters to fp16""" def _convert_weights_to_fp16(l): if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): l.weight.data = l.weight.data.half() if l.bias is not None: l.bias.data = l.bias.data.half() if isinstance(l, nn.MultiheadAttention): for attr in [ *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", 
"bias_k", "bias_v" ]:
tensor = getattr(l, attr) if tensor is not None: tensor.data = tensor.data.half() for name in ["text_projection", "proj"]: if hasattr(l, name): attr = getattr(l, name) if attr is not None: attr.data = attr.data.half() model.apply(_convert_weights_to_fp16) def build_model(state_dict: dict): vit = "visual.proj" in state_dict if vit: vision_width = state_dict["visual.conv1.weight"].shape[0] vision_layers = len([ k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight") ]) vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] grid_size = round( (state_dict["visual.positional_embedding"].shape[0] - 1)**0.5) image_resolution = vision_patch_size * grid_size else: counts: list = [ len( set( k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4] ] vision_layers = tuple(counts) vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] output_width = round( (state_dict["visual.attnpool.positional_embedding"].shape[0] - 1)**0.5) vision_patch_size = None assert output_width**2 + 1 == state_dict[ "visual.attnpool.positional_embedding"].shape[0] image_resolution = output_width * 32 embed_dim = state_dict["text_projection"].shape[1] context_length = state_dict["positional_embedding"].shape[0] vocab_size = state_dict["token_embedding.weight"].shape[0] transformer_width = state_dict["ln_final.weight"].shape[0] transformer_heads = transformer_width // 64 transformer_layers = len( set( k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) model = CLIP(embed_dim, image_resolution, vision_layers, vision_width, vision_patch_size, context_length, vocab_size, transformer_width, transformer_heads, transformer_layers) for key in ["input_resolution", "context_length", "vocab_size"]: if key in state_dict: del state_dict[key] convert_weights(model) model.load_state_dict(state_dict) return model.eval() ================================================ FILE: 
ProGrad.public/clip/simple_tokenizer.py ================================================ import gzip import html import os from functools import lru_cache import ftfy import regex as re @lru_cache() def default_bpe(): return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") @lru_cache() def bytes_to_unicode(): """ Returns a dict mapping each utf-8 byte to a corresponding unicode character. The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup tables between utf-8 bytes and unicode strings. The mapping also avoids whitespace/control characters, which the bpe code cannot handle. """ bs = list(range(ord("!"), ord("~") + 1)) + list(range( ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) cs = bs[:] n = 0 for b in range(2**8): if b not in bs: bs.append(b) cs.append(2**8 + n) n += 1 cs = [chr(n) for n in cs] return dict(zip(bs, cs)) def get_pairs(word): """Return set of symbol pairs in a word. Word is represented as tuple of symbols (symbols being variable-length strings).
""" pairs = set() prev_char = word[0] for char in word[1:]: pairs.add((prev_char, char)) prev_char = char return pairs def basic_clean(text): text = ftfy.fix_text(text) text = html.unescape(html.unescape(text)) return text.strip() def whitespace_clean(text): text = re.sub(r'\s+', ' ', text) text = text.strip() return text class SimpleTokenizer(object): def __init__(self, bpe_path: str = default_bpe()): self.byte_encoder = bytes_to_unicode() self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') merges = merges[1:49152 - 256 - 2 + 1] merges = [tuple(merge.split()) for merge in merges] vocab = list(bytes_to_unicode().values()) vocab = vocab + [v + '</w>' for v in vocab] for merge in merges: vocab.append(''.join(merge)) vocab.extend(['<|startoftext|>', '<|endoftext|>']) self.encoder = dict(zip(vocab, range(len(vocab)))) self.decoder = {v: k for k, v in self.encoder.items()} self.bpe_ranks = dict(zip(merges, range(len(merges)))) self.cache = { '<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>' } self.pat = re.compile( r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) def bpe(self, token): if token in self.cache: return self.cache[token] word = tuple(token[:-1]) + (token[-1] + '</w>', ) pairs = get_pairs(word) if not pairs: return token + '</w>' while True: bigram = min( pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) if bigram not in self.bpe_ranks: break first, second = bigram new_word = [] i = 0 while i < len(word): try: j = word.index(first, i) new_word.extend(word[i:j]) i = j except ValueError: new_word.extend(word[i:]) break if word[i] == first and i < len(word) - 1 and word[ i + 1] == second: new_word.append(first + second) i += 2 else: new_word.append(word[i]) i += 1 new_word = tuple(new_word) word = new_word if len(word) == 1: break else: pairs = get_pairs(word) word = ' 
'.join(word) self.cache[token] = word
return word def encode(self, text): bpe_tokens = [] text = whitespace_clean(basic_clean(text)).lower() for token in re.findall(self.pat, text): token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) return bpe_tokens def decode(self, tokens): text = ''.join([self.decoder[token] for token in tokens]) text = bytearray([self.byte_decoder[c] for c in text ]).decode('utf-8', errors="replace").replace('</w>', ' ') return text ================================================ FILE: ProGrad.public/configs/datasets/caltech101.yaml ================================================ DATASET: NAME: "Caltech101" ================================================ FILE: ProGrad.public/configs/datasets/dtd.yaml ================================================ DATASET: NAME: "DescribableTextures" ================================================ FILE: ProGrad.public/configs/datasets/eurosat.yaml ================================================ DATASET: NAME: "EuroSAT" ================================================ FILE: ProGrad.public/configs/datasets/fgvc_aircraft.yaml ================================================ DATASET: NAME: "FGVCAircraft" ================================================ FILE: ProGrad.public/configs/datasets/food101.yaml ================================================ DATASET: NAME: "Food101" ================================================ FILE: ProGrad.public/configs/datasets/imagenet.yaml ================================================ DATASET: NAME: "ImageNet" ================================================ FILE: ProGrad.public/configs/datasets/imagenet_a.yaml ================================================ DATASET: NAME: "ImageNetA" ================================================ FILE: ProGrad.public/configs/datasets/imagenet_r.yaml ================================================ DATASET: NAME: "ImageNetR" 
================================================
FILE: ProGrad.public/configs/datasets/imagenet_sketch.yaml ================================================ DATASET: NAME: "ImageNetSketch" ================================================ FILE: ProGrad.public/configs/datasets/imagenetv2.yaml ================================================ DATASET: NAME: "ImageNetV2" ================================================ FILE: ProGrad.public/configs/datasets/oxford_flowers.yaml ================================================ DATASET: NAME: "OxfordFlowers" ================================================ FILE: ProGrad.public/configs/datasets/oxford_pets.yaml ================================================ DATASET: NAME: "OxfordPets" ================================================ FILE: ProGrad.public/configs/datasets/stanford_cars.yaml ================================================ DATASET: NAME: "StanfordCars" ================================================ FILE: ProGrad.public/configs/datasets/sun397.yaml ================================================ DATASET: NAME: "SUN397" ================================================ FILE: ProGrad.public/configs/datasets/ucf101.yaml ================================================ DATASET: NAME: "UCF101" ================================================ FILE: ProGrad.public/configs/trainers/CoCoOp/rn50_c4_ep10_batch1_ctxv1.yaml ================================================ DATALOADER: TRAIN_X: BATCH_SIZE: 1 TEST: BATCH_SIZE: 100 NUM_WORKERS: 8 INPUT: SIZE: (224, 224) INTERPOLATION: "bicubic" PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] OPTIM: NAME: "sgd" LR: 0.002 MAX_EPOCH: 10 LR_SCHEDULER: "cosine" WARMUP_EPOCH: 1 WARMUP_TYPE: "constant" WARMUP_CONS_LR: 1e-5 TRAIN: PRINT_FREQ: 20 MODEL: BACKBONE: NAME: "RN50" TRAINER: COCOOP: N_CTX: 4 CTX_INIT: True PREC: "fp16" ================================================ FILE: 
ProGrad.public/configs/trainers/CoCoOp/rn50_ep100_init.yaml ================================================ DATALOADER: TRAIN_X: BATCH_SIZE: 1 TEST: BATCH_SIZE: 100 NUM_WORKERS: 8 INPUT: SIZE: (224, 224) INTERPOLATION: "bicubic" PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] OPTIM: NAME: "sgd" LR: 0.002 MAX_EPOCH: 100 LR_SCHEDULER: "cosine" WARMUP_EPOCH: 1 WARMUP_TYPE: "constant" WARMUP_CONS_LR: 1e-5 TRAIN: PRINT_FREQ: 20 MODEL: BACKBONE: NAME: "RN50" TRAINER: COCOOP: N_CTX: 16 CTX_INIT: True PREC: "fp16" ================================================ FILE: ProGrad.public/configs/trainers/CoCoOp/rn50_ep50.yaml ================================================ DATALOADER: TRAIN_X: BATCH_SIZE: 1 TEST: BATCH_SIZE: 100 NUM_WORKERS: 8 INPUT: SIZE: (224, 224) INTERPOLATION: "bicubic" PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] OPTIM: NAME: "sgd" LR: 0.002 MAX_EPOCH: 50 LR_SCHEDULER: "cosine" WARMUP_EPOCH: 1 WARMUP_TYPE: "constant" WARMUP_CONS_LR: 1e-5 TRAIN: PRINT_FREQ: 20 MODEL: BACKBONE: NAME: "RN50" TRAINER: COCOOP: N_CTX: 16 CTX_INIT: True PREC: "fp16" ================================================ FILE: ProGrad.public/configs/trainers/CoCoOp/vit_b16_c16_ep10_batch1.yaml ================================================ DATALOADER: TRAIN_X: BATCH_SIZE: 1 TEST: BATCH_SIZE: 100 NUM_WORKERS: 8 INPUT: SIZE: (224, 224) INTERPOLATION: "bicubic" PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] OPTIM: NAME: "sgd" LR: 0.002 MAX_EPOCH: 10 LR_SCHEDULER: "cosine" WARMUP_EPOCH: 1 WARMUP_TYPE: "constant" WARMUP_CONS_LR: 1e-5 TRAIN: PRINT_FREQ: 20 MODEL: BACKBONE: NAME: "ViT-B/16" TRAINER: COCOOP: N_CTX: 16 CTX_INIT: "" PREC: "fp16" 
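The `OPTIM` block in the CoCoOp configs pairs a cosine schedule with a one-epoch "constant" warmup. The following is a minimal sketch of the per-epoch learning rate these keys are assumed to produce, using the values from `vit_b16_c16_ep10_batch1.yaml` (`LR: 0.002`, `MAX_EPOCH: 10`, `WARMUP_EPOCH: 1`, `WARMUP_CONS_LR: 1e-5`). `lr_at_epoch` is a hypothetical helper and not part of Dassl. It assumes the warmup holds `WARMUP_CONS_LR` for `WARMUP_EPOCH` epochs, after which cosine annealing (minimum 0, period `MAX_EPOCH`) starts from `LR`. Check Dassl's scheduler code for the exact epoch alignment.

```python
import math


def lr_at_epoch(epoch: int,
                base_lr: float = 0.002,
                max_epoch: int = 10,
                warmup_epoch: int = 1,
                warmup_cons_lr: float = 1e-5) -> float:
    """Hypothetical helper: learning rate used during a 0-based epoch.

    Defaults mirror OPTIM in vit_b16_c16_ep10_batch1.yaml. Assumes a
    "constant" warmup holds warmup_cons_lr for the first warmup_epoch
    epochs, after which cosine annealing (eta_min = 0, T_max = max_epoch)
    starts from base_lr.
    """
    if epoch < warmup_epoch:
        # Constant warmup phase.
        return warmup_cons_lr
    # Number of cosine steps taken since warmup ended (assumed alignment).
    t = epoch - warmup_epoch
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * t / max_epoch))


if __name__ == "__main__":
    for e in range(10):
        print(f"epoch {e}: lr = {lr_at_epoch(e):.6f}")
```

Under this assumption the cosine period is `MAX_EPOCH` but only `MAX_EPOCH - WARMUP_EPOCH` cosine steps occur, so the rate is still above zero in the final epoch.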
================================================ FILE: ProGrad.public/configs/trainers/CoCoOp/vit_b16_c4_ep10_batch1.yaml ================================================ DATALOADER: TRAIN_X: BATCH_SIZE: 1 TEST: BATCH_SIZE: 100 NUM_WORKERS: 8 INPUT: SIZE: (224, 224) INTERPOLATION: "bicubic" PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] OPTIM: NAME: "sgd" LR: 0.002 MAX_EPOCH: 10 LR_SCHEDULER: "cosine" WARMUP_EPOCH: 1 WARMUP_TYPE: "constant" WARMUP_CONS_LR: 1e-5 TRAIN: PRINT_FREQ: 20 MODEL: BACKBONE: NAME: "ViT-B/16" TRAINER: COCOOP: N_CTX: 4 CTX_INIT: "" PREC: "fp16" ================================================ FILE: ProGrad.public/configs/trainers/CoCoOp/vit_b16_c4_ep10_batch1_ctxv1.yaml ================================================ DATALOADER: TRAIN_X: BATCH_SIZE: 1 TEST: BATCH_SIZE: 100 NUM_WORKERS: 8 INPUT: SIZE: (224, 224) INTERPOLATION: "bicubic" PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] OPTIM: NAME: "sgd" LR: 0.002 MAX_EPOCH: 10 LR_SCHEDULER: "cosine" WARMUP_EPOCH: 1 WARMUP_TYPE: "constant" WARMUP_CONS_LR: 1e-5 TRAIN: PRINT_FREQ: 20 MODEL: BACKBONE: NAME: "ViT-B/16" TRAINER: COCOOP: N_CTX: 4 CTX_INIT: "a photo of a" PREC: "fp16" ================================================ FILE: ProGrad.public/configs/trainers/CoCoOp/vit_b16_c8_ep10_batch1.yaml ================================================ DATALOADER: TRAIN_X: BATCH_SIZE: 1 TEST: BATCH_SIZE: 100 NUM_WORKERS: 8 INPUT: SIZE: (224, 224) INTERPOLATION: "bicubic" PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] OPTIM: NAME: "sgd" LR: 0.002 MAX_EPOCH: 10 LR_SCHEDULER: "cosine" WARMUP_EPOCH: 1 WARMUP_TYPE: "constant" WARMUP_CONS_LR: 1e-5 TRAIN: PRINT_FREQ: 20 MODEL: 
BACKBONE: NAME: "ViT-B/16" TRAINER: COCOOP: N_CTX: 8 CTX_INIT: "" PREC: "fp16" ================================================ FILE: ProGrad.public/configs/trainers/CoOp/rn50.yaml ================================================ DATALOADER: TRAIN_X: BATCH_SIZE: 32 TEST: BATCH_SIZE: 100 NUM_WORKERS: 8 INPUT: SIZE: (224, 224) INTERPOLATION: "bicubic" PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] OPTIM: NAME: "sgd" LR: 0.002 MAX_EPOCH: 200 LR_SCHEDULER: "cosine" WARMUP_EPOCH: 1 WARMUP_TYPE: "constant" WARMUP_CONS_LR: 1e-5 TRAIN: PRINT_FREQ: 5 MODEL: BACKBONE: NAME: "RN50" ================================================ FILE: ProGrad.public/configs/trainers/CoOp/rn50_ep100.yaml ================================================ DATALOADER: TRAIN_X: BATCH_SIZE: 32 TEST: BATCH_SIZE: 100 NUM_WORKERS: 8 INPUT: SIZE: (224, 224) INTERPOLATION: "bicubic" PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] OPTIM: NAME: "sgd" LR: 0.002 MAX_EPOCH: 100 LR_SCHEDULER: "cosine" WARMUP_EPOCH: 1 WARMUP_TYPE: "constant" WARMUP_CONS_LR: 1e-5 TRAIN: PRINT_FREQ: 5 MODEL: BACKBONE: NAME: "RN50" ================================================ FILE: ProGrad.public/configs/trainers/CoOp/rn50_ep50.yaml ================================================ DATALOADER: TRAIN_X: BATCH_SIZE: 32 TEST: BATCH_SIZE: 100 NUM_WORKERS: 8 INPUT: SIZE: (224, 224) INTERPOLATION: "bicubic" PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] OPTIM: NAME: "sgd" LR: 0.002 MAX_EPOCH: 50 LR_SCHEDULER: "cosine" WARMUP_EPOCH: 1 WARMUP_TYPE: "constant" WARMUP_CONS_LR: 1e-5 TRAIN: PRINT_FREQ: 5 MODEL: BACKBONE: NAME: "RN50" ================================================ FILE: 
ProGrad.public/configs/trainers/CoOp/rn50_val.yaml ================================================ DATALOADER: TRAIN_X: BATCH_SIZE: 32 TEST: BATCH_SIZE: 32 NUM_WORKERS: 8 INPUT: SIZE: (224, 224) INTERPOLATION: "bicubic" PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] MODEL: BACKBONE: NAME: "RN50" ================================================ FILE: ProGrad.public/configs/trainers/ProGrad/rn50.yaml ================================================ DATALOADER: TRAIN_X: BATCH_SIZE: 32 TEST: BATCH_SIZE: 100 NUM_WORKERS: 8 INPUT: SIZE: (224, 224) INTERPOLATION: "bicubic" PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] OPTIM: NAME: "sgd" LR: 0.002 MAX_EPOCH: 200 LR_SCHEDULER: "cosine" WARMUP_EPOCH: 1 WARMUP_TYPE: "constant" WARMUP_CONS_LR: 1e-5 LOSS: NAME: "prograd" T: 1.0 TRAIN: PRINT_FREQ: 5 MODEL: BACKBONE: NAME: "RN50" TRAINER: COOP: CTX_INIT: True ================================================ FILE: ProGrad.public/configs/trainers/ProGrad/rn50_ep100.yaml ================================================ DATALOADER: TRAIN_X: BATCH_SIZE: 32 TEST: BATCH_SIZE: 100 NUM_WORKERS: 8 INPUT: SIZE: (224, 224) INTERPOLATION: "bicubic" PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] OPTIM: NAME: "sgd" LR: 0.002 MAX_EPOCH: 100 LR_SCHEDULER: "cosine" WARMUP_EPOCH: 1 WARMUP_TYPE: "constant" WARMUP_CONS_LR: 1e-5 LOSS: NAME: "prograd" T: 1.0 TRAIN: PRINT_FREQ: 5 MODEL: BACKBONE: NAME: "RN50" TRAINER: COOP: CTX_INIT: True ================================================ FILE: ProGrad.public/configs/trainers/ProGrad/rn50_ep50.yaml ================================================ DATALOADER: TRAIN_X: BATCH_SIZE: 32 TEST: BATCH_SIZE: 100 NUM_WORKERS: 8 
INPUT: SIZE: (224, 224) INTERPOLATION: "bicubic" PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] OPTIM: NAME: "sgd" LR: 0.002 MAX_EPOCH: 50 LR_SCHEDULER: "cosine" WARMUP_EPOCH: 1 WARMUP_TYPE: "constant" WARMUP_CONS_LR: 1e-5 LOSS: NAME: "prograd" T: 1.0 TRAIN: PRINT_FREQ: 5 MODEL: BACKBONE: NAME: "RN50" TRAINER: COOP: CTX_INIT: True ================================================ FILE: ProGrad.public/datasets/__init__.py ================================================ ================================================ FILE: ProGrad.public/datasets/caltech101.py ================================================ import os import pickle from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase from dassl.utils import mkdir_if_missing from .oxford_pets import OxfordPets from .dtd import DescribableTextures as DTD IGNORED = ["BACKGROUND_Google", "Faces_easy"] NEW_CNAMES = { "airplanes": "airplane", "Faces": "face", "Leopards": "leopard", "Motorbikes": "motorbike", } @DATASET_REGISTRY.register() class Caltech101(DatasetBase): dataset_dir = "caltech-101" def __init__(self, cfg): root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) self.dataset_dir = os.path.join(root, self.dataset_dir) self.image_dir = os.path.join(self.dataset_dir, "101_ObjectCategories") self.split_path = os.path.join(self.dataset_dir, "split_zhou_Caltech101.json") self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") mkdir_if_missing(self.split_fewshot_dir) if os.path.exists(self.split_path): train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) else: train, val, test = DTD.read_and_split_data(self.image_dir, ignored=IGNORED, new_cnames=NEW_CNAMES) OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) num_shots = cfg.DATASET.NUM_SHOTS if num_shots >= 1: seed = cfg.SEED preprocessed = 
os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") if os.path.exists(preprocessed): print( f"Loading preprocessed few-shot data from {preprocessed}") with open(preprocessed, "rb") as file: data = pickle.load(file) train, val = data["train"], data["val"] else: train = self.generate_fewshot_dataset(train, num_shots=num_shots) val = self.generate_fewshot_dataset(val, num_shots=min( num_shots, 4)) data = {"train": train, "val": val} print(f"Saving preprocessed few-shot data to {preprocessed}") with open(preprocessed, "wb") as file: pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) subsample = cfg.DATASET.SUBSAMPLE_CLASSES train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) super().__init__(train_x=train, val=val, test=test) ================================================ FILE: ProGrad.public/datasets/dtd.py ================================================ import os import pickle import random from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase from dassl.utils import listdir_nohidden, mkdir_if_missing from .oxford_pets import OxfordPets @DATASET_REGISTRY.register() class DescribableTextures(DatasetBase): dataset_dir = "dtd" def __init__(self, cfg): root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) self.dataset_dir = os.path.join(root, self.dataset_dir) self.image_dir = os.path.join(self.dataset_dir, "images") self.split_path = os.path.join(self.dataset_dir, "split_zhou_DescribableTextures.json") self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") mkdir_if_missing(self.split_fewshot_dir) if os.path.exists(self.split_path): train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) else: train, val, test = self.read_and_split_data(self.image_dir) OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) num_shots = cfg.DATASET.NUM_SHOTS if num_shots >= 1: seed = cfg.SEED preprocessed = os.path.join(self.split_fewshot_dir, 
f"shot_{num_shots}-seed_{seed}.pkl") if os.path.exists(preprocessed): print( f"Loading preprocessed few-shot data from {preprocessed}") with open(preprocessed, "rb") as file: data = pickle.load(file) train, val = data["train"], data["val"] else: train = self.generate_fewshot_dataset(train, num_shots=num_shots) val = self.generate_fewshot_dataset(val, num_shots=min( num_shots, 4)) data = {"train": train, "val": val} print(f"Saving preprocessed few-shot data to {preprocessed}") with open(preprocessed, "wb") as file: pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) subsample = cfg.DATASET.SUBSAMPLE_CLASSES train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) super().__init__(train_x=train, val=val, test=test) @staticmethod def read_and_split_data(image_dir, p_trn=0.5, p_val=0.2, ignored=[], new_cnames=None): # The data are supposed to be organized into the following structure # ============= # images/ # dog/ # cat/ # horse/ # ============= categories = listdir_nohidden(image_dir) categories = [c for c in categories if c not in ignored] categories.sort() p_tst = 1 - p_trn - p_val print( f"Splitting into {p_trn:.0%} train, {p_val:.0%} val, and {p_tst:.0%} test" ) def _collate(ims, y, c): items = [] for im in ims: item = Datum(impath=im, label=y, classname=c) # is already 0-based items.append(item) return items train, val, test = [], [], [] for label, category in enumerate(categories): category_dir = os.path.join(image_dir, category) images = listdir_nohidden(category_dir) images = [os.path.join(category_dir, im) for im in images] random.shuffle(images) n_total = len(images) n_train = round(n_total * p_trn) n_val = round(n_total * p_val) n_test = n_total - n_train - n_val assert n_train > 0 and n_val > 0 and n_test > 0 if new_cnames is not None and category in new_cnames: category = new_cnames[category] train.extend(_collate(images[:n_train], label, category)) val.extend( _collate(images[n_train:n_train + n_val], label, 
category)) test.extend(_collate(images[n_train + n_val:], label, category)) return train, val, test ================================================ FILE: ProGrad.public/datasets/eurosat.py ================================================ import os import pickle from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase from dassl.utils import mkdir_if_missing from .oxford_pets import OxfordPets from .dtd import DescribableTextures as DTD NEW_CNAMES = { "AnnualCrop": "Annual Crop Land", "Forest": "Forest", "HerbaceousVegetation": "Herbaceous Vegetation Land", "Highway": "Highway or Road", "Industrial": "Industrial Buildings", "Pasture": "Pasture Land", "PermanentCrop": "Permanent Crop Land", "Residential": "Residential Buildings", "River": "River", "SeaLake": "Sea or Lake", } @DATASET_REGISTRY.register() class EuroSAT(DatasetBase): dataset_dir = "eurosat" def __init__(self, cfg): root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) self.dataset_dir = os.path.join(root, self.dataset_dir) self.image_dir = os.path.join(self.dataset_dir, "2750") self.split_path = os.path.join(self.dataset_dir, "split_zhou_EuroSAT.json") self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") mkdir_if_missing(self.split_fewshot_dir) if os.path.exists(self.split_path): train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) else: train, val, test = DTD.read_and_split_data(self.image_dir, new_cnames=NEW_CNAMES) OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) num_shots = cfg.DATASET.NUM_SHOTS if num_shots >= 1: seed = cfg.SEED preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") if os.path.exists(preprocessed): print( f"Loading preprocessed few-shot data from {preprocessed}") with open(preprocessed, "rb") as file: data = pickle.load(file) train, val = data["train"], data["val"] else: train = self.generate_fewshot_dataset(train, num_shots=num_shots) val = 
self.generate_fewshot_dataset(val, num_shots=min( num_shots, 4)) data = {"train": train, "val": val} print(f"Saving preprocessed few-shot data to {preprocessed}") with open(preprocessed, "wb") as file: pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) subsample = cfg.DATASET.SUBSAMPLE_CLASSES train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) super().__init__(train_x=train, val=val, test=test) def update_classname(self, dataset_old): dataset_new = [] for item_old in dataset_old: cname_old = item_old.classname cname_new = NEW_CNAMES[cname_old] item_new = Datum(impath=item_old.impath, label=item_old.label, classname=cname_new) dataset_new.append(item_new) return dataset_new ================================================ FILE: ProGrad.public/datasets/fgvc_aircraft.py ================================================ import os import pickle from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase from dassl.utils import mkdir_if_missing from .oxford_pets import OxfordPets @DATASET_REGISTRY.register() class FGVCAircraft(DatasetBase): dataset_dir = "fgvc_aircraft" def __init__(self, cfg): root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) self.dataset_dir = os.path.join(root, self.dataset_dir) self.image_dir = os.path.join(self.dataset_dir, "images") self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") mkdir_if_missing(self.split_fewshot_dir) classnames = [] with open(os.path.join(self.dataset_dir, "variants.txt"), "r") as f: lines = f.readlines() for line in lines: classnames.append(line.strip()) cname2lab = {c: i for i, c in enumerate(classnames)} train = self.read_data(cname2lab, "images_variant_train.txt") val = self.read_data(cname2lab, "images_variant_val.txt") test = self.read_data(cname2lab, "images_variant_test.txt") num_shots = cfg.DATASET.NUM_SHOTS if num_shots >= 1: seed = cfg.SEED preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
if os.path.exists(preprocessed): print( f"Loading preprocessed few-shot data from {preprocessed}") with open(preprocessed, "rb") as file: data = pickle.load(file) train, val = data["train"], data["val"] else: train = self.generate_fewshot_dataset(train, num_shots=num_shots) val = self.generate_fewshot_dataset(val, num_shots=min( num_shots, 4)) data = {"train": train, "val": val} print(f"Saving preprocessed few-shot data to {preprocessed}") with open(preprocessed, "wb") as file: pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) subsample = cfg.DATASET.SUBSAMPLE_CLASSES train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) super().__init__(train_x=train, val=val, test=test) def read_data(self, cname2lab, split_file): filepath = os.path.join(self.dataset_dir, split_file) items = [] with open(filepath, "r") as f: lines = f.readlines() for line in lines: line = line.strip().split(" ") imname = line[0] + ".jpg" classname = " ".join(line[1:]) impath = os.path.join(self.image_dir, imname) label = cname2lab[classname] item = Datum(impath=impath, label=label, classname=classname) items.append(item) return items ================================================ FILE: ProGrad.public/datasets/food101.py ================================================ import os import pickle from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase from dassl.utils import mkdir_if_missing from .oxford_pets import OxfordPets from .dtd import DescribableTextures as DTD @DATASET_REGISTRY.register() class Food101(DatasetBase): dataset_dir = "food-101" def __init__(self, cfg): root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) self.dataset_dir = os.path.join(root, self.dataset_dir) self.image_dir = os.path.join(self.dataset_dir, "images") self.split_path = os.path.join(self.dataset_dir, "split_zhou_Food101.json") self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") mkdir_if_missing(self.split_fewshot_dir) if 
os.path.exists(self.split_path): train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) else: train, val, test = DTD.read_and_split_data(self.image_dir) OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) num_shots = cfg.DATASET.NUM_SHOTS if num_shots >= 1: seed = cfg.SEED preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") if os.path.exists(preprocessed): print( f"Loading preprocessed few-shot data from {preprocessed}") with open(preprocessed, "rb") as file: data = pickle.load(file) train, val = data["train"], data["val"] else: train = self.generate_fewshot_dataset(train, num_shots=num_shots) val = self.generate_fewshot_dataset(val, num_shots=min( num_shots, 4)) data = {"train": train, "val": val} print(f"Saving preprocessed few-shot data to {preprocessed}") with open(preprocessed, "wb") as file: pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) subsample = cfg.DATASET.SUBSAMPLE_CLASSES train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) super().__init__(train_x=train, val=val, test=test) ================================================ FILE: ProGrad.public/datasets/imagenet.py ================================================ import os import pickle from collections import OrderedDict from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase from dassl.utils import listdir_nohidden, mkdir_if_missing from .oxford_pets import OxfordPets @DATASET_REGISTRY.register() class ImageNet(DatasetBase): dataset_dir = "imagenet" def __init__(self, cfg): root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) self.dataset_dir = os.path.join(root, self.dataset_dir) self.image_dir = os.path.join(self.dataset_dir, "images") self.preprocessed = os.path.join(self.dataset_dir, "preprocessed.pkl") self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") mkdir_if_missing(self.split_fewshot_dir) if os.path.exists(self.preprocessed): with 
open(self.preprocessed, "rb") as f: preprocessed = pickle.load(f) train = preprocessed["train"] test = preprocessed["test"] else: text_file = os.path.join(self.dataset_dir, "classnames.txt") classnames = self.read_classnames(text_file) train = self.read_data(classnames, "train") # Follow standard practice to perform evaluation on the val set # Also used as the val set (so evaluate the last-step model) test = self.read_data(classnames, "val") preprocessed = {"train": train, "test": test} with open(self.preprocessed, "wb") as f: pickle.dump(preprocessed, f, protocol=pickle.HIGHEST_PROTOCOL) num_shots = cfg.DATASET.NUM_SHOTS if num_shots >= 1: seed = cfg.SEED preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") if os.path.exists(preprocessed): print( f"Loading preprocessed few-shot data from {preprocessed}") with open(preprocessed, "rb") as file: data = pickle.load(file) train = data["train"] else: train = self.generate_fewshot_dataset(train, num_shots=num_shots) data = {"train": train} print(f"Saving preprocessed few-shot data to {preprocessed}") with open(preprocessed, "wb") as file: pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) subsample = cfg.DATASET.SUBSAMPLE_CLASSES train, test = OxfordPets.subsample_classes(train, test, subsample=subsample) super().__init__(train_x=train, val=test, test=test) @staticmethod def read_classnames(text_file): """Return a dictionary containing key-value pairs of <folder name>: <class name>.
""" classnames = OrderedDict() with open(text_file, "r") as f: lines = f.readlines() for line in lines: line = line.strip().split(" ") folder = line[0] classname = " ".join(line[1:]) classnames[folder] = classname return classnames def read_data(self, classnames, split_dir): split_dir = os.path.join(self.image_dir, split_dir) folders = sorted(f.name for f in os.scandir(split_dir) if f.is_dir()) items = [] for label, folder in enumerate(folders): imnames = listdir_nohidden(os.path.join(split_dir, folder)) classname = classnames[folder] for imname in imnames: impath = os.path.join(split_dir, folder, imname) item = Datum(impath=impath, label=label, classname=classname) items.append(item) return items ================================================ FILE: ProGrad.public/datasets/imagenet_a.py ================================================ import os from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase from dassl.utils import listdir_nohidden from .imagenet import ImageNet TO_BE_IGNORED = ["README.txt"] @DATASET_REGISTRY.register() class ImageNetA(DatasetBase): """ImageNet-A(dversarial). This dataset is used for testing only. 
""" dataset_dir = "imagenet-adversarial" def __init__(self, cfg): root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) self.dataset_dir = os.path.join(root, self.dataset_dir) self.image_dir = os.path.join(self.dataset_dir, "imagenet-a") text_file = os.path.join(self.dataset_dir, "classnames.txt") classnames = ImageNet.read_classnames(text_file) data = self.read_data(classnames) super().__init__(train_x=data, test=data) def read_data(self, classnames): image_dir = self.image_dir folders = listdir_nohidden(image_dir, sort=True) folders = [f for f in folders if f not in TO_BE_IGNORED] items = [] for label, folder in enumerate(folders): imnames = listdir_nohidden(os.path.join(image_dir, folder)) classname = classnames[folder] for imname in imnames: impath = os.path.join(image_dir, folder, imname) item = Datum(impath=impath, label=label, classname=classname) items.append(item) return items ================================================ FILE: ProGrad.public/datasets/imagenet_r.py ================================================ import os from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase from dassl.utils import listdir_nohidden from .imagenet import ImageNet TO_BE_IGNORED = ["README.txt"] @DATASET_REGISTRY.register() class ImageNetR(DatasetBase): """ImageNet-R(endition). This dataset is used for testing only. 
""" dataset_dir = "imagenet-rendition" def __init__(self, cfg): root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) self.dataset_dir = os.path.join(root, self.dataset_dir) self.image_dir = os.path.join(self.dataset_dir, "imagenet-r") text_file = os.path.join(self.dataset_dir, "classnames.txt") classnames = ImageNet.read_classnames(text_file) data = self.read_data(classnames) super().__init__(train_x=data, test=data) def read_data(self, classnames): image_dir = self.image_dir folders = listdir_nohidden(image_dir, sort=True) folders = [f for f in folders if f not in TO_BE_IGNORED] items = [] for label, folder in enumerate(folders): imnames = listdir_nohidden(os.path.join(image_dir, folder)) classname = classnames[folder] for imname in imnames: impath = os.path.join(image_dir, folder, imname) item = Datum(impath=impath, label=label, classname=classname) items.append(item) return items ================================================ FILE: ProGrad.public/datasets/imagenet_sketch.py ================================================ import os from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase from dassl.utils import listdir_nohidden from .imagenet import ImageNet @DATASET_REGISTRY.register() class ImageNetSketch(DatasetBase): """ImageNet-Sketch. This dataset is used for testing only. 
""" dataset_dir = "imagenet-sketch" def __init__(self, cfg): root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) self.dataset_dir = os.path.join(root, self.dataset_dir) self.image_dir = os.path.join(self.dataset_dir, "images") text_file = os.path.join(self.dataset_dir, "classnames.txt") classnames = ImageNet.read_classnames(text_file) data = self.read_data(classnames) super().__init__(train_x=data, test=data) def read_data(self, classnames): image_dir = self.image_dir folders = listdir_nohidden(image_dir, sort=True) items = [] for label, folder in enumerate(folders): imnames = listdir_nohidden(os.path.join(image_dir, folder)) classname = classnames[folder] for imname in imnames: impath = os.path.join(image_dir, folder, imname) item = Datum(impath=impath, label=label, classname=classname) items.append(item) return items ================================================ FILE: ProGrad.public/datasets/imagenetv2.py ================================================ import os from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase from dassl.utils import listdir_nohidden from .imagenet import ImageNet @DATASET_REGISTRY.register() class ImageNetV2(DatasetBase): """ImageNetV2. This dataset is used for testing only. 
""" dataset_dir = "imagenetv2" def __init__(self, cfg): root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) self.dataset_dir = os.path.join(root, self.dataset_dir) image_dir = "imagenetv2-matched-frequency-format-val" self.image_dir = os.path.join(self.dataset_dir, image_dir) text_file = os.path.join(self.dataset_dir, "classnames.txt") classnames = ImageNet.read_classnames(text_file) data = self.read_data(classnames) super().__init__(train_x=data, test=data) def read_data(self, classnames): image_dir = self.image_dir folders = list(classnames.keys()) items = [] for label in range(1000): class_dir = os.path.join(image_dir, str(label)) imnames = listdir_nohidden(class_dir) folder = folders[label] classname = classnames[folder] for imname in imnames: impath = os.path.join(class_dir, imname) item = Datum(impath=impath, label=label, classname=classname) items.append(item) return items ================================================ FILE: ProGrad.public/datasets/oxford_flowers.py ================================================ import os import pickle import random from scipy.io import loadmat from collections import defaultdict from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase from dassl.utils import read_json, mkdir_if_missing from .oxford_pets import OxfordPets @DATASET_REGISTRY.register() class OxfordFlowers(DatasetBase): dataset_dir = "oxford_flowers" def __init__(self, cfg): root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) self.dataset_dir = os.path.join(root, self.dataset_dir) self.image_dir = os.path.join(self.dataset_dir, "jpg") self.label_file = os.path.join(self.dataset_dir, "imagelabels.mat") self.lab2cname_file = os.path.join(self.dataset_dir, "cat_to_name.json") self.split_path = os.path.join(self.dataset_dir, "split_zhou_OxfordFlowers.json") self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") mkdir_if_missing(self.split_fewshot_dir) if os.path.exists(self.split_path): train, val, test = 
OxfordPets.read_split(self.split_path, self.image_dir) else: train, val, test = self.read_data() OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) num_shots = cfg.DATASET.NUM_SHOTS if num_shots >= 1: seed = cfg.SEED preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") if os.path.exists(preprocessed): print( f"Loading preprocessed few-shot data from {preprocessed}") with open(preprocessed, "rb") as file: data = pickle.load(file) train, val = data["train"], data["val"] else: train = self.generate_fewshot_dataset(train, num_shots=num_shots) val = self.generate_fewshot_dataset(val, num_shots=min( num_shots, 4)) data = {"train": train, "val": val} print(f"Saving preprocessed few-shot data to {preprocessed}") with open(preprocessed, "wb") as file: pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) subsample = cfg.DATASET.SUBSAMPLE_CLASSES train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) super().__init__(train_x=train, val=val, test=test) def read_data(self): tracker = defaultdict(list) label_file = loadmat(self.label_file)["labels"][0] for i, label in enumerate(label_file): imname = f"image_{str(i + 1).zfill(5)}.jpg" impath = os.path.join(self.image_dir, imname) label = int(label) tracker[label].append(impath) print("Splitting data into 50% train, 20% val, and 30% test") def _collate(ims, y, c): items = [] for im in ims: item = Datum(impath=im, label=y - 1, classname=c) # convert to 0-based label items.append(item) return items lab2cname = read_json(self.lab2cname_file) train, val, test = [], [], [] for label, impaths in tracker.items(): random.shuffle(impaths) n_total = len(impaths) n_train = round(n_total * 0.5) n_val = round(n_total * 0.2) n_test = n_total - n_train - n_val assert n_train > 0 and n_val > 0 and n_test > 0 cname = lab2cname[str(label)] train.extend(_collate(impaths[:n_train], label, cname)) val.extend(_collate(impaths[n_train:n_train + n_val], label, 
cname)) test.extend(_collate(impaths[n_train + n_val:], label, cname)) return train, val, test ================================================ FILE: ProGrad.public/datasets/oxford_pets.py ================================================ import os import pickle import math import random from collections import defaultdict from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase from dassl.utils import read_json, write_json, mkdir_if_missing @DATASET_REGISTRY.register() class OxfordPets(DatasetBase): dataset_dir = "oxford_pets" def __init__(self, cfg): root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) self.dataset_dir = os.path.join(root, self.dataset_dir) self.image_dir = os.path.join(self.dataset_dir, "images") self.anno_dir = os.path.join(self.dataset_dir, "annotations") self.split_path = os.path.join(self.dataset_dir, "split_zhou_OxfordPets.json") self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") mkdir_if_missing(self.split_fewshot_dir) if os.path.exists(self.split_path): train, val, test = self.read_split(self.split_path, self.image_dir) else: trainval = self.read_data(split_file="trainval.txt") test = self.read_data(split_file="test.txt") train, val = self.split_trainval(trainval) self.save_split(train, val, test, self.split_path, self.image_dir) num_shots = cfg.DATASET.NUM_SHOTS if num_shots >= 1: seed = cfg.SEED preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") if os.path.exists(preprocessed): print( f"Loading preprocessed few-shot data from {preprocessed}") with open(preprocessed, "rb") as file: data = pickle.load(file) train, val = data["train"], data["val"] else: train = self.generate_fewshot_dataset(train, num_shots=num_shots) val = self.generate_fewshot_dataset(val, num_shots=min( num_shots, 4)) data = {"train": train, "val": val} print(f"Saving preprocessed few-shot data to {preprocessed}") with open(preprocessed, "wb") as file: pickle.dump(data, file, 
protocol=pickle.HIGHEST_PROTOCOL) subsample = cfg.DATASET.SUBSAMPLE_CLASSES train, val, test = self.subsample_classes(train, val, test, subsample=subsample) super().__init__(train_x=train, val=val, test=test) def read_data(self, split_file): filepath = os.path.join(self.anno_dir, split_file) items = [] with open(filepath, "r") as f: lines = f.readlines() for line in lines: line = line.strip() imname, label, species, _ = line.split(" ") breed = imname.split("_")[:-1] breed = "_".join(breed) breed = breed.lower() imname += ".jpg" impath = os.path.join(self.image_dir, imname) label = int(label) - 1 # convert to 0-based index item = Datum(impath=impath, label=label, classname=breed) items.append(item) return items @staticmethod def split_trainval(trainval, p_val=0.2): p_trn = 1 - p_val print(f"Splitting trainval into {p_trn:.0%} train and {p_val:.0%} val") tracker = defaultdict(list) for idx, item in enumerate(trainval): label = item.label tracker[label].append(idx) train, val = [], [] for label, idxs in tracker.items(): n_val = round(len(idxs) * p_val) assert n_val > 0 random.shuffle(idxs) for n, idx in enumerate(idxs): item = trainval[idx] if n < n_val: val.append(item) else: train.append(item) return train, val @staticmethod def save_split(train, val, test, filepath, path_prefix): def _extract(items): out = [] for item in items: impath = item.impath label = item.label classname = item.classname impath = impath.replace(path_prefix, "") if impath.startswith("/"): impath = impath[1:] out.append((impath, label, classname)) return out train = _extract(train) val = _extract(val) test = _extract(test) split = {"train": train, "val": val, "test": test} write_json(split, filepath) print(f"Saved split to {filepath}") @staticmethod def read_split(filepath, path_prefix): def _convert(items): out = [] for impath, label, classname in items: impath = os.path.join(path_prefix, impath) item = Datum(impath=impath, label=int(label), classname=classname) out.append(item) return out 
print(f"Reading split from {filepath}") split = read_json(filepath) train = _convert(split["train"]) val = _convert(split["val"]) test = _convert(split["test"]) return train, val, test @staticmethod def subsample_classes(*args, subsample="all"): """Divide classes into two groups. The first group represents base classes while the second group represents new classes. Args: args: a list of datasets, e.g. train, val and test. subsample (str): what classes to subsample. """ assert subsample in ["all", "base", "new"] if subsample == "all": return args dataset = args[0] labels = set() for item in dataset: labels.add(item.label) labels = list(labels) labels.sort() n = len(labels) # Divide classes into two halves m = math.ceil(n / 2) print(f"SUBSAMPLE {subsample.upper()} CLASSES!") if subsample == "base": selected = labels[:m] # take the first half else: selected = labels[m:] # take the second half relabeler = {y: y_new for y_new, y in enumerate(selected)} output = [] for dataset in args: dataset_new = [] for item in dataset: if item.label not in selected: continue item_new = Datum(impath=item.impath, label=relabeler[item.label], classname=item.classname) dataset_new.append(item_new) output.append(dataset_new) return output ================================================ FILE: ProGrad.public/datasets/stanford_cars.py ================================================ import os import pickle from scipy.io import loadmat from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase from dassl.utils import mkdir_if_missing from .oxford_pets import OxfordPets @DATASET_REGISTRY.register() class StanfordCars(DatasetBase): dataset_dir = "stanford_cars" def __init__(self, cfg): root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) self.dataset_dir = os.path.join(root, self.dataset_dir) self.split_path = os.path.join(self.dataset_dir, "split_zhou_StanfordCars.json") self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 
mkdir_if_missing(self.split_fewshot_dir) if os.path.exists(self.split_path): train, val, test = OxfordPets.read_split(self.split_path, self.dataset_dir) else: trainval_file = os.path.join(self.dataset_dir, "devkit", "cars_train_annos.mat") test_file = os.path.join(self.dataset_dir, "cars_test_annos_withlabels.mat") meta_file = os.path.join(self.dataset_dir, "devkit", "cars_meta.mat") trainval = self.read_data("cars_train", trainval_file, meta_file) test = self.read_data("cars_test", test_file, meta_file) train, val = OxfordPets.split_trainval(trainval) OxfordPets.save_split(train, val, test, self.split_path, self.dataset_dir) num_shots = cfg.DATASET.NUM_SHOTS if num_shots >= 1: seed = cfg.SEED preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") if os.path.exists(preprocessed): print( f"Loading preprocessed few-shot data from {preprocessed}") with open(preprocessed, "rb") as file: data = pickle.load(file) train, val = data["train"], data["val"] else: train = self.generate_fewshot_dataset(train, num_shots=num_shots) val = self.generate_fewshot_dataset(val, num_shots=min( num_shots, 4)) data = {"train": train, "val": val} print(f"Saving preprocessed few-shot data to {preprocessed}") with open(preprocessed, "wb") as file: pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) subsample = cfg.DATASET.SUBSAMPLE_CLASSES train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) super().__init__(train_x=train, val=val, test=test) def read_data(self, image_dir, anno_file, meta_file): anno_file = loadmat(anno_file)["annotations"][0] meta_file = loadmat(meta_file)["class_names"][0] items = [] for i in range(len(anno_file)): imname = anno_file[i]["fname"][0] impath = os.path.join(self.dataset_dir, image_dir, imname) label = anno_file[i]["class"][0, 0] label = int(label) - 1 # convert to 0-based index classname = meta_file[label][0] names = classname.split(" ") year = names.pop(-1) names.insert(0, year) 
            classname = " ".join(names)
            item = Datum(impath=impath, label=label, classname=classname)
            items.append(item)

        return items

================================================
FILE: ProGrad.public/datasets/sun397.py
================================================
import os
import pickle

from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
from dassl.utils import mkdir_if_missing

from .oxford_pets import OxfordPets


@DATASET_REGISTRY.register()
class SUN397(DatasetBase):

    dataset_dir = "sun397"

    def __init__(self, cfg):
        root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = os.path.join(root, self.dataset_dir)
        self.image_dir = os.path.join(self.dataset_dir, "SUN397")
        self.split_path = os.path.join(self.dataset_dir, "split_zhou_SUN397.json")
        self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
        mkdir_if_missing(self.split_fewshot_dir)

        if os.path.exists(self.split_path):
            train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
        else:
            classnames = []
            with open(os.path.join(self.dataset_dir, "ClassName.txt"), "r") as f:
                lines = f.readlines()
                for line in lines:
                    line = line.strip()[1:]  # remove leading "/"
                    classnames.append(line)
            cname2lab = {c: i for i, c in enumerate(classnames)}
            trainval = self.read_data(cname2lab, "Training_01.txt")
            test = self.read_data(cname2lab, "Testing_01.txt")
            train, val = OxfordPets.split_trainval(trainval)
            OxfordPets.save_split(train, val, test, self.split_path, self.image_dir)

        num_shots = cfg.DATASET.NUM_SHOTS
        if num_shots >= 1:
            seed = cfg.SEED
            preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")

            if os.path.exists(preprocessed):
                print(f"Loading preprocessed few-shot data from {preprocessed}")
                with open(preprocessed, "rb") as file:
                    data = pickle.load(file)
                    train, val = data["train"], data["val"]
            else:
                train = self.generate_fewshot_dataset(train, num_shots=num_shots)
                val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
                data = {"train": train, "val": val}
                print(f"Saving preprocessed few-shot data to {preprocessed}")
                with open(preprocessed, "wb") as file:
                    pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)

        subsample = cfg.DATASET.SUBSAMPLE_CLASSES
        train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample)

        super().__init__(train_x=train, val=val, test=test)

    def read_data(self, cname2lab, text_file):
        text_file = os.path.join(self.dataset_dir, text_file)
        items = []

        with open(text_file, "r") as f:
            lines = f.readlines()
            for line in lines:
                imname = line.strip()[1:]  # remove leading "/"
                classname = os.path.dirname(imname)
                label = cname2lab[classname]
                impath = os.path.join(self.image_dir, imname)

                names = classname.split("/")[1:]  # drop the single-letter directory (e.g. "a/")
                names = names[::-1]  # put words like indoor/outdoor first
                classname = " ".join(names)

                item = Datum(impath=impath, label=label, classname=classname)
                items.append(item)

        return items

================================================
FILE: ProGrad.public/datasets/ucf101.py
================================================
import os
import pickle
import re

from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
from dassl.utils import mkdir_if_missing

from .oxford_pets import OxfordPets


@DATASET_REGISTRY.register()
class UCF101(DatasetBase):

    dataset_dir = "ucf101"

    def __init__(self, cfg):
        root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = os.path.join(root, self.dataset_dir)
        self.image_dir = os.path.join(self.dataset_dir, "UCF-101-midframes")
        self.split_path = os.path.join(self.dataset_dir, "split_zhou_UCF101.json")
        self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
        mkdir_if_missing(self.split_fewshot_dir)

        if os.path.exists(self.split_path):
            train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
        else:
            cname2lab = {}
            filepath = os.path.join(self.dataset_dir, "ucfTrainTestlist/classInd.txt")
            with open(filepath, "r") as f:
                lines = f.readlines()
                for line in lines:
                    label, classname = line.strip().split(" ")
                    label = int(label) - 1  # convert to 0-based index
                    cname2lab[classname] = label

            trainval = self.read_data(cname2lab, "ucfTrainTestlist/trainlist01.txt")
            test = self.read_data(cname2lab, "ucfTrainTestlist/testlist01.txt")
            train, val = OxfordPets.split_trainval(trainval)
            OxfordPets.save_split(train, val, test, self.split_path, self.image_dir)

        num_shots = cfg.DATASET.NUM_SHOTS
        if num_shots >= 1:
            seed = cfg.SEED
            preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")

            if os.path.exists(preprocessed):
                print(f"Loading preprocessed few-shot data from {preprocessed}")
                with open(preprocessed, "rb") as file:
                    data = pickle.load(file)
                    train, val = data["train"], data["val"]
            else:
                train = self.generate_fewshot_dataset(train, num_shots=num_shots)
                val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
                data = {"train": train, "val": val}
                print(f"Saving preprocessed few-shot data to {preprocessed}")
                with open(preprocessed, "wb") as file:
                    pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)

        subsample = cfg.DATASET.SUBSAMPLE_CLASSES
        train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample)

        super().__init__(train_x=train, val=val, test=test)

    def read_data(self, cname2lab, text_file):
        text_file = os.path.join(self.dataset_dir, text_file)
        items = []

        with open(text_file, "r") as f:
            lines = f.readlines()
            for line in lines:
                line = line.strip().split(" ")[0]  # trainlist: filename, label
                action, filename = line.split("/")
                label = cname2lab[action]

                elements = re.findall("[A-Z][^A-Z]*", action)
                renamed_action = "_".join(elements)

                filename = filename.replace(".avi", ".jpg")
                impath = os.path.join(self.image_dir, renamed_action, filename)

                item = Datum(impath=impath, label=label, classname=renamed_action)
                items.append(item)

        return items

================================================
FILE: ProGrad.public/interpret_prompt.py
================================================
import os
import sys
import argparse
import torch

from clip.simple_tokenizer import SimpleTokenizer
from clip import clip


def load_clip_to_cpu(backbone_name="RN50"):
    url = clip._MODELS[backbone_name]
    model_path = clip._download(url)

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location="cpu").eval()
        state_dict = None

    except RuntimeError:
        state_dict = torch.load(model_path, map_location="cpu")

    model = clip.build_model(state_dict or model.state_dict())

    return model


parser = argparse.ArgumentParser()
parser.add_argument("fpath", type=str, help="Path to the learned prompt")
parser.add_argument("topk", type=int, help="Select top-k similar words")
args = parser.parse_args()

fpath = args.fpath
topk = args.topk

assert os.path.exists(fpath)

print(f"Return the top-{topk} matched words")

tokenizer = SimpleTokenizer()
clip_model = load_clip_to_cpu()
token_embedding = clip_model.token_embedding.weight
print(f"Size of token embedding: {token_embedding.shape}")

prompt_learner = torch.load(fpath, map_location="cpu")["state_dict"]
ctx = prompt_learner["ctx"]
ctx = ctx.float()
print(f"Size of context: {ctx.shape}")

if ctx.dim() == 2:
    # Generic context
    distance = torch.cdist(ctx, token_embedding)
    print(f"Size of distance matrix: {distance.shape}")
    sorted_idxs = torch.argsort(distance, dim=1)
    sorted_idxs = sorted_idxs[:, :topk]

    for m, idxs in enumerate(sorted_idxs):
        words = [tokenizer.decoder[idx.item()] for idx in idxs]
        dist = [f"{distance[m, idx].item():.4f}" for idx in idxs]
        print(f"{m+1}: {words} {dist}")

elif ctx.dim() == 3:
    # Class-specific context
    raise NotImplementedError

================================================
FILE: ProGrad.public/lpclip/README.md
================================================
# Linear Probe CLIP

To run linear probe baselines, make sure that your current working directory is `lpclip/`.
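A linear probe here is a logistic-regression classifier trained on frozen CLIP image features, so its main knob is the inverse regularization strength `C`. `linear_probe.py` first scores a coarse grid (`1e6` down to `1e-6`, in factors of 100) on a few-shot validation split, then brackets the best value by one decade on each side and, for `--num_step` rounds, keeps whichever bracket end scores higher while moving the other end to the midpoint in log10 space. The sketch below isolates only that search rule. `coarse_to_fine_search` is a hypothetical name, and the `score` callable is an assumed stand-in for "fit `LogisticRegression(C=...)` and return validation accuracy".

```python
import math
from typing import Callable, Sequence


def coarse_to_fine_search(
    score: Callable[[float], float],
    grid: Sequence[float] = (1e6, 1e4, 1e2, 1.0, 1e-2, 1e-4, 1e-6),
    num_step: int = 8,
) -> float:
    """Hypothetical helper mirroring the C search rule in linear_probe.py.

    `score(C)` is an assumed stand-in for "fit the probe with this C and
    return validation accuracy"; higher is better.
    """
    # Coarse sweep: keep the grid value with the best score
    # (the first one wins ties, like np.argmax).
    scores = [score(c) for c in grid]
    c_peak = grid[scores.index(max(scores))]

    # Initial bracket: one decade on each side of the peak, in log10 space.
    lo, hi = math.log10(c_peak) - 1.0, math.log10(c_peak) + 1.0
    best = c_peak
    for _ in range(num_step):
        s_lo, s_hi = score(10.0 ** lo), score(10.0 ** hi)
        mid = 0.5 * (lo + hi)
        if s_lo < s_hi:
            # Right end wins: keep it and move the left end to the midpoint.
            best, lo = 10.0 ** hi, mid
        else:
            # Left end wins (ties included): keep it and pull the right end in.
            best, hi = 10.0 ** lo, mid
    return best
```

With a synthetic score peaked at `C = 10**0.3`, e.g. `coarse_to_fine_search(lambda c: -(math.log10(c) - 0.3) ** 2)`, the default 8 rounds finish within about 0.003 decades of the peak. Each round costs two fits and never evaluates the midpoint itself, so this is a bracket-shrinking heuristic rather than a guaranteed optimum; ties during refinement go to the smaller `C`, matching the `else` branch in the script.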
Step 1: Extract Features using the CLIP Image Encoder
```bash
sh feat_extractor.sh
```

Step 2: Train few-shot linear probe
```bash
sh linear_probe.sh
```

We follow the instructions stated in the Appendix A3 (pp.38) of [the original CLIP paper](https://arxiv.org/pdf/2103.00020.pdf), with a careful hyperparameter sweep.

Note: please pull the latest Dassl (version >= `606a2c6`).

================================================
FILE: ProGrad.public/lpclip/feat_extractor.py
================================================
import os, argparse
import numpy as np
import torch
import sys

sys.path.append(os.path.abspath(".."))

from datasets.oxford_pets import OxfordPets
from datasets.oxford_flowers import OxfordFlowers
from datasets.fgvc_aircraft import FGVCAircraft
from datasets.dtd import DescribableTextures
from datasets.eurosat import EuroSAT
from datasets.stanford_cars import StanfordCars
from datasets.food101 import Food101
from datasets.sun397 import SUN397
from datasets.caltech101 import Caltech101
from datasets.ucf101 import UCF101
from datasets.imagenet import ImageNet
from datasets.imagenetv2 import ImageNetV2
from datasets.imagenet_sketch import ImageNetSketch
from datasets.imagenet_a import ImageNetA
from datasets.imagenet_r import ImageNetR

from dassl.utils import setup_logger, set_random_seed, collect_env_info
from dassl.config import get_cfg_default
from dassl.data.transforms import build_transform
from dassl.data import DatasetWrapper

import clip

# import pdb; pdb.set_trace()


def print_args(args, cfg):
    print("***************")
    print("** Arguments **")
    print("***************")
    optkeys = list(args.__dict__.keys())
    optkeys.sort()
    for key in optkeys:
        print("{}: {}".format(key, args.__dict__[key]))
    print("************")
    print("** Config **")
    print("************")
    print(cfg)


def reset_cfg(cfg, args):
    if args.root:
        cfg.DATASET.ROOT = args.root

    if args.output_dir:
        cfg.OUTPUT_DIR = args.output_dir

    if args.trainer:
        cfg.TRAINER.NAME = args.trainer

    if args.backbone:
        cfg.MODEL.BACKBONE.NAME = args.backbone

    if args.head:
        cfg.MODEL.HEAD.NAME = args.head


def extend_cfg(cfg):
    """
    Add new config variables.

    E.g.
        from yacs.config import CfgNode as CN
        cfg.TRAINER.MY_MODEL = CN()
        cfg.TRAINER.MY_MODEL.PARAM_A = 1.
        cfg.TRAINER.MY_MODEL.PARAM_B = 0.5
        cfg.TRAINER.MY_MODEL.PARAM_C = False
    """
    from yacs.config import CfgNode as CN

    cfg.TRAINER.OURS = CN()
    cfg.TRAINER.OURS.N_CTX = 10  # number of context vectors
    cfg.TRAINER.OURS.CSC = False  # class-specific context
    cfg.TRAINER.OURS.CTX_INIT = ""  # initialize context vectors with given words
    cfg.TRAINER.OURS.WEIGHT_U = 0.1  # weight for the unsupervised loss

    cfg.DATASET.SUBSAMPLE_CLASSES = "all"  # all, base or new


def setup_cfg(args):
    cfg = get_cfg_default()
    extend_cfg(cfg)

    # 1. From the dataset config file
    if args.dataset_config_file:
        cfg.merge_from_file(args.dataset_config_file)

    # 2. From the method config file
    if args.config_file:
        cfg.merge_from_file(args.config_file)

    # 3. From input arguments
    reset_cfg(cfg, args)

    cfg.freeze()

    return cfg


def main(args):
    cfg = setup_cfg(args)
    if cfg.SEED >= 0:
        print("Setting fixed seed: {}".format(cfg.SEED))
        set_random_seed(cfg.SEED)
    setup_logger(cfg.OUTPUT_DIR)

    if torch.cuda.is_available() and cfg.USE_CUDA:
        torch.backends.cudnn.benchmark = True

    print_args(args, cfg)
    print("Collecting env info ...")
    print("** System info **\n{}\n".format(collect_env_info()))

    ######################################
    #   Setup DataLoader
    ######################################
    dataset = eval(cfg.DATASET.NAME)(cfg)

    if args.split == "train":
        dataset_input = dataset.train_x
    elif args.split == "val":
        dataset_input = dataset.val
    else:
        dataset_input = dataset.test

    tfm_train = build_transform(cfg, is_train=False)
    data_loader = torch.utils.data.DataLoader(
        DatasetWrapper(cfg, dataset_input, transform=tfm_train, is_train=False),
        batch_size=cfg.DATALOADER.TRAIN_X.BATCH_SIZE,
        sampler=None,
        shuffle=False,
        num_workers=cfg.DATALOADER.NUM_WORKERS,
        drop_last=False,
        pin_memory=(torch.cuda.is_available() and cfg.USE_CUDA),
    )

    ########################################
    #   Setup Network
    ########################################
    clip_model, _ = clip.load("RN50", "cuda", jit=False)
    clip_model.eval()
    ###################################################################################################################
    # Start Feature Extractor
    feature_list = []
    label_list = []
    train_dataiter = iter(data_loader)
    # Inference only: disable autograd so activations are not kept for backprop
    with torch.no_grad():
        for train_step in range(1, len(train_dataiter) + 1):
            batch = next(train_dataiter)
            data = batch["img"].cuda()
            feature = clip_model.visual(data)
            feature = feature.cpu()
            for idx in range(len(data)):
                feature_list.append(feature[idx].tolist())
            label_list.extend(batch["label"].tolist())
    save_dir = os.path.join(cfg.OUTPUT_DIR, cfg.DATASET.NAME)
    os.makedirs(save_dir, exist_ok=True)
    save_filename = f"{args.split}"
    np.savez(
        os.path.join(save_dir, save_filename),
        feature_list=feature_list,
        label_list=label_list,
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--root", type=str, default="", help="path to dataset")
    parser.add_argument("--output-dir", type=str, default="", help="output directory")
    parser.add_argument("--config-file", type=str, default="", help="path to config file")
    parser.add_argument(
        "--dataset-config-file",
        type=str,
        default="",
        help="path to config file for dataset setup",
    )
    parser.add_argument("--num-shot", type=int, default=1, help="number of shots")
    parser.add_argument("--split", type=str, choices=["train", "val", "test"], help="which split")
    parser.add_argument("--trainer", type=str, default="", help="name of trainer")
    parser.add_argument("--backbone", type=str, default="", help="name of CNN backbone")
    parser.add_argument("--head", type=str, default="", help="name of head")
    parser.add_argument("--seed", type=int, default=-1, help="only positive value enables a fixed seed")
    parser.add_argument("--eval-only", action="store_true", help="evaluation only")
    args = parser.parse_args()
    main(args)

================================================
FILE: ProGrad.public/lpclip/feat_extractor.sh
================================================
# sh feat_extractor.sh
DATA=/data1/CoOpData
OUTPUT='/data1/CoOpData/clip_feat/'
SEED=1
GPULIST=(0 1 2 3)
GPUIDX=0

# oxford_pets oxford_flowers fgvc_aircraft dtd eurosat stanford_cars food101 sun397 caltech101 ucf101 imagenet
# imagenet oxford_pets oxford_flowers stanford_cars food101 caltech101
for DATASET in imagenetv2 imagenet_sketch imagenet_a imagenet_r
do
    for SPLIT in train val test
    do
        while true
        do
            sleep 10
            let STATIDX=GPULIST[GPUIDX]+2
            stat=$(gpustat | awk '{print $11}' | sed -n ${STATIDX}'p')
            if [ "$stat" -lt 20 ]
            then
                break
            fi
            let GPUIDX=(GPUIDX+1)%${#GPULIST[@]}
            echo $GPUIDX'N'
        done
        CUDA_VISIBLE_DEVICES=${GPULIST[${GPUIDX}]} python feat_extractor.py \
        --split ${SPLIT} \
        --root ${DATA} \
        --seed ${SEED} \
        --dataset-config-file ../configs/datasets/${DATASET}.yaml \
        --config-file ../configs/trainers/CoOp/rn50_val.yaml \
        --output-dir ${OUTPUT} \
        --eval-only &
        sleep 10
    done
done

================================================
FILE: ProGrad.public/lpclip/linear_probe.py
================================================
import numpy as np
import os
from sklearn.linear_model import LogisticRegression
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, default="", help="path to dataset")
parser.add_argument("--num_step", type=int, default=8, help="number of steps")
parser.add_argument("--num_run", type=int, default=10, help="number of runs")
parser.add_argument("--feature_dir", type=str, default="clip_feat", help="feature dir path")
args = parser.parse_args()

dataset = args.dataset
dataset_path = os.path.join(f"{args.feature_dir}", dataset)

train_file = np.load(os.path.join(dataset_path, "train.npz"))
train_feature, train_label = train_file["feature_list"], train_file["label_list"]
val_file = np.load(os.path.join(dataset_path, "val.npz"))
val_feature, val_label = val_file["feature_list"], val_file["label_list"]
test_file = np.load(os.path.join(dataset_path, "test.npz"))
test_feature, test_label = test_file["feature_list"], test_file["label_list"]

os.makedirs("report", exist_ok=True)

val_shot_list = {1: 1, 2: 2, 4: 4, 8: 4, 16: 4}

# for num_shot in [1, 2, 4, 8, 16]:
for num_shot in [4, 16]:
    test_acc_step_list = np.zeros([args.num_run, args.num_step])
    for seed in range(1, args.num_run + 1):
        np.random.seed(seed)
        print(f"-- Seed: {seed} --------------------------------------------------------------")
        # Sampling
        all_label_list = np.unique(train_label)
        selected_idx_list = []
        for label in all_label_list:
            label_collection = np.where(train_label == label)[0]
            selected_idx = np.random.choice(label_collection, size=num_shot, replace=False)
            selected_idx_list.extend(selected_idx)

        fewshot_train_feature = train_feature[selected_idx_list]
        fewshot_train_label = train_label[selected_idx_list]

        val_num_shot = val_shot_list[num_shot]
        val_selected_idx_list = []
        for label in all_label_list:
            label_collection = np.where(val_label == label)[0]
            selected_idx = np.random.choice(label_collection, size=val_num_shot, replace=False)
            val_selected_idx_list.extend(selected_idx)

        fewshot_val_feature = val_feature[val_selected_idx_list]
        fewshot_val_label = val_label[val_selected_idx_list]

        # search initialization
        search_list = [1e6, 1e4, 1e2, 1, 1e-2, 1e-4, 1e-6]
        acc_list = []
        for c_weight in search_list:
            clf = LogisticRegression(solver="lbfgs", max_iter=1000, penalty="l2",
                                     C=c_weight).fit(fewshot_train_feature, fewshot_train_label)
            pred = clf.predict(fewshot_val_feature)
            acc_val = sum(pred == fewshot_val_label) / len(fewshot_val_label)
            acc_list.append(acc_val)

        print(acc_list, flush=True)

        # binary search
        peak_idx = np.argmax(acc_list)
        c_peak = search_list[peak_idx]
        c_left, c_right = 1e-1 * c_peak, 1e1 * c_peak

        def binary_search(c_left, c_right, seed, step, test_acc_step_list):
            clf_left = LogisticRegression(solver="lbfgs", max_iter=1000, penalty="l2",
                                          C=c_left).fit(fewshot_train_feature, fewshot_train_label)
            pred_left = clf_left.predict(fewshot_val_feature)
            acc_left = sum(pred_left == fewshot_val_label) / len(fewshot_val_label)
            print("Val accuracy (Left): {:.2f}".format(100 * acc_left), flush=True)

            clf_right = LogisticRegression(solver="lbfgs", max_iter=1000, penalty="l2",
                                           C=c_right).fit(fewshot_train_feature, fewshot_train_label)
            pred_right = clf_right.predict(fewshot_val_feature)
            acc_right = sum(pred_right == fewshot_val_label) / len(fewshot_val_label)
            print("Val accuracy (Right): {:.2f}".format(100 * acc_right), flush=True)

            # find maximum and update ranges
            if acc_left < acc_right:
                c_final = c_right
                clf_final = clf_right
                # range for the next step
                c_left = 0.5 * (np.log10(c_right) + np.log10(c_left))
                c_right = np.log10(c_right)
            else:
                c_final = c_left
                clf_final = clf_left
                # range for the next step
                c_right = 0.5 * (np.log10(c_right) + np.log10(c_left))
                c_left = np.log10(c_left)

            pred = clf_final.predict(test_feature)
            test_acc = 100 * sum(pred == test_label) / len(pred)
            print("Test Accuracy: {:.2f}".format(test_acc), flush=True)
            test_acc_step_list[seed - 1, step] = test_acc

            saveline = "{}, seed {}, {} shot, weight {}, test_acc {:.2f}\n".format(
                dataset, seed, num_shot, c_final, test_acc)
            with open(
                    "./report/{}_s{}r{}_details.txt".format('clip_feat', args.num_step, args.num_run),
                    "a+",
            ) as writer:
                writer.write(saveline)
            return (
                np.power(10, c_left),
                np.power(10, c_right),
                seed,
                step,
                test_acc_step_list,
            )

        for step in range(args.num_step):
            print(
                f"{dataset}, {num_shot} Shot, Round {step}: {c_left}/{c_right}",
                flush=True,
            )
            c_left, c_right, seed, step, test_acc_step_list = binary_search(
                c_left, c_right, seed, step, test_acc_step_list)

    # save results of last step
    test_acc_list = test_acc_step_list[:, -1]
    acc_mean = np.mean(test_acc_list)
    acc_std = np.std(test_acc_list)
    save_line = "{}, {} Shot, Test acc stat: {:.2f} ({:.2f})\n".format(
        dataset, num_shot, acc_mean, acc_std)
    print(save_line, flush=True)
    with open(
            "./report/{}_s{}r{}.txt".format('clip_feat', args.num_step, args.num_run),
            "a+",
    ) as writer:
        writer.write(save_line)

================================================
FILE: ProGrad.public/lpclip/linear_probe.sh
================================================
feature_dir=/data1/CoOpData/clip_feat/

# ImageNet OxfordPets OxfordFlowers StanfordCars Food101 Caltech101
for DATASET in ImageNet
do
    python linear_probe.py \
    --dataset ${DATASET} \
    --feature_dir ${feature_dir} \
    --num_step 8 \
    --num_run 3
done

================================================
FILE: ProGrad.public/lpclip/linear_probe_transfer.py
================================================
import numpy as np
import os
from sklearn.linear_model import LogisticRegression
import argparse

parser = argparse.ArgumentParser()
# parser.add_argument("--train_dataset",
#                     type=str,
#                     default="",
#                     help="path to train dataset")
# parser.add_argument("--test_dataset",
#                     type=str,
#                     default="",
#                     help="path to test dataset")
parser.add_argument("--num_step", type=int, default=8, help="number of steps")
parser.add_argument("--num_run", type=int, default=10, help="number of runs")
parser.add_argument("--feature_dir", type=str, default="/data1/CoOpData/clip_feat/", help="feature dir path")
args = parser.parse_args()

train_dataset = 'ImageNet'
train_dataset_path = os.path.join(f"{args.feature_dir}", train_dataset)

test_datasets = ['ImageNetV2', 'ImageNetSketch', 'ImageNetR', 'ImageNetA']
test_dataset_paths = [
    os.path.join(f"{args.feature_dir}", test_dataset)
    for test_dataset in test_datasets
]

train_file = np.load(os.path.join(train_dataset_path, "train.npz"))
train_feature, train_label = train_file["feature_list"], train_file["label_list"]
val_file = np.load(os.path.join(train_dataset_path, "val.npz"))
val_feature, val_label = val_file["feature_list"], val_file["label_list"]
test_files = [
    np.load(os.path.join(test_dataset_path, "test.npz"))
    for test_dataset_path in test_dataset_paths
]
test_features, test_labels = [
    test_file["feature_list"] for test_file in test_files
], [test_file["label_list"] for test_file in test_files]

os.makedirs("report", exist_ok=True)

val_shot_list = {1: 1, 2: 2, 4: 4, 8: 4, 16: 4}

# for num_shot in [1, 2, 4, 8, 16]:
for num_shot in [16]:
    test_acc_step_list = np.zeros([len(test_datasets), args.num_run, args.num_step])
    for seed in range(1, args.num_run + 1):
        np.random.seed(seed)
        print(f"-- Seed: {seed} --------------------------------------------------------------")
        # Sampling
        all_label_list = np.unique(train_label)
        selected_idx_list = []
        for label in all_label_list:
            label_collection = np.where(train_label == label)[0]
            selected_idx = np.random.choice(label_collection, size=num_shot, replace=False)
            selected_idx_list.extend(selected_idx)

        fewshot_train_feature = train_feature[selected_idx_list]
        fewshot_train_label = train_label[selected_idx_list]

        val_num_shot = val_shot_list[num_shot]
        val_selected_idx_list = []
        for label in all_label_list:
            label_collection = np.where(val_label == label)[0]
            selected_idx = np.random.choice(label_collection, size=val_num_shot, replace=False)
            val_selected_idx_list.extend(selected_idx)

        fewshot_val_feature = val_feature[val_selected_idx_list]
        fewshot_val_label = val_label[val_selected_idx_list]

        # search initialization
        search_list = [1e6, 1e4, 1e2, 1, 1e-2, 1e-4, 1e-6]
        acc_list = []
        for c_weight in search_list:
            clf = LogisticRegression(solver="lbfgs", max_iter=1000, penalty="l2",
                                     C=c_weight).fit(fewshot_train_feature, fewshot_train_label)
            pred = clf.predict(fewshot_val_feature)
            acc_val = sum(pred == fewshot_val_label) / len(fewshot_val_label)
            acc_list.append(acc_val)

        print(acc_list, flush=True)

        # binary search
        peak_idx = np.argmax(acc_list)
        c_peak = search_list[peak_idx]
        c_left, c_right = 1e-1 * c_peak, 1e1 * c_peak

        def binary_search(c_left, c_right, seed, step, test_acc_step_list):
            clf_left = LogisticRegression(solver="lbfgs", max_iter=1000, penalty="l2",
                                          C=c_left).fit(fewshot_train_feature, fewshot_train_label)
            pred_left = clf_left.predict(fewshot_val_feature)
            acc_left = sum(pred_left == fewshot_val_label) / len(fewshot_val_label)
            print("Val accuracy (Left): {:.2f}".format(100 * acc_left), flush=True)

            clf_right = LogisticRegression(solver="lbfgs", max_iter=1000, penalty="l2",
                                           C=c_right).fit(fewshot_train_feature, fewshot_train_label)
            pred_right = clf_right.predict(fewshot_val_feature)
            acc_right = sum(pred_right == fewshot_val_label) / len(fewshot_val_label)
            print("Val accuracy (Right): {:.2f}".format(100 * acc_right), flush=True)

            # find maximum and update ranges
            if acc_left < acc_right:
                c_final = c_right
                clf_final = clf_right
                # range for the next step
                c_left = 0.5 * (np.log10(c_right) + np.log10(c_left))
                c_right = np.log10(c_right)
            else:
                c_final = c_left
                clf_final = clf_left
                # range for the next step
                c_right = 0.5 * (np.log10(c_right) + np.log10(c_left))
                c_left = np.log10(c_left)

            for i, (test_feature, test_label, test_dataset) in enumerate(
                    zip(test_features, test_labels, test_datasets)):
                pred = clf_final.predict(test_feature)
                test_acc = 100 * sum(pred == test_label) / len(pred)
                print("Test Accuracy: {:.2f}".format(test_acc), flush=True)
                test_acc_step_list[i, seed - 1, step] = test_acc

                saveline = "{}, {}, seed {}, {} shot, weight {}, test_acc {:.2f}\n".format(
                    train_dataset, test_dataset, seed, num_shot, c_final, test_acc)
                with open(
                        "./report/{}_s{}r{}_details.txt".format('clip_feat', args.num_step, args.num_run),
                        "a+",
                ) as writer:
                    writer.write(saveline)
            return (
                np.power(10, c_left),
                np.power(10, c_right),
                seed,
                step,
                test_acc_step_list,
            )

        for step in range(args.num_step):
            print(
                f"{train_dataset}, {num_shot} Shot, Round {step}: {c_left}/{c_right}",
                flush=True,
            )
            c_left, c_right, seed, step, test_acc_step_list = binary_search(
                c_left, c_right, seed, step, test_acc_step_list)

    # save results of last step; NumPy reductions take `axis`, not `dim`
    test_acc_list = test_acc_step_list[:, :, -1]
    acc_mean = np.mean(test_acc_list, axis=-1)
    acc_std = np.std(test_acc_list, axis=-1)
    for i in range(len(test_datasets)):
        save_line = "{}, {}, {} Shot, Test acc stat: {:.2f} ({:.2f})\n".format(
            train_dataset, test_datasets[i], num_shot, acc_mean[i], acc_std[i])
        print(save_line, flush=True)
        with open(
                "./report/{}_s{}r{}.txt".format('clip_feat', args.num_step, args.num_run),
                "a+",
        ) as writer:
            writer.write(save_line)

================================================
FILE: ProGrad.public/parse_test_res.py
================================================
"""
Goal
---
1. Read test results from log.txt files
2. Compute mean and std across different folders (seeds)

Usage
---
Assume the output files are saved under output/my_experiment,
which contains results of different seeds, e.g.,

my_experiment/
    seed1/
        log.txt
    seed2/
        log.txt
    seed3/
        log.txt

Run the following command from the root directory:

$ python tools/parse_test_res.py output/my_experiment

Add --ci95 to the arguments if you want the 95% confidence
interval instead of the standard deviation:

$ python tools/parse_test_res.py output/my_experiment --ci95

If my_experiment/ has the following structure,

my_experiment/
    exp-1/
        seed1/
            log.txt
            ...
        seed2/
            log.txt
            ...
        seed3/
            log.txt
            ...
    exp-2/
        ...
    exp-3/
        ...
Run

$ python tools/parse_test_res.py output/my_experiment --multi-exp
"""
import re
import numpy as np
import os.path as osp
import argparse
from collections import OrderedDict, defaultdict

from dassl.utils import check_isfile, listdir_nohidden


def compute_ci95(res):
    return 1.96 * np.std(res) / np.sqrt(len(res))


def parse_function(*metrics, directory="", args=None, end_signal=None):
    print(f"Parsing files in {directory}")
    subdirs = listdir_nohidden(directory, sort=True)

    outputs = []

    for subdir in subdirs:
        fpath = osp.join(directory, subdir, "log.txt")
        assert check_isfile(fpath)
        good_to_go = False
        output = OrderedDict()

        with open(fpath, "r") as f:
            lines = f.readlines()

            for line in lines:
                line = line.strip()

                if line == end_signal:
                    good_to_go = True

                for metric in metrics:
                    match = metric["regex"].search(line)
                    if match and good_to_go:
                        if "file" not in output:
                            output["file"] = fpath
                        num = float(match.group(1))
                        name = metric["name"]
                        output[name] = num

        if output:
            outputs.append(output)

    assert len(outputs) > 0, f"Nothing found in {directory}"

    metrics_results = defaultdict(list)

    for output in outputs:
        msg = ""
        for key, value in output.items():
            if isinstance(value, float):
                msg += f"{key}: {value:.2f}%. "
            else:
                msg += f"{key}: {value}. "
            if key != "file":
                metrics_results[key].append(value)
        print(msg)

    output_results = OrderedDict()

    print("===")
    print(f"Summary of directory: {directory}")
    for key, values in metrics_results.items():
        avg = np.mean(values)
        std = compute_ci95(values) if args.ci95 else np.std(values)
        print(f"* {key}: {avg:.2f}% +- {std:.2f}%")
        output_results[key] = avg
    print("===")

    return output_results


def main(args, end_signal):
    metric = {
        "name": args.keyword,
        "regex": re.compile(fr"\* {args.keyword}: ([\.\deE+-]+)%"),
    }

    if args.multi_exp:
        final_results = defaultdict(list)

        for directory in listdir_nohidden(args.directory, sort=True):
            directory = osp.join(args.directory, directory)
            results = parse_function(
                metric, directory=directory, args=args, end_signal=end_signal
            )

            for key, value in results.items():
                final_results[key].append(value)

        print("Average performance")
        for key, values in final_results.items():
            avg = np.mean(values)
            print(f"* {key}: {avg:.2f}%")

    else:
        parse_function(
            metric, directory=args.directory, args=args, end_signal=end_signal
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("directory", type=str, help="path to directory")
    parser.add_argument(
        # argparse %-formats help strings, so a literal percent sign must be written as %%
        "--ci95", action="store_true", help="compute 95%% confidence interval"
    )
    parser.add_argument(
        "--test-log", action="store_true", help="parse test-only logs"
    )
    parser.add_argument(
        "--multi-exp", action="store_true", help="parse multiple experiments"
    )
    parser.add_argument(
        "--keyword", default="accuracy", type=str, help="which keyword to extract"
    )
    args = parser.parse_args()

    end_signal = "Finished training"
    if args.test_log:
        end_signal = "=> result"

    main(args, end_signal)

================================================
FILE: ProGrad.public/requirements.txt
================================================
ftfy
regex
tqdm

================================================
FILE: ProGrad.public/scripts/base2new_test_main.sh
================================================
#!/bin/bash

cd ..
# custom config
DATA=/data1/CoOpData/
TRAINER=CoOp

DATASET=$1
CFG=rn50_ep100  # config file
CTP=end  # class token position (end or middle)
NCTX=16  # number of context tokens
SHOTS=4  # number of shots (1, 2, 4, 8, 16)
CSC=False  # class-specific context (False or True)
LOADEP=100
SUB=new

for SEED in 1 2 3
do
    COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
    MODEL_DIR=output/base2new/train_base/${COMMON_DIR}
    DIR=output/base2new/test_${SUB}/${COMMON_DIR}
    if [ -d "$DIR" ]; then
        echo "Results are available in ${DIR}. Skip this job"
    else
        echo "Run this job and save the output to ${DIR}"
        python train.py \
        --root ${DATA} \
        --seed ${SEED} \
        --trainer ${TRAINER} \
        --dataset-config-file configs/datasets/${DATASET}.yaml \
        --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
        --output-dir ${DIR} \
        --model-dir ${MODEL_DIR} \
        --load-epoch ${LOADEP} \
        --eval-only \
        TRAINER.COOP.N_CTX ${NCTX} \
        TRAINER.COOP.CSC ${CSC} \
        TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \
        DATASET.NUM_SHOTS ${SHOTS} \
        DATASET.SUBSAMPLE_CLASSES ${SUB}
    fi
done

================================================
FILE: ProGrad.public/scripts/base2new_test_prograd.sh
================================================
#!/bin/bash

cd ..

# custom config
DATA=/data1/CoOpData/
TRAINER=ProGrad

DATASET=$1
CFG=rn50_ep100  # config file
CTP=end  # class token position (end or middle)
NCTX=16  # number of context tokens
SHOTS=4  # number of shots (1, 2, 4, 8, 16)
CSC=False  # class-specific context (False or True)
LAMBDA=1.0
LOADEP=100
SUB=new

for SEED in 1 2 3
do
    COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
    MODEL_DIR=output/base2new/train_base/${COMMON_DIR}
    DIR=output/base2new/test_${SUB}/${COMMON_DIR}
    if [ -d "$DIR" ]; then
        echo "Results are available in ${DIR}. Skip this job"
    else
        echo "Run this job and save the output to ${DIR}"
        python train.py \
        --root ${DATA} \
        --seed ${SEED} \
        --trainer ${TRAINER} \
        --dataset-config-file configs/datasets/${DATASET}.yaml \
        --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
        --output-dir ${DIR} \
        --model-dir ${MODEL_DIR} \
        --load-epoch ${LOADEP} \
        --eval-only \
        LOSS.LAMBDA ${LAMBDA} \
        TRAINER.COOP.N_CTX ${NCTX} \
        TRAINER.COOP.CSC ${CSC} \
        TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \
        DATASET.NUM_SHOTS ${SHOTS} \
        DATASET.SUBSAMPLE_CLASSES ${SUB}
    fi
done

================================================
FILE: ProGrad.public/scripts/base2new_train_main.sh
================================================
#!/bin/bash

cd ..

# custom config
DATA=/data1/CoOpData/
TRAINER=CoOp

DATASET=$1
CFG=rn50_ep100  # config file
CTP=end  # class token position (end or middle)
NCTX=16  # number of context tokens
SHOTS=4  # number of shots (1, 2, 4, 8, 16)
CSC=False  # class-specific context (False or True)

for SEED in 1 2 3
do
    DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
    if [ -d "$DIR" ]; then
        echo "Results are available in ${DIR}. Skip this job"
    else
        echo "Run this job and save the output to ${DIR}"
        python train.py \
        --root ${DATA} \
        --seed ${SEED} \
        --trainer ${TRAINER} \
        --dataset-config-file configs/datasets/${DATASET}.yaml \
        --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
        --output-dir ${DIR} \
        TRAINER.COOP.N_CTX ${NCTX} \
        TRAINER.COOP.CSC ${CSC} \
        TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \
        DATASET.NUM_SHOTS ${SHOTS} \
        DATASET.SUBSAMPLE_CLASSES base
    fi
done

================================================
FILE: ProGrad.public/scripts/base2new_train_prograd.sh
================================================
#!/bin/bash

cd ..
# custom config
DATA=/data1/CoOpData/
TRAINER=ProGrad

DATASET=$1
CFG=rn50_ep100  # config file
CTP=end  # class token position (end or middle)
NCTX=16  # number of context tokens
SHOTS=4  # number of shots (1, 2, 4, 8, 16)
CSC=False  # class-specific context (False or True)
LAMBDA=1.0

for SEED in 1 2 3
do
    DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
    if [ -d "$DIR" ]; then
        echo "Results are available in ${DIR}. Skip this job"
    else
        echo "Run this job and save the output to ${DIR}"
        python train.py \
        --root ${DATA} \
        --seed ${SEED} \
        --trainer ${TRAINER} \
        --dataset-config-file configs/datasets/${DATASET}.yaml \
        --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
        --output-dir ${DIR} \
        LOSS.LAMBDA ${LAMBDA} \
        TRAINER.COOP.N_CTX ${NCTX} \
        TRAINER.COOP.CSC ${CSC} \
        TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \
        DATASET.NUM_SHOTS ${SHOTS} \
        DATASET.SUBSAMPLE_CLASSES base
    fi
done

================================================
FILE: ProGrad.public/scripts/eval.sh
================================================
#!/bin/bash

cd ..

# custom config
DATA=/path/to/datasets
TRAINER=CoOp
SHOTS=4
NCTX=16
CSC=False
CTP=end

DATASET=$1
CFG=$2

for SEED in 1 2 3
do
    python train.py \
    --root ${DATA} \
    --seed ${SEED} \
    --trainer ${TRAINER} \
    --dataset-config-file configs/datasets/${DATASET}.yaml \
    --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
    --output-dir output/evaluation/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/${DATASET}/seed${SEED} \
    --model-dir output/imagenet/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/seed${SEED} \
    --load-epoch 50 \
    --eval-only \
    TRAINER.COOP.N_CTX ${NCTX} \
    TRAINER.COOP.CSC ${CSC} \
    TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP}
done

================================================
FILE: ProGrad.public/scripts/main.sh
================================================
#!/bin/bash

cd ..
# custom config
DATA=/data1/CoOpData/
TRAINER=CoOp

DATASET=$1
CFG=$2 # config file
CTP=$3 # class token position (end or middle)
NCTX=$4 # number of context tokens
SHOTS=$5 # number of shots (1, 2, 4, 8, 16)
CSC=$6 # class-specific context (False or True)

for SEED in 1 2 3
do
    DIR=output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/seed${SEED}
    if [ -d "$DIR" ]; then
        echo "Results are available in ${DIR}. Skip this job"
    else
        echo "Run this job and save the output to ${DIR}"
        python train.py \
        --root ${DATA} \
        --seed ${SEED} \
        --trainer ${TRAINER} \
        --dataset-config-file configs/datasets/${DATASET}.yaml \
        --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
        --output-dir ${DIR} \
        TRAINER.COOP.N_CTX ${NCTX} \
        TRAINER.COOP.CSC ${CSC} \
        TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \
        DATASET.NUM_SHOTS ${SHOTS}
    fi
done

================================================
FILE: ProGrad.public/scripts/prograd.sh
================================================
#!/bin/bash

cd ..

# custom config
DATA=/data1/CoOpData/
TRAINER=ProGrad

DATASET=$1
CFG=$2 # config file
CTP=$3 # class token position (end or middle)
NCTX=$4 # number of context tokens
SHOTS=$5 # number of shots (1, 2, 4, 8, 16)
CSC=$6 # class-specific context (False or True)
LAMBDA=1.0

for SEED in 1 2 3
do
    DIR=output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/seed${SEED}
    if [ -d "$DIR" ]; then
        echo "Results are available in ${DIR}. Skip this job"
    else
        echo "Run this job and save the output to ${DIR}"
        python train.py \
        --root ${DATA} \
        --seed ${SEED} \
        --trainer ${TRAINER} \
        --dataset-config-file configs/datasets/${DATASET}.yaml \
        --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
        --output-dir ${DIR} \
        LOSS.LAMBDA ${LAMBDA} \
        TRAINER.COOP.N_CTX ${NCTX} \
        TRAINER.COOP.CSC ${CSC} \
        TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \
        DATASET.NUM_SHOTS ${SHOTS}
    fi
done

================================================
FILE: ProGrad.public/scripts/zeroshot.sh
================================================
#!/bin/bash

cd ..

# custom config
DATA=/data1/CoOpData
TRAINER=ZeroshotCLIP
DATASET=$1
CFG=$2 # rn50, rn101, vit_b32 or vit_b16

python train.py \
--root ${DATA} \
--trainer ${TRAINER} \
--dataset-config-file configs/datasets/${DATASET}.yaml \
--config-file configs/trainers/CoOp/${CFG}.yaml \
--output-dir output/${TRAINER}/${CFG}/${DATASET} \
--eval-only

================================================
FILE: ProGrad.public/train.py
================================================
import argparse
import torch
import time
import os

from dassl.utils import setup_logger, set_random_seed, collect_env_info
from dassl.config import get_cfg_default
from dassl.engine import build_trainer

# custom
import datasets.oxford_pets
import datasets.oxford_flowers
import datasets.fgvc_aircraft
import datasets.dtd
import datasets.eurosat
import datasets.stanford_cars
import datasets.food101
import datasets.sun397
import datasets.caltech101
import datasets.ucf101
import datasets.imagenet

import datasets.imagenet_sketch
import datasets.imagenetv2
import datasets.imagenet_a
import datasets.imagenet_r

import trainers.coop
import trainers.cocoop
import trainers.zsclip
import trainers.prograd


def print_args(args, cfg):
    print("***************")
    print("** Arguments **")
    print("***************")
    optkeys = list(args.__dict__.keys())
    optkeys.sort()
    for key in optkeys:
        print("{}: {}".format(key, args.__dict__[key]))
    print("************")
    print("** Config **")
    print("************")
    print(cfg)


def reset_cfg(cfg, args):
    if args.root:
        cfg.DATASET.ROOT = args.root

    if args.output_dir:
        cfg.OUTPUT_DIR = args.output_dir

    if args.resume:
        cfg.RESUME = args.resume

    if args.seed:
        cfg.SEED = args.seed

    if args.source_domains:
        cfg.DATASET.SOURCE_DOMAINS = args.source_domains

    if args.target_domains:
        cfg.DATASET.TARGET_DOMAINS = args.target_domains

    if args.transforms:
        cfg.INPUT.TRANSFORMS = args.transforms

    if args.trainer:
        cfg.TRAINER.NAME = args.trainer

    if args.backbone:
        cfg.MODEL.BACKBONE.NAME = args.backbone

    if args.head:
        cfg.MODEL.HEAD.NAME = args.head


def extend_cfg(cfg):
    """
    Add new config variables.

    E.g.
        from yacs.config import CfgNode as CN
        cfg.TRAINER.MY_MODEL = CN()
        cfg.TRAINER.MY_MODEL.PARAM_A = 1.
        cfg.TRAINER.MY_MODEL.PARAM_B = 0.5
        cfg.TRAINER.MY_MODEL.PARAM_C = False
    """
    from yacs.config import CfgNode as CN

    cfg.TRAINER.COOP = CN()
    cfg.TRAINER.COOP.ALPHA = 1.0
    cfg.TRAINER.COOP.N_CTX = 16  # number of context vectors
    cfg.TRAINER.COOP.CSC = False  # class-specific context
    cfg.TRAINER.COOP.CTX_INIT = False  # initialization words
    cfg.TRAINER.COOP.PREC = "fp16"  # fp16, fp32, amp
    cfg.TRAINER.COOP.CLASS_TOKEN_POSITION = "end"  # 'middle' or 'end' or 'front'

    cfg.TRAINER.COCOOP = CN()
    cfg.TRAINER.COCOOP.N_CTX = 16  # number of context vectors
    cfg.TRAINER.COCOOP.CTX_INIT = False  # initialization words
    cfg.TRAINER.COCOOP.PREC = "fp16"  # fp16, fp32, amp

    cfg.DATASET.SUBSAMPLE_CLASSES = "all"  # all, base or new
    """
    Add new config
    """
    cfg.LOSS = CN()
    cfg.LOSS.GM = False
    cfg.LOSS.NAME = ""
    cfg.LOSS.ALPHA = 0.
    cfg.LOSS.T = 1.
    cfg.LOSS.LAMBDA = 1.


def setup_cfg(args):
    cfg = get_cfg_default()
    extend_cfg(cfg)

    # 1. From the dataset config file
    if args.dataset_config_file:
        cfg.merge_from_file(args.dataset_config_file)

    # 2. From the method config file
    if args.config_file:
        cfg.merge_from_file(args.config_file)

    # 3. From input arguments
    reset_cfg(cfg, args)

    # 4. From optional input arguments
    cfg.merge_from_list(args.opts)

    cfg.freeze()

    return cfg


def main(args):
    cfg = setup_cfg(args)
    if cfg.SEED >= 0:
        print("Setting fixed seed: {}".format(cfg.SEED))
        set_random_seed(cfg.SEED)
    setup_logger(cfg.OUTPUT_DIR)

    if torch.cuda.is_available() and cfg.USE_CUDA:
        torch.backends.cudnn.benchmark = True

    print_args(args, cfg)
    print("Collecting env info ...")
    print("** System info **\n{}\n".format(collect_env_info()))

    trainer = build_trainer(cfg)

    if args.eval_only:
        trainer.load_model(args.model_dir, epoch=args.load_epoch)
        trainer.test()
        return

    if not args.no_train:
        trainer.train()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--root", type=str, default="", help="path to dataset")
    parser.add_argument("--output-dir", type=str, default="", help="output directory")
    parser.add_argument(
        "--resume",
        type=str,
        default="",
        help="checkpoint directory (from which the training resumes)",
    )
    parser.add_argument("--seed", type=int, default=-1, help="only positive value enables a fixed seed")
    parser.add_argument("--source-domains", type=str, nargs="+", help="source domains for DA/DG")
    parser.add_argument("--target-domains", type=str, nargs="+", help="target domains for DA/DG")
    parser.add_argument("--transforms", type=str, nargs="+", help="data augmentation methods")
    parser.add_argument("--config-file", type=str, default="", help="path to config file")
    parser.add_argument(
        "--dataset-config-file",
        type=str,
        default="",
        help="path to config file for dataset setup",
    )
    parser.add_argument("--trainer", type=str, default="", help="name of trainer")
    parser.add_argument("--backbone", type=str, default="", help="name of CNN backbone")
    parser.add_argument("--head", type=str, default="", help="name of head")
    parser.add_argument("--eval-only", action="store_true", help="evaluation only")
    parser.add_argument(
        "--model-dir",
        type=str,
        default="",
        help="load model from this directory for eval-only mode",
    )
    parser.add_argument("--load-epoch", type=int, help="load model weights at this epoch for evaluation")
    parser.add_argument("--no-train", action="store_true", help="do not call trainer.train()")
    parser.add_argument(
        "opts",
        default=None,
        nargs=argparse.REMAINDER,
        help="modify config options using the command-line",
    )
    args = parser.parse_args()
    main(args)

================================================
FILE: ProGrad.public/trainers/__init__.py
================================================

================================================
FILE: ProGrad.public/trainers/cocoop.py
================================================
import os.path as osp
from collections import OrderedDict
import math

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import GradScaler, autocast

from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy
from dassl.utils import load_pretrained_weights, load_checkpoint
from dassl.optim import build_optimizer, build_lr_scheduler

from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

_tokenizer = _Tokenizer()


def load_clip_to_cpu(cfg):
    backbone_name = cfg.MODEL.BACKBONE.NAME
    url = clip._MODELS[backbone_name]
    model_path = clip._download(url)

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location="cpu").eval()
        state_dict = None

    except RuntimeError:
        state_dict = torch.load(model_path, map_location="cpu")

    model = clip.build_model(state_dict or model.state_dict())

    return model


class TextEncoder(nn.Module):
    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype

    def forward(self, prompts, tokenized_prompts):
        x = prompts + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection

        return x


class PromptLearner(nn.Module):
    def __init__(self, cfg, classnames, clip_model):
        super().__init__()
        n_cls = len(classnames)
        n_ctx = cfg.TRAINER.COCOOP.N_CTX
        ctx_init = cfg.TRAINER.COCOOP.CTX_INIT
        dtype = clip_model.dtype
        ctx_dim = clip_model.ln_final.weight.shape[0]
        vis_dim = clip_model.visual.output_dim
        clip_imsize = clip_model.visual.input_resolution
        cfg_imsize = cfg.INPUT.SIZE[0]
        assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})"

        if ctx_init:
            ctx_init = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
            ctx_init = ctx_init.replace(" {}.", "")
            ctx_init = ctx_init.replace("_", " ")
            prompt_n_ctx = len(ctx_init.split(" "))

            assert n_ctx >= prompt_n_ctx, f"#tokens ({n_ctx}) should larger equal than #initial prompt tokens ({prompt_n_ctx}, {ctx_init})"

            prompt = clip.tokenize(ctx_init)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(dtype)

            ctx_vectors = torch.zeros(n_ctx, ctx_dim, dtype=dtype)

            ctx_vectors[n_ctx - prompt_n_ctx:, :] = embedding[0, 1:1 + prompt_n_ctx, :]
            prompt_prefix = " ".join(["X"] * (n_ctx - prompt_n_ctx))
            prompt_prefix = f"{prompt_prefix} {ctx_init}"
        else:
            # random initialization
            ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {n_ctx}")

        self.ctx = nn.Parameter(ctx_vectors)

        self.meta_net = nn.Sequential(
            OrderedDict([("linear1", nn.Linear(vis_dim, vis_dim // 16)),
                         ("relu", nn.ReLU(inplace=True)),
                         ("linear2", nn.Linear(vis_dim // 16, ctx_dim))]))

        if cfg.TRAINER.COCOOP.PREC == "fp16":
            self.meta_net.half()

        classnames = [name.replace("_", " ") for name in classnames]
        name_lens = [len(_tokenizer.encode(name)) for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]

        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])  # (n_cls, n_tkn)
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(
                dtype)

        # These token vectors will be saved when in save_model(),
        # but they should be ignored in load_model() as we want to use
        # those computed using the current class names
        self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx:, :])  # CLS, EOS

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts  # torch.Tensor
        self.name_lens = name_lens

    def construct_prompts(self, ctx, prefix, suffix, label=None):
        # dim0 is either batch_size (during training) or n_cls (during testing)
        # ctx: context tokens, with shape of (dim0, n_ctx, ctx_dim)
        # prefix: the sos token, with shape of (n_cls, 1, ctx_dim)
        # suffix: remaining tokens, with shape of (n_cls, *, ctx_dim)

        if label is not None:
            prefix = prefix[label]
            suffix = suffix[label]

        prompts = torch.cat(
            [
                prefix,  # (dim0, 1, dim)
                ctx,  # (dim0, n_ctx, dim)
                suffix,  # (dim0, *, dim)
            ],
            dim=1,
        )

        return prompts

    def forward(self, im_features):
        prefix = self.token_prefix
        suffix = self.token_suffix
        ctx = self.ctx  # (n_ctx, ctx_dim)
        bias = self.meta_net(im_features)  # (batch, ctx_dim)
        bias = bias.unsqueeze(1)  # (batch, 1, ctx_dim)
        ctx = ctx.unsqueeze(0)  # (1, n_ctx, ctx_dim)
        ctx_shifted = ctx + bias  # (batch, n_ctx, ctx_dim)

        # Use instance-conditioned context tokens for all classes
        prompts = []
        for ctx_shifted_i in ctx_shifted:
            ctx_i = ctx_shifted_i.unsqueeze(0).expand(self.n_cls, -1, -1)
            pts_i = self.construct_prompts(ctx_i, prefix, suffix)  # (n_cls, n_tkn, ctx_dim)
            prompts.append(pts_i)
        prompts = torch.stack(prompts)

        return prompts


CUSTOM_TEMPLATES = {
    # "OxfordPets": "a photo of a {}, a type of pet.",
    "OxfordPets": "a type of pet, a photo of a {}.",
    # "OxfordFlowers": "a photo of a {}, a type of flower.",
    "OxfordFlowers": "a type of flower, a photo of a {}.",
    "FGVCAircraft": "a type of aircraft, a photo of a {}.",
    "DescribableTextures": "a texture of {}.",
    "EuroSAT": "a centered satellite photo of {}.",
    "StanfordCars": "a photo of a {}.",
    # "Food101": "a photo of {}, a type of food.",
    "Food101": "a type of food, a photo of {}.",
    "SUN397": "a photo of a {}.",
    "Caltech101": "a photo of a {}.",
    "UCF101": "a photo of a person doing {}.",
    "ImageNet": "a photo of a {}.",
    "ImageNetSketch": "a photo of a {}.",
    "ImageNetV2": "a photo of a {}.",
    "ImageNetA": "a photo of a {}.",
    "ImageNetR": "a photo of a {}.",
}


class CustomCLIP(nn.Module):
    def __init__(self, cfg, classnames, clip_model):
        super().__init__()
        self.prompt_learner = PromptLearner(cfg, classnames, clip_model)
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts
        self.image_encoder = clip_model.visual
        self.text_encoder = TextEncoder(clip_model)
        self.logit_scale = clip_model.logit_scale
        self.dtype = clip_model.dtype

    def forward(self, image, label=None):
        tokenized_prompts = self.tokenized_prompts
        logit_scale = self.logit_scale.exp()

        image_features = self.image_encoder(image.type(self.dtype))
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        prompts = self.prompt_learner(image_features)

        logits = []
        for pts_i, imf_i in zip(prompts, image_features):
            text_features = self.text_encoder(pts_i, tokenized_prompts)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            l_i = logit_scale * imf_i @ text_features.t()
            logits.append(l_i)
        logits = torch.stack(logits)

        if self.prompt_learner.training:
            return F.cross_entropy(logits, label)

        return logits


@TRAINER_REGISTRY.register()
class CoCoOp(TrainerX):
    def check_cfg(self, cfg):
        assert cfg.TRAINER.COCOOP.PREC in ["fp16", "fp32", "amp"]

    def build_model(self):
        cfg = self.cfg
        classnames = self.dm.dataset.classnames

        print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(cfg)

        if cfg.TRAINER.COCOOP.PREC == "fp32" or cfg.TRAINER.COCOOP.PREC == "amp":
            # CLIP's default precision is fp16
            clip_model.float()

        print("Building custom CLIP")
        self.model = CustomCLIP(cfg, classnames, clip_model)

        print("Turning off gradients in both the image and the text encoder")
        name_to_update = "prompt_learner"

        for name, param in self.model.named_parameters():
            if name_to_update not in name:
                param.requires_grad_(False)

        # Double check
        enabled = set()
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                enabled.add(name)
        print(f"Parameters to be updated: {enabled}")

        if cfg.MODEL.INIT_WEIGHTS:
            load_pretrained_weights(self.model.prompt_learner, cfg.MODEL.INIT_WEIGHTS)

        self.model.to(self.device)
        # NOTE: only give prompt_learner to the optimizer
        self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM)
        self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
        self.register_model("prompt_learner", self.model.prompt_learner, self.optim, self.sched)

        self.scaler = GradScaler() if cfg.TRAINER.COCOOP.PREC == "amp" else None

        # Note that multi-gpu training could be slow because CLIP's size is
        # big, which slows down the copy operation in DataParallel
        device_count = torch.cuda.device_count()
        if device_count > 1:
            print(
                f"Multiple GPUs detected (n_gpus={device_count}), use all of them!"
            )
            self.model = nn.DataParallel(self.model)

    def forward_backward(self, batch):
        image, label = self.parse_batch_train(batch)

        model = self.model
        optim = self.optim
        scaler = self.scaler

        prec = self.cfg.TRAINER.COCOOP.PREC
        if prec == "amp":
            with autocast():
                loss = model(image, label)
            optim.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optim)
            scaler.update()
        else:
            loss = model(image, label)
            optim.zero_grad()
            loss.backward()
            optim.step()

        loss_summary = {"loss": loss.item()}

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch):
        input = batch["img"]
        label = batch["label"]
        input = input.to(self.device)
        label = label.to(self.device)
        return input, label

    def load_model(self, directory, epoch=None):
        if not directory:
            print(
                "Note that load_model() is skipped as no pretrained model is given"
            )
            return

        names = self.get_model_names()

        # By default, the best model is loaded
        model_file = "model-best.pth.tar"

        if epoch is not None:
            model_file = "model.pth.tar-" + str(epoch)

        for name in names:
            model_path = osp.join(directory, name, model_file)

            if not osp.exists(model_path):
                raise FileNotFoundError(
                    'Model not found at "{}"'.format(model_path))

            checkpoint = load_checkpoint(model_path)
            state_dict = checkpoint["state_dict"]
            epoch = checkpoint["epoch"]

            # Ignore fixed token vectors
            if "token_prefix" in state_dict:
                del state_dict["token_prefix"]

            if "token_suffix" in state_dict:
                del state_dict["token_suffix"]

            print("Loading weights to {} "
                  'from "{}" (epoch = {})'.format(name, model_path, epoch))
            # set strict=False
            self._models[name].load_state_dict(state_dict, strict=False)

================================================
FILE: ProGrad.public/trainers/coop.py
================================================
import os.path as osp

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import GradScaler, autocast

from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy
from dassl.utils import load_pretrained_weights, load_checkpoint
from dassl.optim import build_optimizer, build_lr_scheduler

from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

_tokenizer = _Tokenizer()


def load_clip_to_cpu(cfg):
    backbone_name = cfg.MODEL.BACKBONE.NAME
    url = clip._MODELS[backbone_name]
    model_path = clip._download(url)

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location="cpu").eval()
        state_dict = None

    except RuntimeError:
        state_dict = torch.load(model_path, map_location="cpu")

    model = clip.build_model(state_dict or model.state_dict())

    return model


class TextEncoder(nn.Module):
    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype

    def forward(self, prompts, tokenized_prompts):
        x = prompts + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection

        return x


class PromptLearner(nn.Module):
    def __init__(self, cfg, classnames, clip_model):
        super().__init__()
        n_cls = len(classnames)
        n_ctx = cfg.TRAINER.COOP.N_CTX
        ctx_init = cfg.TRAINER.COOP.CTX_INIT
        dtype = clip_model.dtype
        ctx_dim = clip_model.ln_final.weight.shape[0]
        clip_imsize = clip_model.visual.input_resolution
        cfg_imsize = cfg.INPUT.SIZE[0]
        assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})"

        # if ctx_init:
        #     # use given words to initialize context vectors
        #     ctx_init = ctx_init.replace("_", " ")
        #     n_ctx = len(ctx_init.split(" "))
        #     prompt = clip.tokenize(ctx_init)
        #     with torch.no_grad():
        #         embedding = clip_model.token_embedding(prompt).type(dtype)
        #     ctx_vectors = embedding[0, 1:1 + n_ctx, :]
        #     prompt_prefix = ctx_init

        if ctx_init:
            ctx_init = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
            ctx_init = ctx_init.replace(" {}.", "")
            ctx_init = ctx_init.replace("_", " ")
            prompt_n_ctx = len(ctx_init.split(" "))

            assert n_ctx >= prompt_n_ctx, f"#tokens ({n_ctx}) should larger equal than #initial prompt tokens ({prompt_n_ctx}, {ctx_init})"

            prompt = clip.tokenize(ctx_init)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(dtype)

            ctx_vectors = torch.zeros(n_ctx, ctx_dim, dtype=dtype)

            ctx_vectors[n_ctx - prompt_n_ctx:, :] = embedding[0, 1:1 + prompt_n_ctx, :]
            prompt_prefix = " ".join(["X"] * (n_ctx - prompt_n_ctx))
            prompt_prefix = f"{prompt_prefix} {ctx_init}"
        else:
            # random initialization
            if cfg.TRAINER.COOP.CSC:
                print("Initializing class-specific contexts")
                ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype)
            else:
                print("Initializing a generic context")
                ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {n_ctx}")

        self.ctx = nn.Parameter(ctx_vectors)  # to be optimized

        classnames = [name.replace("_", " ") for name in classnames]
        name_lens = [len(_tokenizer.encode(name)) for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]

        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(
                dtype)

        # These token vectors will be saved when in save_model(),
        # but they should be ignored in load_model() as we want to use
        # those computed using the current class names
        self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx:, :])  # CLS, EOS

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts  # torch.Tensor
        self.name_lens = name_lens
        self.class_token_position = cfg.TRAINER.COOP.CLASS_TOKEN_POSITION

    def forward(self):
        ctx = self.ctx
        if ctx.dim() == 2:
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)

        prefix = self.token_prefix
        suffix = self.token_suffix

        if self.class_token_position == "end":
            prompts = torch.cat(
                [
                    prefix,  # (n_cls, 1, dim)
                    ctx,  # (n_cls, n_ctx, dim)
                    suffix,  # (n_cls, *, dim)
                ],
                dim=1,
            )

        elif self.class_token_position == "middle":
            half_n_ctx = self.n_ctx // 2
            prompts = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                prefix_i = prefix[i:i + 1, :, :]
                class_i = suffix[i:i + 1, :name_len, :]
                suffix_i = suffix[i:i + 1, name_len:, :]
                ctx_i_half1 = ctx[i:i + 1, :half_n_ctx, :]
                ctx_i_half2 = ctx[i:i + 1, half_n_ctx:, :]
                prompt = torch.cat(
                    [
                        prefix_i,  # (1, 1, dim)
                        ctx_i_half1,  # (1, n_ctx//2, dim)
                        class_i,  # (1, name_len, dim)
                        ctx_i_half2,  # (1, n_ctx//2, dim)
                        suffix_i,  # (1, *, dim)
                    ],
                    dim=1,
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)

        elif self.class_token_position == "front":
            prompts = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                prefix_i = prefix[i:i + 1, :, :]
                class_i = suffix[i:i + 1, :name_len, :]
                suffix_i = suffix[i:i + 1, name_len:, :]
                ctx_i = ctx[i:i + 1, :, :]
                prompt = torch.cat(
                    [
                        prefix_i,  # (1, 1, dim)
                        class_i,  # (1, name_len, dim)
                        ctx_i,  # (1, n_ctx, dim)
                        suffix_i,  # (1, *, dim)
                    ],
                    dim=1,
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)

        else:
            raise ValueError

        return prompts


CUSTOM_TEMPLATES = {
    # "OxfordPets": "a photo of a {}, a type of pet.",
    "OxfordPets": "a type of pet, a photo of a {}.",
    # "OxfordFlowers": "a photo of a {}, a type of flower.",
    "OxfordFlowers": "a type of flower, a photo of a {}.",
    "FGVCAircraft": "a type of aircraft, a photo of a {}.",
    "DescribableTextures": "a texture of {}.",
    "EuroSAT": "a centered satellite photo of {}.",
    "StanfordCars": "a photo of a {}.",
    # "Food101": "a photo of {}, a type of food.",
    "Food101": "a type of food, a photo of {}.",
    "SUN397": "a photo of a {}.",
    "Caltech101": "a photo of a {}.",
    "UCF101": "a photo of a person doing {}.",
    "ImageNet": "a photo of a {}.",
    "ImageNetSketch": "a photo of a {}.",
    "ImageNetV2": "a photo of a {}.",
    "ImageNetA": "a photo of a {}.",
    "ImageNetR": "a photo of a {}.",
}


class CustomCLIP(nn.Module):
    def __init__(self, cfg, classnames, clip_model):
        super().__init__()
        self.prompt_learner = PromptLearner(cfg, classnames, clip_model)
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts
        self.image_encoder = clip_model.visual
        self.text_encoder = TextEncoder(clip_model)
        self.logit_scale = clip_model.logit_scale
        self.dtype = clip_model.dtype

    def forward(self, image):
        image_features = self.image_encoder(image.type(self.dtype))

        prompts = self.prompt_learner()
        tokenized_prompts = self.tokenized_prompts
        text_features = self.text_encoder(prompts, tokenized_prompts)

        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        logit_scale = self.logit_scale.exp()
        logits = logit_scale * image_features @ text_features.t()

        return logits


@TRAINER_REGISTRY.register()
class CoOp(TrainerX):
    """Context Optimization (CoOp).

    Learning to Prompt for Vision-Language Models
    https://arxiv.org/abs/2109.01134
    """

    def check_cfg(self, cfg):
        assert cfg.TRAINER.COOP.PREC in ["fp16", "fp32", "amp"]

    def build_model(self):
        cfg = self.cfg
        classnames = self.dm.dataset.classnames

        print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(cfg)

        if cfg.TRAINER.COOP.PREC == "fp32" or cfg.TRAINER.COOP.PREC == "amp":
            # CLIP's default precision is fp16
            clip_model.float()

        print("Building custom CLIP")
        self.model = CustomCLIP(cfg, classnames, clip_model)

        print("Turning off gradients in both the image and the text encoder")
        for name, param in self.model.named_parameters():
            if "prompt_learner" not in name:
                param.requires_grad_(False)

        if cfg.MODEL.INIT_WEIGHTS:
            load_pretrained_weights(self.model.prompt_learner, cfg.MODEL.INIT_WEIGHTS)

        self.model.to(self.device)
        # NOTE: only give prompt_learner to the optimizer
        self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM)
        self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
        self.register_model("prompt_learner", self.model.prompt_learner, self.optim, self.sched)

        self.scaler = GradScaler() if cfg.TRAINER.COOP.PREC == "amp" else None

        # Note that multi-gpu training could be slow because CLIP's size is
        # big, which slows down the copy operation in DataParallel
        device_count = torch.cuda.device_count()
        if device_count > 1:
            print(
                f"Multiple GPUs detected (n_gpus={device_count}), use all of them!"
            )
            self.model = nn.DataParallel(self.model)

    def forward_backward(self, batch):
        image, label = self.parse_batch_train(batch)

        prec = self.cfg.TRAINER.COOP.PREC
        if prec == "amp":
            with autocast():
                output = self.model(image)
                loss = F.cross_entropy(output, label)
            self.optim.zero_grad()
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optim)
            self.scaler.update()
        else:
            output = self.model(image)
            loss = F.cross_entropy(output, label)
            self.model_backward_and_update(loss)

        loss_summary = {
            "loss": loss.item(),
            "acc": compute_accuracy(output, label)[0].item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch):
        input = batch["img"]
        label = batch["label"]
        input = input.to(self.device)
        label = label.to(self.device)
        return input, label

    def load_model(self, directory, epoch=None):
        if not directory:
            print(
                "Note that load_model() is skipped as no pretrained model is given"
            )
            return

        names = self.get_model_names()

        # By default, the best model is loaded
        model_file = "model-best.pth.tar"

        if epoch is not None:
            model_file = "model.pth.tar-" + str(epoch)

        for name in names:
            model_path = osp.join(directory, name, model_file)

            if not osp.exists(model_path):
                raise FileNotFoundError(
                    'Model not found at "{}"'.format(model_path))

            checkpoint = load_checkpoint(model_path)
            state_dict = checkpoint["state_dict"]
            epoch = checkpoint["epoch"]

            # Ignore fixed token vectors
            if "token_prefix" in state_dict:
                del state_dict["token_prefix"]

            if "token_suffix" in state_dict:
                del state_dict["token_suffix"]

            print("Loading weights to {} "
                  'from "{}" (epoch = {})'.format(name, model_path, epoch))
            # set strict=False
            self._models[name].load_state_dict(state_dict, strict=False)

================================================
FILE: ProGrad.public/trainers/imagenet_templates.py
================================================
# source: https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb

IMAGENET_TEMPLATES = [
    "a bad photo of a {}.",
    "a photo of many {}.",
    "a sculpture of a {}.",
    "a photo of the hard to see {}.",
    "a low resolution photo of the {}.",
    "a rendering of a {}.",
    "graffiti of a {}.",
    "a bad photo of the {}.",
    "a cropped photo of the {}.",
    "a tattoo of a {}.",
    "the embroidered {}.",
    "a photo of a hard to see {}.",
    "a bright photo of a {}.",
    "a photo of a clean {}.",
    "a photo of a dirty {}.",
    "a dark photo of the {}.",
    "a drawing of a {}.",
    "a photo of my {}.",
    "the plastic {}.",
    "a photo of the cool {}.",
    "a close-up photo of a {}.",
    "a black and white photo of the {}.",
    "a painting of the {}.",
    "a painting of a {}.",
    "a pixelated photo of the {}.",
    "a sculpture of the {}.",
    "a bright photo of the {}.",
    "a cropped photo of a {}.",
    "a plastic {}.",
    "a photo of the dirty {}.",
    "a jpeg corrupted photo of a {}.",
    "a blurry photo of the {}.",
    "a photo of the {}.",
    "a good photo of the {}.",
    "a rendering of the {}.",
    "a {} in a video game.",
    "a photo of one {}.",
    "a doodle of a {}.",
    "a close-up photo of the {}.",
    "a photo of a {}.",
    "the origami {}.",
    "the {} in a video game.",
    "a sketch of a {}.",
    "a doodle of the {}.",
    "a origami {}.",
    "a low resolution photo of a {}.",
    "the toy {}.",
    "a rendition of the {}.",
    "a photo of the clean {}.",
    "a photo of a large {}.",
    "a rendition of a {}.",
    "a photo of a nice {}.",
    "a photo of a weird {}.",
    "a blurry photo of a {}.",
    "a cartoon {}.",
    "art of a {}.",
    "a sketch of the {}.",
    "a embroidered {}.",
    "a pixelated photo of a {}.",
    "itap of the {}.",
    "a jpeg corrupted photo of the {}.",
    "a good photo of a {}.",
    "a plushie {}.",
    "a photo of the nice {}.",
    "a photo of the small {}.",
    "a photo of the weird {}.",
    "the cartoon {}.",
    "art of the {}.",
    "a drawing of the {}.",
    "a photo of the large {}.",
    "a black and white photo of a {}.",
    "the plushie {}.",
    "a dark photo of a {}.",
    "itap of a {}.",
    "graffiti of the {}.",
    "a toy {}.",
    "itap of my {}.",
    "a photo of a cool {}.",
    "a photo of a small {}.",
    "a tattoo of the {}.",
]

IMAGENET_TEMPLATES_SELECT = [
    "itap of a {}.",
    "a bad photo of the {}.",
    "a origami {}.",
    "a photo of the large {}.",
    "a {} in a video game.",
    "art of the {}.",
    "a photo of the small {}.",
]

================================================
FILE: ProGrad.public/trainers/prograd.py
================================================
import os.path as osp

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import GradScaler, autocast

from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy
from dassl.utils import load_pretrained_weights, load_checkpoint
from dassl.optim import build_optimizer, build_lr_scheduler

from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
from torch.nn.modules.loss import _Loss
from tqdm import tqdm
import json

_tokenizer = _Tokenizer()


def load_clip_to_cpu(cfg):
    backbone_name = cfg.MODEL.BACKBONE.NAME
    url = clip._MODELS[backbone_name]
    model_path = clip._download(url)

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location="cpu").eval()
        state_dict = None

    except RuntimeError:
        state_dict = torch.load(model_path, map_location="cpu")

    model = clip.build_model(state_dict or model.state_dict())

    return model


class TextEncoder(nn.Module):
    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype

    def forward(self, prompts, tokenized_prompts):
        x = prompts + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection

        return x


class PromptLearner(nn.Module):
    def __init__(self, cfg, classnames, clip_model):
        super().__init__()
        n_cls = len(classnames)
        n_ctx = cfg.TRAINER.COOP.N_CTX
        ctx_init = cfg.TRAINER.COOP.CTX_INIT
        dtype = clip_model.dtype
        ctx_dim = clip_model.ln_final.weight.shape[0]
        clip_imsize = clip_model.visual.input_resolution
        cfg_imsize = cfg.INPUT.SIZE[0]
        assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})"

        if ctx_init:
            ctx_init = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
            ctx_init = ctx_init.replace(" {}.", "")
            ctx_init = ctx_init.replace("_", " ")
            prompt_n_ctx = len(ctx_init.split(" "))

            assert n_ctx >= prompt_n_ctx, f"#tokens ({n_ctx}) should larger equal than #initial prompt tokens ({prompt_n_ctx}, {ctx_init})"

            prompt = clip.tokenize(ctx_init)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(dtype)

            ctx_vectors = torch.zeros(n_ctx, ctx_dim, dtype=dtype)

            ctx_vectors[n_ctx - prompt_n_ctx:, :] = embedding[0, 1:1 + prompt_n_ctx, :]
            prompt_prefix = " ".join(["X"] * (n_ctx - prompt_n_ctx))
            prompt_prefix = f"{prompt_prefix} {ctx_init}"
        else:
            # random initialization
            if cfg.TRAINER.COOP.CSC:
                print("Initializing class-specific contexts")
                ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype)
            else:
                print("Initializing a generic context")
                ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {n_ctx}")

        self.ctx = nn.Parameter(ctx_vectors)  # to be optimized

        classnames = [name.replace("_", " ") for name in classnames]
        name_lens = [len(_tokenizer.encode(name)) for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]

        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(
                dtype)

        # These token vectors will be saved when in save_model(),
        # but they should be ignored in load_model() as we want to use
        # those computed using the current class names
        self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx:, :])  # CLS, EOS

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts  # torch.Tensor
        self.class_token_position = cfg.TRAINER.COOP.CLASS_TOKEN_POSITION
        self.name_lens = name_lens

    def forward(self):
        ctx = self.ctx
        if ctx.dim() == 2:
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)

        prefix = self.token_prefix
        suffix = self.token_suffix

        if self.class_token_position == "end":
            prompts = torch.cat(
                [
                    prefix,  # (n_cls, 1, dim)
                    ctx,  # (n_cls, n_ctx, dim)
                    suffix,  # (n_cls, *, dim)
                ],
                dim=1,
            )

        elif self.class_token_position == "middle":
            # n_ctx is only a local in __init__; use the stored attribute here
            half_n_ctx = self.n_ctx // 2
            prompts = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                prefix_i = prefix[i:i + 1, :, :]
                class_i = suffix[i:i + 1, :name_len, :]
                suffix_i = suffix[i:i + 1, name_len:, :]
                ctx_i_half1 = ctx[i:i + 1, :half_n_ctx, :]
                ctx_i_half2 = ctx[i:i + 1, half_n_ctx:, :]
                prompt = torch.cat(
                    [
                        prefix_i,  # (1, 1, dim)
                        ctx_i_half1,  # (1, n_ctx//2, dim)
                        class_i,  # (1, name_len, dim)
                        ctx_i_half2,  # (1, n_ctx//2, dim)
                        suffix_i,  # (1, *, dim)
                    ],
                    dim=1,
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)

        elif self.class_token_position == "front":
            prompts = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                prefix_i = prefix[i:i + 1, :, :]
                class_i = suffix[i:i + 1, :name_len, :]
                suffix_i = suffix[i:i + 1, name_len:, :]
                ctx_i = ctx[i:i + 1, :, :]
                prompt = torch.cat(
                    [
                        prefix_i,  # (1, 1, dim)
                        class_i,  # (1, name_len, dim)
                        ctx_i,  # (1, n_ctx, dim)
                        suffix_i,  # (1, *, dim)
                    ],
                    dim=1,
                )
                prompts.append(prompt)
            prompts =
torch.cat(prompts, dim=0) else: raise ValueError return prompts CUSTOM_TEMPLATES = { "OxfordPets": "a type of pet, a photo of a {}.", "OxfordFlowers": "a type of flower, a photo of a {}.", "FGVCAircraft": "a type of aircraft, a photo of a {}.", "DescribableTextures": "a texture of {}.", "EuroSAT": "a centered satellite photo of {}.", "StanfordCars": "a photo of a {}.", "Food101": "a type of food, a photo of {}.", "SUN397": "a photo of a {}.", "Caltech101": "a photo of a {}.", "UCF101": "a photo of a person doing {}.", "ImageNet": "a photo of a {}.", "ImageNetSketch": "a photo of a {}.", "ImageNetV2": "a photo of a {}.", "ImageNetA": "a photo of a {}.", "ImageNetR": "a photo of a {}.", } class CLIP(nn.Module): def __init__(self, cfg, classnames): super().__init__() print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") clip_model = load_clip_to_cpu(cfg) clip_model.float() temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME] prompts = [temp.format(c.replace("_", " ")) for c in classnames] print(f"Prompts: {prompts}") prompts = torch.cat([clip.tokenize(p) for p in prompts]) with torch.no_grad(): text_features = clip_model.encode_text(prompts) text_features = text_features / text_features.norm(dim=-1, keepdim=True) self.text_features = text_features self.clip_model = clip_model def forward(self, image): image_features = self.clip_model.encode_image(image) image_features = image_features / image_features.norm(dim=-1, keepdim=True) logit_scale = self.clip_model.logit_scale.exp() text_features = self.text_features text_features = text_features.to(image_features.device) logits = logit_scale * image_features @ text_features.t() return logits class CustomCLIP(nn.Module): def __init__(self, cfg, classnames, clip_model): super().__init__() self.prompt_learner = PromptLearner(cfg, classnames, clip_model) self.tokenized_prompts = self.prompt_learner.tokenized_prompts self.image_encoder = clip_model.visual self.text_encoder = TextEncoder(clip_model) self.logit_scale = 
clip_model.logit_scale self.dtype = clip_model.dtype def forward(self, image): image_features = self.image_encoder(image.type(self.dtype)) prompts = self.prompt_learner() tokenized_prompts = self.tokenized_prompts text_features = self.text_encoder(prompts, tokenized_prompts) image_features = image_features / image_features.norm(dim=-1, keepdim=True) text_features = text_features / text_features.norm(dim=-1, keepdim=True) logit_scale = self.logit_scale.exp() logits = logit_scale * image_features @ text_features.t() return logits class ProGradLoss(_Loss): def __init__(self, T): super(ProGradLoss, self).__init__() self.T = T def forward(self, stu_logits, tea_logits, label): xe_loss = F.cross_entropy(stu_logits, label) tea_prob = F.softmax(tea_logits / self.T, dim=-1) kl_loss = -tea_prob * F.log_softmax(stu_logits / self.T, -1) * self.T * self.T kl_loss = kl_loss.sum(1).mean() return xe_loss, kl_loss @TRAINER_REGISTRY.register() class ProGrad(TrainerX): """Projected Gradient for few-shot CLIP """ def check_cfg(self, cfg): assert cfg.TRAINER.COOP.PREC in ["fp16", "fp32", "amp"] def build_model(self): cfg = self.cfg classnames = self.dm.dataset.classnames print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") clip_model = load_clip_to_cpu(cfg) if cfg.TRAINER.COOP.PREC == "fp32" or cfg.TRAINER.COOP.PREC == "amp": # CLIP's default precision is fp16 clip_model.float() print("Building zeroshot CLIP") self.zs_clip = CLIP(cfg, classnames) print("Building custom CLIP") self.model = CustomCLIP(cfg, classnames, clip_model) print("Turning off gradients in ZS Clip model") for name, param in self.zs_clip.named_parameters(): param.requires_grad_(False) print("Turning off gradients in CoOp model") for name, param in self.model.named_parameters(): if "prompt_learner" not in name: param.requires_grad_(False) if cfg.MODEL.INIT_WEIGHTS: load_pretrained_weights(self.model.prompt_learner, cfg.MODEL.INIT_WEIGHTS) self.model.to(self.device) self.zs_clip = self.zs_clip.cuda() # NOTE: 
only give prompt_learner to the optimizer self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM) self.sched = build_lr_scheduler(self.optim, cfg.OPTIM) self.register_model("prompt_learner", self.model.prompt_learner, self.optim, self.sched) self.scaler = GradScaler() if cfg.TRAINER.COOP.PREC == "amp" else None # Note that multi-gpu training could be slow because CLIP's size is # big, which slows down the copy operation in DataParallel device_count = torch.cuda.device_count() if device_count > 1: print( f"Multiple GPUs detected (n_gpus={device_count}), use all of them!" ) self.model = nn.DataParallel(self.model) self.zs_clip = nn.DataParallel(self.zs_clip) # build criterion if cfg.LOSS.NAME == "prograd": self.criterion = ProGradLoss(T=cfg.LOSS.T) else: raise NotImplementedError def forward_backward(self, batch): image, label = self.parse_batch_train(batch) prec = self.cfg.TRAINER.COOP.PREC if prec == "amp": with autocast(): output = self.model(image) with torch.no_grad(): zs_clip_output = self.zs_clip(image) loss = self.criterion(output, zs_clip_output.detach(), label) self.optim.zero_grad() self.scaler.scale(loss).backward() self.scaler.step(self.optim) self.scaler.update() else: output = self.model(image) with torch.no_grad(): zs_clip_output = self.zs_clip(image) xe_loss, kl_loss = self.criterion(output, zs_clip_output.detach(), label) self.prograd_backward_and_update(xe_loss, kl_loss, self.cfg.LOSS.LAMBDA) loss_summary = { "xe_loss": xe_loss.item(), "kl_loss": kl_loss.item(), "acc": compute_accuracy(output, label)[0].item(), } if (self.batch_idx + 1) == self.num_batches: self.update_lr() return loss_summary def parse_batch_train(self, batch): input = batch["img"] label = batch["label"] input = input.to(self.device) label = label.to(self.device) return input, label def load_model(self, directory, epoch=None): if not directory: print( "Note that load_model() is skipped as no pretrained model is given" ) return names = self.get_model_names() # By 
default, the best model is loaded model_file = "model-best.pth.tar" if epoch is not None: model_file = "model.pth.tar-" + str(epoch) for name in names: model_path = osp.join(directory, name, model_file) if not osp.exists(model_path): raise FileNotFoundError( 'Model not found at "{}"'.format(model_path)) checkpoint = load_checkpoint(model_path) state_dict = checkpoint["state_dict"] epoch = checkpoint["epoch"] # Ignore fixed token vectors if "token_prefix" in state_dict: del state_dict["token_prefix"] if "token_suffix" in state_dict: del state_dict["token_suffix"] print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch)) # set strict=False self._models[name].load_state_dict(state_dict, strict=False) ================================================ FILE: ProGrad.public/trainers/zsclip.py ================================================ import torch import torch.nn as nn from dassl.engine import TRAINER_REGISTRY, TrainerX from dassl.optim import build_optimizer, build_lr_scheduler from clip import clip from clip.model import convert_weights from .coop import load_clip_to_cpu from .imagenet_templates import IMAGENET_TEMPLATES, IMAGENET_TEMPLATES_SELECT CUSTOM_TEMPLATES = { # "OxfordPets": "a photo of a {}, a type of pet.", "OxfordPets": "a type of pet, a photo of a {}.", # "OxfordFlowers": "a photo of a {}, a type of flower.", "OxfordFlowers": "a type of flower, a photo of a {}.", "FGVCAircraft": "a photo of a {}, a type of aircraft.", "DescribableTextures": "{} texture.", "EuroSAT": "a centered satellite photo of {}.", "StanfordCars": "a photo of a {}.", # "Food101": "a photo of {}, a type of food.", "Food101": "a type of food, a photo of {}.", "SUN397": "a photo of a {}.", "Caltech101": "a photo of a {}.", "UCF101": "a photo of a person doing {}.", "ImageNet": "a photo of a {}.", "ImageNetSketch": "a photo of a {}.", "ImageNetV2": "a photo of a {}.", "ImageNetA": "a photo of a {}.", "ImageNetR": "a photo of a {}.", } 
@TRAINER_REGISTRY.register() class ZeroshotCLIP(TrainerX): def build_model(self): cfg = self.cfg classnames = self.dm.dataset.classnames print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") clip_model = load_clip_to_cpu(cfg) clip_model.to(self.device) temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME] prompts = [temp.format(c.replace("_", " ")) for c in classnames] print(f"Prompts: {prompts}") prompts = torch.cat([clip.tokenize(p) for p in prompts]) prompts = prompts.to(self.device) with torch.no_grad(): text_features = clip_model.encode_text(prompts) text_features = text_features / text_features.norm(dim=-1, keepdim=True) self.text_features = text_features self.clip_model = clip_model def model_inference(self, image): image_features = self.clip_model.encode_image(image) image_features = image_features / image_features.norm(dim=-1, keepdim=True) logit_scale = self.clip_model.logit_scale.exp() logits = logit_scale * image_features @ self.text_features.t() return logits @TRAINER_REGISTRY.register() class ZeroshotCLIP2(ZeroshotCLIP): """Prompt ensembling.""" # templates = IMAGENET_TEMPLATES templates = IMAGENET_TEMPLATES_SELECT def build_model(self): cfg = self.cfg classnames = self.dm.dataset.classnames print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") clip_model = load_clip_to_cpu(cfg) clip_model.to(self.device) for params in clip_model.parameters(): params.requires_grad_(False) # add custom-made prompt if cfg.DATASET.NAME != "ImageNet": self.templates += [CUSTOM_TEMPLATES[cfg.DATASET.NAME]] num_temp = len(self.templates) print(f"Prompt ensembling (n={num_temp})") mean_text_features = 0 for i, temp in enumerate(self.templates): prompts = [temp.format(c.replace("_", " ")) for c in classnames] prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(self.device) text_features = clip_model.encode_text(prompts) text_features = text_features / text_features.norm(dim=-1, keepdim=True) mean_text_features = mean_text_features + text_features mean_text_features 
= mean_text_features / num_temp mean_text_features = mean_text_features / mean_text_features.norm( dim=-1, keepdim=True) self.text_features = mean_text_features self.clip_model = clip_model ================================================ FILE: readme.md ================================================ # [ICCV23] Prompt-aligned Gradient for Prompt Tuning We present Prompt-aligned Gradient, dubbed ProGrad, to prevent prompt tuning from forgetting the the general knowledge learned from VLMs. In particular, ProGrad only updates the prompt whose gradient is aligned (or non-conflicting) to the “general direction”, which is represented as the gradient of the KL loss of the pre-defined prompt prediction. Extensive experiments demonstrate the stronger few-shot generalization ability of ProGrad over state-of-the-art prompt tuning methods. ![image](ProGrad.public/Pipeline.png) [[paper link]](https://doi.org/10.48550/arxiv.2205.14865) The codes are organized into two folders: 1. [Dassl.ProGrad.pytorch](Dassl.ProGrad.pytorch/) is the modified toolbox of [Dassl.pytorch](https://github.com/KaiyangZhou/Dassl.pytorch). 2. [ProGrad.public](ProGrad.public/). To get the results in our paper, follow the [README.md](ProGrad.public/README.md) under [ProGrad.public/](ProGrad.public/) to set the environment. ## Citation If you find our paper or this project helps your research, please kindly consider citing our paper in your publication. ``` @inproceedings{https://doi.org/10.48550/arxiv.2205.14865, author = {Zhu, Beier and Niu, Yulei and Han, Yucheng and Wu, Yue and Zhang, Hanwang}, title = {Prompt-aligned Gradient for Prompt Tuning}, publisher = {International Conference on Computer Vision}, year = {2023}, } ``` ## Acknowledgement Our codes are built on top of [CoOp](https://github.com/KaiyangZhou/CoOp) and [Dassl](https://github.com/KaiyangZhou/Dassl.pytorch).
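
## Sketch of the gradient rule

The update rule summarized at the top of this README (keep the prompt gradient when it agrees with the “general direction”, otherwise remove the conflicting part) can be illustrated with a small, framework-free sketch. It is not the repository's implementation: the trainer delegates this step to `prograd_backward_and_update` in the modified Dassl toolbox, which is not reproduced here. The sketch assumes both gradients have already been flattened into plain Python lists, and `lam` is a hypothetical projection-strength knob; it is not claimed to match the semantics of `cfg.LOSS.LAMBDA`.

```python
from typing import List


def dot(a: List[float], b: List[float]) -> float:
    """Inner product of two flattened gradient vectors."""
    return sum(x * y for x, y in zip(a, b))


def prograd_direction(g_xe: List[float], g_kl: List[float], lam: float = 1.0) -> List[float]:
    """Illustrative update direction for the prompt parameters.

    g_xe: gradient of the few-shot cross-entropy loss.
    g_kl: gradient of the KL loss to zero-shot CLIP (the "general direction").
    lam:  hypothetical knob; lam=1.0 removes the conflicting component entirely.
    """
    d = dot(g_xe, g_kl)
    if d >= 0:
        # Aligned or orthogonal (non-conflicting): use the cross-entropy gradient as-is.
        return list(g_xe)
    # Conflicting: subtract (a fraction of) the projection of g_xe onto g_kl.
    scale = lam * d / dot(g_kl, g_kl)
    return [x - scale * y for x, y in zip(g_xe, g_kl)]


if __name__ == "__main__":
    general = [1.0, 0.0]
    print(prograd_direction([0.5, 2.0], general))   # [0.5, 2.0]  (no conflict)
    print(prograd_direction([-1.0, 2.0], general))  # [0.0, 2.0]  (conflict removed)
```

With `lam=1.0` the returned direction is orthogonal to the general direction whenever a conflict occurs, so a step along it does not increase the KL loss to first order; a zero `g_kl` yields an inner product of 0 and returns `g_xe` unchanged without dividing by zero.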