Repository: BeierZhu/Prompt-align
Branch: main
Commit: 1b762036cbe9
Files: 244
Total size: 606.4 KB
Directory structure:
gitextract_cxcff6dn/
├── Dassl.ProGrad.pytorch/
│ ├── .flake8
│ ├── .gitignore
│ ├── .isort.cfg
│ ├── .style.yapf
│ ├── DATASETS.md
│ ├── LICENSE
│ ├── README.md
│ ├── configs/
│ │ ├── README.md
│ │ ├── datasets/
│ │ │ ├── da/
│ │ │ │ ├── cifar_stl.yaml
│ │ │ │ ├── digit5.yaml
│ │ │ │ ├── domainnet.yaml
│ │ │ │ ├── mini_domainnet.yaml
│ │ │ │ ├── office31.yaml
│ │ │ │ ├── office_home.yaml
│ │ │ │ └── visda17.yaml
│ │ │ ├── dg/
│ │ │ │ ├── cifar100_c.yaml
│ │ │ │ ├── cifar10_c.yaml
│ │ │ │ ├── digit_single.yaml
│ │ │ │ ├── digits_dg.yaml
│ │ │ │ ├── office_home_dg.yaml
│ │ │ │ ├── pacs.yaml
│ │ │ │ └── vlcs.yaml
│ │ │ └── ssl/
│ │ │ ├── cifar10.yaml
│ │ │ ├── cifar100.yaml
│ │ │ ├── stl10.yaml
│ │ │ └── svhn.yaml
│ │ └── trainers/
│ │ ├── da/
│ │ │ ├── dael/
│ │ │ │ ├── digit5.yaml
│ │ │ │ ├── domainnet.yaml
│ │ │ │ └── mini_domainnet.yaml
│ │ │ ├── m3sda/
│ │ │ │ ├── digit5.yaml
│ │ │ │ ├── domainnet.yaml
│ │ │ │ └── mini_domainnet.yaml
│ │ │ └── source_only/
│ │ │ ├── digit5.yaml
│ │ │ ├── mini_domainnet.yaml
│ │ │ ├── office31.yaml
│ │ │ └── visda17.yaml
│ │ ├── dg/
│ │ │ ├── dael/
│ │ │ │ ├── digits_dg.yaml
│ │ │ │ ├── office_home_dg.yaml
│ │ │ │ └── pacs.yaml
│ │ │ ├── ddaig/
│ │ │ │ ├── digits_dg.yaml
│ │ │ │ ├── office_home_dg.yaml
│ │ │ │ └── pacs.yaml
│ │ │ └── vanilla/
│ │ │ ├── digits_dg.yaml
│ │ │ ├── mini_domainnet.yaml
│ │ │ ├── office_home_dg.yaml
│ │ │ └── pacs.yaml
│ │ └── ssl/
│ │ └── fixmatch/
│ │ └── cifar10.yaml
│ ├── dassl/
│ │ ├── __init__.py
│ │ ├── config/
│ │ │ ├── __init__.py
│ │ │ └── defaults.py
│ │ ├── data/
│ │ │ ├── __init__.py
│ │ │ ├── data_manager.py
│ │ │ ├── datasets/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── base_dataset.py
│ │ │ │ ├── build.py
│ │ │ │ ├── da/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── cifarstl.py
│ │ │ │ │ ├── digit5.py
│ │ │ │ │ ├── domainnet.py
│ │ │ │ │ ├── mini_domainnet.py
│ │ │ │ │ ├── office31.py
│ │ │ │ │ ├── office_home.py
│ │ │ │ │ └── visda17.py
│ │ │ │ ├── dg/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── cifar_c.py
│ │ │ │ │ ├── digit_single.py
│ │ │ │ │ ├── digits_dg.py
│ │ │ │ │ ├── office_home_dg.py
│ │ │ │ │ ├── pacs.py
│ │ │ │ │ └── vlcs.py
│ │ │ │ └── ssl/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── cifar.py
│ │ │ │ ├── stl10.py
│ │ │ │ └── svhn.py
│ │ │ ├── samplers.py
│ │ │ └── transforms/
│ │ │ ├── __init__.py
│ │ │ ├── autoaugment.py
│ │ │ ├── randaugment.py
│ │ │ └── transforms.py
│ │ ├── engine/
│ │ │ ├── __init__.py
│ │ │ ├── build.py
│ │ │ ├── da/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── adabn.py
│ │ │ │ ├── adda.py
│ │ │ │ ├── dael.py
│ │ │ │ ├── dann.py
│ │ │ │ ├── m3sda.py
│ │ │ │ ├── mcd.py
│ │ │ │ ├── mme.py
│ │ │ │ ├── self_ensembling.py
│ │ │ │ └── source_only.py
│ │ │ ├── dg/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── crossgrad.py
│ │ │ │ ├── daeldg.py
│ │ │ │ ├── ddaig.py
│ │ │ │ └── vanilla.py
│ │ │ ├── ssl/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── entmin.py
│ │ │ │ ├── fixmatch.py
│ │ │ │ ├── mean_teacher.py
│ │ │ │ ├── mixmatch.py
│ │ │ │ └── sup_baseline.py
│ │ │ └── trainer.py
│ │ ├── evaluation/
│ │ │ ├── __init__.py
│ │ │ ├── build.py
│ │ │ └── evaluator.py
│ │ ├── metrics/
│ │ │ ├── __init__.py
│ │ │ ├── accuracy.py
│ │ │ └── distance.py
│ │ ├── modeling/
│ │ │ ├── __init__.py
│ │ │ ├── backbone/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── alexnet.py
│ │ │ │ ├── backbone.py
│ │ │ │ ├── build.py
│ │ │ │ ├── cnn_digit5_m3sda.py
│ │ │ │ ├── cnn_digitsdg.py
│ │ │ │ ├── cnn_digitsingle.py
│ │ │ │ ├── efficientnet/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── model.py
│ │ │ │ │ └── utils.py
│ │ │ │ ├── mobilenetv2.py
│ │ │ │ ├── preact_resnet18.py
│ │ │ │ ├── resnet.py
│ │ │ │ ├── shufflenetv2.py
│ │ │ │ ├── vgg.py
│ │ │ │ └── wide_resnet.py
│ │ │ ├── head/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── build.py
│ │ │ │ └── mlp.py
│ │ │ ├── network/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── build.py
│ │ │ │ └── ddaig_fcn.py
│ │ │ └── ops/
│ │ │ ├── __init__.py
│ │ │ ├── cross_entropy.py
│ │ │ ├── dsbn.py
│ │ │ ├── efdmix.py
│ │ │ ├── mixstyle.py
│ │ │ ├── mixup.py
│ │ │ ├── mmd.py
│ │ │ ├── optimal_transport.py
│ │ │ ├── reverse_grad.py
│ │ │ ├── sequential2.py
│ │ │ ├── transnorm.py
│ │ │ └── utils.py
│ │ ├── optim/
│ │ │ ├── __init__.py
│ │ │ ├── lr_scheduler.py
│ │ │ ├── optimizer.py
│ │ │ └── radam.py
│ │ └── utils/
│ │ ├── __init__.py
│ │ ├── logger.py
│ │ ├── meters.py
│ │ ├── registry.py
│ │ ├── tools.py
│ │ └── torchtools.py
│ ├── datasets/
│ │ ├── da/
│ │ │ ├── cifar_stl.py
│ │ │ ├── digit5.py
│ │ │ └── visda17.sh
│ │ ├── dg/
│ │ │ └── cifar_c.py
│ │ └── ssl/
│ │ ├── cifar10_cifar100_svhn.py
│ │ └── stl10.py
│ ├── linter.sh
│ ├── requirements.txt
│ ├── setup.py
│ └── tools/
│ ├── parse_test_res.py
│ ├── replace_text.py
│ └── train.py
├── ProGrad.public/
│ ├── .gitignore
│ ├── DATASETS.md
│ ├── LICENSE
│ ├── README.md
│ ├── clip/
│ │ ├── __init__.py
│ │ ├── clip.py
│ │ ├── model.py
│ │ └── simple_tokenizer.py
│ ├── configs/
│ │ ├── datasets/
│ │ │ ├── caltech101.yaml
│ │ │ ├── dtd.yaml
│ │ │ ├── eurosat.yaml
│ │ │ ├── fgvc_aircraft.yaml
│ │ │ ├── food101.yaml
│ │ │ ├── imagenet.yaml
│ │ │ ├── imagenet_a.yaml
│ │ │ ├── imagenet_r.yaml
│ │ │ ├── imagenet_sketch.yaml
│ │ │ ├── imagenetv2.yaml
│ │ │ ├── oxford_flowers.yaml
│ │ │ ├── oxford_pets.yaml
│ │ │ ├── stanford_cars.yaml
│ │ │ ├── sun397.yaml
│ │ │ └── ucf101.yaml
│ │ └── trainers/
│ │ ├── CoCoOp/
│ │ │ ├── rn50_c4_ep10_batch1_ctxv1.yaml
│ │ │ ├── rn50_ep100_init.yaml
│ │ │ ├── rn50_ep50.yaml
│ │ │ ├── vit_b16_c16_ep10_batch1.yaml
│ │ │ ├── vit_b16_c4_ep10_batch1.yaml
│ │ │ ├── vit_b16_c4_ep10_batch1_ctxv1.yaml
│ │ │ └── vit_b16_c8_ep10_batch1.yaml
│ │ ├── CoOp/
│ │ │ ├── rn50.yaml
│ │ │ ├── rn50_ep100.yaml
│ │ │ ├── rn50_ep50.yaml
│ │ │ └── rn50_val.yaml
│ │ └── ProGrad/
│ │ ├── rn50.yaml
│ │ ├── rn50_ep100.yaml
│ │ └── rn50_ep50.yaml
│ ├── datasets/
│ │ ├── __init__.py
│ │ ├── caltech101.py
│ │ ├── dtd.py
│ │ ├── eurosat.py
│ │ ├── fgvc_aircraft.py
│ │ ├── food101.py
│ │ ├── imagenet.py
│ │ ├── imagenet_a.py
│ │ ├── imagenet_r.py
│ │ ├── imagenet_sketch.py
│ │ ├── imagenetv2.py
│ │ ├── oxford_flowers.py
│ │ ├── oxford_pets.py
│ │ ├── stanford_cars.py
│ │ ├── sun397.py
│ │ └── ucf101.py
│ ├── interpret_prompt.py
│ ├── lpclip/
│ │ ├── README.md
│ │ ├── feat_extractor.py
│ │ ├── feat_extractor.sh
│ │ ├── linear_probe.py
│ │ ├── linear_probe.sh
│ │ └── linear_probe_transfer.py
│ ├── parse_test_res.py
│ ├── requirements.txt
│ ├── scripts/
│ │ ├── base2new_test_main.sh
│ │ ├── base2new_test_prograd.sh
│ │ ├── base2new_train_main.sh
│ │ ├── base2new_train_prograd.sh
│ │ ├── eval.sh
│ │ ├── main.sh
│ │ ├── prograd.sh
│ │ └── zeroshot.sh
│ ├── train.py
│ └── trainers/
│ ├── __init__.py
│ ├── cocoop.py
│ ├── coop.py
│ ├── imagenet_templates.py
│ ├── prograd.py
│ └── zsclip.py
└── readme.md
================================================
FILE CONTENTS
================================================
================================================
FILE: Dassl.ProGrad.pytorch/.flake8
================================================
[flake8]
ignore =
    # At least two spaces before inline comment
    E261,
    # Line lengths are recommended to be no greater than 79 characters
    E501,
    # Missing whitespace around arithmetic operator
    E226,
    # Blank line contains whitespace
    W293,
    # Do not use bare 'except'
    E722,
    # Line break after binary operator
    W504,
    # Too many leading '#' for block comment
    E266,
    # line break before binary operator
    W503,
    # continuation line over-indented for hanging indent
    E126
max-line-length = 79
exclude = __init__.py, build
================================================
FILE: Dassl.ProGrad.pytorch/.gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# OS X
.DS_Store
.Spotlight-V100
.Trashes
._*
# This project
output/
debug.sh
debug.py
================================================
FILE: Dassl.ProGrad.pytorch/.isort.cfg
================================================
[isort]
line_length=79
multi_line_output=6
length_sort=true
known_standard_library=numpy,setuptools
known_myself=dassl
known_third_party=matplotlib,cv2,torch,torchvision,PIL,yacs,scipy,gdown
no_lines_before=STDLIB,THIRDPARTY
sections=FUTURE,STDLIB,THIRDPARTY,myself,FIRSTPARTY,LOCALFOLDER
default_section=FIRSTPARTY
================================================
FILE: Dassl.ProGrad.pytorch/.style.yapf
================================================
[style]
BASED_ON_STYLE = pep8
BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true
SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true
DEDENT_CLOSING_BRACKETS = true
SPACES_BEFORE_COMMENT = 2
ARITHMETIC_PRECEDENCE_INDICATION = true
================================================
FILE: Dassl.ProGrad.pytorch/DATASETS.md
================================================
# How to Install Datasets
`$DATA` denotes the location where datasets are installed, e.g.
```
$DATA/
|–– office31/
|–– office_home/
|–– visda17/
```
[Domain Adaptation](#domain-adaptation)
- [Office-31](#office-31)
- [Office-Home](#office-home)
- [VisDA17](#visda17)
- [CIFAR10-STL10](#cifar10-stl10)
- [Digit-5](#digit-5)
- [DomainNet](#domainnet)
- [miniDomainNet](#minidomainnet)
[Domain Generalization](#domain-generalization)
- [PACS](#pacs)
- [VLCS](#vlcs)
- [Office-Home-DG](#office-home-dg)
- [Digits-DG](#digits-dg)
- [Digit-Single](#digit-single)
- [CIFAR-10-C](#cifar-10-c)
- [CIFAR-100-C](#cifar-100-c)
[Semi-Supervised Learning](#semi-supervised-learning)
- [CIFAR10/100 and SVHN](#cifar10100-and-svhn)
- [STL10](#stl10)
## Domain Adaptation
### Office-31
Download link: https://people.eecs.berkeley.edu/~jhoffman/domainadapt/#datasets_code.
File structure:
```
office31/
|–– amazon/
| |–– back_pack/
| |–– bike/
| |–– ...
|–– dslr/
| |–– back_pack/
| |–– bike/
| |–– ...
|–– webcam/
| |–– back_pack/
| |–– bike/
| |–– ...
```
Note that within each domain folder you need to move all class folders out of the `images/` folder and then delete the `images/` folder.
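If you prefer to script this step, here is a minimal Python sketch. It assumes the archive was extracted to `$DATA/office31` and that each of the three domain folders contains an `images/` subfolder holding the class folders. `flatten_office31` is a hypothetical helper and is not part of Dassl.

```python
import shutil
from pathlib import Path


def flatten_office31(root):
    """Move each class folder out of <domain>/images/ and delete images/."""
    root = Path(root)
    for domain in ("amazon", "dslr", "webcam"):
        images_dir = root / domain / "images"
        if not images_dir.is_dir():
            continue  # domain missing or already flattened
        # Snapshot the listing first because we modify the folder while looping.
        for class_dir in list(images_dir.iterdir()):
            # e.g. office31/amazon/images/bike -> office31/amazon/bike
            shutil.move(str(class_dir), str(root / domain / class_dir.name))
        images_dir.rmdir()  # images/ is empty at this point
```

You can call `flatten_office31("/path/to/DATA/office31")` more than once, because domains without an `images/` folder are skipped.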
### Office-Home
Download link: http://hemanthdv.org/OfficeHome-Dataset/.
File structure:
```
office_home/
|–– art/
|–– clipart/
|–– product/
|–– real_world/
```
### VisDA17
Download link: http://ai.bu.edu/visda-2017/.
The dataset can also be downloaded using our script at `datasets/da/visda17.sh`. Run the following command in your terminal under `Dassl.pytorch/datasets/da`,
```bash
sh visda17.sh $DATA
```
Once the download is finished, the file structure will look like
```
visda17/
|–– train/
|–– test/
|–– validation/
```
### CIFAR10-STL10
Run the following command in your terminal under `Dassl.pytorch/datasets/da`,
```bash
python cifar_stl.py $DATA/cifar_stl
```
This will create a folder named `cifar_stl` under `$DATA`. The file structure will look like
```
cifar_stl/
|–– cifar/
| |–– train/
| |–– test/
|–– stl/
| |–– train/
| |–– test/
```
Note that only 9 classes shared by both datasets are kept.
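The 9 shared classes come from intersecting the two label sets once CIFAR-10's `automobile` is treated as STL-10's `car`. Frog (CIFAR-10 only) and monkey (STL-10 only) have no counterpart. The sketch below assumes that aliasing. It does not reproduce how `cifar_stl.py` names or orders the classes.

```python
CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]
STL10_CLASSES = ["airplane", "bird", "car", "cat", "deer",
                 "dog", "horse", "monkey", "ship", "truck"]

# Assumption: CIFAR-10's "automobile" and STL-10's "car" are the same class.
ALIASES = {"automobile": "car"}


def shared_classes(cifar_classes, stl_classes, aliases=ALIASES):
    """Return the sorted class names present in both label sets."""
    stl_set = set(stl_classes)
    renamed = (aliases.get(name, name) for name in cifar_classes)
    return sorted(name for name in renamed if name in stl_set)


shared = shared_classes(CIFAR10_CLASSES, STL10_CLASSES)
# shared has 9 entries; "frog" and "monkey" are left out
```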
### Digit-5
Create a folder `$DATA/digit5` and download the dataset from [here](https://github.com/VisionLearningGroup/VisionLearningGroup.github.io/tree/master/M3SDA/code_MSDA_digit#digit-five-download) into this folder. This should give you
```
digit5/
|–– Digit-Five/
```
Then, run the following command in your terminal under `Dassl.pytorch/datasets/da`,
```bash
python digit5.py $DATA/digit5
```
This will extract the data and organize the file structure as
```
digit5/
|–– Digit-Five/
|–– mnist/
|–– mnist_m/
|–– usps/
|–– svhn/
|–– syn/
```
### DomainNet
Download link: http://ai.bu.edu/M3SDA/ (please download the cleaned version of the split files).
File structure:
```
domainnet/
|–– clipart/
|–– infograph/
|–– painting/
|–– quickdraw/
|–– real/
|–– sketch/
|–– splits/
| |–– clipart_train.txt
| |–– clipart_test.txt
| |–– ...
```
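Each split file lists one image per line. The sketch below assumes the layout `<relative image path> <integer label>`, separated by whitespace, with the path relative to the `domainnet/` folder. Check the files you downloaded before relying on this. `read_split_file` is a hypothetical helper and is not part of Dassl.

```python
import os


def read_split_file(split_path, image_root):
    """Parse a split file into a list of (image_path, label) pairs."""
    items = []
    with open(split_path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines
            # Split on the last whitespace so the label is the final token.
            rel_path, label = line.rsplit(maxsplit=1)
            items.append((os.path.join(image_root, rel_path), int(label)))
    return items
```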
### miniDomainNet
You need to download the DomainNet dataset first. The split files for miniDomainNet can be downloaded from this [google drive](https://drive.google.com/open?id=15rrLDCrzyi6ZY-1vJar3u7plgLe4COL7) link. After extracting the zip file, you should have the folder `$DATA/domainnet/splits_mini/`.
## Domain Generalization
### PACS
Download link: [google drive](https://drive.google.com/open?id=1m4X4fROCCXMO0lRLrr6Zz9Vb3974NWhE).
File structure:
```
pacs/
|–– images/
|–– splits/
```
You do not have to download this dataset manually. When you run `tools/train.py`, the code checks whether the dataset exists and, if it is missing, downloads it automatically to `$DATA`. This also applies to VLCS, Office-Home-DG, and Digits-DG.
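The check-then-download behavior can be sketched as follows. `ensure_dataset` and its `download_fn` callback are hypothetical stand-ins. Dassl's own download helper and the Google Drive URLs are not reproduced here.

```python
import os


def ensure_dataset(dataset_dir, download_fn):
    """Call download_fn(dataset_dir) only if dataset_dir does not exist yet.

    download_fn is a hypothetical callback that fetches and extracts the
    archive into dataset_dir. Returns True if a download was triggered.
    """
    if os.path.exists(dataset_dir):
        return False  # dataset already present: nothing to do
    # Make sure the parent (e.g. $DATA) exists before downloading into it.
    os.makedirs(os.path.dirname(dataset_dir) or ".", exist_ok=True)
    download_fn(dataset_dir)
    return True
```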
### VLCS
Download link: [google drive](https://drive.google.com/file/d/1r0WL5DDqKfSPp9E3tRENwHaXNs1olLZd/view?usp=sharing) (credit to https://github.com/fmcarlucci/JigenDG#vlcs)
File structure:
```
VLCS/
|–– CALTECH/
|–– LABELME/
|–– PASCAL/
|–– SUN/
```
### Office-Home-DG
Download link: [google drive](https://drive.google.com/open?id=1gkbf_KaxoBws-GWT3XIPZ7BnkqbAxIFa).
File structure:
```
office_home_dg/
|–– art/
|–– clipart/
|–– product/
|–– real_world/
```
### Digits-DG
Download link: [google drive](https://drive.google.com/open?id=15V7EsHfCcfbKgsDmzQKj_DfXt_XYp_P7).
File structure:
```
digits_dg/
|–– mnist/
|–– mnist_m/
|–– svhn/
|–– syn/
```
### Digit-Single
Follow the steps for [Digit-5](#digit-5) to organize the dataset.
### CIFAR-10-C
First download the CIFAR-10-C dataset from https://zenodo.org/record/2535967#.YFxHEWQzb0o to, e.g., $DATA, and extract the file under the same directory. Then, navigate to `Dassl.pytorch/datasets/dg` and run the following command in your terminal
```bash
python cifar_c.py $DATA/CIFAR-10-C
```
where the first argument denotes the path to the (uncompressed) CIFAR-10-C dataset.
The script will extract images from the `.npy` files and save them to `cifar10_c/` created under $DATA. The file structure will look like
```
cifar10_c/
|–– brightness/
| |–– 1/ # 5 intensity levels in total
| |–– 2/
| |–– 3/
| |–– 4/
| |–– 5/
|–– ... # 19 corruption types in total
```
Note that `cifar10_c/` only contains the test images. The training images are the normal CIFAR-10 images. See [CIFAR10/100 and SVHN](#cifar10100-and-svhn) for how to prepare the CIFAR-10 dataset.
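If you extract the images yourself, you need to know how severity levels map to rows. The sketch below assumes the layout described with the CIFAR-10-C release: each `<corruption>.npy` file stacks the 10,000 test images once per severity, levels 1 to 5 in order, for 50,000 rows in total. `severity_slice` is a hypothetical helper. The same arithmetic applies to CIFAR-100-C.

```python
def severity_slice(level, images_per_level=10000):
    """Return the (start, stop) row range for one severity level.

    `level` is 1-based to match the folder names 1/ ... 5/.
    """
    if not 1 <= level <= 5:
        raise ValueError("severity level must be between 1 and 5")
    start = (level - 1) * images_per_level
    return start, start + images_per_level


# With a loaded array `data` of shape (50000, 32, 32, 3), the images for
# severity 3 would be data[slice(*severity_slice(3))], i.e. rows 20000..29999.
```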
### CIFAR-100-C
First download the CIFAR-100-C dataset from https://zenodo.org/record/3555552#.YFxpQmQzb0o to, e.g., $DATA, and extract the file under the same directory. Then, navigate to `Dassl.pytorch/datasets/dg` and run the following command in your terminal
```bash
python cifar_c.py $DATA/CIFAR-100-C
```
where the first argument denotes the path to the (uncompressed) CIFAR-100-C dataset.
The script will extract images from the `.npy` files and save them to `cifar100_c/` created under $DATA. The file structure will look like
```
cifar100_c/
|–– brightness/
| |–– 1/ # 5 intensity levels in total
| |–– 2/
| |–– 3/
| |–– 4/
| |–– 5/
|–– ... # 19 corruption types in total
```
Note that `cifar100_c/` only contains the test images. The training images are the normal CIFAR-100 images. See [CIFAR10/100 and SVHN](#cifar10100-and-svhn) for how to prepare the CIFAR-100 dataset.
## Semi-Supervised Learning
### CIFAR10/100 and SVHN
Run the following command in your terminal under `Dassl.pytorch/datasets/ssl`,
```bash
python cifar10_cifar100_svhn.py $DATA
```
This will create three folders under `$DATA`, i.e.
```
cifar10/
|–– train/
|–– test/
cifar100/
|–– train/
|–– test/
svhn/
|–– train/
|–– test/
```
### STL10
Run the following command in your terminal under `Dassl.pytorch/datasets/ssl`,
```bash
python stl10.py $DATA/stl10
```
This will create a folder named `stl10` under `$DATA` and extract the data into three folders: `train`, `test`, and `unlabeled`. Then download the "Binary files" from http://ai.stanford.edu/~acoates/stl10/ and extract them under `stl10`.
The file structure will look like
```
stl10/
|–– train/
|–– test/
|–– unlabeled/
|–– stl10_binary/
```
================================================
FILE: Dassl.ProGrad.pytorch/LICENSE
================================================
MIT License
Copyright (c) 2020 Kaiyang
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: Dassl.ProGrad.pytorch/README.md
================================================
# Dassl
## Introduction
Dassl is a [PyTorch](https://pytorch.org) toolbox initially developed for our project [Domain Adaptive Ensemble Learning (DAEL)](https://arxiv.org/abs/2003.07325) to support research in domain adaptation and generalization---since in DAEL we study how to unify these two problems in a single learning framework. Given that domain adaptation is closely related to semi-supervised learning---both study how to exploit unlabeled data---we also incorporate components that support research for the latter.
Why the name "Dassl"? Dassl combines the initials of domain adaptation (DA) and semi-supervised learning (SSL), which sounds natural and informative.
Dassl has a modular design and unified interfaces, allowing fast prototyping and experimentation with new DA/DG/SSL methods. With Dassl, a new method can be implemented with only a few lines of code. Don't believe it? Take a look at the [engine](https://github.com/KaiyangZhou/Dassl.pytorch/tree/master/dassl/engine) folder, which contains the implementations of many existing methods (then you will come back and star this repo). :-)
Basically, Dassl is perfect for doing research in the following areas:
- Domain adaptation
- Domain generalization
- Semi-supervised learning
BUT, thanks to the neat design, Dassl can also be used as a codebase to develop any deep learning project, like [this](https://github.com/KaiyangZhou/CoOp). :-)
A drawback of Dassl is that it doesn't (yet? hmm) support distributed multi-GPU training (Dassl uses `DataParallel` to wrap a model, which is less efficient than `DistributedDataParallel`).
We don't provide detailed documentation for Dassl, unlike another [project](https://kaiyangzhou.github.io/deep-person-reid/) of ours. This is because Dassl is developed for research purposes, and as researchers, we think it's important to be able to read source code, which we highly encourage you to do (definitely not because we are lazy). :-)
## What's new
- Mar 2022: A new domain generalization method [EFDM](https://arxiv.org/abs/2203.07740) developed by [Yabin Zhang (PolyU)](https://ybzh.github.io/) and to appear at CVPR'22 is added to this repo. See [here](https://github.com/KaiyangZhou/Dassl.pytorch/pull/36) for more details.
- Feb 2022: In case you don't know, a class in the painting domain of DomainNet (the official splits) only has test images (no training images), which could affect performance. See section 4.a in our [paper](https://arxiv.org/abs/2003.07325) for more details.
- Oct 2021: `v0.5.0`: **Important changes** made to `transforms.py`. 1) `center_crop` becomes a default transform in testing (applied after resizing the smaller edge to a certain size to keep the image aspect ratio). 2) For training, `Resize(cfg.INPUT.SIZE)` is deactivated when `random_crop` or `random_resized_crop` is used. These changes won't make any difference to the training transforms used in existing config files, nor to the testing transforms unless the raw images are not square (the only difference is that now the image aspect ratio is respected).
- Oct 2021: `v0.4.3`: Copies the attributes in `self.dm` (data manager) to `SimpleTrainer` and makes `self.dm` optional, which means that from now on you can build data loaders from any source you like rather than being forced to use `DataManager`.
- Sep 2021: `v0.4.2`: An important update is to set `drop_last=is_train and len(data_source)>=batch_size` when constructing a data loader, to avoid a zero-length data loader.
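To see why the `drop_last` condition in `v0.4.2` matters: with `drop_last=True`, a training set smaller than the batch size produces a loader with zero batches. The helper names below are hypothetical. Only the boolean expression comes from the changelog entry, and the batch count follows how PyTorch's batch sampler counts batches.

```python
def make_drop_last(is_train, num_samples, batch_size):
    # Drop the last incomplete batch only when training AND at least one
    # full batch exists; otherwise drop_last=True would leave zero batches.
    return is_train and num_samples >= batch_size


def num_batches(num_samples, batch_size, drop_last):
    """Number of batches a loader yields for num_samples items."""
    if drop_last:
        return num_samples // batch_size
    return -(-num_samples // batch_size)  # ceiling division


# 20 training samples, batch size 32:
#   naive drop_last=True -> num_batches(20, 32, True) == 0
#   with the rule        -> num_batches(20, 32, make_drop_last(True, 20, 32)) == 1
```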
- Aug 2021: `v0.4.0`: The most noteworthy update is adding the learning rate warmup scheduler. The implementation is detailed [here](https://github.com/KaiyangZhou/Dassl.pytorch/blob/master/dassl/optim/lr_scheduler.py#L10) and the config variables are specified [here](https://github.com/KaiyangZhou/Dassl.pytorch/blob/master/dassl/config/defaults.py#L171).
- Jul 2021: `v0.3.4`: Adds a new function `generate_fewshot_dataset()` to the base dataset class, which allows for the generation of a few-shot learning setting. One can customize a few-shot dataset by specifying `_C.DATASET.NUM_SHOTS` and passing it to `generate_fewshot_dataset()`.
- Jul 2021: `v0.3.2`: Adds `_C.INPUT.INTERPOLATION` (default: `bilinear`). Available interpolation modes are `bilinear`, `nearest`, and `bicubic`.
- Jul 2021 `v0.3.1`: Now you can use `*.register(force=True)` to replace previously registered modules.
- Jul 2021 `v0.3.0`: Allows deploying the model with the best validation performance for the final test (for the purpose of model selection). Specifically, a new config variable named `_C.TEST.FINAL_MODEL` is introduced, which takes either `"last_step"` (default) or `"best_val"`. When set to `"best_val"`, the model will be evaluated on the `val` set after each epoch, and the one with the best validation performance will be saved and used for the final test (see this [code](https://github.com/KaiyangZhou/Dassl.pytorch/blob/master/dassl/engine/trainer.py#L412)).
- Jul 2021 `v0.2.7`: Adds attribute `classnames` to the base dataset class. Now you can get a list of class names ordered by numeric labels by calling `trainer.dm.dataset.classnames`.
- Jun 2021 `v0.2.6`: Merges `MixStyle2` into `MixStyle`. A new variable `self.mix` is used to switch between random mixing and cross-domain mixing. Please see [this](https://github.com/KaiyangZhou/Dassl.pytorch/issues/23) for more details on the new features.
- Jun 2021 `v0.2.5`: Fixes a [bug](https://github.com/KaiyangZhou/Dassl.pytorch/commit/29881c7faee7405f80f5f674de4bbbf80d5dc77a) in the calculation of per-class recognition accuracy.
- Jun 2021 `v0.2.4`: Adds `extend_cfg(cfg)` to `train.py`. This function is particularly useful when you build your own methods on top of Dassl.pytorch and need to define some custom variables. Please see the repository [mixstyle-release](https://github.com/KaiyangZhou/mixstyle-release) or [ssdg-benchmark](https://github.com/KaiyangZhou/ssdg-benchmark) for examples.
- Jun 2021 New benchmarks for semi-supervised domain generalization at https://github.com/KaiyangZhou/ssdg-benchmark.
- Apr 2021 Do you know you can use `tools/parse_test_res.py` to read the log files and automatically calculate and print out the results including mean and standard deviation? Check the instructions in `tools/parse_test_res.py` for more details.
- Apr 2021 `v0.2.3`: A [MixStyle](https://openreview.net/forum?id=6xHJ37MVxxp) layer can now be deactivated or activated by using `model.apply(deactivate_mixstyle)` or `model.apply(activate_mixstyle)` without modifying the source code. See [dassl/modeling/ops/mixstyle.py](dassl/modeling/ops/mixstyle.py) for the details.
- Apr 2021 `v0.2.2`: Adds `RandomClassSampler`, which samples from a certain number of classes a certain number of images to form a minibatch (the code is modified from [Torchreid](https://github.com/KaiyangZhou/deep-person-reid)).
- Apr 2021 `v0.2.1`: Slightly adjusts the ordering in `setup_cfg()` (see `tools/train.py`).
- Apr 2021 `v0.2.0`: Adds `_C.DATASET.ALL_AS_UNLABELED` (for the SSL setting) to the config variable list. When this variable is set to `True`, all labeled data will be included in the unlabeled data set.
- Apr 2021 `v0.1.9`: Adds [VLCS](https://people.csail.mit.edu/torralba/publications/datasets_cvpr11.pdf) to the benchmark datasets (see `dassl/data/datasets/dg/vlcs.py`).
- Mar 2021 `v0.1.8`: Allows `optim` and `sched` to be `None` in `register_model()`.
- Mar 2021 `v0.1.7`: Adds [MixStyle](https://openreview.net/forum?id=6xHJ37MVxxp) models to [dassl/modeling/backbone/resnet.py](dassl/modeling/backbone/resnet.py). The training configs in `configs/trainers/dg/vanilla` can be used to train MixStyle models.
- Mar 2021 `v0.1.6`: Adds [CIFAR-10/100-C](https://arxiv.org/abs/1807.01697) to the benchmark datasets for evaluating a model's robustness to image corruptions.
- Mar 2021 We have just released a survey on domain generalization at https://arxiv.org/abs/2103.02503, which summarizes the ten-year development in this topic with coverage on the history, related problems, datasets, methodologies, potential directions, and so on.
- Jan 2021 Our recent work, [MixStyle](https://openreview.net/forum?id=6xHJ37MVxxp) (mixing instance-level feature statistics of samples of different domains for improving domain generalization), is accepted to ICLR'21. The code is available at https://github.com/KaiyangZhou/mixstyle-release where the cross-domain image classification part is based on Dassl.pytorch.
- May 2020 `v0.1.3`: Adds the `Digit-Single` dataset for benchmarking single-source DG methods. The corresponding CNN model is [dassl/modeling/backbone/cnn_digitsingle.py](dassl/modeling/backbone/cnn_digitsingle.py) and the dataset config file is [configs/datasets/dg/digit_single.yaml](configs/datasets/dg/digit_single.yaml). See [Volpi et al. NIPS'18](https://arxiv.org/abs/1805.12018) for how to do evaluation.
- May 2020 `v0.1.2`: 1) Adds [EfficientNet](https://arxiv.org/abs/1905.11946) models (B0-B7) (credit to https://github.com/lukemelas/EfficientNet-PyTorch). To use EfficientNet, set `MODEL.BACKBONE.NAME` to `efficientnet_b{N}` where `N={0, ..., 7}`. 2) `dassl/modeling/models` is renamed to `dassl/modeling/network` (`build_model()` to `build_network()` and `MODEL_REGISTRY` to `NETWORK_REGISTRY`).
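The learning rate warmup added in `v0.4.0` can be pictured with the stand-alone sketch below. It ramps the rate linearly from a small value to the base rate over the first few epochs, then hands over to cosine decay. This is not Dassl's implementation: Dassl's `lr_scheduler.py` wraps an existing scheduler instead. The function name and arguments here are hypothetical.

```python
import math


def lr_at_epoch(epoch, base_lr, max_epoch, warmup_epochs, warmup_min_lr):
    """Learning rate for a 0-based epoch: linear warmup, then cosine decay.

    During warmup the rate rises linearly from warmup_min_lr toward base_lr;
    afterwards it follows cosine annealing from base_lr down to 0 at max_epoch.
    """
    if max_epoch <= warmup_epochs:
        raise ValueError("max_epoch must exceed warmup_epochs")
    if epoch < warmup_epochs:
        frac = epoch / warmup_epochs
        return warmup_min_lr + frac * (base_lr - warmup_min_lr)
    t = epoch - warmup_epochs          # epochs since warmup ended
    total = max_epoch - warmup_epochs  # length of the cosine phase
    return 0.5 * base_lr * (1 + math.cos(math.pi * t / total))
```

A linear ramp avoids large, noisy updates in the first epochs. Continuity at the hand-over comes for free, because the cosine term equals `base_lr` when `t == 0`.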
## Overview
Dassl has implemented the following methods:
- Single-source domain adaptation
    - [Semi-supervised Domain Adaptation via Minimax Entropy (ICCV'19)](https://arxiv.org/abs/1904.06487) [[dassl/engine/da/mme.py](dassl/engine/da/mme.py)]
    - [Maximum Classifier Discrepancy for Unsupervised Domain Adaptation (CVPR'18)](https://arxiv.org/abs/1712.02560) [[dassl/engine/da/mcd.py](dassl/engine/da/mcd.py)]
    - [Self-ensembling for visual domain adaptation (ICLR'18)](https://arxiv.org/abs/1706.05208) [[dassl/engine/da/self_ensembling.py](dassl/engine/da/self_ensembling.py)]
    - [Revisiting Batch Normalization For Practical Domain Adaptation (ICLR-W'17)](https://arxiv.org/abs/1603.04779) [[dassl/engine/da/adabn.py](dassl/engine/da/adabn.py)]
    - [Adversarial Discriminative Domain Adaptation (CVPR'17)](https://arxiv.org/abs/1702.05464) [[dassl/engine/da/adda.py](dassl/engine/da/adda.py)]
    - [Domain-Adversarial Training of Neural Networks (JMLR'16)](https://arxiv.org/abs/1505.07818) [[dassl/engine/da/dann.py](dassl/engine/da/dann.py)]
- Multi-source domain adaptation
    - [Domain Adaptive Ensemble Learning](https://arxiv.org/abs/2003.07325) [[dassl/engine/da/dael.py](dassl/engine/da/dael.py)]
    - [Moment Matching for Multi-Source Domain Adaptation (ICCV'19)](https://arxiv.org/abs/1812.01754) [[dassl/engine/da/m3sda.py](dassl/engine/da/m3sda.py)]
- Domain generalization
    - [Exact Feature Distribution Matching for Arbitrary Style Transfer and Domain Generalization (CVPR'22)](https://arxiv.org/abs/2203.07740) [[dassl/modeling/ops/efdmix.py](dassl/modeling/ops/efdmix.py)]
    - [Domain Generalization with MixStyle (ICLR'21)](https://openreview.net/forum?id=6xHJ37MVxxp) [[dassl/modeling/ops/mixstyle.py](dassl/modeling/ops/mixstyle.py)]
    - [Deep Domain-Adversarial Image Generation for Domain Generalisation (AAAI'20)](https://arxiv.org/abs/2003.06054) [[dassl/engine/dg/ddaig.py](dassl/engine/dg/ddaig.py)]
    - [Generalizing Across Domains via Cross-Gradient Training (ICLR'18)](https://arxiv.org/abs/1804.10745) [[dassl/engine/dg/crossgrad.py](dassl/engine/dg/crossgrad.py)]
- Semi-supervised learning
    - [FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence](https://arxiv.org/abs/2001.07685) [[dassl/engine/ssl/fixmatch.py](dassl/engine/ssl/fixmatch.py)]
    - [MixMatch: A Holistic Approach to Semi-Supervised Learning (NeurIPS'19)](https://arxiv.org/abs/1905.02249) [[dassl/engine/ssl/mixmatch.py](dassl/engine/ssl/mixmatch.py)]
    - [Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results (NeurIPS'17)](https://arxiv.org/abs/1703.01780) [[dassl/engine/ssl/mean_teacher.py](dassl/engine/ssl/mean_teacher.py)]
    - [Semi-supervised Learning by Entropy Minimization (NeurIPS'04)](http://papers.nips.cc/paper/2740-semi-supervised-learning-by-entropy-minimization.pdf) [[dassl/engine/ssl/entmin.py](dassl/engine/ssl/entmin.py)]
*Feel free to make a [PR](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork) to add your methods here to make it easier for others to benchmark!*
Dassl supports the following datasets:
- Domain adaptation
    - [Office-31](https://scalable.mpi-inf.mpg.de/files/2013/04/saenko_eccv_2010.pdf)
    - [Office-Home](http://hemanthdv.org/OfficeHome-Dataset/)
    - [VisDA17](http://ai.bu.edu/visda-2017/)
    - [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html)-[STL10](https://cs.stanford.edu/~acoates/stl10/)
    - [Digit-5](https://github.com/VisionLearningGroup/VisionLearningGroup.github.io/tree/master/M3SDA/code_MSDA_digit#digit-five-download)
    - [DomainNet](http://ai.bu.edu/M3SDA/)
    - [miniDomainNet](https://arxiv.org/abs/2003.07325)
- Domain generalization
    - [PACS](https://arxiv.org/abs/1710.03077)
    - [VLCS](https://people.csail.mit.edu/torralba/publications/datasets_cvpr11.pdf)
    - [Office-Home](http://hemanthdv.org/OfficeHome-Dataset/)
    - [Digits-DG](https://arxiv.org/abs/2003.06054)
    - [Digit-Single](https://arxiv.org/abs/1805.12018)
    - [CIFAR-10-C](https://arxiv.org/abs/1807.01697)
    - [CIFAR-100-C](https://arxiv.org/abs/1807.01697)
- Semi-supervised learning
    - [CIFAR10/100](https://www.cs.toronto.edu/~kriz/cifar.html)
    - [SVHN](http://ufldl.stanford.edu/housenumbers/)
    - [STL10](https://cs.stanford.edu/~acoates/stl10/)
## Get started
### Installation
Make sure [conda](https://www.anaconda.com/distribution/) is installed properly.
```bash
# Clone this repo
git clone https://github.com/KaiyangZhou/Dassl.pytorch.git
cd Dassl.pytorch/
# Create a conda environment
conda create -n dassl python=3.7
# Activate the environment
conda activate dassl
# Install dependencies
pip install -r requirements.txt
# Install torch (version >= 1.7.1) and torchvision
conda install pytorch torchvision cudatoolkit=10.1 -c pytorch
# Install this library (no need to re-build if the source code is modified)
python setup.py develop
```
Follow the instructions in [DATASETS.md](./DATASETS.md) to preprocess the datasets.
### Training
The main interface is implemented in `tools/train.py`, which does the following:
1. initialize the config with `cfg = setup_cfg(args)` where `args` contains the command-line input (see `tools/train.py` for the list of input arguments);
2. instantiate a `trainer` with `build_trainer(cfg)` which loads the dataset and builds a deep neural network model;
3. call `trainer.train()` for training and evaluating the model.
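
Below is a dependency-free sketch of the three steps, with the command-line arguments threaded through to the trainer. `setup_cfg` and `build_trainer` here are simplified stand-ins that only share their names with the real functions. The dict-based config, `TRAINER_REGISTRY` and `ToySourceOnly` are hypothetical. The real code uses a yacs config and Dassl's own trainer registry.

```python
# Hypothetical sketch of the setup_cfg -> build_trainer -> train flow.
# Nothing here is Dassl's actual implementation; names are illustrative only.
import argparse

TRAINER_REGISTRY = {}


def register_trainer(cls):
    """Store a trainer class under its class name so it can be built from a string."""
    TRAINER_REGISTRY[cls.__name__] = cls
    return cls


def setup_cfg(args):
    # Step 1: turn parsed command-line arguments into a (here: plain dict) config.
    return {
        "DATASET": {
            "ROOT": args.root,
            "SOURCE_DOMAINS": args.source_domains,
            "TARGET_DOMAINS": args.target_domains,
        },
        "TRAINER": {"NAME": args.trainer},
    }


def build_trainer(cfg):
    # Step 2: look up the trainer class by name and instantiate it with the config.
    return TRAINER_REGISTRY[cfg["TRAINER"]["NAME"]](cfg)


@register_trainer
class ToySourceOnly:
    def __init__(self, cfg):
        self.cfg = cfg

    def train(self):
        # Step 3: a real trainer would run the training loop and then evaluate.
        src = ",".join(self.cfg["DATASET"]["SOURCE_DOMAINS"])
        tgt = ",".join(self.cfg["DATASET"]["TARGET_DOMAINS"])
        return f"trained on {src}, evaluated on {tgt}"


parser = argparse.ArgumentParser()
parser.add_argument("--root", default="")
parser.add_argument("--trainer", default="")
parser.add_argument("--source-domains", nargs="+", default=[])
parser.add_argument("--target-domains", nargs="+", default=[])

args = parser.parse_args([
    "--root", "/data", "--trainer", "ToySourceOnly",
    "--source-domains", "amazon", "--target-domains", "webcam",
])
cfg = setup_cfg(args)
trainer = build_trainer(cfg)
print(trainer.train())  # trained on amazon, evaluated on webcam
```

Resolving the trainer by name is what lets `--trainer` switch algorithms without touching the rest of the command.
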
Below is an example of training a source-only baseline on Office-31, a popular domain adaptation dataset:
```bash
CUDA_VISIBLE_DEVICES=0 python tools/train.py \
--root $DATA \
--trainer SourceOnly \
--source-domains amazon \
--target-domains webcam \
--dataset-config-file configs/datasets/da/office31.yaml \
--config-file configs/trainers/da/source_only/office31.yaml \
--output-dir output/source_only_office31
```
`$DATA` denotes the location where the datasets are installed. `--dataset-config-file` loads the common settings for the dataset (Office-31 in this case), such as image size and model architecture. `--config-file` loads the algorithm-specific settings, such as hyper-parameters and optimization parameters.
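
The two config files are layered on top of the defaults in `dassl/config/defaults.py`. The sketch below illustrates that layering with a plain recursive dict merge. It assumes, without verifying against `tools/train.py`, that the dataset config is applied first and the method config second, so later values win. The values are a subset copied from `defaults.py`, `configs/datasets/da/office31.yaml` and `configs/trainers/da/source_only/office31.yaml`. yacs, which the real config is built on, also rejects unknown keys and checks types; this sketch does neither.

```python
# Sketch of layered config merging: defaults <- dataset config <- method config.
# deep_merge is a hypothetical helper, not part of Dassl or yacs.
import copy


def deep_merge(base, override):
    """Return a copy of `base` with values from `override` merged in recursively."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)  # merge nested sections
        else:
            merged[key] = copy.deepcopy(value)  # leaf value: later source wins
    return merged


# Subset of dassl/config/defaults.py
defaults = {
    "INPUT": {"SIZE": (224, 224), "TRANSFORMS": ()},
    "MODEL": {"BACKBONE": {"NAME": "", "PRETRAINED": True}},
    "OPTIM": {"NAME": "adam", "LR": 0.0003, "MAX_EPOCH": 10},
}
# Subset of configs/datasets/da/office31.yaml
dataset_cfg = {
    "INPUT": {"TRANSFORMS": ["random_flip", "random_translation", "normalize"]},
    "MODEL": {"BACKBONE": {"NAME": "resnet50"}},
}
# Subset of configs/trainers/da/source_only/office31.yaml
method_cfg = {"OPTIM": {"NAME": "sgd", "LR": 0.002, "MAX_EPOCH": 20}}

cfg = deep_merge(deep_merge(defaults, dataset_cfg), method_cfg)
print(cfg["MODEL"]["BACKBONE"])  # {'NAME': 'resnet50', 'PRETRAINED': True}
print(cfg["OPTIM"]["LR"])        # 0.002
```

Keys that neither file sets, such as `INPUT.SIZE` or `MODEL.BACKBONE.PRETRAINED`, keep their default values.
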
To use multiple sources, i.e. for multi-source domain adaptation, simply add more domains to `--source-domains`. For instance, to train a source-only baseline on miniDomainNet, run
```bash
CUDA_VISIBLE_DEVICES=0 python tools/train.py \
--root $DATA \
--trainer SourceOnly \
--source-domains clipart painting real \
--target-domains sketch \
--dataset-config-file configs/datasets/da/mini_domainnet.yaml \
--config-file configs/trainers/da/source_only/mini_domainnet.yaml \
--output-dir output/source_only_minidn
```
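
With several source domains, some trainer configs in this repo draw domain-balanced mini-batches. For example, `configs/trainers/da/dael/mini_domainnet.yaml` sets `SAMPLER: "RandomDomainSampler"` with `BATCH_SIZE: 192`. The sketch below is a simplified, hypothetical version of that idea and is not Dassl's sampler. Each batch takes an equal share from each selected domain, and `n_domain <= 0` means "all domains", mirroring the `N_DOMAIN` comment in `defaults.py`. With the three sources above, 192 splits into 64 images per domain.

```python
# Hypothetical sketch of domain-balanced batch sampling (not Dassl's RandomDomainSampler).
import random
from collections import defaultdict


def domain_balanced_batch(domains, batch_size, n_domain=0, rng=random):
    """Draw one batch of sample indices with an equal count from each selected domain.

    domains:  list where domains[i] is the domain of sample i.
    n_domain: number of domains per batch; 0 or negative means all domains.
    Each selected domain must hold at least batch_size // n_domain samples.
    """
    by_domain = defaultdict(list)
    for idx, d in enumerate(domains):
        by_domain[d].append(idx)
    all_domains = sorted(by_domain)
    if n_domain <= 0 or n_domain > len(all_domains):
        n_domain = len(all_domains)
    if batch_size % n_domain != 0:
        raise ValueError("batch_size must be divisible by the number of domains")
    per_domain = batch_size // n_domain
    batch = []
    for d in rng.sample(all_domains, n_domain):  # pick which domains to use
        batch.extend(rng.sample(by_domain[d], per_domain))  # no repeats within a domain
    return batch


domains = ["clipart"] * 100 + ["painting"] * 100 + ["real"] * 100
batch = domain_balanced_batch(domains, batch_size=192, rng=random.Random(0))
counts = {d: sum(domains[i] == d for i in batch) for d in set(domains)}
print(sorted(counts.items()))  # [('clipart', 64), ('painting', 64), ('real', 64)]
```
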
After the training finishes, the model weights will be saved under the specified output directory, along with a log file and a tensorboard file for visualization.
To print out the results saved in the log files (so that you do not need to go through every log file and calculate the mean/std yourself), you can use `tools/parse_test_res.py`. Usage instructions can be found in the code.
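
The aggregation amounts to "take the final accuracy from each run's log and report mean and standard deviation". The sketch below assumes each log contains lines of the form `* accuracy: 80.00%`. That format, the regex, the helper names and the sample log snippets are all assumptions for illustration, and the sketch does not reproduce `parse_test_res.py` itself. It uses the sample standard deviation (`statistics.stdev`).

```python
# Hypothetical sketch: aggregate the final accuracy across several runs' logs.
import re
import statistics

# Assumed log line format, e.g. "* accuracy: 80.00%"
ACC_PATTERN = re.compile(r"\* accuracy: ([\d.]+)%")


def last_accuracy(log_text):
    """Return the last accuracy reported in a log, or None if there is none."""
    matches = ACC_PATTERN.findall(log_text)
    return float(matches[-1]) if matches else None


def summarize(log_texts):
    """Mean and sample standard deviation of the final accuracy of each run."""
    accs = [a for a in (last_accuracy(t) for t in log_texts) if a is not None]
    mean = statistics.mean(accs)
    std = statistics.stdev(accs) if len(accs) > 1 else 0.0
    return mean, std


# Made-up log snippets standing in for three runs with different seeds
logs = [
    "=> result\n* total: 795\n* accuracy: 80.00%\n",
    "=> result\n* total: 795\n* accuracy: 82.00%\n",
    "=> result\n* total: 795\n* accuracy: 84.00%\n",
]
mean, std = summarize(logs)
print(f"accuracy: {mean:.2f}% +- {std:.2f}%")  # accuracy: 82.00% +- 2.00%
```
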
For other trainers such as `MCD`, you can set `--trainer MCD` while keeping the config file unchanged, i.e. using the same training parameters as `SourceOnly` (in the simplest case). To modify the hyper-parameters in MCD, like `N_STEP_F` (number of steps to update the feature extractor), you can append `TRAINER.MCD.N_STEP_F 4` to the existing input arguments (otherwise the default value will be used). Alternatively, you can create a new `.yaml` config file to store your custom setting. See [here](https://github.com/KaiyangZhou/Dassl.pytorch/blob/master/dassl/config/defaults.py#L176) for a complete list of algorithm-specific hyper-parameters.
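
Trailing `KEY VALUE` pairs such as `TRAINER.MCD.N_STEP_F 4` are handled by yacs' `CfgNode.merge_from_list`. The sketch below is a hypothetical, dict-based re-implementation of that idea, not yacs' code. It walks the dotted key, refuses unknown keys, and parses the value as a Python literal when possible, falling back to a plain string. Unlike yacs, it does not check that the new value has the same type as the old one.

```python
# Hypothetical sketch of applying "KEY VALUE" overrides to a nested config dict.
import ast


def parse_value(text):
    """Interpret a command-line string as a Python literal when possible, else keep str."""
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError):
        return text


def merge_from_list(cfg, opts):
    if len(opts) % 2 != 0:
        raise ValueError("overrides must come in KEY VALUE pairs")
    for full_key, raw in zip(opts[0::2], opts[1::2]):
        *parents, leaf = full_key.split(".")
        node = cfg
        for name in parents:
            node = node[name]  # raises KeyError for an unknown section
        if leaf not in node:
            raise KeyError(f"unknown config key: {full_key}")
        node[leaf] = parse_value(raw)
    return cfg


# N_STEP_F defaults to 4 in defaults.py; override it to 2 to show a visible change
cfg = {"TRAINER": {"NAME": "MCD", "MCD": {"N_STEP_F": 4}}, "OPTIM": {"LR": 0.0003}}
merge_from_list(cfg, ["TRAINER.MCD.N_STEP_F", "2", "OPTIM.LR", "0.002"])
print(cfg["TRAINER"]["MCD"]["N_STEP_F"], cfg["OPTIM"]["LR"])  # 2 0.002
```

Refusing unknown keys catches typos such as `TRAINER.MCD.N_STEPS_F`, which would otherwise be silently ignored.
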
### Test
Model testing can be done by using `--eval-only`, which asks the code to run `trainer.test()`. You also need to provide the trained model and specify which model file (i.e. saved at which epoch) to use. For example, to use `model.pth.tar-20` saved at `output/source_only_office31/model`, you can do
```bash
CUDA_VISIBLE_DEVICES=0 python tools/train.py \
--root $DATA \
--trainer SourceOnly \
--source-domains amazon \
--target-domains webcam \
--dataset-config-file configs/datasets/da/office31.yaml \
--config-file configs/trainers/da/source_only/office31.yaml \
--output-dir output/source_only_office31_test \
--eval-only \
--model-dir output/source_only_office31 \
--load-epoch 20
```
Note that `--model-dir` takes as input the directory path which was specified in `--output-dir` in the training stage.
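
Following the example above (`model.pth.tar-20` under `output/source_only_office31/model`), the checkpoint used for testing is `<model-dir>/model/model.pth.tar-<epoch>`. The helpers below are hypothetical. They build that path and list the epochs that actually have a checkpoint, which can help before choosing `--load-epoch`. The layout is taken from the example; the default model name `"model"` is an assumption.

```python
# Hypothetical helpers for locating checkpoints saved as <dir>/<model_name>/model.pth.tar-<epoch>.
import os
import os.path as osp
import re


def checkpoint_path(model_dir, load_epoch, model_name="model"):
    """Path of the checkpoint selected by --model-dir and --load-epoch (assumed layout)."""
    return osp.join(model_dir, model_name, f"model.pth.tar-{load_epoch}")


def available_epochs(model_dir, model_name="model"):
    """Epochs for which a model.pth.tar-<epoch> file exists, in ascending order."""
    folder = osp.join(model_dir, model_name)
    if not osp.isdir(folder):
        return []
    epochs = []
    for fname in os.listdir(folder):
        m = re.fullmatch(r"model\.pth\.tar-(\d+)", fname)
        if m:
            epochs.append(int(m.group(1)))
    return sorted(epochs)


print(checkpoint_path("output/source_only_office31", 20))
# output/source_only_office31/model/model.pth.tar-20  (on POSIX systems)
```
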
### Write a new trainer
A good practice is to go through `dassl/engine/trainer.py` to get familiar with the base trainer classes, which provide generic functions and training loops. To write a trainer class for domain adaptation or semi-supervised learning, subclass `TrainerXU`; for domain generalization, subclass `TrainerX`. The main difference between `TrainerXU` and `TrainerX` is whether a data loader for unlabeled data is used. With the base classes, a new trainer may only need to implement the `forward_backward()` method, which performs loss computation and model update. See `dassl/engine/da/source_only.py` for an example.
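
The base class owns the loop and the subclass only implements `forward_backward()`. The framework-free toy below sketches that division of labour. `ToyTrainerX` and `ToySourceOnly` are hypothetical stand-ins, not Dassl's `TrainerX`. The "model" is a single weight fitted to `y = 2x` by gradient descent, so the example runs without PyTorch. Returning a dict of scalar losses for logging is a convention this sketch adopts.

```python
# Toy sketch: base class runs epochs/batches, subclass implements forward_backward().


class ToyTrainerX:
    def __init__(self, train_loader_x, max_epoch):
        self.train_loader_x = train_loader_x
        self.max_epoch = max_epoch
        self.history = []  # one loss summary per batch

    def forward_backward(self, batch):
        raise NotImplementedError  # subclasses compute the loss and update the model

    def run_epoch(self):
        for batch in self.train_loader_x:
            self.history.append(self.forward_backward(batch))

    def train(self):
        for _ in range(self.max_epoch):
            self.run_epoch()


class ToySourceOnly(ToyTrainerX):
    """Fits a single weight w so that w * x approximates y, using mean squared error."""

    def __init__(self, train_loader_x, max_epoch, lr=0.1):
        super().__init__(train_loader_x, max_epoch)
        self.w = 0.0
        self.lr = lr

    def forward_backward(self, batch):
        xs, ys = batch
        n = len(xs)
        # forward: mean squared error of the prediction w * x
        loss = sum((self.w * x - y) ** 2 for x, y in zip(xs, ys)) / n
        # backward: gradient of the loss w.r.t. w, then one gradient-descent step
        grad = sum(2 * (self.w * x - y) * x for x, y in zip(xs, ys)) / n
        self.w -= self.lr * grad
        return {"loss": loss}


loader = [([1.0, 2.0], [2.0, 4.0]), ([0.5, 1.5], [1.0, 3.0])]
trainer = ToySourceOnly(loader, max_epoch=20)
trainer.train()
print(round(trainer.w, 3))  # close to 2.0
```
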
### Add a new backbone/head/network
`backbone` corresponds to a convolutional neural network model that performs feature extraction. `head` (an optional module) is mounted on top of `backbone` for further processing and can be, for example, an MLP. `backbone` and `head` are the basic building blocks for constructing a `SimpleNet()` (see `dassl/engine/trainer.py`), which serves as the primary model for a task. `network` contains custom neural network models, such as an image generator.
To add a new module, namely a backbone/head/network, you first need to register it with the corresponding registry, i.e. `BACKBONE_REGISTRY` for `backbone`, `HEAD_REGISTRY` for `head` and `NETWORK_REGISTRY` for `network`. Note that a new `backbone` must subclass `Backbone` as defined in `dassl/modeling/backbone/backbone.py` and set the `self._out_features` attribute.
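
In a backbone → optional head → classifier chain, the classifier's input size follows whichever module comes last. The sketch below models only that dimension bookkeeping, without tensors. The classes are hypothetical stand-ins, not Dassl's `SimpleNet`. The numbers come from `configs/datasets/da/office31.yaml` (ResNet-50 backbone, MLP head with `HIDDEN_LAYERS: [256]`), and 2048 is the size of ResNet-50's pooled feature vector.

```python
# Hypothetical sketch of feature-dimension bookkeeping for backbone + optional head.
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class ToyBackbone:
    out_features: int


@dataclass
class ToyMLPHead:
    in_features: int
    hidden_layers: Sequence[int]

    def __post_init__(self):
        if not self.hidden_layers:
            raise ValueError("an MLP head needs at least one hidden layer")

    @property
    def out_features(self):
        return self.hidden_layers[-1]  # the last hidden layer feeds the classifier


def classifier_in_features(backbone, head: Optional[ToyMLPHead]):
    """The classifier consumes the head's output if a head exists, else the backbone's."""
    return head.out_features if head is not None else backbone.out_features


backbone = ToyBackbone(out_features=2048)  # ResNet-50 pooled features
head = ToyMLPHead(in_features=backbone.out_features, hidden_layers=[256])
print(classifier_in_features(backbone, head))  # 256
print(classifier_in_features(backbone, None))  # 2048
```
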
We provide an example below for how to add a new `backbone`.
```python
from dassl.modeling import Backbone, BACKBONE_REGISTRY


class MyBackbone(Backbone):

    def __init__(self):
        super().__init__()
        # Create layers
        self.conv = ...

        self._out_features = 2048

    def forward(self, x):
        # Extract and return features
        features = self.conv(x)
        return features


@BACKBONE_REGISTRY.register()
def my_backbone(**kwargs):
    return MyBackbone()
```
Then, you can set `MODEL.BACKBONE.NAME` to `my_backbone` to use your own architecture. For more details, please refer to the source code in `dassl/modeling`.
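
Setting `MODEL.BACKBONE.NAME` works because the decorator stores the function under its name, and the builder looks that name up later. The minimal registry below sketches this mechanism. It is hypothetical: Dassl ships its own `Registry` class, whose interface may differ. The toy registry is named `TOY_BACKBONE_REGISTRY` to make clear it is not the real one.

```python
# Hypothetical minimal registry: maps a function/class name to the object itself.


class Registry:
    def __init__(self, name):
        self._name = name
        self._obj_map = {}

    def register(self):
        def decorator(fn_or_class):
            key = fn_or_class.__name__
            if key in self._obj_map:
                raise KeyError(f"{key} is already registered in {self._name}")
            self._obj_map[key] = fn_or_class
            return fn_or_class  # the decorated object is returned unchanged
        return decorator

    def get(self, key):
        if key not in self._obj_map:
            raise KeyError(f"{key} not found in {self._name}; available: {sorted(self._obj_map)}")
        return self._obj_map[key]


TOY_BACKBONE_REGISTRY = Registry("BACKBONE")


@TOY_BACKBONE_REGISTRY.register()
def my_backbone(**kwargs):
    return {"name": "my_backbone", "kwargs": kwargs}  # stands in for a real model


backbone_name = "my_backbone"  # plays the role of cfg.MODEL.BACKBONE.NAME
model = TOY_BACKBONE_REGISTRY.get(backbone_name)(pretrained=True)
print(model)  # {'name': 'my_backbone', 'kwargs': {'pretrained': True}}
```
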
### Add a dataset
An example code structure is shown below. Make sure you subclass `DatasetBase` and register the dataset with `@DATASET_REGISTRY.register()`. All you need to do is load `train_x`, `train_u` (optional), `val` (optional) and `test`; `train_u` and `val` can be `None` or simply left out. Each of these variables contains a list of `Datum` objects. A `Datum` object (implemented [here](https://github.com/KaiyangZhou/Dassl.pytorch/blob/master/dassl/data/datasets/base_dataset.py#L12)) holds the information for a single image, such as `impath` (string) and `label` (int).
```python
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase


@DATASET_REGISTRY.register()
class NewDataset(DatasetBase):

    dataset_dir = ''

    def __init__(self, cfg):

        train_x = ...
        train_u = ...  # optional, can be None
        val = ...  # optional, can be None
        test = ...

        super().__init__(train_x=train_x, train_u=train_u, val=val, test=test)
```
We suggest looking at the dataset code of projects built on top of Dassl, such as [CoOp](https://github.com/KaiyangZhou/CoOp).
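
Most of the work in a new dataset class is turning (image path, class name) pairs into `Datum`-like records with integer labels. The sketch below shows that step with a hypothetical `ToyDatum`. It mirrors the `impath` and `label` fields named above; the `domain` and `classname` fields and the sorted-name label mapping are assumptions, not a copy of Dassl's `Datum`. The file paths are made up.

```python
# Hypothetical sketch: build Datum-like records with deterministic integer labels.
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyDatum:
    impath: str
    label: int
    domain: int = 0
    classname: str = ""


def make_items(pairs, domain=0):
    """Map class names to labels in sorted order so the mapping is reproducible."""
    classnames = sorted({c for _, c in pairs})
    cname2lab = {c: i for i, c in enumerate(classnames)}
    return [
        ToyDatum(impath=p, label=cname2lab[c], domain=domain, classname=c)
        for p, c in pairs
    ]


pairs = [
    ("images/cat/001.jpg", "cat"),
    ("images/dog/001.jpg", "dog"),
    ("images/cat/002.jpg", "cat"),
]
train_x = make_items(pairs)
print([(d.classname, d.label) for d in train_x])
# [('cat', 0), ('dog', 1), ('cat', 0)]
```
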
## Relevant Research
We would like to share here our research relevant to Dassl.
- [Domain Adaptive Ensemble Learning](https://arxiv.org/abs/2003.07325), TIP, 2021.
- [MixStyle Neural Networks for Domain Generalization and Adaptation](https://arxiv.org/abs/2107.02053), arXiv preprint, 2021.
- [Semi-Supervised Domain Generalization with Stochastic StyleMatch](https://arxiv.org/abs/2106.00592), arXiv preprint, 2021.
- [Domain Generalization in Vision: A Survey](https://arxiv.org/abs/2103.02503), arXiv preprint, 2021.
- [Domain Generalization with MixStyle](https://openreview.net/forum?id=6xHJ37MVxxp), in ICLR 2021.
- [Learning to Generate Novel Domains for Domain Generalization](https://arxiv.org/abs/2007.03304), in ECCV 2020.
- [Deep Domain-Adversarial Image Generation for Domain Generalisation](https://arxiv.org/abs/2003.06054), in AAAI 2020.
## Citation
If you find this code useful in your research, please cite the following paper:
```
@article{zhou2020domain,
title={Domain Adaptive Ensemble Learning},
author={Zhou, Kaiyang and Yang, Yongxin and Qiao, Yu and Xiang, Tao},
journal={IEEE Transactions on Image Processing (TIP)},
year={2021}
}
```
================================================
FILE: Dassl.ProGrad.pytorch/configs/README.md
================================================
The `datasets/` folder contains dataset-specific config files which define the standard protocols (e.g., image size, data augmentation, network architecture) used by most papers. The `trainers/` folder contains method-specific config files which define optimization algorithms (e.g., optimizer, epoch) and hyperparameter settings.
================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/da/cifar_stl.yaml
================================================
INPUT:
  SIZE: (32, 32)
  PIXEL_MEAN: [0.5, 0.5, 0.5]
  PIXEL_STD: [0.5, 0.5, 0.5]
DATASET:
  NAME: "CIFARSTL"
================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/da/digit5.yaml
================================================
INPUT:
  SIZE: (32, 32)
  PIXEL_MEAN: [0.5, 0.5, 0.5]
  PIXEL_STD: [0.5, 0.5, 0.5]
  TRANSFORMS: ["normalize"]
DATASET:
  NAME: "Digit5"
MODEL:
  BACKBONE:
    NAME: "cnn_digit5_m3sda"
================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/da/domainnet.yaml
================================================
INPUT:
  SIZE: (224, 224)
  TRANSFORMS: ["random_flip", "random_translation", "normalize"]
DATASET:
  NAME: "DomainNet"
MODEL:
  BACKBONE:
    NAME: "resnet101"
================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/da/mini_domainnet.yaml
================================================
INPUT:
  SIZE: (96, 96)
  TRANSFORMS: ["random_flip", "random_translation", "normalize"]
DATASET:
  NAME: "miniDomainNet"
MODEL:
  BACKBONE:
    NAME: "resnet18"
================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/da/office31.yaml
================================================
INPUT:
  SIZE: (224, 224)
  TRANSFORMS: ["random_flip", "random_translation", "normalize"]
DATASET:
  NAME: "Office31"
MODEL:
  BACKBONE:
    NAME: "resnet50"
  HEAD:
    NAME: "mlp"
    HIDDEN_LAYERS: [256]
    DROPOUT: 0.
================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/da/office_home.yaml
================================================
INPUT:
  SIZE: (224, 224)
DATASET:
  NAME: "OfficeHome"
================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/da/visda17.yaml
================================================
INPUT:
  SIZE: (224, 224)
  TRANSFORMS: ["random_flip", "center_crop", "normalize"]
DATASET:
  NAME: "VisDA17"
MODEL:
  BACKBONE:
    NAME: "resnet101"
TEST:
  PER_CLASS_RESULT: True
================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/dg/cifar100_c.yaml
================================================
INPUT:
  SIZE: (32, 32)
  TRANSFORMS: ["random_flip", "random_crop", "normalize"]
  PIXEL_MEAN: [0.5, 0.5, 0.5]
  PIXEL_STD: [0.5, 0.5, 0.5]
DATASET:
  NAME: "CIFAR100C"
  CIFAR_C_TYPE: "fog"
  CIFAR_C_LEVEL: 5
MODEL:
  BACKBONE:
    NAME: "wide_resnet_16_4"
================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/dg/cifar10_c.yaml
================================================
INPUT:
  SIZE: (32, 32)
  TRANSFORMS: ["random_flip", "random_crop", "normalize"]
  PIXEL_MEAN: [0.5, 0.5, 0.5]
  PIXEL_STD: [0.5, 0.5, 0.5]
DATASET:
  NAME: "CIFAR10C"
  CIFAR_C_TYPE: "fog"
  CIFAR_C_LEVEL: 5
MODEL:
  BACKBONE:
    NAME: "wide_resnet_16_4"
================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/dg/digit_single.yaml
================================================
INPUT:
  SIZE: (32, 32)
  TRANSFORMS: ["normalize"]
  PIXEL_MEAN: [0.5, 0.5, 0.5]
  PIXEL_STD: [0.5, 0.5, 0.5]
DATASET:
  NAME: "DigitSingle"
MODEL:
  BACKBONE:
    NAME: "cnn_digitsingle"
================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/dg/digits_dg.yaml
================================================
INPUT:
  SIZE: (32, 32)
  TRANSFORMS: ["normalize"]
  PIXEL_MEAN: [0.5, 0.5, 0.5]
  PIXEL_STD: [0.5, 0.5, 0.5]
DATASET:
  NAME: "DigitsDG"
MODEL:
  BACKBONE:
    NAME: "cnn_digitsdg"
================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/dg/office_home_dg.yaml
================================================
INPUT:
  SIZE: (224, 224)
  TRANSFORMS: ["random_flip", "random_translation", "normalize"]
DATASET:
  NAME: "OfficeHomeDG"
MODEL:
  BACKBONE:
    NAME: "resnet18"
    PRETRAINED: True
================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/dg/pacs.yaml
================================================
INPUT:
  SIZE: (224, 224)
  TRANSFORMS: ["random_flip", "random_translation", "normalize"]
DATASET:
  NAME: "PACS"
MODEL:
  BACKBONE:
    NAME: "resnet18"
    PRETRAINED: True
================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/dg/vlcs.yaml
================================================
INPUT:
  SIZE: (224, 224)
  TRANSFORMS: ["random_flip", "random_translation", "normalize"]
DATASET:
  NAME: "VLCS"
MODEL:
  BACKBONE:
    NAME: "resnet18"
    PRETRAINED: True
================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/ssl/cifar10.yaml
================================================
INPUT:
  SIZE: (32, 32)
  TRANSFORMS: ["random_flip", "random_crop", "normalize"]
  PIXEL_MEAN: [0.5, 0.5, 0.5]
  PIXEL_STD: [0.5, 0.5, 0.5]
DATASET:
  NAME: "CIFAR10"
  NUM_LABELED: 4000
  VAL_PERCENT: 0.
MODEL:
  BACKBONE:
    NAME: "wide_resnet_28_2"
================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/ssl/cifar100.yaml
================================================
INPUT:
  SIZE: (32, 32)
  TRANSFORMS: ["random_flip", "random_crop", "normalize"]
  PIXEL_MEAN: [0.5, 0.5, 0.5]
  PIXEL_STD: [0.5, 0.5, 0.5]
  CROP_PADDING: 4
DATASET:
  NAME: "CIFAR100"
  NUM_LABELED: 10000
  VAL_PERCENT: 0.
MODEL:
  BACKBONE:
    NAME: "wide_resnet_28_2"
================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/ssl/stl10.yaml
================================================
INPUT:
  SIZE: (96, 96)
  TRANSFORMS: ["random_flip", "random_crop", "normalize"]
  PIXEL_MEAN: [0.5, 0.5, 0.5]
  PIXEL_STD: [0.5, 0.5, 0.5]
  CROP_PADDING: 4
DATASET:
  NAME: "STL10"
  STL10_FOLD: 0
MODEL:
  BACKBONE:
    NAME: "wide_resnet_28_2"
================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/ssl/svhn.yaml
================================================
INPUT:
  SIZE: (32, 32)
  TRANSFORMS: ["random_crop", "normalize"]
  PIXEL_MEAN: [0.5, 0.5, 0.5]
  PIXEL_STD: [0.5, 0.5, 0.5]
  CROP_PADDING: 4
DATASET:
  NAME: "SVHN"
  NUM_LABELED: 1000
  VAL_PERCENT: 0.
MODEL:
  BACKBONE:
    NAME: "wide_resnet_28_2"
================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/da/dael/digit5.yaml
================================================
DATALOADER:
  TRAIN_X:
    SAMPLER: "RandomDomainSampler"
    BATCH_SIZE: 256
  TRAIN_U:
    SAME_AS_X: False
    BATCH_SIZE: 64
  TEST:
    BATCH_SIZE: 256
OPTIM:
  NAME: "sgd"
  LR: 0.05
  STEPSIZE: [30]
  MAX_EPOCH: 30
  LR_SCHEDULER: "cosine"
TRAINER:
  DAEL:
    STRONG_TRANSFORMS: ["randaugment2", "normalize"]
================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/da/dael/domainnet.yaml
================================================
DATALOADER:
  NUM_WORKERS: 4
  TRAIN_X:
    SAMPLER: "RandomDomainSampler"
    BATCH_SIZE: 30
  TRAIN_U:
    SAME_AS_X: False
    BATCH_SIZE: 6
  TEST:
    BATCH_SIZE: 30
OPTIM:
  NAME: "sgd"
  LR: 0.002
  MAX_EPOCH: 40
  LR_SCHEDULER: "cosine"
TRAINER:
  DAEL:
    STRONG_TRANSFORMS: ["random_flip", "cutout", "randaugment2", "normalize"]
================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/da/dael/mini_domainnet.yaml
================================================
DATALOADER:
  NUM_WORKERS: 8
  TRAIN_X:
    SAMPLER: "RandomDomainSampler"
    BATCH_SIZE: 192
  TRAIN_U:
    SAME_AS_X: False
    BATCH_SIZE: 64
  TEST:
    BATCH_SIZE: 200
OPTIM:
  NAME: "sgd"
  LR: 0.005
  MAX_EPOCH: 60
  LR_SCHEDULER: "cosine"
TRAINER:
  DAEL:
    STRONG_TRANSFORMS: ["random_flip", "cutout", "randaugment2", "normalize"]
================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/da/m3sda/digit5.yaml
================================================
DATALOADER:
  TRAIN_X:
    SAMPLER: "RandomDomainSampler"
    BATCH_SIZE: 256
  TRAIN_U:
    SAME_AS_X: False
    BATCH_SIZE: 64
  TEST:
    BATCH_SIZE: 256
OPTIM:
  NAME: "sgd"
  LR: 0.05
  STEPSIZE: [30]
  MAX_EPOCH: 30
  LR_SCHEDULER: "cosine"
================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/da/m3sda/domainnet.yaml
================================================
DATALOADER:
  NUM_WORKERS: 4
  TRAIN_X:
    SAMPLER: "RandomDomainSampler"
    BATCH_SIZE: 30
  TRAIN_U:
    SAME_AS_X: False
    BATCH_SIZE: 6
  TEST:
    BATCH_SIZE: 30
OPTIM:
  NAME: "sgd"
  LR: 0.002
  MAX_EPOCH: 40
  LR_SCHEDULER: "cosine"
================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/da/m3sda/mini_domainnet.yaml
================================================
DATALOADER:
  NUM_WORKERS: 8
  TRAIN_X:
    SAMPLER: "RandomDomainSampler"
    BATCH_SIZE: 192
  TRAIN_U:
    SAME_AS_X: False
    BATCH_SIZE: 64
  TEST:
    BATCH_SIZE: 200
OPTIM:
  NAME: "sgd"
  LR: 0.005
  MAX_EPOCH: 60
  LR_SCHEDULER: "cosine"
================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/da/source_only/digit5.yaml
================================================
DATALOADER:
  TRAIN_X:
    BATCH_SIZE: 256
  TEST:
    BATCH_SIZE: 256
OPTIM:
  NAME: "sgd"
  LR: 0.05
  STEPSIZE: [30]
  MAX_EPOCH: 30
  LR_SCHEDULER: "cosine"
================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/da/source_only/mini_domainnet.yaml
================================================
DATALOADER:
  NUM_WORKERS: 8
  TRAIN_X:
    BATCH_SIZE: 128
  TEST:
    BATCH_SIZE: 128
OPTIM:
  NAME: "sgd"
  LR: 0.005
  MAX_EPOCH: 60
  LR_SCHEDULER: "cosine"
================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/da/source_only/office31.yaml
================================================
DATALOADER:
  TRAIN_X:
    BATCH_SIZE: 32
  TEST:
    BATCH_SIZE: 32
OPTIM:
  NAME: "sgd"
  LR: 0.002
  STEPSIZE: [20]
  MAX_EPOCH: 20
================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/da/source_only/visda17.yaml
================================================
DATALOADER:
  TRAIN_X:
    BATCH_SIZE: 32
  TEST:
    BATCH_SIZE: 32
OPTIM:
  NAME: "sgd"
  LR: 0.0001
  STEPSIZE: [2]
  MAX_EPOCH: 2
TRAIN:
  PRINT_FREQ: 50
  COUNT_ITER: "train_u"
================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/dg/dael/digits_dg.yaml
================================================
DATALOADER:
  TRAIN_X:
    SAMPLER: "RandomDomainSampler"
    BATCH_SIZE: 120
  TEST:
    BATCH_SIZE: 100
OPTIM:
  NAME: "sgd"
  LR: 0.05
  STEPSIZE: [20]
  MAX_EPOCH: 50
TRAINER:
  DAEL:
    STRONG_TRANSFORMS: ["randaugment2", "normalize"]
================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/dg/dael/office_home_dg.yaml
================================================
DATALOADER:
  TRAIN_X:
    SAMPLER: "RandomDomainSampler"
    BATCH_SIZE: 30
  TEST:
    BATCH_SIZE: 100
OPTIM:
  NAME: "sgd"
  LR: 0.002
  MAX_EPOCH: 40
  LR_SCHEDULER: "cosine"
TRAINER:
  DAEL:
    STRONG_TRANSFORMS: ["random_flip", "cutout", "randaugment2", "normalize"]
================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/dg/dael/pacs.yaml
================================================
DATALOADER:
  TRAIN_X:
    SAMPLER: "RandomDomainSampler"
    BATCH_SIZE: 30
  TEST:
    BATCH_SIZE: 100
OPTIM:
  NAME: "sgd"
  LR: 0.002
  MAX_EPOCH: 40
  LR_SCHEDULER: "cosine"
TRAINER:
  DAEL:
    STRONG_TRANSFORMS: ["random_flip", "cutout", "randaugment2", "normalize"]
================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/dg/ddaig/digits_dg.yaml
================================================
INPUT:
  PIXEL_MEAN: [0., 0., 0.]
  PIXEL_STD: [1., 1., 1.]
DATALOADER:
  TRAIN_X:
    BATCH_SIZE: 128
  TEST:
    BATCH_SIZE: 128
OPTIM:
  NAME: "sgd"
  LR: 0.05
  STEPSIZE: [20]
  MAX_EPOCH: 50
TRAINER:
  DDAIG:
    G_ARCH: "fcn_3x32_gctx"
    LMDA: 0.3
================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/dg/ddaig/office_home_dg.yaml
================================================
INPUT:
  PIXEL_MEAN: [0., 0., 0.]
  PIXEL_STD: [1., 1., 1.]
DATALOADER:
  TRAIN_X:
    BATCH_SIZE: 16
  TEST:
    BATCH_SIZE: 16
OPTIM:
  NAME: "sgd"
  LR: 0.0005
  STEPSIZE: [20]
  MAX_EPOCH: 25
TRAINER:
  DDAIG:
    G_ARCH: "fcn_3x64_gctx"
    WARMUP: 3
    LMDA: 0.3
================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/dg/ddaig/pacs.yaml
================================================
INPUT:
  PIXEL_MEAN: [0., 0., 0.]
  PIXEL_STD: [1., 1., 1.]
DATALOADER:
  TRAIN_X:
    BATCH_SIZE: 16
  TEST:
    BATCH_SIZE: 16
OPTIM:
  NAME: "sgd"
  LR: 0.0005
  STEPSIZE: [20]
  MAX_EPOCH: 25
TRAINER:
  DDAIG:
    G_ARCH: "fcn_3x64_gctx"
    WARMUP: 3
    LMDA: 0.3
================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/dg/vanilla/digits_dg.yaml
================================================
DATALOADER:
  TRAIN_X:
    BATCH_SIZE: 128
  TEST:
    BATCH_SIZE: 100
  NUM_WORKERS: 8
OPTIM:
  NAME: "sgd"
  LR: 0.05
  STEPSIZE: [20]
  MAX_EPOCH: 50
TRAIN:
  PRINT_FREQ: 20
================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/dg/vanilla/mini_domainnet.yaml
================================================
DATALOADER:
  NUM_WORKERS: 8
  TRAIN_X:
    BATCH_SIZE: 128
  TEST:
    BATCH_SIZE: 128
OPTIM:
  NAME: "sgd"
  LR: 0.005
  MAX_EPOCH: 60
  LR_SCHEDULER: "cosine"
================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/dg/vanilla/office_home_dg.yaml
================================================
DATALOADER:
  TRAIN_X:
    BATCH_SIZE: 64
  TEST:
    BATCH_SIZE: 100
  NUM_WORKERS: 8
OPTIM:
  NAME: "sgd"
  LR: 0.001
  MAX_EPOCH: 50
  LR_SCHEDULER: "cosine"
================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/dg/vanilla/pacs.yaml
================================================
DATALOADER:
  TRAIN_X:
    BATCH_SIZE: 64
  TEST:
    BATCH_SIZE: 100
  NUM_WORKERS: 8
OPTIM:
  NAME: "sgd"
  LR: 0.001
  MAX_EPOCH: 50
  LR_SCHEDULER: "cosine"
================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/ssl/fixmatch/cifar10.yaml
================================================
DATALOADER:
  TRAIN_X:
    BATCH_SIZE: 64
  TRAIN_U:
    SAME_AS_X: False
    BATCH_SIZE: 448
  TEST:
    BATCH_SIZE: 500
OPTIM:
  NAME: "sgd"
  LR: 0.05
  STEPSIZE: [4000]
  MAX_EPOCH: 4000
  LR_SCHEDULER: "cosine"
TRAIN:
  COUNT_ITER: "train_u"
  PRINT_FREQ: 10
TRAINER:
  FIXMATCH:
    STRONG_TRANSFORMS: ["random_flip", "randaugment_fixmatch", "normalize", "cutout"]
================================================
FILE: Dassl.ProGrad.pytorch/dassl/__init__.py
================================================
"""
Dassl
------
PyTorch toolbox for domain adaptation and semi-supervised learning.
URL: https://github.com/KaiyangZhou/Dassl.pytorch
@article{zhou2020domain,
title={Domain Adaptive Ensemble Learning},
author={Zhou, Kaiyang and Yang, Yongxin and Qiao, Yu and Xiang, Tao},
journal={arXiv preprint arXiv:2003.07325},
year={2020}
}
"""
__version__ = "0.5.0"
__author__ = "Kaiyang Zhou"
__homepage__ = "https://kaiyangzhou.github.io/"
================================================
FILE: Dassl.ProGrad.pytorch/dassl/config/__init__.py
================================================
from .defaults import _C as cfg_default
def get_cfg_default():
    return cfg_default.clone()
================================================
FILE: Dassl.ProGrad.pytorch/dassl/config/defaults.py
================================================
from yacs.config import CfgNode as CN
###########################
# Config definition
###########################
_C = CN()
_C.VERSION = 1
# Directory to save the output files (like log.txt and model weights)
_C.OUTPUT_DIR = "./output"
# Path to a directory where the files were saved previously
_C.RESUME = ""
# Set seed to negative value to randomize everything
# Set seed to positive value to use a fixed seed
_C.SEED = -1
_C.USE_CUDA = True
# Print detailed information
# E.g. trainer, dataset, and backbone
_C.VERBOSE = True
###########################
# Input
###########################
_C.INPUT = CN()
_C.INPUT.SIZE = (224, 224)
# Mode of interpolation in resize functions
_C.INPUT.INTERPOLATION = "bilinear"
# For available choices please refer to transforms.py
_C.INPUT.TRANSFORMS = ()
# If True, tfm_train and tfm_test will be None
_C.INPUT.NO_TRANSFORM = False
# Default mean and std come from ImageNet
_C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406]
_C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225]
# Padding for random crop
_C.INPUT.CROP_PADDING = 4
# Cutout
_C.INPUT.CUTOUT_N = 1
_C.INPUT.CUTOUT_LEN = 16
# Gaussian noise
_C.INPUT.GN_MEAN = 0.0
_C.INPUT.GN_STD = 0.15
# RandAugment
_C.INPUT.RANDAUGMENT_N = 2
_C.INPUT.RANDAUGMENT_M = 10
# ColorJitter (brightness, contrast, saturation, hue)
_C.INPUT.COLORJITTER_B = 0.4
_C.INPUT.COLORJITTER_C = 0.4
_C.INPUT.COLORJITTER_S = 0.4
_C.INPUT.COLORJITTER_H = 0.1
# Random gray scale's probability
_C.INPUT.RGS_P = 0.2
# Gaussian blur
_C.INPUT.GB_P = 0.5  # probability of applying this operation
_C.INPUT.GB_K = 21 # kernel size (should be an odd number)
###########################
# Dataset
###########################
_C.DATASET = CN()
# Directory where datasets are stored
_C.DATASET.ROOT = ""
_C.DATASET.NAME = ""
# List of names of source domains
_C.DATASET.SOURCE_DOMAINS = ()
# List of names of target domains
_C.DATASET.TARGET_DOMAINS = ()
# Number of labeled instances in total
# Useful for the semi-supervised learning
_C.DATASET.NUM_LABELED = -1
# Number of images per class
_C.DATASET.NUM_SHOTS = -1
# Percentage of validation data (only used for SSL datasets)
# Set to 0 if you do not want to use val data
# Using val data for hyperparameter tuning was done in Oliver et al. 2018
_C.DATASET.VAL_PERCENT = 0.1
# Fold index for STL-10 dataset (normal range is 0 - 9)
# Negative number means None
_C.DATASET.STL10_FOLD = -1
# CIFAR-10/100-C's corruption type and intensity level
_C.DATASET.CIFAR_C_TYPE = ""
_C.DATASET.CIFAR_C_LEVEL = 1
# Use all data in the unlabeled data set (e.g. FixMatch)
_C.DATASET.ALL_AS_UNLABELED = False
###########################
# Dataloader
###########################
_C.DATALOADER = CN()
_C.DATALOADER.NUM_WORKERS = 4
# Apply transformations to an image K times (during training)
_C.DATALOADER.K_TRANSFORMS = 1
# img0 denotes image tensor without augmentation
# Useful for consistency learning
_C.DATALOADER.RETURN_IMG0 = False
# Setting for the train_x data-loader
_C.DATALOADER.TRAIN_X = CN()
_C.DATALOADER.TRAIN_X.SAMPLER = "RandomSampler"
_C.DATALOADER.TRAIN_X.BATCH_SIZE = 32
# Parameter for RandomDomainSampler
# 0 or -1 means sampling from all domains
_C.DATALOADER.TRAIN_X.N_DOMAIN = 0
# Parameter of RandomClassSampler
# Number of instances per class
_C.DATALOADER.TRAIN_X.N_INS = 16
# Setting for the train_u data-loader
_C.DATALOADER.TRAIN_U = CN()
# Set to false if you want to have unique
# data loader params for train_u
_C.DATALOADER.TRAIN_U.SAME_AS_X = True
_C.DATALOADER.TRAIN_U.SAMPLER = "RandomSampler"
_C.DATALOADER.TRAIN_U.BATCH_SIZE = 32
_C.DATALOADER.TRAIN_U.N_DOMAIN = 0
_C.DATALOADER.TRAIN_U.N_INS = 16
# Setting for the test data-loader
_C.DATALOADER.TEST = CN()
_C.DATALOADER.TEST.SAMPLER = "SequentialSampler"
_C.DATALOADER.TEST.BATCH_SIZE = 32
###########################
# Model
###########################
_C.MODEL = CN()
# Path to model weights (for initialization)
_C.MODEL.INIT_WEIGHTS = ""
_C.MODEL.BACKBONE = CN()
_C.MODEL.BACKBONE.NAME = ""
_C.MODEL.BACKBONE.PRETRAINED = True
# Definition of embedding layers
_C.MODEL.HEAD = CN()
# If empty, no embedding layers are constructed and the
# backbone's output is passed directly to the classifier
_C.MODEL.HEAD.NAME = ""
# Structure of hidden layers (a list), e.g. [512, 512]
# If undefined, no embedding layer will be constructed
_C.MODEL.HEAD.HIDDEN_LAYERS = ()
_C.MODEL.HEAD.ACTIVATION = "relu"
_C.MODEL.HEAD.BN = True
_C.MODEL.HEAD.DROPOUT = 0.0
###########################
# Optimization
###########################
_C.OPTIM = CN()
_C.OPTIM.NAME = "adam"
_C.OPTIM.LR = 0.0003
_C.OPTIM.WEIGHT_DECAY = 5e-4
_C.OPTIM.MOMENTUM = 0.9
_C.OPTIM.SGD_DAMPNING = 0
_C.OPTIM.SGD_NESTEROV = False
_C.OPTIM.RMSPROP_ALPHA = 0.99
_C.OPTIM.ADAM_BETA1 = 0.9
_C.OPTIM.ADAM_BETA2 = 0.999
# STAGED_LR allows different layers to have
# different lr, e.g. pre-trained base layers
# can be assigned a smaller lr than the new
# classification layer
_C.OPTIM.STAGED_LR = False
_C.OPTIM.NEW_LAYERS = ()
_C.OPTIM.BASE_LR_MULT = 0.1
# Learning rate scheduler
_C.OPTIM.LR_SCHEDULER = "single_step"
# -1 or 0 means the stepsize is equal to max_epoch
_C.OPTIM.STEPSIZE = (-1, )
_C.OPTIM.GAMMA = 0.1
_C.OPTIM.MAX_EPOCH = 10
# Set WARMUP_EPOCH larger than 0 to activate warmup training
_C.OPTIM.WARMUP_EPOCH = -1
# Either linear or constant
_C.OPTIM.WARMUP_TYPE = "linear"
# Constant learning rate when type=constant
_C.OPTIM.WARMUP_CONS_LR = 1e-5
# Minimum learning rate when type=linear
_C.OPTIM.WARMUP_MIN_LR = 1e-5
# Recount epoch for the next scheduler (last_epoch=-1)
# Otherwise last_epoch=warmup_epoch
_C.OPTIM.WARMUP_RECOUNT = True
###########################
# Train
###########################
_C.TRAIN = CN()
# How often (epoch) to save model during training
# Set to 0 or negative value to only save the last one
_C.TRAIN.CHECKPOINT_FREQ = 0
# How often (batch) to print training information
_C.TRAIN.PRINT_FREQ = 10
# Use 'train_x', 'train_u' or 'smaller_one' to count
# the number of iterations in an epoch (for DA and SSL)
_C.TRAIN.COUNT_ITER = "train_x"
###########################
# Test
###########################
_C.TEST = CN()
_C.TEST.EVALUATOR = "Classification"
_C.TEST.PER_CLASS_RESULT = False
# Compute confusion matrix, which will be saved
# to $OUTPUT_DIR/cmat.pt
_C.TEST.COMPUTE_CMAT = False
# If NO_TEST=True, no testing will be conducted
_C.TEST.NO_TEST = False
# Use test or val set for FINAL evaluation
_C.TEST.SPLIT = "test"
# Which model to test after training
# Either last_step or best_val
_C.TEST.FINAL_MODEL = "last_step"
###########################
# Trainer specifics
###########################
_C.TRAINER = CN()
_C.TRAINER.NAME = ""
# MCD
_C.TRAINER.MCD = CN()
_C.TRAINER.MCD.N_STEP_F = 4 # number of steps to train F
# MME
_C.TRAINER.MME = CN()
_C.TRAINER.MME.LMDA = 0.1 # weight for the entropy loss
# SelfEnsembling
_C.TRAINER.SE = CN()
_C.TRAINER.SE.EMA_ALPHA = 0.999
_C.TRAINER.SE.CONF_THRE = 0.95
_C.TRAINER.SE.RAMPUP = 300
# M3SDA
_C.TRAINER.M3SDA = CN()
_C.TRAINER.M3SDA.LMDA = 0.5 # weight for the moment distance loss
_C.TRAINER.M3SDA.N_STEP_F = 4 # follow MCD
# DAEL
_C.TRAINER.DAEL = CN()
_C.TRAINER.DAEL.WEIGHT_U = 0.5 # weight on the unlabeled loss
_C.TRAINER.DAEL.CONF_THRE = 0.95 # confidence threshold
_C.TRAINER.DAEL.STRONG_TRANSFORMS = ()
# CrossGrad
_C.TRAINER.CG = CN()
_C.TRAINER.CG.EPS_F = 1.0 # scaling parameter for D's gradients
_C.TRAINER.CG.EPS_D = 1.0 # scaling parameter for F's gradients
_C.TRAINER.CG.ALPHA_F = 0.5 # balancing weight for the label net's loss
_C.TRAINER.CG.ALPHA_D = 0.5 # balancing weight for the domain net's loss
# DDAIG
_C.TRAINER.DDAIG = CN()
_C.TRAINER.DDAIG.G_ARCH = "" # generator's architecture
_C.TRAINER.DDAIG.LMDA = 0.3 # perturbation weight
_C.TRAINER.DDAIG.CLAMP = False # clamp perturbation values
_C.TRAINER.DDAIG.CLAMP_MIN = -1.0
_C.TRAINER.DDAIG.CLAMP_MAX = 1.0
_C.TRAINER.DDAIG.WARMUP = 0
_C.TRAINER.DDAIG.ALPHA = 0.5 # balancing weight for the losses
# EntMin
_C.TRAINER.ENTMIN = CN()
_C.TRAINER.ENTMIN.LMDA = 1e-3 # weight on the entropy loss
# Mean Teacher
_C.TRAINER.MEANTEA = CN()
_C.TRAINER.MEANTEA.WEIGHT_U = 1.0 # weight on the unlabeled loss
_C.TRAINER.MEANTEA.EMA_ALPHA = 0.999
_C.TRAINER.MEANTEA.RAMPUP = 5 # epochs used to ramp up the loss_u weight
# MixMatch
_C.TRAINER.MIXMATCH = CN()
_C.TRAINER.MIXMATCH.WEIGHT_U = 100.0 # weight on the unlabeled loss
_C.TRAINER.MIXMATCH.TEMP = 2.0 # temperature for sharpening the probability
_C.TRAINER.MIXMATCH.MIXUP_BETA = 0.75
_C.TRAINER.MIXMATCH.RAMPUP = 20000 # steps used to ramp up the loss_u weight
# FixMatch
_C.TRAINER.FIXMATCH = CN()
_C.TRAINER.FIXMATCH.WEIGHT_U = 1.0 # weight on the unlabeled loss
_C.TRAINER.FIXMATCH.CONF_THRE = 0.95 # confidence threshold
_C.TRAINER.FIXMATCH.STRONG_TRANSFORMS = ()
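The `TRAIN.COUNT_ITER` option defined earlier in this config decides how many iterations make up one epoch when both a labeled (`train_x`) and an unlabeled (`train_u`) loader exist. Below is a minimal sketch of that selection. It assumes the number of batches in each loader is already known, and `num_batches_per_epoch` is a hypothetical name, not a Dassl function.

```python
def num_batches_per_epoch(len_train_x, len_train_u, count_iter="train_x"):
    """Return the number of iterations per epoch (hypothetical helper).

    len_train_x: number of batches in the labeled loader.
    len_train_u: number of batches in the unlabeled loader.
    count_iter: one of "train_x", "train_u" or "smaller_one".
    """
    if count_iter == "train_x":
        return len_train_x
    if count_iter == "train_u":
        return len_train_u
    if count_iter == "smaller_one":
        return min(len_train_x, len_train_u)
    raise ValueError(f"Unknown COUNT_ITER: {count_iter!r}")


print(num_batches_per_epoch(120, 300))                 # 120
print(num_batches_per_epoch(120, 300, "smaller_one"))  # 120
print(num_batches_per_epoch(120, 300, "train_u"))      # 300
```

If the epoch is counted by the longer loader, a trainer would presumably have to restart the shorter loader's iterator within the epoch. That detail is an assumption and is not shown here.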
================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/__init__.py
================================================
from .data_manager import DataManager, DatasetWrapper
================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/data_manager.py
================================================
import torch
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import Dataset as TorchDataset
from dassl.utils import read_image
from .datasets import build_dataset
from .samplers import build_sampler
from .transforms import build_transform
INTERPOLATION_MODES = {
"bilinear": Image.BILINEAR,
"bicubic": Image.BICUBIC,
"nearest": Image.NEAREST,
}
def build_data_loader(
cfg,
sampler_type="SequentialSampler",
data_source=None,
batch_size=64,
n_domain=0,
n_ins=2,
tfm=None,
is_train=True,
dataset_wrapper=None,
):
# Build sampler
sampler = build_sampler(
sampler_type,
cfg=cfg,
data_source=data_source,
batch_size=batch_size,
n_domain=n_domain,
n_ins=n_ins,
)
if dataset_wrapper is None:
dataset_wrapper = DatasetWrapper
# Build data loader
data_loader = torch.utils.data.DataLoader(
dataset_wrapper(cfg, data_source, transform=tfm, is_train=is_train),
batch_size=batch_size,
sampler=sampler,
num_workers=cfg.DATALOADER.NUM_WORKERS,
drop_last=is_train and len(data_source) >= batch_size,
pin_memory=(torch.cuda.is_available() and cfg.USE_CUDA),
)
assert len(data_loader) > 0
return data_loader
class DataManager:
def __init__(
self,
cfg,
custom_tfm_train=None,
custom_tfm_test=None,
dataset_wrapper=None
):
# Load dataset
dataset = build_dataset(cfg)
# Build transform
if custom_tfm_train is None:
tfm_train = build_transform(cfg, is_train=True)
else:
print("* Using custom transform for training")
tfm_train = custom_tfm_train
if custom_tfm_test is None:
tfm_test = build_transform(cfg, is_train=False)
else:
print("* Using custom transform for testing")
tfm_test = custom_tfm_test
# Build train_loader_x
train_loader_x = build_data_loader(
cfg,
sampler_type=cfg.DATALOADER.TRAIN_X.SAMPLER,
data_source=dataset.train_x,
batch_size=cfg.DATALOADER.TRAIN_X.BATCH_SIZE,
n_domain=cfg.DATALOADER.TRAIN_X.N_DOMAIN,
n_ins=cfg.DATALOADER.TRAIN_X.N_INS,
tfm=tfm_train,
is_train=True,
dataset_wrapper=dataset_wrapper,
)
# Build train_loader_u
train_loader_u = None
if dataset.train_u:
sampler_type_ = cfg.DATALOADER.TRAIN_U.SAMPLER
batch_size_ = cfg.DATALOADER.TRAIN_U.BATCH_SIZE
n_domain_ = cfg.DATALOADER.TRAIN_U.N_DOMAIN
n_ins_ = cfg.DATALOADER.TRAIN_U.N_INS
if cfg.DATALOADER.TRAIN_U.SAME_AS_X:
sampler_type_ = cfg.DATALOADER.TRAIN_X.SAMPLER
batch_size_ = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
n_domain_ = cfg.DATALOADER.TRAIN_X.N_DOMAIN
n_ins_ = cfg.DATALOADER.TRAIN_X.N_INS
train_loader_u = build_data_loader(
cfg,
sampler_type=sampler_type_,
data_source=dataset.train_u,
batch_size=batch_size_,
n_domain=n_domain_,
n_ins=n_ins_,
tfm=tfm_train,
is_train=True,
dataset_wrapper=dataset_wrapper,
)
# Build val_loader
val_loader = None
if dataset.val:
val_loader = build_data_loader(
cfg,
sampler_type=cfg.DATALOADER.TEST.SAMPLER,
data_source=dataset.val,
batch_size=cfg.DATALOADER.TEST.BATCH_SIZE,
tfm=tfm_test,
is_train=False,
dataset_wrapper=dataset_wrapper,
)
# Build test_loader
test_loader = build_data_loader(
cfg,
sampler_type=cfg.DATALOADER.TEST.SAMPLER,
data_source=dataset.test,
batch_size=cfg.DATALOADER.TEST.BATCH_SIZE,
tfm=tfm_test,
is_train=False,
dataset_wrapper=dataset_wrapper,
)
# Attributes
self._num_classes = dataset.num_classes
self._num_source_domains = len(cfg.DATASET.SOURCE_DOMAINS)
self._lab2cname = dataset.lab2cname
# Dataset and data-loaders
self.dataset = dataset
self.train_loader_x = train_loader_x
self.train_loader_u = train_loader_u
self.val_loader = val_loader
self.test_loader = test_loader
if cfg.VERBOSE:
self.show_dataset_summary(cfg)
@property
def num_classes(self):
return self._num_classes
@property
def num_source_domains(self):
return self._num_source_domains
@property
def lab2cname(self):
return self._lab2cname
def show_dataset_summary(self, cfg):
print("***** Dataset statistics *****")
print(" Dataset: {}".format(cfg.DATASET.NAME))
if cfg.DATASET.SOURCE_DOMAINS:
print(" Source domains: {}".format(cfg.DATASET.SOURCE_DOMAINS))
if cfg.DATASET.TARGET_DOMAINS:
print(" Target domains: {}".format(cfg.DATASET.TARGET_DOMAINS))
print(" # classes: {:,}".format(self.num_classes))
print(" # train_x: {:,}".format(len(self.dataset.train_x)))
if self.dataset.train_u:
print(" # train_u: {:,}".format(len(self.dataset.train_u)))
if self.dataset.val:
print(" # val: {:,}".format(len(self.dataset.val)))
print(" # test: {:,}".format(len(self.dataset.test)))
class DatasetWrapper(TorchDataset):
def __init__(self, cfg, data_source, transform=None, is_train=False):
self.cfg = cfg
self.data_source = data_source
self.transform = transform # a single transform or a list/tuple of transforms
self.is_train = is_train
# Augmenting an image K>1 times is only allowed during training
self.k_tfm = cfg.DATALOADER.K_TRANSFORMS if is_train else 1
self.return_img0 = cfg.DATALOADER.RETURN_IMG0
if self.k_tfm > 1 and transform is None:
raise ValueError(
"Cannot augment the image {} times "
"because transform is None".format(self.k_tfm)
)
# Build transform that doesn't apply any data augmentation
interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]
to_tensor = []
to_tensor += [T.Resize(cfg.INPUT.SIZE, interpolation=interp_mode)]
to_tensor += [T.ToTensor()]
if "normalize" in cfg.INPUT.TRANSFORMS:
normalize = T.Normalize(
mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD
)
to_tensor += [normalize]
self.to_tensor = T.Compose(to_tensor)
def __len__(self):
return len(self.data_source)
def __getitem__(self, idx):
item = self.data_source[idx]
output = {
"label": item.label,
"domain": item.domain,
"impath": item.impath
}
img0 = read_image(item.impath)
if self.transform is not None:
if isinstance(self.transform, (list, tuple)):
for i, tfm in enumerate(self.transform):
img = self._transform_image(tfm, img0)
keyname = "img"
if (i + 1) > 1:
keyname += str(i + 1)
output[keyname] = img
else:
img = self._transform_image(self.transform, img0)
output["img"] = img
if self.return_img0:
output["img0"] = self.to_tensor(img0)
return output
def _transform_image(self, tfm, img0):
img_list = []
for k in range(self.k_tfm):
img_list.append(tfm(img0))
img = img_list
if len(img) == 1:
img = img[0]
return img
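When `transform` is a list or tuple, `DatasetWrapper.__getitem__` stores one entry per transform under the keys `img`, `img2`, `img3`, and so on. When `DATALOADER.K_TRANSFORMS > 1`, each entry becomes a list of K augmented views. The sketch below mimics that key-naming logic with plain callables instead of torchvision transforms and an integer instead of a PIL image. `build_output`, `transform_image`, `add_one` and `times_ten` are hypothetical names used only for illustration.

```python
def transform_image(tfm, img0, k_tfm):
    # Apply the same (possibly random) transform k_tfm times.
    views = [tfm(img0) for _ in range(k_tfm)]
    # A single view is returned unwrapped, as in _transform_image.
    return views[0] if len(views) == 1 else views


def build_output(img0, transform, k_tfm=1):
    output = {}
    if isinstance(transform, (list, tuple)):
        for i, tfm in enumerate(transform):
            # First transform -> "img", then "img2", "img3", ...
            keyname = "img" if i == 0 else "img" + str(i + 1)
            output[keyname] = transform_image(tfm, img0, k_tfm)
    else:
        output["img"] = transform_image(transform, img0, k_tfm)
    return output


def add_one(x):
    return x + 1


def times_ten(x):
    return x * 10


print(build_output(3, [add_one, times_ten]))  # {'img': 4, 'img2': 30}
print(build_output(3, add_one, k_tfm=2))      # {'img': [4, 4]}
```

This is why semi-supervised trainers such as FixMatch can read a weakly and a strongly augmented view of the same image from one batch dictionary.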
================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/__init__.py
================================================
from .build import DATASET_REGISTRY, build_dataset # isort:skip
from .base_dataset import Datum, DatasetBase # isort:skip
from .da import *
from .dg import *
from .ssl import *
================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/base_dataset.py
================================================
import os
import random
import os.path as osp
import tarfile
import zipfile
from collections import defaultdict
import gdown
from dassl.utils import check_isfile
class Datum:
"""Data instance which defines the basic attributes.
Args:
impath (str): image path.
label (int): class label.
domain (int): domain label.
classname (str): class name.
"""
def __init__(self, impath="", label=0, domain=0, classname=""):
assert isinstance(impath, str)
assert check_isfile(impath)
self._impath = impath
self._label = label
self._domain = domain
self._classname = classname
@property
def impath(self):
return self._impath
@property
def label(self):
return self._label
@property
def domain(self):
return self._domain
@property
def classname(self):
return self._classname
class DatasetBase:
"""A unified dataset class for
1) domain adaptation
2) domain generalization
3) semi-supervised learning
"""
dataset_dir = "" # the directory where the dataset is stored
domains = [] # string names of all domains
def __init__(self, train_x=None, train_u=None, val=None, test=None):
self._train_x = train_x # labeled training data
self._train_u = train_u # unlabeled training data (optional)
self._val = val # validation data (optional)
self._test = test # test data
self._num_classes = self.get_num_classes(train_x)
self._lab2cname, self._classnames = self.get_lab2cname(train_x)
@property
def train_x(self):
return self._train_x
@property
def train_u(self):
return self._train_u
@property
def val(self):
return self._val
@property
def test(self):
return self._test
@property
def lab2cname(self):
return self._lab2cname
@property
def classnames(self):
return self._classnames
@property
def num_classes(self):
return self._num_classes
def get_num_classes(self, data_source):
"""Count number of classes.
Args:
data_source (list): a list of Datum objects.
"""
label_set = set()
for item in data_source:
label_set.add(item.label)
return max(label_set) + 1
def get_lab2cname(self, data_source):
"""Get a label-to-classname mapping (dict).
Args:
data_source (list): a list of Datum objects.
"""
container = set()
for item in data_source:
container.add((item.label, item.classname))
mapping = {label: classname for label, classname in container}
labels = list(mapping.keys())
labels.sort()
classnames = [mapping[label] for label in labels]
return mapping, classnames
def check_input_domains(self, source_domains, target_domains):
self.is_input_domain_valid(source_domains)
self.is_input_domain_valid(target_domains)
def is_input_domain_valid(self, input_domains):
for domain in input_domains:
if domain not in self.domains:
raise ValueError(
"Input domain must belong to {}, "
"but got [{}]".format(self.domains, domain)
)
def download_data(self, url, dst, from_gdrive=True):
if not osp.exists(osp.dirname(dst)):
os.makedirs(osp.dirname(dst))
if from_gdrive:
gdown.download(url, dst, quiet=False)
else:
raise NotImplementedError
print("Extracting file ...")
try:
tar = tarfile.open(dst)
tar.extractall(path=osp.dirname(dst))
tar.close()
except tarfile.ReadError: # not a tar archive, try zip instead
zip_ref = zipfile.ZipFile(dst, "r")
zip_ref.extractall(osp.dirname(dst))
zip_ref.close()
print("File extracted to {}".format(osp.dirname(dst)))
def generate_fewshot_dataset(
self, *data_sources, num_shots=-1, repeat=False
):
"""Generate a few-shot dataset (typically for the training set).
This function is useful when one wants to evaluate a model
in a few-shot learning setting where each class only contains
a small number of images.
Args:
data_sources: each individual is a list containing Datum objects.
num_shots (int): number of instances per class to sample.
repeat (bool): sample with replacement for classes that have fewer than num_shots images (default: False).
"""
if num_shots < 1:
if len(data_sources) == 1:
return data_sources[0]
return data_sources
print(f"Creating a {num_shots}-shot dataset")
output = []
for data_source in data_sources:
tracker = self.split_dataset_by_label(data_source)
dataset = []
for label, items in tracker.items():
if len(items) >= num_shots:
sampled_items = random.sample(items, num_shots)
else:
if repeat:
sampled_items = random.choices(items, k=num_shots)
else:
sampled_items = items
dataset.extend(sampled_items)
output.append(dataset)
if len(output) == 1:
return output[0]
return output
def split_dataset_by_label(self, data_source):
"""Split a dataset, i.e. a list of Datum objects,
into class-specific groups stored in a dictionary.
Args:
data_source (list): a list of Datum objects.
"""
output = defaultdict(list)
for item in data_source:
output[item.label].append(item)
return output
def split_dataset_by_domain(self, data_source):
"""Split a dataset, i.e. a list of Datum objects,
into domain-specific groups stored in a dictionary.
Args:
data_source (list): a list of Datum objects.
"""
output = defaultdict(list)
for item in data_source:
output[item.domain].append(item)
return output
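`generate_fewshot_dataset` groups items by label and then draws `num_shots` items per class. It samples without replacement when a class has enough items and with replacement when `repeat=True`. Otherwise it keeps every item of an undersized class. The following self-contained sketch shows that per-source logic. It uses `(label, name)` tuples in place of `Datum` objects and a seeded `random.Random` for reproducibility. Both substitutions are assumptions made only for illustration.

```python
import random
from collections import defaultdict


def fewshot_sample(items, num_shots, repeat=False, rng=None):
    """items: list of (label, name) tuples standing in for Datum objects."""
    rng = rng or random.Random(0)
    by_label = defaultdict(list)
    for item in items:
        by_label[item[0]].append(item)
    sampled = []
    for label, group in by_label.items():
        if len(group) >= num_shots:
            # Enough data: sample without replacement.
            sampled.extend(rng.sample(group, num_shots))
        elif repeat:
            # Too few items: sample with replacement to reach num_shots.
            sampled.extend(rng.choices(group, k=num_shots))
        else:
            # Too few items and no repetition: keep the whole class.
            sampled.extend(group)
    return sampled


data = [(0, f"a{i}") for i in range(10)] + [(1, "b0"), (1, "b1")]
print(len(fewshot_sample(data, 4)))               # 6  (4 + 2)
print(len(fewshot_sample(data, 4, repeat=True)))  # 8  (4 + 4)
```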
================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/build.py
================================================
from dassl.utils import Registry, check_availability
DATASET_REGISTRY = Registry("DATASET")
def build_dataset(cfg):
avai_datasets = DATASET_REGISTRY.registered_names()
check_availability(cfg.DATASET.NAME, avai_datasets)
if cfg.VERBOSE:
print("Loading dataset: {}".format(cfg.DATASET.NAME))
return DATASET_REGISTRY.get(cfg.DATASET.NAME)(cfg)
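`build_dataset` relies on a registry. Each dataset class registers itself with `@DATASET_REGISTRY.register()`, and `cfg.DATASET.NAME` is then looked up by class name. The class below is a minimal, hypothetical stand-in for `dassl.utils.Registry`, written only to show the pattern. Its method names mirror the calls used above (`register`, `get`, `registered_names`), but its internals are assumed and are not Dassl's actual implementation. `ToyDataset` is likewise hypothetical.

```python
class Registry:
    """Minimal name -> class registry (illustrative only)."""

    def __init__(self, name):
        self._name = name
        self._obj_map = {}

    def register(self):
        # Used as @REGISTRY.register(); stores the class under its own name.
        def decorator(cls):
            if cls.__name__ in self._obj_map:
                raise KeyError(f"{cls.__name__} already registered in {self._name}")
            self._obj_map[cls.__name__] = cls
            return cls
        return decorator

    def get(self, name):
        if name not in self._obj_map:
            raise KeyError(f"{name!r} not found in {self._name} registry")
        return self._obj_map[name]

    def registered_names(self):
        return list(self._obj_map.keys())


DATASET_REGISTRY = Registry("DATASET")


@DATASET_REGISTRY.register()
class ToyDataset:
    def __init__(self, cfg):
        self.cfg = cfg


ds = DATASET_REGISTRY.get("ToyDataset")({"NAME": "ToyDataset"})
print(DATASET_REGISTRY.registered_names())  # ['ToyDataset']
```

With this design, registration happens as a side effect of importing the module that defines the class. That is why `dassl/data/datasets/__init__.py` imports every dataset subpackage.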
================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/da/__init__.py
================================================
from .digit5 import Digit5
from .visda17 import VisDA17
from .cifarstl import CIFARSTL
from .office31 import Office31
from .domainnet import DomainNet
from .office_home import OfficeHome
from .mini_domainnet import miniDomainNet
================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/da/cifarstl.py
================================================
import os.path as osp
from dassl.utils import listdir_nohidden
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
@DATASET_REGISTRY.register()
class CIFARSTL(DatasetBase):
"""CIFAR-10 and STL-10.
CIFAR-10:
- 60,000 32x32 colour images.
- 10 classes, with 6,000 images per class.
- 50,000 training images and 10,000 test images.
- URL: https://www.cs.toronto.edu/~kriz/cifar.html.
STL-10:
- 10 classes: airplane, bird, car, cat, deer, dog, horse,
monkey, ship, truck.
- Images are 96x96 pixels, color.
- 500 training images (10 pre-defined folds), 800 test images
per class.
- URL: https://cs.stanford.edu/~acoates/stl10/.
Reference:
- Krizhevsky. Learning Multiple Layers of Features
from Tiny Images. Tech report.
- Coates et al. An Analysis of Single Layer Networks in
Unsupervised Feature Learning. AISTATS 2011.
"""
dataset_dir = "cifar_stl"
domains = ["cifar", "stl"]
def __init__(self, cfg):
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = osp.join(root, self.dataset_dir)
self.check_input_domains(
cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
)
train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train")
test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")
super().__init__(train_x=train_x, train_u=train_u, test=test)
def _read_data(self, input_domains, split="train"):
items = []
for domain, dname in enumerate(input_domains):
data_dir = osp.join(self.dataset_dir, dname, split)
class_names = listdir_nohidden(data_dir)
for class_name in class_names:
class_dir = osp.join(data_dir, class_name)
imnames = listdir_nohidden(class_dir)
label = int(class_name.split("_")[0])
for imname in imnames:
impath = osp.join(class_dir, imname)
item = Datum(impath=impath, label=label, domain=domain)
items.append(item)
return items
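`CIFARSTL._read_data` takes the class label from the numeric prefix of each class folder name (`int(class_name.split("_")[0])`), so both domains must use the same prefixed folder names for the labels to line up. It does not set `classname`, so `Datum.classname` stays empty. The following sketch shows that parsing and also recovers the suffix as a class name, which the code above does not do. The folder names (for example `3_cat`) are assumed for illustration, and `parse_class_folder` is a hypothetical helper.

```python
def parse_class_folder(class_name):
    """Split a folder name such as '3_cat' into (label, classname)."""
    prefix, _, rest = class_name.partition("_")
    return int(prefix), rest


folders = ["0_airplane", "2_bird", "3_cat"]
print([parse_class_folder(f) for f in folders])
# [(0, 'airplane'), (2, 'bird'), (3, 'cat')]
```

If a label prefix is skipped, as with `1` in this example, `DatasetBase.get_num_classes` (which returns `max(label) + 1`) would still count the gap as a class.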
================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/da/digit5.py
================================================
import random
import os.path as osp
from dassl.utils import listdir_nohidden
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
# Folder names for train and test sets
MNIST = {"train": "train_images", "test": "test_images"}
MNIST_M = {"train": "train_images", "test": "test_images"}
SVHN = {"train": "train_images", "test": "test_images"}
SYN = {"train": "train_images", "test": "test_images"}
USPS = {"train": "train_images", "test": "test_images"}
def read_image_list(im_dir, n_max=None, n_repeat=None):
items = []
for imname in listdir_nohidden(im_dir):
imname_noext = osp.splitext(imname)[0]
label = int(imname_noext.split("_")[1])
impath = osp.join(im_dir, imname)
items.append((impath, label))
if n_max is not None:
items = random.sample(items, n_max)
if n_repeat is not None:
items *= n_repeat
return items
def load_mnist(dataset_dir, split="train"):
data_dir = osp.join(dataset_dir, MNIST[split])
n_max = 25000 if split == "train" else 9000
return read_image_list(data_dir, n_max=n_max)
def load_mnist_m(dataset_dir, split="train"):
data_dir = osp.join(dataset_dir, MNIST_M[split])
n_max = 25000 if split == "train" else 9000
return read_image_list(data_dir, n_max=n_max)
def load_svhn(dataset_dir, split="train"):
data_dir = osp.join(dataset_dir, SVHN[split])
n_max = 25000 if split == "train" else 9000
return read_image_list(data_dir, n_max=n_max)
def load_syn(dataset_dir, split="train"):
data_dir = osp.join(dataset_dir, SYN[split])
n_max = 25000 if split == "train" else 9000
return read_image_list(data_dir, n_max=n_max)
def load_usps(dataset_dir, split="train"):
data_dir = osp.join(dataset_dir, USPS[split])
n_repeat = 3 if split == "train" else None
return read_image_list(data_dir, n_repeat=n_repeat)
@DATASET_REGISTRY.register()
class Digit5(DatasetBase):
"""Five digit datasets.
It contains:
- MNIST: hand-written digits.
- MNIST-M: variant of MNIST with blended background.
- SVHN: street view house number.
- SYN: synthetic digits.
- USPS: hand-written digits, slightly different from MNIST.
For MNIST, MNIST-M, SVHN and SYN, we randomly sample 25,000 images from
the training set and 9,000 images from the test set. For USPS which has only
9,298 images in total, we use the entire dataset but replicate its training
set 3 times to roughly match the training set size of the other domains.
Reference:
- Lecun et al. Gradient-based learning applied to document
recognition. IEEE 1998.
- Ganin et al. Domain-adversarial training of neural networks.
JMLR 2016.
- Netzer et al. Reading digits in natural images with unsupervised
feature learning. NIPS-W 2011.
"""
dataset_dir = "digit5"
domains = ["mnist", "mnist_m", "svhn", "syn", "usps"]
def __init__(self, cfg):
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = osp.join(root, self.dataset_dir)
self.check_input_domains(
cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
)
train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train")
test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")
super().__init__(train_x=train_x, train_u=train_u, test=test)
def _read_data(self, input_domains, split="train"):
items = []
for domain, dname in enumerate(input_domains):
func = "load_" + dname
domain_dir = osp.join(self.dataset_dir, dname)
items_d = eval(func)(domain_dir, split=split)
for impath, label in items_d:
item = Datum(
impath=impath,
label=label,
domain=domain,
classname=str(label)
)
items.append(item)
return items
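Digit5's `read_image_list` takes the label from the second `_`-separated field of each file name. This implies names of the form `<index>_<label>.<ext>`, although that pattern is inferred from the parsing code rather than documented. It then optionally draws a random subset of `n_max` items and replicates the list `n_repeat` times. The sketch below reproduces this on an in-memory list of file names. `parse_digit_items` and its `seed` argument are hypothetical: the original uses the global `random` state.

```python
import os.path as osp
import random


def parse_digit_items(imnames, n_max=None, n_repeat=None, seed=0):
    """Mimic Digit5's read_image_list on a list of file names (no disk I/O)."""
    items = []
    for imname in imnames:
        imname_noext = osp.splitext(imname)[0]
        # The label is the second "_"-separated field, e.g. "00012_7" -> 7.
        label = int(imname_noext.split("_")[1])
        items.append((imname, label))
    if n_max is not None:
        # Random subset without replacement; raises ValueError if n_max > len(items).
        items = random.Random(seed).sample(items, n_max)
    if n_repeat is not None:
        # Replicate the whole list, as done for the USPS training split.
        items *= n_repeat
    return items


names = ["00000_5.png", "00001_0.png", "00002_4.png"]
print(parse_digit_items(names))
# [('00000_5.png', 5), ('00001_0.png', 0), ('00002_4.png', 4)]
print(len(parse_digit_items(names, n_repeat=3)))  # 9
```

Because `random.sample` raises `ValueError` when the requested size exceeds the population, a domain folder with fewer than 25,000 training images would fail to load under the `load_*` functions above.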
================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/da/domainnet.py
================================================
import os.path as osp
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
@DATASET_REGISTRY.register()
class DomainNet(DatasetBase):
"""DomainNet.
Statistics:
- 6 distinct domains: Clipart, Infograph, Painting, Quickdraw,
Real, Sketch.
- Around 0.6M images.
- 345 categories.
- URL: http://ai.bu.edu/M3SDA/.
Special note: the t-shirt class (327) is missing in painting_train.txt.
Reference:
- Peng et al. Moment Matching for Multi-Source Domain
Adaptation. ICCV 2019.
"""
dataset_dir = "domainnet"
domains = [
"clipart", "infograph", "painting", "quickdraw", "real", "sketch"
]
def __init__(self, cfg):
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = osp.join(root, self.dataset_dir)
self.split_dir = osp.join(self.dataset_dir, "splits")
self.check_input_domains(
cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
)
train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train")
val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="test")
test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")
super().__init__(train_x=train_x, train_u=train_u, val=val, test=test)
def _read_data(self, input_domains, split="train"):
items = []
for domain, dname in enumerate(input_domains):
filename = dname + "_" + split + ".txt"
split_file = osp.join(self.split_dir, filename)
with open(split_file, "r") as f:
lines = f.readlines()
for line in lines:
line = line.strip()
impath, label = line.split(" ")
classname = impath.split("/")[1]
impath = osp.join(self.dataset_dir, impath)
label = int(label)
item = Datum(
impath=impath,
label=label,
domain=domain,
classname=classname
)
items.append(item)
return items
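Each line of a DomainNet split file holds a relative image path and an integer label separated by one space. The class name is the second path component (`<domain>/<class>/<file>`). The sketch below parses one such line. The sample line, its label value, and the `/data/domainnet` root are illustrative assumptions, and `parse_split_line` is a hypothetical helper.

```python
import os.path as osp


def parse_split_line(line, dataset_dir):
    """Parse one '<relative/path> <label>' line of a DomainNet split file."""
    impath, label = line.strip().split(" ")
    # The class name is the second path component: <domain>/<class>/<file>.
    classname = impath.split("/")[1]
    return osp.join(dataset_dir, impath), int(label), classname


path, label, cname = parse_split_line(
    "clipart/airplane/clipart_002_000001.jpg 1\n", "/data/domainnet"
)
print(path)          # /data/domainnet/clipart/airplane/clipart_002_000001.jpg
print(label, cname)  # 1 airplane
```

Like the original, this uses `split(" ")`, which assumes exactly one space between path and label. Calling `split()` without an argument would also tolerate tabs or repeated spaces.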
================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/da/mini_domainnet.py
================================================
import os.path as osp
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
@DATASET_REGISTRY.register()
class miniDomainNet(DatasetBase):
"""A subset of DomainNet.
Reference:
- Peng et al. Moment Matching for Multi-Source Domain
Adaptation. ICCV 2019.
- Zhou et al. Domain Adaptive Ensemble Learning.
"""
dataset_dir = "domainnet"
domains = ["clipart", "painting", "real", "sketch"]
def __init__(self, cfg):
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = osp.join(root, self.dataset_dir)
self.split_dir = osp.join(self.dataset_dir, "splits_mini")
self.check_input_domains(
cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
)
train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train")
test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")
super().__init__(train_x=train_x, train_u=train_u, test=test)
def _read_data(self, input_domains, split="train"):
items = []
for domain, dname in enumerate(input_domains):
filename = dname + "_" + split + ".txt"
split_file = osp.join(self.split_dir, filename)
with open(split_file, "r") as f:
lines = f.readlines()
for line in lines:
line = line.strip()
impath, label = line.split(" ")
classname = impath.split("/")[1]
impath = osp.join(self.dataset_dir, impath)
label = int(label)
item = Datum(
impath=impath,
label=label,
domain=domain,
classname=classname
)
items.append(item)
return items
================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/da/office31.py
================================================
import os.path as osp
from dassl.utils import listdir_nohidden
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
@DATASET_REGISTRY.register()
class Office31(DatasetBase):
"""Office-31.
Statistics:
- 4,110 images.
- 31 classes related to office objects.
- 3 domains: Amazon, Webcam, Dslr.
- URL: https://people.eecs.berkeley.edu/~jhoffman/domainadapt/.
Reference:
- Saenko et al. Adapting visual category models to
new domains. ECCV 2010.
"""
dataset_dir = "office31"
domains = ["amazon", "webcam", "dslr"]
def __init__(self, cfg):
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = osp.join(root, self.dataset_dir)
self.check_input_domains(
cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
)
train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS)
train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS)
test = self._read_data(cfg.DATASET.TARGET_DOMAINS)
super().__init__(train_x=train_x, train_u=train_u, test=test)
def _read_data(self, input_domains):
items = []
for domain, dname in enumerate(input_domains):
domain_dir = osp.join(self.dataset_dir, dname)
class_names = listdir_nohidden(domain_dir)
class_names.sort()
for label, class_name in enumerate(class_names):
class_path = osp.join(domain_dir, class_name)
imnames = listdir_nohidden(class_path)
for imname in imnames:
impath = osp.join(class_path, imname)
item = Datum(
impath=impath,
label=label,
domain=domain,
classname=class_name
)
items.append(item)
return items
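`Office31._read_data`, like the identical loop in `OfficeHome`, assigns labels by enumerating the class folders in sorted order, separately for each domain. Labels therefore agree across domains only if every domain directory contains exactly the same set of class folders. The small sketch below makes this explicit and checks it. It uses in-memory folder lists in place of `listdir_nohidden`, and `build_label_maps` is a hypothetical helper, not part of Dassl.

```python
def build_label_maps(domain_to_classes):
    """domain_to_classes: dict mapping domain name -> list of class folder names."""
    label_maps = {}
    for dname, class_names in domain_to_classes.items():
        # Sorting makes the label order independent of directory listing order.
        label_maps[dname] = {c: i for i, c in enumerate(sorted(class_names))}
    reference = next(iter(label_maps.values()))
    for dname, mapping in label_maps.items():
        if mapping != reference:
            raise ValueError(f"Class folders of domain {dname!r} differ")
    return label_maps


maps = build_label_maps({
    "amazon": ["mug", "bike", "desk_chair"],
    "webcam": ["desk_chair", "mug", "bike"],
})
print(maps["webcam"])  # {'bike': 0, 'desk_chair': 1, 'mug': 2}
```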
================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/da/office_home.py
================================================
import os.path as osp
from dassl.utils import listdir_nohidden
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
@DATASET_REGISTRY.register()
class OfficeHome(DatasetBase):
"""Office-Home.
Statistics:
- Around 15,500 images.
- 65 classes related to office and home objects.
- 4 domains: Art, Clipart, Product, Real World.
- URL: http://hemanthdv.org/OfficeHome-Dataset/.
Reference:
- Venkateswara et al. Deep Hashing Network for Unsupervised
Domain Adaptation. CVPR 2017.
"""
dataset_dir = "office_home"
domains = ["art", "clipart", "product", "real_world"]
def __init__(self, cfg):
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = osp.join(root, self.dataset_dir)
self.check_input_domains(
cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
)
train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS)
train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS)
test = self._read_data(cfg.DATASET.TARGET_DOMAINS)
super().__init__(train_x=train_x, train_u=train_u, test=test)
def _read_data(self, input_domains):
items = []
for domain, dname in enumerate(input_domains):
domain_dir = osp.join(self.dataset_dir, dname)
class_names = listdir_nohidden(domain_dir)
class_names.sort()
for label, class_name in enumerate(class_names):
class_path = osp.join(domain_dir, class_name)
imnames = listdir_nohidden(class_path)
for imname in imnames:
impath = osp.join(class_path, imname)
item = Datum(
impath=impath,
label=label,
domain=domain,
classname=class_name.lower(),
)
items.append(item)
return items
================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/da/visda17.py
================================================
import os.path as osp
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
@DATASET_REGISTRY.register()
class VisDA17(DatasetBase):
"""VisDA17.
Focusing on simulation-to-reality domain shift.
URL: http://ai.bu.edu/visda-2017/.
Reference:
- Peng et al. VisDA: The Visual Domain Adaptation
Challenge. ArXiv 2017.
"""
dataset_dir = "visda17"
domains = ["synthetic", "real"]
def __init__(self, cfg):
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = osp.join(root, self.dataset_dir)
self.check_input_domains(
cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
)
train_x = self._read_data("synthetic")
train_u = self._read_data("real")
test = self._read_data("real")
super().__init__(train_x=train_x, train_u=train_u, test=test)
def _read_data(self, dname):
filedir = "train" if dname == "synthetic" else "validation"
image_list = osp.join(self.dataset_dir, filedir, "image_list.txt")
items = []
# VisDA17 has a single source domain, so every item gets domain label 0
domain = 0
with open(image_list, "r") as f:
lines = f.readlines()
for line in lines:
line = line.strip()
impath, label = line.split(" ")
classname = impath.split("/")[0]
impath = osp.join(self.dataset_dir, filedir, impath)
label = int(label)
item = Datum(
impath=impath,
label=label,
domain=domain,
classname=classname
)
items.append(item)
return items
================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/dg/__init__.py
================================================
from .pacs import PACS
from .vlcs import VLCS
from .cifar_c import CIFAR10C, CIFAR100C
from .digits_dg import DigitsDG
from .digit_single import DigitSingle
from .office_home_dg import OfficeHomeDG
================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/dg/cifar_c.py
================================================
import os.path as osp
from dassl.utils import listdir_nohidden
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
AVAI_C_TYPES = [
"brightness",
"contrast",
"defocus_blur",
"elastic_transform",
"fog",
"frost",
"gaussian_blur",
"gaussian_noise",
"glass_blur",
"impulse_noise",
"jpeg_compression",
"motion_blur",
"pixelate",
"saturate",
"shot_noise",
"snow",
"spatter",
"speckle_noise",
"zoom_blur",
]
@DATASET_REGISTRY.register()
class CIFAR10C(DatasetBase):
"""CIFAR-10 -> CIFAR-10-C.
Dataset link: https://zenodo.org/record/2535967#.YFwtV2Qzb0o
Statistics:
- 2 domains: the normal CIFAR-10 vs. a corrupted CIFAR-10
- 10 categories
Reference:
- Hendrycks et al. Benchmarking neural network robustness
to common corruptions and perturbations. ICLR 2019.
"""
dataset_dir = ""
domains = ["cifar10", "cifar10_c"]
def __init__(self, cfg):
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = root
self.check_input_domains(
cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
)
source_domain = cfg.DATASET.SOURCE_DOMAINS[0]
target_domain = cfg.DATASET.TARGET_DOMAINS[0]
assert source_domain == self.domains[0]
assert target_domain == self.domains[1]
c_type = cfg.DATASET.CIFAR_C_TYPE
c_level = cfg.DATASET.CIFAR_C_LEVEL
if not c_type:
raise ValueError(
"Please specify DATASET.CIFAR_C_TYPE in the config file"
)
assert (
c_type in AVAI_C_TYPES
), f'C_TYPE is expected to belong to {AVAI_C_TYPES}, but got "{c_type}"'
assert 1 <= c_level <= 5
train_dir = osp.join(self.dataset_dir, source_domain, "train")
test_dir = osp.join(
self.dataset_dir, target_domain, c_type, str(c_level)
)
if not osp.exists(test_dir):
raise ValueError(f"Test directory not found: {test_dir}")
train = self._read_data(train_dir)
test = self._read_data(test_dir)
super().__init__(train_x=train, test=test)
def _read_data(self, data_dir):
class_names = listdir_nohidden(data_dir)
class_names.sort()
items = []
for label, class_name in enumerate(class_names):
class_dir = osp.join(data_dir, class_name)
imnames = listdir_nohidden(class_dir)
for imname in imnames:
impath = osp.join(class_dir, imname)
item = Datum(impath=impath, label=label, domain=0)
items.append(item)
return items
@DATASET_REGISTRY.register()
class CIFAR100C(CIFAR10C):
"""CIFAR-100 -> CIFAR-100-C.
Dataset link: https://zenodo.org/record/3555552#.YFxpQmQzb0o
Statistics:
- 2 domains: the normal CIFAR-100 vs. a corrupted CIFAR-100
- 100 categories
Reference:
- Hendrycks et al. Benchmarking neural network robustness
to common corruptions and perturbations. ICLR 2019.
"""
dataset_dir = ""
domains = ["cifar100", "cifar100_c"]
def __init__(self, cfg):
super().__init__(cfg)
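`CIFAR10C.__init__`, inherited unchanged by `CIFAR100C`, reads training images from `<ROOT>/<source>/train` and test images from `<ROOT>/<target>/<c_type>/<c_level>`. Before doing so, it checks that `c_type` is one of `AVAI_C_TYPES` and that `1 <= c_level <= 5`. The sketch below performs the same path construction and validation as a standalone function. The function name `cifar_c_dirs` and the shortened corruption list are hypothetical.

```python
import os.path as osp

AVAI_C_TYPES = ["brightness", "contrast", "fog", "gaussian_noise"]  # subset only


def cifar_c_dirs(root, source_domain, target_domain, c_type, c_level):
    if not c_type:
        raise ValueError("Please specify DATASET.CIFAR_C_TYPE")
    if c_type not in AVAI_C_TYPES:
        raise ValueError(f"Unknown corruption type: {c_type!r}")
    if not 1 <= c_level <= 5:
        raise ValueError(f"c_level must be in [1, 5], got {c_level}")
    train_dir = osp.join(root, source_domain, "train")
    test_dir = osp.join(root, target_domain, c_type, str(c_level))
    return train_dir, test_dir


print(cifar_c_dirs("/data", "cifar10", "cifar10_c", "fog", 3))
# ('/data/cifar10/train', '/data/cifar10_c/fog/3')
```

This sketch raises `ValueError` where the original uses `assert`, because assertions are skipped when Python runs with `-O`.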
================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/dg/digit_single.py
================================================
import os.path as osp
from dassl.utils import listdir_nohidden
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
# Folder names for train and test sets
MNIST = {"train": "train_images", "test": "test_images"}
MNIST_M = {"train": "train_images", "test": "test_images"}
SVHN = {"train": "train_images", "test": "test_images"}
SYN = {"train": "train_images", "test": "test_images"}
USPS = {"train": "train_images", "test": "test_images"}
def read_image_list(im_dir, n_max=None, n_repeat=None):
items = []
for imname in listdir_nohidden(im_dir):
imname_noext = osp.splitext(imname)[0]
label = int(imname_noext.split("_")[1])
impath = osp.join(im_dir, imname)
items.append((impath, label))
if n_max is not None:
# Note that the sampling is NOT random: the first n_max images
# are kept, following Volpi et al. NIPS'18.
items = items[:n_max]
if n_repeat is not None:
items *= n_repeat
return items
def load_mnist(dataset_dir, split="train"):
data_dir = osp.join(dataset_dir, MNIST[split])
n_max = 10000 if split == "train" else None
return read_image_list(data_dir, n_max=n_max)
def load_mnist_m(dataset_dir, split="train"):
data_dir = osp.join(dataset_dir, MNIST_M[split])
n_max = 10000 if split == "train" else None
return read_image_list(data_dir, n_max=n_max)
def load_svhn(dataset_dir, split="train"):
data_dir = osp.join(dataset_dir, SVHN[split])
n_max = 10000 if split == "train" else None
return read_image_list(data_dir, n_max=n_max)
def load_syn(dataset_dir, split="train"):
data_dir = osp.join(dataset_dir, SYN[split])
n_max = 10000 if split == "train" else None
return read_image_list(data_dir, n_max=n_max)
def load_usps(dataset_dir, split="train"):
data_dir = osp.join(dataset_dir, USPS[split])
return read_image_list(data_dir)
@DATASET_REGISTRY.register()
class DigitSingle(DatasetBase):
"""Digit recognition datasets for single-source domain generalization.
There are five digit datasets:
- MNIST: hand-written digits.
- MNIST-M: variant of MNIST with blended background.
- SVHN: street view house number.
- SYN: synthetic digits.
- USPS: hand-written digits, slightly different from MNIST.
Protocol:
Volpi et al. train a model using 10,000 images from MNIST and
evaluate the model on the test split of the other four datasets. However,
the code does not restrict you to only use MNIST as the source dataset.
Instead, you can use any dataset as the source. But note that only 10,000
images will be sampled from the source dataset for training.
Reference:
- Lecun et al. Gradient-based learning applied to document
recognition. IEEE 1998.
- Ganin et al. Domain-adversarial training of neural networks.
JMLR 2016.
- Netzer et al. Reading digits in natural images with unsupervised
feature learning. NIPS-W 2011.
- Volpi et al. Generalizing to Unseen Domains via Adversarial Data
Augmentation. NIPS 2018.
"""
# Reuse the digit-5 folder instead of creating a new folder
dataset_dir = "digit5"
domains = ["mnist", "mnist_m", "svhn", "syn", "usps"]
def __init__(self, cfg):
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = osp.join(root, self.dataset_dir)
self.check_input_domains(
cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
)
train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="test")
test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")
super().__init__(train_x=train, val=val, test=test)
def _read_data(self, input_domains, split="train"):
items = []
for domain, dname in enumerate(input_domains):
func = "load_" + dname
domain_dir = osp.join(self.dataset_dir, dname)
items_d = eval(func)(domain_dir, split=split)
for impath, label in items_d:
item = Datum(impath=impath, label=label, domain=domain)
items.append(item)
return items
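`read_image_list` takes the label from the second underscore-separated token of each file name, then truncates to the first `n_max` entries and optionally repeats the list. A minimal sketch of the same steps on an in-memory list of names, so no directory is needed; the helper name and the file names are hypothetical.

```python
import os.path as osp


def parse_digit_items(imnames, im_dir, n_max=None, n_repeat=None):
    """Same parsing, truncation and repetition steps as read_image_list."""
    items = []
    for imname in imnames:
        imname_noext = osp.splitext(imname)[0]
        label = int(imname_noext.split("_")[1])  # "<id>_<label>.<ext>"
        items.append((osp.join(im_dir, imname), label))
    if n_max is not None:
        items = items[:n_max]  # deterministic: keeps the first n_max entries
    if n_repeat is not None:
        items *= n_repeat  # the whole (truncated) list is repeated
    return items


names = ["00000_7.png", "00001_2.png", "00002_9.png"]
items = parse_digit_items(names, "mnist/train_images", n_max=2, n_repeat=2)
print([label for _, label in items])  # [7, 2, 7, 2]
```

Since truncation keeps the first entries, which 10,000 images are used depends on the order in which the directory is listed.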
================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/dg/digits_dg.py
================================================
import glob
import os.path as osp

from dassl.utils import listdir_nohidden

from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase


@DATASET_REGISTRY.register()
class DigitsDG(DatasetBase):
    """Digits-DG.

    It contains 4 digit datasets:
        - MNIST: hand-written digits.
        - MNIST-M: variant of MNIST with blended background.
        - SVHN: street view house number.
        - SYN: synthetic digits.

    Reference:
        - LeCun et al. Gradient-based learning applied to document
        recognition. IEEE 1998.
        - Ganin et al. Domain-adversarial training of neural networks.
        JMLR 2016.
        - Netzer et al. Reading digits in natural images with unsupervised
        feature learning. NIPS-W 2011.
        - Zhou et al. Deep Domain-Adversarial Image Generation for Domain
        Generalisation. AAAI 2020.
    """

    dataset_dir = "digits_dg"
    domains = ["mnist", "mnist_m", "svhn", "syn"]
    data_url = "https://drive.google.com/uc?id=15V7EsHfCcfbKgsDmzQKj_DfXt_XYp_P7"

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        if not osp.exists(self.dataset_dir):
            dst = osp.join(root, "digits_dg.zip")
            self.download_data(self.data_url, dst, from_gdrive=True)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train = self.read_data(
            self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "train"
        )
        val = self.read_data(
            self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "val"
        )
        test = self.read_data(
            self.dataset_dir, cfg.DATASET.TARGET_DOMAINS, "all"
        )

        super().__init__(train_x=train, val=val, test=test)

    @staticmethod
    def read_data(dataset_dir, input_domains, split):

        def _load_data_from_directory(directory):
            folders = listdir_nohidden(directory)
            folders.sort()
            items_ = []

            for label, folder in enumerate(folders):
                impaths = glob.glob(osp.join(directory, folder, "*.jpg"))

                for impath in impaths:
                    items_.append((impath, label))

            return items_

        items = []

        for domain, dname in enumerate(input_domains):
            if split == "all":
                train_dir = osp.join(dataset_dir, dname, "train")
                impath_label_list = _load_data_from_directory(train_dir)
                val_dir = osp.join(dataset_dir, dname, "val")
                impath_label_list += _load_data_from_directory(val_dir)
            else:
                split_dir = osp.join(dataset_dir, dname, split)
                impath_label_list = _load_data_from_directory(split_dir)

            for impath, label in impath_label_list:
                # Class name is the parent folder; osp handles both "/" and
                # the native separator, unlike impath.split("/")
                class_name = osp.basename(osp.dirname(impath)).lower()
                item = Datum(
                    impath=impath,
                    label=label,
                    domain=domain,
                    classname=class_name
                )
                items.append(item)

        return items
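`DigitsDG.read_data` (reused by `OfficeHomeDG`) expects `<dataset_dir>/<domain>/<train|val>/<class>/*.jpg`; the `"all"` split concatenates `train` and `val`, and domain ids are positions in the *input* list, so a single target domain gets id 0 just like the first source domain. A minimal sketch of that behaviour on a temporary tree, assuming nothing beyond the standard library; `load_split` and the folder contents are hypothetical.

```python
import glob
import os
import os.path as osp
import tempfile


def load_split(dataset_dir, input_domains, split):
    """Mirror of read_data returning (impath, label, domain, classname)."""
    items = []
    for domain, dname in enumerate(input_domains):  # id = position in input list
        subsets = ["train", "val"] if split == "all" else [split]
        for subset in subsets:
            directory = osp.join(dataset_dir, dname, subset)
            for label, folder in enumerate(sorted(os.listdir(directory))):
                pattern = osp.join(directory, folder, "*.jpg")
                for impath in sorted(glob.glob(pattern)):
                    classname = osp.basename(osp.dirname(impath)).lower()
                    items.append((impath, label, domain, classname))
    return items


root = tempfile.mkdtemp()
for dname in ["mnist", "svhn"]:
    for subset in ["train", "val"]:
        for cls in ["0", "1"]:
            os.makedirs(osp.join(root, dname, subset, cls))
            open(osp.join(root, dname, subset, cls, "img.jpg"), "w").close()

train = load_split(root, ["mnist"], "train")
test = load_split(root, ["svhn"], "all")
print(len(train), len(test))       # 2 4
print({d for _, _, d, _ in test})  # {0}
```

The sketch sorts the `glob` results for reproducibility; the original keeps whatever order `glob.glob` returns.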
================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/dg/office_home_dg.py
================================================
import os.path as osp

from ..build import DATASET_REGISTRY
from .digits_dg import DigitsDG
from ..base_dataset import DatasetBase


@DATASET_REGISTRY.register()
class OfficeHomeDG(DatasetBase):
    """Office-Home.

    Statistics:
        - Around 15,500 images.
        - 65 classes related to office and home objects.
        - 4 domains: Art, Clipart, Product, Real World.
        - URL: http://hemanthdv.org/OfficeHome-Dataset/.

    Reference:
        - Venkateswara et al. Deep Hashing Network for Unsupervised
        Domain Adaptation. CVPR 2017.
    """

    dataset_dir = "office_home_dg"
    domains = ["art", "clipart", "product", "real_world"]
    data_url = "https://drive.google.com/uc?id=1gkbf_KaxoBws-GWT3XIPZ7BnkqbAxIFa"

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        if not osp.exists(self.dataset_dir):
            dst = osp.join(root, "office_home_dg.zip")
            self.download_data(self.data_url, dst, from_gdrive=True)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train = DigitsDG.read_data(
            self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "train"
        )
        val = DigitsDG.read_data(
            self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "val"
        )
        test = DigitsDG.read_data(
            self.dataset_dir, cfg.DATASET.TARGET_DOMAINS, "all"
        )

        super().__init__(train_x=train, val=val, test=test)
================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/dg/pacs.py
================================================
import os.path as osp

from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase


@DATASET_REGISTRY.register()
class PACS(DatasetBase):
    """PACS.

    Statistics:
        - 4 domains: Photo (1,670), Art (2,048), Cartoon
        (2,344), Sketch (3,929).
        - 7 categories: dog, elephant, giraffe, guitar, horse,
        house and person.

    Reference:
        - Li et al. Deeper, broader and artier domain generalization.
        ICCV 2017.
    """

    dataset_dir = "pacs"
    domains = ["art_painting", "cartoon", "photo", "sketch"]
    data_url = "https://drive.google.com/uc?id=1m4X4fROCCXMO0lRLrr6Zz9Vb3974NWhE"
    # the following images contain errors and should be ignored
    _error_paths = ["sketch/dog/n02103406_4068-1.png"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.image_dir = osp.join(self.dataset_dir, "images")
        self.split_dir = osp.join(self.dataset_dir, "splits")

        if not osp.exists(self.dataset_dir):
            dst = osp.join(root, "pacs.zip")
            self.download_data(self.data_url, dst, from_gdrive=True)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "train")
        val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "crossval")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, "all")

        super().__init__(train_x=train, val=val, test=test)

    def _read_data(self, input_domains, split):
        items = []

        for domain, dname in enumerate(input_domains):
            if split == "all":
                file_train = osp.join(
                    self.split_dir, dname + "_train_kfold.txt"
                )
                impath_label_list = self._read_split_pacs(file_train)
                file_val = osp.join(
                    self.split_dir, dname + "_crossval_kfold.txt"
                )
                impath_label_list += self._read_split_pacs(file_val)
            else:
                file = osp.join(
                    self.split_dir, dname + "_" + split + "_kfold.txt"
                )
                impath_label_list = self._read_split_pacs(file)

            for impath, label in impath_label_list:
                classname = impath.split("/")[-2]
                item = Datum(
                    impath=impath,
                    label=label,
                    domain=domain,
                    classname=classname
                )
                items.append(item)

        return items

    def _read_split_pacs(self, split_file):
        items = []

        with open(split_file, "r") as f:
            lines = f.readlines()

            for line in lines:
                line = line.strip()
                impath, label = line.split(" ")
                if impath in self._error_paths:
                    continue
                impath = osp.join(self.image_dir, impath)
                label = int(label) - 1
                items.append((impath, label))

        return items
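Each PACS split file line is `<relative image path> <label>`, with 1-based labels that `_read_split_pacs` shifts to 0-based after dropping paths listed in `_error_paths`. A minimal sketch of that parsing on in-memory lines; the function name is hypothetical, and apart from the one error path copied from `_error_paths`, the file names and labels are made up.

```python
import os.path as osp

# Copied from PACS._error_paths
ERROR_PATHS = {"sketch/dog/n02103406_4068-1.png"}


def parse_pacs_split(lines, image_dir):
    """Parse '<relpath> <1-based label>' lines into (abs path, 0-based label)."""
    items = []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # tolerate blank lines (the original would fail on them)
        impath, label = line.split(" ")
        if impath in ERROR_PATHS:
            continue  # known-bad image, skipped
        items.append((osp.join(image_dir, impath), int(label) - 1))
    return items


lines = [
    "sketch/dog/n02103406_4068-1.png 1\n",
    "sketch/dog/n02103406_343.png 1\n",
    "sketch/person/12240.png 7\n",
]
items = parse_pacs_split(lines, "pacs/images")
print([label for _, label in items])  # [0, 6]
```

The error path is matched against the relative path as written in the split file, before `image_dir` is prepended, which is why the check happens first.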
================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/dg/vlcs.py
================================================
import glob
import os.path as osp

from dassl.utils import listdir_nohidden

from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase


@DATASET_REGISTRY.register()
class VLCS(DatasetBase):
    """VLCS.

    Statistics:
        - 4 domains: CALTECH, LABELME, PASCAL, SUN
        - 5 categories: bird, car, chair, dog, and person.

    Reference:
        - Torralba and Efros. Unbiased look at dataset bias. CVPR 2011.
    """

    dataset_dir = "VLCS"
    domains = ["caltech", "labelme", "pascal", "sun"]
    data_url = "https://drive.google.com/uc?id=1r0WL5DDqKfSPp9E3tRENwHaXNs1olLZd"

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        if not osp.exists(self.dataset_dir):
            dst = osp.join(root, "vlcs.zip")
            self.download_data(self.data_url, dst, from_gdrive=True)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "train")
        val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "crossval")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, "test")

        super().__init__(train_x=train, val=val, test=test)

    def _read_data(self, input_domains, split):
        items = []

        for domain, dname in enumerate(input_domains):
            # Domain folders on disk are upper-case (e.g. CALTECH)
            dname = dname.upper()
            path = osp.join(self.dataset_dir, dname, split)
            folders = listdir_nohidden(path)
            folders.sort()

            for label, folder in enumerate(folders):
                impaths = glob.glob(osp.join(path, folder, "*.jpg"))

                for impath in impaths:
                    item = Datum(impath=impath, label=label, domain=domain)
                    items.append(item)

        return items
================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/ssl/__init__.py
================================================
from .svhn import SVHN
from .cifar import CIFAR10, CIFAR100
from .stl10 import STL10
================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/ssl/cifar.py
================================================
import math
import random
import os.path as osp

from dassl.utils import listdir_nohidden

from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase


@DATASET_REGISTRY.register()
class CIFAR10(DatasetBase):
    """CIFAR10 for SSL.

    Reference:
        - Krizhevsky. Learning Multiple Layers of Features
        from Tiny Images. Tech report.
    """

    dataset_dir = "cifar10"

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)
        train_dir = osp.join(self.dataset_dir, "train")
        test_dir = osp.join(self.dataset_dir, "test")

        assert cfg.DATASET.NUM_LABELED > 0

        train_x, train_u, val = self._read_data_train(
            train_dir, cfg.DATASET.NUM_LABELED, cfg.DATASET.VAL_PERCENT
        )
        test = self._read_data_test(test_dir)

        if cfg.DATASET.ALL_AS_UNLABELED:
            train_u = train_u + train_x

        if len(val) == 0:
            val = None

        super().__init__(train_x=train_x, train_u=train_u, val=val, test=test)

    def _read_data_train(self, data_dir, num_labeled, val_percent):
        class_names = listdir_nohidden(data_dir)
        class_names.sort()
        num_labeled_per_class = num_labeled / len(class_names)
        items_x, items_u, items_v = [], [], []

        for label, class_name in enumerate(class_names):
            class_dir = osp.join(data_dir, class_name)
            imnames = listdir_nohidden(class_dir)

            # Split into train and val following Oliver et al. 2018
            # Set cfg.DATASET.VAL_PERCENT to 0 to not use val data
            num_val = math.floor(len(imnames) * val_percent)
            imnames_train = imnames[num_val:]
            imnames_val = imnames[:num_val]

            # Note we do shuffle after split
            random.shuffle(imnames_train)

            for i, imname in enumerate(imnames_train):
                impath = osp.join(class_dir, imname)
                item = Datum(impath=impath, label=label)

                if (i + 1) <= num_labeled_per_class:
                    items_x.append(item)
                else:
                    items_u.append(item)

            for imname in imnames_val:
                impath = osp.join(class_dir, imname)
                item = Datum(impath=impath, label=label)
                items_v.append(item)

        return items_x, items_u, items_v

    def _read_data_test(self, data_dir):
        class_names = listdir_nohidden(data_dir)
        class_names.sort()
        items = []

        for label, class_name in enumerate(class_names):
            class_dir = osp.join(data_dir, class_name)
            imnames = listdir_nohidden(class_dir)

            for imname in imnames:
                impath = osp.join(class_dir, imname)
                item = Datum(impath=impath, label=label)
                items.append(item)

        return items


@DATASET_REGISTRY.register()
class CIFAR100(CIFAR10):
    """CIFAR100 for SSL.

    Reference:
        - Krizhevsky. Learning Multiple Layers of Features
        from Tiny Images. Tech report.
    """

    dataset_dir = "cifar100"

    def __init__(self, cfg):
        super().__init__(cfg)
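Per class, `_read_data_train` holds out the first `floor(len * val_percent)` names as validation, shuffles the rest, and marks the first `num_labeled / num_classes` shuffled images as labeled. A minimal sketch of one iteration of that class loop on plain file names; `split_one_class` and the numbers are hypothetical, and an explicit `random.Random` replaces the global `random.shuffle` so the run is reproducible.

```python
import math
import random


def split_one_class(imnames, num_labeled_per_class, val_percent, rng):
    """One iteration of the per-class loop in _read_data_train."""
    num_val = math.floor(len(imnames) * val_percent)
    imnames_train = list(imnames[num_val:])  # copy; caller's list is untouched
    imnames_val = list(imnames[:num_val])    # first num_val names -> val
    rng.shuffle(imnames_train)               # shuffle only after the split
    labeled, unlabeled = [], []
    for i, imname in enumerate(imnames_train):
        # Same test as the original: the first num_labeled_per_class
        # shuffled images are labeled, the rest are unlabeled
        if (i + 1) <= num_labeled_per_class:
            labeled.append(imname)
        else:
            unlabeled.append(imname)
    return labeled, unlabeled, imnames_val


# Hypothetical setting: 100 images in one class, 40 labels over 10 classes,
# 10% held out for validation.
imnames = ["img_{:03d}.png".format(i) for i in range(100)]
x, u, v = split_one_class(imnames, 40 / 10, 0.1, random.Random(0))
print(len(x), len(u), len(v))  # 4 86 10
```

With `cfg.DATASET.ALL_AS_UNLABELED` set, the labeled images are additionally appended to `train_u`, so the unlabeled pool then covers all non-validation images.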
================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/ssl/stl10.py
================================================
import os.path as osp

from dassl.utils import listdir_nohidden

from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase


@DATASET_REGISTRY.register()
class STL10(DatasetBase):
    """STL-10 dataset.

    Description:
        - 10 classes: airplane, bird, car, cat, deer, dog, horse,
        monkey, ship, truck.
        - Images are 96x96 pixels, color.
        - 500 training images per class, 800 test images per class.
        - 100,000 unlabeled images for unsupervised learning.

    Reference:
        - Coates et al. An Analysis of Single Layer Networks in
        Unsupervised Feature Learning. AISTATS 2011.
    """

    dataset_dir = "stl10"

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)
        train_dir = osp.join(self.dataset_dir, "train")
        test_dir = osp.join(self.dataset_dir, "test")
        unlabeled_dir = osp.join(self.dataset_dir, "unlabeled")
        fold_file = osp.join(
            self.dataset_dir, "stl10_binary", "fold_indices.txt"
        )

        # Only use the first five splits
        assert 0 <= cfg.DATASET.STL10_FOLD <= 4

        train_x = self._read_data_train(
            train_dir, cfg.DATASET.STL10_FOLD, fold_file
        )
        train_u = self._read_data_all(unlabeled_dir)
        test = self._read_data_all(test_dir)

        if cfg.DATASET.ALL_AS_UNLABELED:
            train_u = train_u + train_x

        super().__init__(train_x=train_x, train_u=train_u, test=test)

    def _read_data_train(self, data_dir, fold, fold_file):
        imnames = listdir_nohidden(data_dir)
        imnames.sort()
        items = []

        list_idx = list(range(len(imnames)))
        if fold >= 0:
            with open(fold_file, "r") as f:
                str_idx = f.read().splitlines()[fold]
                # Parse as Python ints: fold indices refer to the 5,000
                # training images, far beyond np.uint8's 0..255 range
                list_idx = [int(idx) for idx in str_idx.split()]

        for i in list_idx:
            imname = imnames[i]
            impath = osp.join(data_dir, imname)
            label = osp.splitext(imname)[0].split("_")[1]
            label = int(label)
            item = Datum(impath=impath, label=label)
            items.append(item)

        return items

    def _read_data_all(self, data_dir):
        imnames = listdir_nohidden(data_dir)
        items = []

        for imname in imnames:
            impath = osp.join(data_dir, imname)
            label = osp.splitext(imname)[0].split("_")[1]
            if label == "none":
                label = -1
            else:
                label = int(label)
            item = Datum(impath=impath, label=label)
            items.append(item)

        return items
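STL-10 training folds are selected by reading one line of `fold_indices.txt` (one line per fold, whitespace-separated indices into the sorted training file list), and labels are parsed from the token after the first underscore, with `none` mapped to -1 for unlabeled images. A minimal sketch of both parsing steps; the helper names, the two-line fold text and the file names are made up.

```python
import os.path as osp


def parse_fold_line(fold_text, fold):
    """Return the training indices for one fold (one line of the file)."""
    line = fold_text.splitlines()[fold]
    # Python ints have no 0..255 limit, unlike np.uint8
    return [int(tok) for tok in line.split()]


def parse_label(imname):
    """Label is the token after the first underscore; 'none' means unlabeled."""
    label = osp.splitext(imname)[0].split("_")[1]
    return -1 if label == "none" else int(label)


fold_text = "0 17 4999\n3 255 256\n"  # made-up two-fold file
print(parse_fold_line(fold_text, 1))                                # [3, 255, 256]
print(parse_label("000123_4.png"), parse_label("000007_none.png"))  # 4 -1
```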
================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/ssl/svhn.py
================================================
from .cifar import CIFAR10
from ..build import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class SVHN(CIFAR10):
    """SVHN for SSL.

    Reference:
        - Netzer et al. Reading Digits in Natural Images with
        Unsupervised Feature Learning. NIPS-W 2011.
    """

    dataset_dir = "svhn"

    def __init__(self, cfg):
        super().__init__(cfg)
================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/samplers.py
================================================
import copy
import numpy as np
import random
from collections import defaultdict

from torch.utils.data.sampler import Sampler, RandomSampler, SequentialSampler


class RandomDomainSampler(Sampler):
    """Randomly samples N domains each with K images
    to form a minibatch of size N*K.

    Args:
        data_source (list): list of Datums.
        batch_size (int): batch size.
        n_domain (int): number of domains to sample in a minibatch.
    """

    def __init__(self, data_source, batch_size, n_domain):
        self.data_source = data_source

        # Keep track of image indices for each domain
        self.domain_dict = defaultdict(list)
        for i, item in enumerate(data_source):
            self.domain_dict[item.domain].append(i)
        self.domains = list(self.domain_dict.keys())

        # Each sampled domain contributes the same number of images
        # (batch_size // n_domain), so batch_size must be divisible by n_domain
        if n_domain is None or n_domain <= 0:
            n_domain = len(self.domains)
        assert batch_size % n_domain == 0
        self.n_img_per_domain = batch_size // n_domain

        self.batch_size = batch_size
        # n_domain denotes number of domains sampled in a minibatch
        self.n_domain = n_domain
        self.length = len(list(self.__iter__()))

    def __iter__(self):
        domain_dict = copy.deepcopy(self.domain_dict)
        final_idxs = []
        stop_sampling = False

        while not stop_sampling:
            selected_domains = random.sample(self.domains, self.n_domain)

            for domain in selected_domains:
                idxs = domain_dict[domain]
                selected_idxs = random.sample(idxs, self.n_img_per_domain)
                final_idxs.extend(selected_idxs)

                for idx in selected_idxs:
                    domain_dict[domain].remove(idx)

                remaining = len(domain_dict[domain])
                if remaining < self.n_img_per_domain:
                    stop_sampling = True

        return iter(final_idxs)

    def __len__(self):
        return self.length


class SeqDomainSampler(Sampler):
    """Sequential domain sampler, which randomly samples K
    images from each domain to form a minibatch.

    Args:
        data_source (list): list of Datums.
        batch_size (int): batch size.
    """

    def __init__(self, data_source, batch_size):
        self.data_source = data_source

        # Keep track of image indices for each domain
        self.domain_dict = defaultdict(list)
        for i, item in enumerate(data_source):
            self.domain_dict[item.domain].append(i)
        self.domains = list(self.domain_dict.keys())
        self.domains.sort()

        # Every domain contributes the same number of images
        # (batch_size // n_domain), so batch_size must be divisible by n_domain
        n_domain = len(self.domains)
        assert batch_size % n_domain == 0
        self.n_img_per_domain = batch_size // n_domain

        self.batch_size = batch_size
        # n_domain denotes number of domains sampled in a minibatch
        self.n_domain = n_domain
        self.length = len(list(self.__iter__()))

    def __iter__(self):
        domain_dict = copy.deepcopy(self.domain_dict)
        final_idxs = []
        stop_sampling = False

        while not stop_sampling:
            for domain in self.domains:
                idxs = domain_dict[domain]
                selected_idxs = random.sample(idxs, self.n_img_per_domain)
                final_idxs.extend(selected_idxs)

                for idx in selected_idxs:
                    domain_dict[domain].remove(idx)

                remaining = len(domain_dict[domain])
                if remaining < self.n_img_per_domain:
                    stop_sampling = True

        return iter(final_idxs)

    def __len__(self):
        return self.length


class RandomClassSampler(Sampler):
    """Randomly samples N classes each with K instances to
    form a minibatch of size N*K.

    Modified from https://github.com/KaiyangZhou/deep-person-reid.

    Args:
        data_source (list): list of Datums.
        batch_size (int): batch size.
        n_ins (int): number of instances per class to sample in a minibatch.
    """

    def __init__(self, data_source, batch_size, n_ins):
        if batch_size < n_ins:
            raise ValueError(
                "batch_size={} must be no less "
                "than n_ins={}".format(batch_size, n_ins)
            )

        self.data_source = data_source
        self.batch_size = batch_size
        self.n_ins = n_ins
        self.ncls_per_batch = self.batch_size // self.n_ins
        self.index_dic = defaultdict(list)
        for index, item in enumerate(data_source):
            self.index_dic[item.label].append(index)
        self.labels = list(self.index_dic.keys())
        assert len(self.labels) >= self.ncls_per_batch

        # estimate number of images in an epoch
        self.length = len(list(self.__iter__()))

    def __iter__(self):
        batch_idxs_dict = defaultdict(list)

        for label in self.labels:
            idxs = copy.deepcopy(self.index_dic[label])
            if len(idxs) < self.n_ins:
                idxs = np.random.choice(idxs, size=self.n_ins, replace=True)
            random.shuffle(idxs)
            batch_idxs = []
            for idx in idxs:
                batch_idxs.append(idx)
                if len(batch_idxs) == self.n_ins:
                    batch_idxs_dict[label].append(batch_idxs)
                    batch_idxs = []

        avai_labels = copy.deepcopy(self.labels)
        final_idxs = []

        while len(avai_labels) >= self.ncls_per_batch:
            selected_labels = random.sample(avai_labels, self.ncls_per_batch)
            for label in selected_labels:
                batch_idxs = batch_idxs_dict[label].pop(0)
                final_idxs.extend(batch_idxs)
                if len(batch_idxs_dict[label]) == 0:
                    avai_labels.remove(label)

        return iter(final_idxs)

    def __len__(self):
        return self.length


def build_sampler(
    sampler_type,
    cfg=None,
    data_source=None,
    batch_size=32,
    n_domain=0,
    n_ins=16
):
    if sampler_type == "RandomSampler":
        return RandomSampler(data_source)

    elif sampler_type == "SequentialSampler":
        return SequentialSampler(data_source)

    elif sampler_type == "RandomDomainSampler":
        return RandomDomainSampler(data_source, batch_size, n_domain)

    elif sampler_type == "SeqDomainSampler":
        return SeqDomainSampler(data_source, batch_size)

    elif sampler_type == "RandomClassSampler":
        return RandomClassSampler(data_source, batch_size, n_ins)

    else:
        raise ValueError("Unknown sampler type: {}".format(sampler_type))
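The two balanced samplers above reduce to simple index bookkeeping: `RandomDomainSampler` draws `batch_size // n_domain` indices without replacement from each of `n_domain` randomly chosen domains until some domain can no longer fill its share, and `RandomClassSampler` pre-cuts each class into shuffled chunks of `n_ins` (oversampling small classes with replacement) and then pops one chunk from each of `batch_size // n_ins` random classes per batch. A minimal sketch of both on plain lists of domain ids and labels instead of `Datum` objects; the function names are hypothetical, and an explicit `random.Random` stands in for the module-level `random`/`np.random` calls.

```python
import random
from collections import defaultdict


def domain_balanced_order(domains, batch_size, n_domain, rng):
    """Sketch of RandomDomainSampler.__iter__ over a list of domain ids."""
    pool = defaultdict(list)
    for i, d in enumerate(domains):
        pool[d].append(i)
    keys = list(pool.keys())
    k = batch_size // n_domain  # images drawn from each selected domain
    order, stop = [], False
    while not stop:
        for d in rng.sample(keys, n_domain):
            chosen = rng.sample(pool[d], k)  # without replacement
            order.extend(chosen)
            for idx in chosen:
                pool[d].remove(idx)
            if len(pool[d]) < k:  # this domain cannot fill another slot
                stop = True
    return order


def class_balanced_order(labels, batch_size, n_ins, rng):
    """Sketch of RandomClassSampler.__iter__ (n_cls classes x n_ins each)."""
    by_label = defaultdict(list)
    for i, y in enumerate(labels):
        by_label[y].append(i)
    n_cls = batch_size // n_ins
    chunks = defaultdict(list)
    for y, idxs in by_label.items():
        idxs = list(idxs)
        if len(idxs) < n_ins:
            # Too few instances: resample with replacement up to n_ins
            idxs = [rng.choice(idxs) for _ in range(n_ins)]
        rng.shuffle(idxs)
        # Keep only full chunks of n_ins; leftovers are dropped
        for s in range(0, len(idxs) - n_ins + 1, n_ins):
            chunks[y].append(idxs[s:s + n_ins])
    available = list(by_label.keys())
    order = []
    while len(available) >= n_cls:
        for y in rng.sample(available, n_cls):
            order.extend(chunks[y].pop(0))
            if not chunks[y]:
                available.remove(y)
    return order


domains = [0] * 10 + [1] * 10 + [2] * 10  # hypothetical domain ids
order = domain_balanced_order(domains, 4, 2, random.Random(0))
batches = [order[i:i + 4] for i in range(0, len(order), 4)]
# Every batch looks like [dA, dA, dB, dB] with dA != dB
print(all(
    domains[b[0]] == domains[b[1]] != domains[b[2]] == domains[b[3]]
    for b in batches
))  # True
```

As in the originals, the number of indices per epoch depends on which domains or classes happen to be drawn, which is why the samplers estimate `self.length` from one pass of `__iter__`.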
================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/transforms/__init__.py
================================================
from .transforms import build_transform
================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/transforms/autoaugment.py
================================================
"""
Source: https://github.com/DeepVoltaire/AutoAugment
"""
import numpy as np
import random
from PIL import Image, ImageOps, ImageEnhance


class ImageNetPolicy:
    """Randomly choose one of the best 25 Sub-policies on ImageNet.

    Example:
    >>> policy = ImageNetPolicy()
    >>> transformed = policy(image)

    Example as a PyTorch Transform:
    >>> transform = transforms.Compose([
    ...     transforms.Resize(256),
    ...     ImageNetPolicy(),
    ...     transforms.ToTensor()])
    """

    def __init__(self, fillcolor=(128, 128, 128)):
        self.policies = [
            SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor),
            SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
            SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor),
            SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
            SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor),
            SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor),
            SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor),
            SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor),
            SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor),
            SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor),
            SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor),
            SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor),
            SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
            SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
            SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor),
            SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor),
            SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor),
            SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor),
            SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor),
            SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
            SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
            SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
            SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
        ]

    def __call__(self, img):
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment ImageNet Policy"


class CIFAR10Policy:
    """Randomly choose one of the best 25 Sub-policies on CIFAR10.

    Example:
    >>> policy = CIFAR10Policy()
    >>> transformed = policy(image)

    Example as a PyTorch Transform:
    >>> transform = transforms.Compose([
    ...     transforms.Resize(256),
    ...     CIFAR10Policy(),
    ...     transforms.ToTensor()])
    """

    def __init__(self, fillcolor=(128, 128, 128)):
        self.policies = [
            SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),
            SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor),
            SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor),
            SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
            SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),
            SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor),
            SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
            SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor),
            SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
            SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor),
            SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),
            SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
            SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor),
            SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
            SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor),
            SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
            SubPolicy(0.2, "equalize", 8, 0.6, "equalize", 4, fillcolor),
            SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
            SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor),
            SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),
            SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),
            SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
            SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
            SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor),
        ]

    def __call__(self, img):
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment CIFAR10 Policy"


class SVHNPolicy:
    """Randomly choose one of the best 25 Sub-policies on SVHN.

    Example:
    >>> policy = SVHNPolicy()
    >>> transformed = policy(image)

    Example as a PyTorch Transform:
    >>> transform = transforms.Compose([
    ...     transforms.Resize(256),
    ...     SVHNPolicy(),
    ...     transforms.ToTensor()])
    """

    def __init__(self, fillcolor=(128, 128, 128)):
        self.policies = [
            SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor),
            SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor),
            SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor),
            SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor),
            SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor),
            SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor),
            SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor),
            SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor),
            SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor),
            SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor),
            SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor),
            SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor),
            SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor),
            SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor),
            SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor),
            SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor),
            SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor),
            SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor),
            SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor),
            SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor),
            SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor),
            SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor),
            SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor),
            SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor),
            SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor),
        ]

    def __call__(self, img):
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment SVHN Policy"
class SubPolicy(object):
def __init__(
self,
p1,
operation1,
magnitude_idx1,
p2,
operation2,
magnitude_idx2,
fillcolor=(128, 128, 128),
):
ranges = {
"shearX": np.linspace(0, 0.3, 10),
"shearY": np.linspace(0, 0.3, 10),
"translateX": np.linspace(0, 150 / 331, 10),
"translateY": np.linspace(0, 150 / 331, 10),
"rotate": np.linspace(0, 30, 10),
"color": np.linspace(0.0, 0.9, 10),
"posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int),
"solarize": np.linspace(256, 0, 10),
"contrast": np.linspace(0.0, 0.9, 10),
"sharpness": np.linspace(0.0, 0.9, 10),
"brightness": np.linspace(0.0, 0.9, 10),
"autocontrast": [0] * 10,
"equalize": [0] * 10,
"invert": [0] * 10,
}
# from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
def rotate_with_fill(img, magnitude):
rot = img.convert("RGBA").rotate(magnitude)
return Image.composite(
rot, Image.new("RGBA", rot.size, (128, ) * 4), rot
).convert(img.mode)
func = {
"shearX":
lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
Image.BICUBIC,
fillcolor=fillcolor,
),
"shearY":
lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
Image.BICUBIC,
fillcolor=fillcolor,
),
"translateX":
lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(
1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0,
1, 0
),
fillcolor=fillcolor,
),
"translateY":
lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(
1, 0, 0, 0, 1, magnitude * img.size[1] * random.
choice([-1, 1])
),
fillcolor=fillcolor,
),
"rotate":
lambda img, magnitude: rotate_with_fill(img, magnitude),
"color":
lambda img, magnitude: ImageEnhance.Color(img).
enhance(1 + magnitude * random.choice([-1, 1])),
"posterize":
lambda img, magnitude: ImageOps.posterize(img, magnitude),
"solarize":
lambda img, magnitude: ImageOps.solarize(img, magnitude),
"contrast":
lambda img, magnitude: ImageEnhance.Contrast(img).
enhance(1 + magnitude * random.choice([-1, 1])),
"sharpness":
lambda img, magnitude: ImageEnhance.Sharpness(img).
enhance(1 + magnitude * random.choice([-1, 1])),
"brightness":
lambda img, magnitude: ImageEnhance.Brightness(img).
enhance(1 + magnitude * random.choice([-1, 1])),
"autocontrast":
lambda img, magnitude: ImageOps.autocontrast(img),
"equalize":
lambda img, magnitude: ImageOps.equalize(img),
"invert":
lambda img, magnitude: ImageOps.invert(img),
}
self.p1 = p1
self.operation1 = func[operation1]
self.magnitude1 = ranges[operation1][magnitude_idx1]
self.p2 = p2
self.operation2 = func[operation2]
self.magnitude2 = ranges[operation2][magnitude_idx2]
def __call__(self, img):
if random.random() < self.p1:
img = self.operation1(img, self.magnitude1)
if random.random() < self.p2:
img = self.operation2(img, self.magnitude2)
return img
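Each `SubPolicy` applies two operations in sequence, and each one fires independently with its own probability (`p1`, `p2`) at a fixed magnitude. The sketch below shows only that gating logic. It uses a number as a toy "image" so the result is easy to check, and `make_subpolicy` is a hypothetical helper, not part of Dassl.

```python
import random
from typing import Any, Callable

Op = Callable[[Any, float], Any]

def make_subpolicy(op1: Op, p1: float, m1: float,
                   op2: Op, p2: float, m2: float,
                   rng: random.Random) -> Callable[[Any], Any]:
    """Hypothetical helper: op1 then op2, each gated by its own probability."""
    def apply(x):
        if rng.random() < p1:  # same test as SubPolicy.__call__
            x = op1(x, m1)
        if rng.random() < p2:
            x = op2(x, m2)
        return x
    return apply

# Toy "image": a number, so the effect of each op is easy to check.
def add(x, m):
    return x + m

def mul(x, m):
    return x * m

always = make_subpolicy(add, 1.0, 3, mul, 1.0, 2, random.Random(0))
never = make_subpolicy(add, 0.0, 3, mul, 0.0, 2, random.Random(0))
print(always(1))  # 8, i.e. (1 + 3) * 2
print(never(1))   # 1, unchanged
```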
================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/transforms/randaugment.py
================================================
"""
Credit to
1) https://github.com/ildoonet/pytorch-randaugment
2) https://github.com/kakaobrain/fast-autoaugment
"""
import numpy as np
import random
import PIL
import torch
import PIL.ImageOps
import PIL.ImageDraw
import PIL.ImageEnhance
from PIL import Image
def ShearX(img, v):
assert -0.3 <= v <= 0.3
if random.random() > 0.5:
v = -v
return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
def ShearY(img, v):
assert -0.3 <= v <= 0.3
if random.random() > 0.5:
v = -v
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
def TranslateX(img, v):
# [-150, 150] => percentage: [-0.45, 0.45]
assert -0.45 <= v <= 0.45
if random.random() > 0.5:
v = -v
v = v * img.size[0]
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
def TranslateXabs(img, v):
# v is an absolute offset in pixels (not scaled by the image size)
assert 0 <= v
if random.random() > 0.5:
v = -v
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
def TranslateY(img, v):
# [-150, 150] => percentage: [-0.45, 0.45]
assert -0.45 <= v <= 0.45
if random.random() > 0.5:
v = -v
v = v * img.size[1]
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
def TranslateYabs(img, v):
# v is an absolute offset in pixels (not scaled by the image size)
assert 0 <= v
if random.random() > 0.5:
v = -v
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
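The shear and translate helpers above all rely on PIL's affine transform. For each *output* pixel (x, y), PIL samples the *input* position (a*x + b*y + c, d*x + e*y + f). So a positive `c` shifts content left, and the random sign flip makes the direction symmetric. The sketch below evaluates only that coordinate mapping in plain Python. `affine_source` is a hypothetical name.

```python
def affine_source(coeffs, x, y):
    """Input position sampled for output pixel (x, y) under PIL's AFFINE
    transform with coefficients (a, b, c, d, e, f)."""
    a, b, c, d, e, f = coeffs
    return (a * x + b * y + c, d * x + e * y + f)

# TranslateX with v = 10 pixels uses (1, 0, v, 0, 1, 0).
print(affine_source((1, 0, 10, 0, 1, 0), 0, 0))    # (10, 0)
# ShearX with v = 0.5 uses (1, v, 0, 0, 1, 0): x_src = x + 0.5 * y.
print(affine_source((1, 0.5, 0, 0, 1, 0), 0, 10))  # (5.0, 10)
```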
def Rotate(img, v):
assert -30 <= v <= 30
if random.random() > 0.5:
v = -v
return img.rotate(v)
def AutoContrast(img, _):
return PIL.ImageOps.autocontrast(img)
def Invert(img, _):
return PIL.ImageOps.invert(img)
def Equalize(img, _):
return PIL.ImageOps.equalize(img)
def Flip(img, _):
return PIL.ImageOps.mirror(img)
def Solarize(img, v):
assert 0 <= v <= 256
return PIL.ImageOps.solarize(img, v)
def SolarizeAdd(img, addition=0, threshold=128):
img_np = np.array(img).astype(np.int64)  # np.int was removed in NumPy 1.24
img_np = img_np + addition
img_np = np.clip(img_np, 0, 255)
img_np = img_np.astype(np.uint8)
img = Image.fromarray(img_np)
return PIL.ImageOps.solarize(img, threshold)
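`SolarizeAdd` brightens the image by `addition`, clips the result to [0, 255], and then solarizes it. PIL's solarize inverts every value at or above the threshold. The sketch below reproduces that pipeline on a list of 8-bit values, assuming integer inputs. The function names are hypothetical.

```python
def solarize(values, threshold=128):
    # Invert (255 - v) every value at or above the threshold.
    return [255 - v if v >= threshold else v for v in values]

def solarize_add(values, addition=0, threshold=128):
    # Mirror SolarizeAdd: add, clip to [0, 255], then solarize.
    shifted = [min(255, max(0, v + addition)) for v in values]
    return solarize(shifted, threshold)

print(solarize_add([0, 100, 200], addition=50))  # [50, 105, 5]
```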
def Posterize(img, v):
assert 4 <= v <= 8
v = int(v)
return PIL.ImageOps.posterize(img, v)
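Posterize keeps only the `bits` most significant bits of each 8-bit channel. That is why its range here is 4 to 8: `bits=8` leaves the image unchanged. The sketch below applies the same bit mask that PIL's lookup table uses to single values. `posterize_value` is a hypothetical name.

```python
def posterize_value(v, bits):
    # Zero out the (8 - bits) least significant bits of an 8-bit value.
    mask = ~(2 ** (8 - bits) - 1) & 0xFF
    return v & mask

print([posterize_value(v, 4) for v in (37, 255)])  # [32, 240]
print(posterize_value(37, 8))                      # 37, no change at 8 bits
```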
def Contrast(img, v):
assert 0.0 <= v <= 2.0
return PIL.ImageEnhance.Contrast(img).enhance(v)
def Color(img, v):
assert 0.0 <= v <= 2.0
return PIL.ImageEnhance.Color(img).enhance(v)
def Brightness(img, v):
assert 0.0 <= v <= 2.0
return PIL.ImageEnhance.Brightness(img).enhance(v)
def Sharpness(img, v):
assert 0.0 <= v <= 2.0
return PIL.ImageEnhance.Sharpness(img).enhance(v)
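The four `PIL.ImageEnhance` wrappers blend the image with a "degenerate" version of itself, using `Image.blend(degenerate, image, factor)`. Brightness blends against black, Contrast against a uniform gray at the image mean, Color against a grayscale copy, and Sharpness against a smoothed copy. A factor of 1.0 therefore returns the original, which is why these ranges are centred on 1.0. The sketch below does this blend for one channel value. Its rounding is an approximation and may not match PIL bit for bit.

```python
def enhance(pixel, degenerate, factor):
    # degenerate * (1 - factor) + pixel * factor, clipped to [0, 255].
    # factor 0 -> degenerate, 1 -> original, > 1 extrapolates away from it.
    out = degenerate * (1.0 - factor) + pixel * factor
    return int(min(255, max(0, round(out))))

print(enhance(100, 0, 1.5))    # 150: brightness, blended against black
print(enhance(200, 0, 1.5))    # 255: clipped
print(enhance(100, 128, 0.5))  # 114: contrast, assuming a mean gray of 128
```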
def Cutout(img, v):
# [0, 60] => percentage: [0, 0.2]
assert 0.0 <= v <= 0.2
if v <= 0.0:
return img
v = v * img.size[0]
return CutoutAbs(img, v)
def CutoutAbs(img, v):
# [0, 60] => percentage: [0, 0.2]
# assert 0 <= v <= 20
if v < 0:
return img
w, h = img.size
# Sample the patch centre uniformly inside the image.
x0 = np.random.uniform(0, w)
y0 = np.random.uniform(0, h)
x0 = int(max(0, x0 - v/2.0))
y0 = int(max(0, y0 - v/2.0))
x1 = min(w, x0 + v)
y1 = min(h, y0 + v)
xy = (x0, y0, x1, y1)
color = (125, 123, 114)
# color = (0, 0, 0)
img = img.copy()
PIL.ImageDraw.Draw(img).rectangle(xy, color)
return img
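`CutoutAbs` picks a random centre, moves `v/2` up and left to get the top-left corner (clamped at 0), and draws a `v x v` box that is clipped at the right and bottom edges. As a result, boxes near those edges cover fewer pixels. The sketch below computes only the box, using `random.Random` instead of NumPy. `cutout_box` is a hypothetical name.

```python
import random

def cutout_box(w, h, v, rng):
    # Centre uniform in [0, w] x [0, h]; top-left clamped to the image;
    # bottom-right clipped, so the box may be smaller than v x v.
    cx = rng.uniform(0, w)
    cy = rng.uniform(0, h)
    x0 = int(max(0, cx - v / 2.0))
    y0 = int(max(0, cy - v / 2.0))
    x1 = min(w, x0 + v)
    y1 = min(h, y0 + v)
    return x0, y0, x1, y1

rng = random.Random(0)
boxes = [cutout_box(32, 32, 16, rng) for _ in range(1000)]
print(boxes[:3])
```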
def SamplePairing(imgs):
# [0, 0.4]
def f(img1, v):
i = np.random.choice(len(imgs))
img2 = PIL.Image.fromarray(imgs[i])
return PIL.Image.blend(img1, img2, v)
return f
def Identity(img, v):
return img
class Lighting:
"""Lighting noise (AlexNet - style PCA - based noise)."""
def __init__(self, alphastd, eigval, eigvec):
self.alphastd = alphastd
self.eigval = torch.Tensor(eigval)
self.eigvec = torch.Tensor(eigvec)
def __call__(self, img):
if self.alphastd == 0:
return img
alpha = img.new().resize_(3).normal_(0, self.alphastd)
rgb = (
self.eigvec.type_as(img).clone().mul(
alpha.view(1, 3).expand(3, 3)
).mul(self.eigval.view(1, 3).expand(3, 3)).sum(1).squeeze()
)
return img.add(rgb.view(3, 1, 1).expand_as(img))
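`Lighting.__call__` computes one RGB offset per image, `rgb[i] = sum_j eigvec[i][j] * alpha[j] * eigval[j]` with `alpha ~ N(0, alphastd)`, and adds it to every pixel of channel `i`. The sketch below reproduces that sum in plain Python so the chained `mul(...).mul(...).sum(1)` is easier to read. The eigenvalues and the fixed `alpha` are made-up illustration values.

```python
def pca_lighting_offset(eigvec, eigval, alpha):
    # rgb[i] = sum over j of eigvec[i][j] * alpha[j] * eigval[j]
    return [
        sum(eigvec[i][j] * alpha[j] * eigval[j] for j in range(3))
        for i in range(3)
    ]

identity = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
# With identity eigenvectors the offset is just alpha * eigval per channel.
offset = pca_lighting_offset(identity, [0.2, 0.02, 0.002], [1.0, 2.0, 3.0])
print(offset)  # approximately [0.2, 0.04, 0.006]
```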
class CutoutDefault:
"""
Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py
"""
def __init__(self, length):
self.length = length
def __call__(self, img):
h, w = img.size(1), img.size(2)
mask = np.ones((h, w), np.float32)
y = np.random.randint(h)
x = np.random.randint(w)
y1 = np.clip(y - self.length // 2, 0, h)
y2 = np.clip(y + self.length // 2, 0, h)
x1 = np.clip(x - self.length // 2, 0, w)
x2 = np.clip(x + self.length // 2, 0, w)
mask[y1:y2, x1:x2] = 0.0
mask = torch.from_numpy(mask)
mask = mask.expand_as(img)
img *= mask
return img
def randaugment_list():
# 16 operations and their ranges
# https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57
# augs = [
# (Identity, 0., 1.0),
# (ShearX, 0., 0.3), # 0
# (ShearY, 0., 0.3), # 1
# (TranslateX, 0., 0.33), # 2
# (TranslateY, 0., 0.33), # 3
# (Rotate, 0, 30), # 4
# (AutoContrast, 0, 1), # 5
# (Invert, 0, 1), # 6
# (Equalize, 0, 1), # 7
# (Solarize, 0, 110), # 8
# (Posterize, 4, 8), # 9
# # (Contrast, 0.1, 1.9), # 10
# (Color, 0.1, 1.9), # 11
# (Brightness, 0.1, 1.9), # 12
# (Sharpness, 0.1, 1.9), # 13
# # (Cutout, 0, 0.2), # 14
# # (SamplePairing(imgs), 0, 0.4) # 15
# ]
# https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505
augs = [
(AutoContrast, 0, 1),
(Equalize, 0, 1),
(Invert, 0, 1),
(Rotate, 0, 30),
(Posterize, 4, 8),
(Solarize, 0, 256),
(SolarizeAdd, 0, 110),
(Color, 0.1, 1.9),
(Contrast, 0.1, 1.9),
(Brightness, 0.1, 1.9),
(Sharpness, 0.1, 1.9),
(ShearX, 0.0, 0.3),
(ShearY, 0.0, 0.3),
(CutoutAbs, 0, 40),
(TranslateXabs, 0.0, 100),
(TranslateYabs, 0.0, 100),
]
return augs
def randaugment_list2():
augs = [
(AutoContrast, 0, 1),
(Brightness, 0.1, 1.9),
(Color, 0.1, 1.9),
(Contrast, 0.1, 1.9),
(Equalize, 0, 1),
(Identity, 0, 1),
(Invert, 0, 1),
(Posterize, 4, 8),
(Rotate, -30, 30),
(Sharpness, 0.1, 1.9),
(ShearX, -0.3, 0.3),
(ShearY, -0.3, 0.3),
(Solarize, 0, 256),
(TranslateX, -0.3, 0.3),
(TranslateY, -0.3, 0.3),
]
return augs
def fixmatch_list():
# https://arxiv.org/abs/2001.07685
augs = [
(AutoContrast, 0, 1),
(Brightness, 0.05, 0.95),
(Color, 0.05, 0.95),
(Contrast, 0.05, 0.95),
(Equalize, 0, 1),
(Identity, 0, 1),
(Posterize, 4, 8),
(Rotate, -30, 30),
(Sharpness, 0.05, 0.95),
(ShearX, -0.3, 0.3),
(ShearY, -0.3, 0.3),
(Solarize, 0, 256),
(TranslateX, -0.3, 0.3),
(TranslateY, -0.3, 0.3),
]
return augs
class RandAugment:
def __init__(self, n=2, m=10):
assert 0 <= m <= 30
self.n = n
self.m = m
self.augment_list = randaugment_list()
def __call__(self, img):
ops = random.choices(self.augment_list, k=self.n)
for op, minval, maxval in ops:
val = (self.m / 30) * (maxval-minval) + minval
img = op(img, val)
return img
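`RandAugment` uses a single global magnitude `m` in [0, 30] and maps it linearly into each operation's own `[minval, maxval]` range. The sketch below shows that mapping for a few entries from `randaugment_list()`. `magnitude_to_value` is a hypothetical helper. Note that `Posterize` then truncates the value with `int()`.

```python
def magnitude_to_value(m, minval, maxval, max_m=30):
    # Linear map of m in [0, max_m] onto [minval, maxval].
    return (m / max_m) * (maxval - minval) + minval

for name, lo, hi in [("Rotate", 0, 30), ("Solarize", 0, 256), ("Posterize", 4, 8)]:
    print(name, magnitude_to_value(10, lo, hi))
# Rotate 10.0, Solarize ~85.33, Posterize ~5.33 (int() -> 5 bits)
```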
class RandAugment2:
def __init__(self, n=2, p=0.6):
self.n = n
self.p = p
self.augment_list = randaugment_list2()
def __call__(self, img):
ops = random.choices(self.augment_list, k=self.n)
for op, minval, maxval in ops:
if random.random() > self.p:
continue
m = random.random()
val = m * (maxval-minval) + minval
img = op(img, val)
return img
class RandAugmentFixMatch:
def __init__(self, n=2):
self.n = n
self.augment_list = fixmatch_list()
def __call__(self, img):
ops = random.choices(self.augment_list, k=self.n)
for op, minval, maxval in ops:
m = random.random()
val = m * (maxval-minval) + minval
img = op(img, val)
return img
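`RandAugment2` and `RandAugmentFixMatch` differ from `RandAugment` in two ways. Each applied op draws a fresh magnitude uniformly from its range, and `RandAugment2` skips each op with probability `1 - p` (`RandAugmentFixMatch` behaves like `p = 1`). In all three classes, `random.choices` samples the `n` ops *with replacement*, so the same op can be applied twice. A minimal sketch with a number as the "image" (`rand_augment_random_m` is hypothetical):

```python
import random

def rand_augment_random_m(x, ops, n, p, rng):
    # Draw n ops with replacement; apply each with probability p,
    # using a magnitude drawn uniformly from [minval, maxval].
    for op, minval, maxval in rng.choices(ops, k=n):
        if rng.random() > p:
            continue
        val = rng.random() * (maxval - minval) + minval
        x = op(x, val)
    return x

applied = []

def record(x, v):
    # Toy op: remember the magnitude it received and count applications.
    applied.append(v)
    return x + 1

out = rand_augment_random_m(0, [(record, -30, 30)], n=2, p=1.0,
                            rng=random.Random(0))
print(out, applied)  # 2, plus two magnitudes in [-30, 30]
```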
================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/transforms/transforms.py
================================================
import numpy as np
import random
import torch
from PIL import Image
from torchvision.transforms import (
Resize, Compose, ToTensor, Normalize, CenterCrop, RandomCrop, ColorJitter,
RandomApply, GaussianBlur, RandomGrayscale, RandomResizedCrop,
RandomHorizontalFlip
)
from .autoaugment import SVHNPolicy, CIFAR10Policy, ImageNetPolicy
from .randaugment import RandAugment, RandAugment2, RandAugmentFixMatch
AVAI_CHOICES = [
"random_flip",
"random_resized_crop",
"normalize",
"instance_norm",
"random_crop",
"random_translation",
"center_crop", # This has become a default operation for test
"cutout",
"imagenet_policy",
"cifar10_policy",
"svhn_policy",
"randaugment",
"randaugment_fixmatch",
"randaugment2",
"gaussian_noise",
"colorjitter",
"randomgrayscale",
"gaussian_blur",
]
INTERPOLATION_MODES = {
"bilinear": Image.BILINEAR,
"bicubic": Image.BICUBIC,
"nearest": Image.NEAREST,
}
class Random2DTranslation:
"""Given an image of (height, width), we resize it to
(height*1.125, width*1.125), and then perform random cropping.
Args:
height (int): target image height.
width (int): target image width.
p (float, optional): probability that this operation takes place.
Default is 0.5.
interpolation (int, optional): desired interpolation. Default is
``PIL.Image.BILINEAR``
"""
def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR):
self.height = height
self.width = width
self.p = p
self.interpolation = interpolation
def __call__(self, img):
if random.uniform(0, 1) > self.p:
return img.resize((self.width, self.height), self.interpolation)
new_width = int(round(self.width * 1.125))
new_height = int(round(self.height * 1.125))
resized_img = img.resize((new_width, new_height), self.interpolation)
x_maxrange = new_width - self.width
y_maxrange = new_height - self.height
x1 = int(round(random.uniform(0, x_maxrange)))
y1 = int(round(random.uniform(0, y_maxrange)))
cropped_img = resized_img.crop(
(x1, y1, x1 + self.width, y1 + self.height)
)
return cropped_img
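When the translation branch fires, `Random2DTranslation` enlarges the image by 12.5% and crops the target size at a random offset. The maximum shift is therefore about `0.125 * width` pixels horizontally and `0.125 * height` vertically. The sketch below computes the enlarged size and crop box. `translation_crop_box` is a hypothetical helper.

```python
import random

def translation_crop_box(width, height, rng, scale=1.125):
    # Upscale by `scale`, then place a width x height crop at a random offset.
    new_w = int(round(width * scale))
    new_h = int(round(height * scale))
    x1 = int(round(rng.uniform(0, new_w - width)))
    y1 = int(round(rng.uniform(0, new_h - height)))
    return (new_w, new_h), (x1, y1, x1 + width, y1 + height)

size, box = translation_crop_box(32, 32, random.Random(0))
print(size, box)  # (36, 36) and a 32 x 32 box offset by at most 4 pixels
```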
class InstanceNormalization:
"""Normalize data using per-channel mean and standard deviation.
Reference:
- Ulyanov et al. Instance normalization: The missing ingredient
for fast stylization. ArXiv 2016.
- Shu et al. A DIRT-T Approach to Unsupervised Domain Adaptation.
ICLR 2018.
"""
def __init__(self, eps=1e-8):
self.eps = eps
def __call__(self, img):
C, H, W = img.shape
img_re = img.reshape(C, H * W)
mean = img_re.mean(1).view(C, 1, 1)
std = img_re.std(1).view(C, 1, 1)
return (img-mean) / (std + self.eps)
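`InstanceNormalization` standardizes each channel of a single image by that channel's own mean and standard deviation. `torch.Tensor.std` uses the unbiased (n - 1) estimator by default, and `eps` guards against constant channels. The sketch below does the same on nested lists, using `statistics.stdev`, which also divides by n - 1. `instance_normalize` is a hypothetical name.

```python
import statistics

def instance_normalize(channels, eps=1e-8):
    # channels: list of per-channel pixel lists (a flattened C x (H*W) image).
    out = []
    for values in channels:
        mean = statistics.mean(values)
        std = statistics.stdev(values)  # n - 1 denominator, like torch's default
        out.append([(v - mean) / (std + eps) for v in values])
    return out

normed = instance_normalize([[1, 2, 3, 4], [10, 10, 10, 20]])
print(normed[1])  # approximately [-0.5, -0.5, -0.5, 1.5]
```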
class Cutout:
"""Randomly mask out one or more patches from an image.
https://github.com/uoguelph-mlrg/Cutout
Args:
n_holes (int, optional): number of patches to cut out
of each image. Default is 1.
length (int, optional): length (in pixels) of each square
patch. Default is 16.
"""
def __init__(self, n_holes=1, length=16):
self.n_holes = n_holes
self.length = length
def __call__(self, img):
"""
Args:
img (Tensor): tensor image of size (C, H, W).
Returns:
Tensor: image with n_holes of dimension
length x length cut out of it.
"""
h = img.size(1)
w = img.size(2)
mask = np.ones((h, w), np.float32)
for n in range(self.n_holes):
y = np.random.randint(h)
x = np.random.randint(w)
y1 = np.clip(y - self.length // 2, 0, h)
y2 = np.clip(y + self.length // 2, 0, h)
x1 = np.clip(x - self.length // 2, 0, w)
x2 = np.clip(x + self.length // 2, 0, w)
mask[y1:y2, x1:x2] = 0.0
mask = torch.from_numpy(mask)
mask = mask.expand_as(img)
return img * mask
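For each hole, `Cutout` zeroes the rows `[y - length//2, y + length//2)` and the matching columns, clipped to the image. This has two consequences. Holes near a border remove fewer pixels, and an odd `length` produces a square of side `length - 1`. The sketch below builds a single-hole mask at a fixed centre so both effects can be checked. `cutout_mask` is a hypothetical helper.

```python
def cutout_mask(h, w, cy, cx, length):
    # 1.0 everywhere except a square around (cy, cx), clipped at the borders.
    y1, y2 = max(0, cy - length // 2), min(h, cy + length // 2)
    x1, x2 = max(0, cx - length // 2), min(w, cx + length // 2)
    mask = [[1.0] * w for _ in range(h)]
    for y in range(y1, y2):
        for x in range(x1, x2):
            mask[y][x] = 0.0
    return mask

def zeros(mask):
    return sum(row.count(0.0) for row in mask)

print(zeros(cutout_mask(4, 4, cy=2, cx=2, length=2)))  # 4: full 2 x 2 hole
print(zeros(cutout_mask(4, 4, cy=0, cx=0, length=2)))  # 1: clipped at the corner
print(zeros(cutout_mask(4, 4, cy=2, cx=2, length=3)))  # 4: odd length -> 2 x 2
```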
class GaussianNoise:
"""Add gaussian noise."""
def __init__(self, mean=0, std=0.15, p=0.5):
self.mean = mean
self.std = std
self.p = p
def __call__(self, img):
if random.uniform(0, 1) > self.p:
return img
noise = torch.randn(img.size()) * self.std + self.mean
return img + noise
def build_transform(cfg, is_train=True, choices=None):
"""Build transformation function.
Args:
cfg (CfgNode): config.
is_train (bool, optional): for training (True) or test (False).
Default is True.
choices (list, optional): list of strings which will overwrite
cfg.INPUT.TRANSFORMS if given. Default is None.
"""
if cfg.INPUT.NO_TRANSFORM:
print("Note: no transform is applied!")
return None
if choices is None:
choices = cfg.INPUT.TRANSFORMS
for choice in choices:
assert choice in AVAI_CHOICES
target_size = f"{cfg.INPUT.SIZE[0]}x{cfg.INPUT.SIZE[1]}"
normalize = Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)
if is_train:
return _build_transform_train(cfg, choices, target_size, normalize)
else:
return _build_transform_test(cfg, choices, target_size, normalize)
def _build_transform_train(cfg, choices, target_size, normalize):
print("Building transform_train")
tfm_train = []
interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]
# Make sure the image size matches the target size
conditions = []
conditions += ["random_crop" not in choices]
conditions += ["random_resized_crop" not in choices]
if all(conditions):
print(f"+ resize to {target_size}")
tfm_train += [Resize(cfg.INPUT.SIZE, interpolation=interp_mode)]
if "random_translation" in choices:
print("+ random translation")
tfm_train += [
Random2DTranslation(cfg.INPUT.SIZE[0], cfg.INPUT.SIZE[1])
]
if "random_crop" in choices:
crop_padding = cfg.INPUT.CROP_PADDING
print("+ random crop (padding = {})".format(crop_padding))
tfm_train += [RandomCrop(cfg.INPUT.SIZE, padding=crop_padding)]
if "random_resized_crop" in choices:
print(f"+ random resized crop (size={cfg.INPUT.SIZE})")
tfm_train += [
RandomResizedCrop(cfg.INPUT.SIZE, interpolation=interp_mode)
]
if "center_crop" in choices:
print(f"+ center crop (size={cfg.INPUT.SIZE})")
tfm_train += [CenterCrop(cfg.INPUT.SIZE)]
if "random_flip" in choices:
print("+ random flip")
tfm_train += [RandomHorizontalFlip()]
if "imagenet_policy" in choices:
print("+ imagenet policy")
tfm_train += [ImageNetPolicy()]
if "cifar10_policy" in choices:
print("+ cifar10 policy")
tfm_train += [CIFAR10Policy()]
if "svhn_policy" in choices:
print("+ svhn policy")
tfm_train += [SVHNPolicy()]
if "randaugment" in choices:
n_ = cfg.INPUT.RANDAUGMENT_N
m_ = cfg.INPUT.RANDAUGMENT_M
print("+ randaugment (n={}, m={})".format(n_, m_))
tfm_train += [RandAugment(n_, m_)]
if "randaugment_fixmatch" in choices:
n_ = cfg.INPUT.RANDAUGMENT_N
print("+ randaugment_fixmatch (n={})".format(n_))
tfm_train += [RandAugmentFixMatch(n_)]
if "randaugment2" in choices:
n_ = cfg.INPUT.RANDAUGMENT_N
print("+ randaugment2 (n={})".format(n_))
tfm_train += [RandAugment2(n_)]
if "colorjitter" in choices:
print("+ color jitter")
tfm_train += [
ColorJitter(
brightness=cfg.INPUT.COLORJITTER_B,
contrast=cfg.INPUT.COLORJITTER_C,
saturation=cfg.INPUT.COLORJITTER_S,
hue=cfg.INPUT.COLORJITTER_H,
)
]
if "randomgrayscale" in choices:
print("+ random gray scale")
tfm_train += [RandomGrayscale(p=cfg.INPUT.RGS_P)]
if "gaussian_blur" in choices:
print(f"+ gaussian blur (kernel={cfg.INPUT.GB_K})")
tfm_train += [
RandomApply([GaussianBlur(cfg.INPUT.GB_K)], p=cfg.INPUT.GB_P)
]
print("+ to torch tensor of range [0, 1]")
tfm_train += [ToTensor()]
if "cutout" in choices:
cutout_n = cfg.INPUT.CUTOUT_N
cutout_len = cfg.INPUT.CUTOUT_LEN
print("+ cutout (n_holes={}, length={})".format(cutout_n, cutout_len))
tfm_train += [Cutout(cutout_n, cutout_len)]
if "normalize" in choices:
print(
"+ normalization (mean={}, "
"std={})".format(cfg.INPUT.PIXEL_MEAN, cfg.INPUT.PIXEL_STD)
)
tfm_train += [normalize]
if "gaussian_noise" in choices:
print(
"+ gaussian noise (mean={}, std={})".format(
cfg.INPUT.GN_MEAN, cfg.INPUT.GN_STD
)
)
tfm_train += [GaussianNoise(cfg.INPUT.GN_MEAN, cfg.INPUT.GN_STD)]
if "instance_norm" in choices:
print("+ instance normalization")
tfm_train += [InstanceNormalization()]
tfm_train = Compose(tfm_train)
return tfm_train
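`_build_transform_train` appends steps in a fixed order that does not depend on the order of `cfg.INPUT.TRANSFORMS`. PIL-level ops come first, then `ToTensor`, then tensor-level ops (cutout, normalization, noise, instance norm). A plain resize is added only when neither `random_crop` nor `random_resized_crop` is chosen. The sketch below reproduces that ordering as a list of step names. `plan_train_pipeline` and `TRAIN_ORDER` are hypothetical.

```python
# Order in which _build_transform_train appends steps.
TRAIN_ORDER = [
    "random_translation", "random_crop", "random_resized_crop", "center_crop",
    "random_flip", "imagenet_policy", "cifar10_policy", "svhn_policy",
    "randaugment", "randaugment_fixmatch", "randaugment2", "colorjitter",
    "randomgrayscale", "gaussian_blur", "to_tensor", "cutout", "normalize",
    "gaussian_noise", "instance_norm",
]

def plan_train_pipeline(choices):
    steps = []
    if "random_crop" not in choices and "random_resized_crop" not in choices:
        steps.append("resize")
    for name in TRAIN_ORDER:
        if name == "to_tensor" or name in choices:  # ToTensor is always added
            steps.append(name)
    return steps

print(plan_train_pipeline(["normalize", "random_flip", "random_resized_crop"]))
# ['random_resized_crop', 'random_flip', 'to_tensor', 'normalize']
```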
def _build_transform_test(cfg, choices, target_size, normalize):
print("Building transform_test")
tfm_test = []
interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]
print(f"+ resize the smaller edge to {max(cfg.INPUT.SIZE)}")
tfm_test += [Resize(max(cfg.INPUT.SIZE), interpolation=interp_mode)]
print(f"+ {target_size} center crop")
tfm_test += [CenterCrop(cfg.INPUT.SIZE)]
print("+ to torch tensor of range [0, 1]")
tfm_test += [ToTensor()]
if "normalize" in choices:
print(
"+ normalization (mean={}, "
"std={})".format(cfg.INPUT.PIXEL_MEAN, cfg.INPUT.PIXEL_STD)
)
tfm_test += [normalize]
if "instance_norm" in choices:
print("+ instance normalization")
tfm_test += [InstanceNormalization()]
tfm_test = Compose(tfm_test)
return tfm_test
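At test time the pipeline resizes the *shorter* edge to `max(cfg.INPUT.SIZE)`, keeping the aspect ratio, and then center-crops to `cfg.INPUT.SIZE`. The sketch below computes the resulting sizes and crop box. We believe the truncation and rounding mirror torchvision's `Resize(int)` and `center_crop`, but treat that as an assumption. Both helpers are hypothetical.

```python
def resize_shorter_edge(width, height, size):
    # Shorter side becomes `size`; longer side is scaled and truncated.
    short, long_ = (width, height) if width <= height else (height, width)
    new_short, new_long = size, int(size * long_ / short)
    return (new_short, new_long) if width <= height else (new_long, new_short)

def center_crop_box(width, height, crop_w, crop_h):
    # (left, top, right, bottom) of a centred crop.
    left = int(round((width - crop_w) / 2.0))
    top = int(round((height - crop_h) / 2.0))
    return left, top, left + crop_w, top + crop_h

w, h = resize_shorter_edge(640, 480, 224)
print((w, h), center_crop_box(w, h, 224, 224))  # (298, 224) (37, 0, 261, 224)
```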
================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/__init__.py
================================================
from .build import TRAINER_REGISTRY, build_trainer # isort:skip
from .trainer import TrainerX, TrainerXU, TrainerBase, SimpleTrainer, SimpleNet # isort:skip
from .da import *
from .dg import *
from .ssl import *
================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/build.py
================================================
from dassl.utils import Registry, check_availability
TRAINER_REGISTRY = Registry("TRAINER")
def build_trainer(cfg):
avai_trainers = TRAINER_REGISTRY.registered_names()
check_availability(cfg.TRAINER.NAME, avai_trainers)
if cfg.VERBOSE:
print("Loading trainer: {}".format(cfg.TRAINER.NAME))
return TRAINER_REGISTRY.get(cfg.TRAINER.NAME)(cfg)
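`build_trainer` looks up a class by `cfg.TRAINER.NAME` in `TRAINER_REGISTRY`. Classes are added by the `@TRAINER_REGISTRY.register()` decorator seen on `AdaBN`, `ADDA`, and the other trainers. This is also why `dassl/engine/__init__.py` imports `.da`, `.dg`, and `.ssl`: importing those modules runs the decorators. The sketch below is a minimal registry of the same shape. `SketchRegistry` is hypothetical and is not `dassl.utils.Registry`.

```python
class SketchRegistry:
    """Minimal name -> class registry (hypothetical)."""

    def __init__(self, name):
        self._name = name
        self._obj_map = {}

    def register(self):
        def deco(obj):
            key = obj.__name__  # registered under the class name
            if key in self._obj_map:
                raise KeyError(f"{key} already registered in {self._name}")
            self._obj_map[key] = obj
            return obj
        return deco

    def registered_names(self):
        return list(self._obj_map)

    def get(self, name):
        if name not in self._obj_map:
            raise KeyError(f"{name!r} not in {self._name}: {self.registered_names()}")
        return self._obj_map[name]

REGISTRY = SketchRegistry("TRAINER")

@REGISTRY.register()
class ToyTrainer:
    def __init__(self, cfg):
        self.cfg = cfg

trainer = REGISTRY.get("ToyTrainer")({"lr": 0.1})
print(REGISTRY.registered_names(), type(trainer).__name__)
```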
================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/da/__init__.py
================================================
from .mcd import MCD
from .mme import MME
from .adda import ADDA
from .dael import DAEL
from .dann import DANN
from .adabn import AdaBN
from .m3sda import M3SDA
from .source_only import SourceOnly
from .self_ensembling import SelfEnsembling
================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/da/adabn.py
================================================
import torch
from dassl.utils import check_isfile
from dassl.engine import TRAINER_REGISTRY, TrainerXU
@TRAINER_REGISTRY.register()
class AdaBN(TrainerXU):
"""Adaptive Batch Normalization.
https://arxiv.org/abs/1603.04779.
"""
def __init__(self, cfg):
super().__init__(cfg)
self.done_reset_bn_stats = False
def check_cfg(self, cfg):
assert check_isfile(
cfg.MODEL.INIT_WEIGHTS
), "The weights of source model must be provided"
def before_epoch(self):
if not self.done_reset_bn_stats:
for m in self.model.modules():
classname = m.__class__.__name__
if classname.find("BatchNorm") != -1:
m.reset_running_stats()
self.done_reset_bn_stats = True
def forward_backward(self, batch_x, batch_u):
input_u = batch_u["img"].to(self.device)
with torch.no_grad():
self.model(input_u)
return None
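AdaBN resets every BatchNorm layer's running statistics once, then only runs forward passes on target data. Under `torch.no_grad()` the weights do not change, but BatchNorm layers in training mode still update their running statistics. The sketch below assumes the trainer keeps the model in training mode and uses PyTorch's default `momentum=0.1`. It shows how a reset running mean (0) converges toward the target-domain mean.

```python
def update_running_mean(running, batch_mean, momentum=0.1):
    # Exponential moving average used for BatchNorm running statistics.
    return (1 - momentum) * running + momentum * batch_mean

running = 0.0  # value of running_mean right after reset_running_stats()
for _ in range(50):
    running = update_running_mean(running, 5.0)  # target batches with mean 5.0
print(running)  # about 4.97, approaching the target mean 5.0
```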
================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/da/adda.py
================================================
import copy
import torch
import torch.nn as nn
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import check_isfile, count_num_param, open_specified_layers
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.modeling import build_head
@TRAINER_REGISTRY.register()
class ADDA(TrainerXU):
"""Adversarial Discriminative Domain Adaptation.
https://arxiv.org/abs/1702.05464.
"""
def __init__(self, cfg):
super().__init__(cfg)
self.open_layers = ["backbone"]
if isinstance(self.model.head, nn.Module):
self.open_layers.append("head")
self.source_model = copy.deepcopy(self.model)
self.source_model.eval()
for param in self.source_model.parameters():
param.requires_grad_(False)
self.build_critic()
self.bce = nn.BCEWithLogitsLoss()
def check_cfg(self, cfg):
assert check_isfile(
cfg.MODEL.INIT_WEIGHTS
), "The weights of source model must be provided"
def build_critic(self):
cfg = self.cfg
print("Building critic network")
fdim = self.model.fdim
critic_body = build_head(
"mlp",
verbose=cfg.VERBOSE,
in_features=fdim,
hidden_layers=[fdim, fdim // 2],
activation="leaky_relu",
)
self.critic = nn.Sequential(critic_body, nn.Linear(fdim // 2, 1))
print("# params: {:,}".format(count_num_param(self.critic)))
self.critic.to(self.device)
self.optim_c = build_optimizer(self.critic, cfg.OPTIM)
self.sched_c = build_lr_scheduler(self.optim_c, cfg.OPTIM)
self.register_model("critic", self.critic, self.optim_c, self.sched_c)
def forward_backward(self, batch_x, batch_u):
open_specified_layers(self.model, self.open_layers)
input_x, _, input_u = self.parse_batch_train(batch_x, batch_u)
domain_x = torch.ones(input_x.shape[0], 1).to(self.device)
domain_u = torch.zeros(input_u.shape[0], 1).to(self.device)
_, feat_x = self.source_model(input_x, return_feature=True)
_, feat_u = self.model(input_u, return_feature=True)
logit_xd = self.critic(feat_x)
logit_ud = self.critic(feat_u.detach())
loss_critic = self.bce(logit_xd, domain_x)
loss_critic += self.bce(logit_ud, domain_u)
self.model_backward_and_update(loss_critic, "critic")
logit_ud = self.critic(feat_u)
loss_model = self.bce(logit_ud, 1 - domain_u)
self.model_backward_and_update(loss_model, "model")
loss_summary = {
"loss_critic": loss_critic.item(),
"loss_model": loss_model.item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
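ADDA alternates two losses. The critic learns to output 1 for source features (from the frozen `source_model`) and 0 for target features. The target encoder is then trained with the *flipped* label `1 - domain_u`, so the critic classifies its features as source. The sketch below computes both losses for single logits, using the numerically stable form of binary cross-entropy with logits. The helper names are hypothetical.

```python
import math

def bce_with_logits(z, y):
    # Stable form of -(y*log(sigmoid(z)) + (1 - y)*log(1 - sigmoid(z))).
    return max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))

def adda_losses(logit_src, logit_tgt):
    # Critic: source -> 1, target -> 0. Encoder: target -> 1 (flipped label).
    loss_critic = bce_with_logits(logit_src, 1.0) + bce_with_logits(logit_tgt, 0.0)
    loss_encoder = bce_with_logits(logit_tgt, 1.0)
    return loss_critic, loss_encoder

print(adda_losses(0.0, 0.0))    # undecided critic: (2*log 2, log 2)
print(adda_losses(10.0, -10.0))  # confident critic: small critic loss, large encoder loss
```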
================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/da/dael.py
================================================
import torch
import torch.nn as nn
from dassl.data import DataManager
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.engine.trainer import SimpleNet
from dassl.data.transforms import build_transform
from dassl.modeling.ops.utils import create_onehot
class Experts(nn.Module):
def __init__(self, n_source, fdim, num_classes):
super().__init__()
self.linears = nn.ModuleList(
[nn.Linear(fdim, num_classes) for _ in range(n_source)]
)
self.softmax = nn.Softmax(dim=1)
def forward(self, i, x):
x = self.linears[i](x)
x = self.softmax(x)
return x
@TRAINER_REGISTRY.register()
class DAEL(TrainerXU):
"""Domain Adaptive Ensemble Learning.
https://arxiv.org/abs/2003.07325.
"""
def __init__(self, cfg):
super().__init__(cfg)
n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN
batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
if n_domain <= 0:
n_domain = self.num_source_domains
self.split_batch = batch_size // n_domain
self.n_domain = n_domain
self.weight_u = cfg.TRAINER.DAEL.WEIGHT_U
self.conf_thre = cfg.TRAINER.DAEL.CONF_THRE
def check_cfg(self, cfg):
assert cfg.DATALOADER.TRAIN_X.SAMPLER == "RandomDomainSampler"
assert not cfg.DATALOADER.TRAIN_U.SAME_AS_X
assert len(cfg.TRAINER.DAEL.STRONG_TRANSFORMS) > 0
def build_data_loader(self):
cfg = self.cfg
tfm_train = build_transform(cfg, is_train=True)
custom_tfm_train = [tfm_train]
choices = cfg.TRAINER.DAEL.STRONG_TRANSFORMS
tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
custom_tfm_train += [tfm_train_strong]
dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
self.train_loader_x = dm.train_loader_x
self.train_loader_u = dm.train_loader_u
self.val_loader = dm.val_loader
self.test_loader = dm.test_loader
self.num_classes = dm.num_classes
self.num_source_domains = dm.num_source_domains
self.lab2cname = dm.lab2cname
def build_model(self):
cfg = self.cfg
print("Building F")
self.F = SimpleNet(cfg, cfg.MODEL, 0)
self.F.to(self.device)
print("# params: {:,}".format(count_num_param(self.F)))
self.optim_F = build_optimizer(self.F, cfg.OPTIM)
self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
self.register_model("F", self.F, self.optim_F, self.sched_F)
fdim = self.F.fdim
print("Building E")
self.E = Experts(self.num_source_domains, fdim, self.num_classes)
self.E.to(self.device)
print("# params: {:,}".format(count_num_param(self.E)))
self.optim_E = build_optimizer(self.E, cfg.OPTIM)
self.sched_E = build_lr_scheduler(self.optim_E, cfg.OPTIM)
self.register_model("E", self.E, self.optim_E, self.sched_E)
def forward_backward(self, batch_x, batch_u):
parsed_data = self.parse_batch_train(batch_x, batch_u)
input_x, input_x2, label_x, domain_x, input_u, input_u2 = parsed_data
input_x = torch.split(input_x, self.split_batch, 0)
input_x2 = torch.split(input_x2, self.split_batch, 0)
label_x = torch.split(label_x, self.split_batch, 0)
domain_x = torch.split(domain_x, self.split_batch, 0)
domain_x = [d[0].item() for d in domain_x]
# Generate pseudo label
with torch.no_grad():
feat_u = self.F(input_u)
pred_u = []
for k in range(self.num_source_domains):
pred_uk = self.E(k, feat_u)
pred_uk = pred_uk.unsqueeze(1)
pred_u.append(pred_uk)
pred_u = torch.cat(pred_u, 1) # (B, K, C)
# Get the highest probability and index (label) for each expert
experts_max_p, experts_max_idx = pred_u.max(2) # (B, K)
# Get the most confident expert
max_expert_p, max_expert_idx = experts_max_p.max(1) # (B)
pseudo_label_u = []
for i, experts_label in zip(max_expert_idx, experts_max_idx):
pseudo_label_u.append(experts_label[i])
pseudo_label_u = torch.stack(pseudo_label_u, 0)
pseudo_label_u = create_onehot(pseudo_label_u, self.num_classes)
pseudo_label_u = pseudo_label_u.to(self.device)
label_u_mask = (max_expert_p >= self.conf_thre).float()
loss_x = 0
loss_cr = 0
acc_x = 0
feat_x = [self.F(x) for x in input_x]
feat_x2 = [self.F(x) for x in input_x2]
feat_u2 = self.F(input_u2)
for feat_xi, feat_x2i, label_xi, i in zip(
feat_x, feat_x2, label_x, domain_x
):
cr_s = [j for j in domain_x if j != i]
# Learning expert
pred_xi = self.E(i, feat_xi)
loss_x += (-label_xi * torch.log(pred_xi + 1e-5)).sum(1).mean()
expert_label_xi = pred_xi.detach()
acc_x += compute_accuracy(pred_xi.detach(),
label_xi.max(1)[1])[0].item()
# Consistency regularization
cr_pred = []
for j in cr_s:
pred_j = self.E(j, feat_x2i)
pred_j = pred_j.unsqueeze(1)
cr_pred.append(pred_j)
cr_pred = torch.cat(cr_pred, 1)
cr_pred = cr_pred.mean(1)
loss_cr += ((cr_pred - expert_label_xi)**2).sum(1).mean()
loss_x /= self.n_domain
loss_cr /= self.n_domain
acc_x /= self.n_domain
# Unsupervised loss
pred_u = []
for k in range(self.num_source_domains):
pred_uk = self.E(k, feat_u2)
pred_uk = pred_uk.unsqueeze(1)
pred_u.append(pred_uk)
pred_u = torch.cat(pred_u, 1)
pred_u = pred_u.mean(1)
l_u = (-pseudo_label_u * torch.log(pred_u + 1e-5)).sum(1)
loss_u = (l_u * label_u_mask).mean()
loss = 0
loss += loss_x
loss += loss_cr
loss += loss_u * self.weight_u
self.model_backward_and_update(loss)
loss_summary = {
"loss_x": loss_x.item(),
"acc_x": acc_x,
"loss_cr": loss_cr.item(),
"loss_u": loss_u.item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def parse_batch_train(self, batch_x, batch_u):
input_x = batch_x["img"]
input_x2 = batch_x["img2"]
label_x = batch_x["label"]
domain_x = batch_x["domain"]
input_u = batch_u["img"]
input_u2 = batch_u["img2"]
label_x = create_onehot(label_x, self.num_classes)
input_x = input_x.to(self.device)
input_x2 = input_x2.to(self.device)
label_x = label_x.to(self.device)
input_u = input_u.to(self.device)
input_u2 = input_u2.to(self.device)
return input_x, input_x2, label_x, domain_x, input_u, input_u2
def model_inference(self, input):
f = self.F(input)
p = []
for k in range(self.num_source_domains):
p_k = self.E(k, f)
p_k = p_k.unsqueeze(1)
p.append(p_k)
p = torch.cat(p, 1)
p = p.mean(1)
return p
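DAEL builds pseudo labels for unlabeled data in three steps. It takes each expert's top class and probability, keeps the label of the single most confident expert, and masks out samples whose best probability is below `CONF_THRE`. The sketch below follows those steps on nested lists of probabilities. `dael_pseudo_labels` is a hypothetical helper, and tie-breaking may differ from `torch.max`.

```python
def dael_pseudo_labels(pred_u, conf_thre):
    # pred_u[b][k] is expert k's class-probability list for sample b.
    labels, mask = [], []
    for expert_probs in pred_u:
        # Top class and its probability for every expert.
        best = [max(range(len(p)), key=p.__getitem__) for p in expert_probs]
        conf = [p[c] for p, c in zip(expert_probs, best)]
        # Keep the label of the most confident expert; mask by threshold.
        k = max(range(len(conf)), key=conf.__getitem__)
        labels.append(best[k])
        mask.append(1.0 if conf[k] >= conf_thre else 0.0)
    return labels, mask

pred_u = [
    [[0.7, 0.3], [0.1, 0.9]],    # expert 1 is most confident -> class 1
    [[0.55, 0.45], [0.6, 0.4]],  # best is 0.6 -> class 0, below threshold
]
print(dael_pseudo_labels(pred_u, conf_thre=0.8))  # ([1, 0], [1.0, 0.0])
```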
================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/da/dann.py
================================================
import numpy as np
import torch
import torch.nn as nn
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling import build_head
from dassl.modeling.ops import ReverseGrad
@TRAINER_REGISTRY.register()
class DANN(TrainerXU):
"""Domain-Adversarial Neural Networks.
https://arxiv.org/abs/1505.07818.
"""
def __init__(self, cfg):
super().__init__(cfg)
self.build_critic()
self.ce = nn.CrossEntropyLoss()
self.bce = nn.BCEWithLogitsLoss()
def build_critic(self):
cfg = self.cfg
print("Building critic network")
fdim = self.model.fdim
critic_body = build_head(
"mlp",
verbose=cfg.VERBOSE,
in_features=fdim,
hidden_layers=[fdim, fdim],
activation="leaky_relu",
)
self.critic = nn.Sequential(critic_body, nn.Linear(fdim, 1))
print("# params: {:,}".format(count_num_param(self.critic)))
self.critic.to(self.device)
self.optim_c = build_optimizer(self.critic, cfg.OPTIM)
self.sched_c = build_lr_scheduler(self.optim_c, cfg.OPTIM)
self.register_model("critic", self.critic, self.optim_c, self.sched_c)
self.revgrad = ReverseGrad()
def forward_backward(self, batch_x, batch_u):
input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)
domain_x = torch.ones(input_x.shape[0], 1).to(self.device)
domain_u = torch.zeros(input_u.shape[0], 1).to(self.device)
global_step = self.batch_idx + self.epoch * self.num_batches
progress = global_step / (self.max_epoch * self.num_batches)
lmda = 2 / (1 + np.exp(-10 * progress)) - 1
logit_x, feat_x = self.model(input_x, return_feature=True)
_, feat_u = self.model(input_u, return_feature=True)
loss_x = self.ce(logit_x, label_x)
feat_x = self.revgrad(feat_x, grad_scaling=lmda)
feat_u = self.revgrad(feat_u, grad_scaling=lmda)
output_xd = self.critic(feat_x)
output_ud = self.critic(feat_u)
loss_d = self.bce(output_xd, domain_x) + self.bce(output_ud, domain_u)
loss = loss_x + loss_d
self.model_backward_and_update(loss)
loss_summary = {
"loss_x": loss_x.item(),
"acc_x": compute_accuracy(logit_x, label_x)[0].item(),
"loss_d": loss_d.item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
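DANN ramps the gradient-reversal strength as `lmda = 2 / (1 + exp(-10 * progress)) - 1`, where `progress` runs from 0 to 1 over training. This equals `tanh(5 * progress)`, so the adversarial signal starts at 0 and saturates near 1. The sketch below evaluates the schedule. It also writes out, as a scalar function, what a gradient-reversal layer does to the upstream gradient in the backward pass. Both helper names are hypothetical.

```python
import math

def grl_lambda(progress, gamma=10.0):
    # 2 / (1 + exp(-gamma * p)) - 1, with p in [0, 1] the training progress.
    return 2.0 / (1.0 + math.exp(-gamma * progress)) - 1.0

def reverse_grad(upstream_grad, lmda):
    # Gradient reversal: identity forward, -lambda * gradient backward.
    return -lmda * upstream_grad

for p in (0.0, 0.1, 0.5, 1.0):
    print(p, round(grl_lambda(p), 4))
```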
================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/da/m3sda.py
================================================
import torch
import torch.nn as nn
from torch.nn import functional as F
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.engine.trainer import SimpleNet
class PairClassifiers(nn.Module):
def __init__(self, fdim, num_classes):
super().__init__()
self.c1 = nn.Linear(fdim, num_classes)
self.c2 = nn.Linear(fdim, num_classes)
def forward(self, x):
z1 = self.c1(x)
if not self.training:
return z1
z2 = self.c2(x)
return z1, z2
@TRAINER_REGISTRY.register()
class M3SDA(TrainerXU):
"""Moment Matching for Multi-Source Domain Adaptation.
https://arxiv.org/abs/1812.01754.
"""
def __init__(self, cfg):
super().__init__(cfg)
n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN
batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
if n_domain <= 0:
n_domain = self.num_source_domains
self.split_batch = batch_size // n_domain
self.n_domain = n_domain
self.n_step_F = cfg.TRAINER.M3SDA.N_STEP_F
self.lmda = cfg.TRAINER.M3SDA.LMDA
def check_cfg(self, cfg):
assert cfg.DATALOADER.TRAIN_X.SAMPLER == "RandomDomainSampler"
assert not cfg.DATALOADER.TRAIN_U.SAME_AS_X
def build_model(self):
cfg = self.cfg
print("Building F")
self.F = SimpleNet(cfg, cfg.MODEL, 0)
self.F.to(self.device)
print("# params: {:,}".format(count_num_param(self.F)))
self.optim_F = build_optimizer(self.F, cfg.OPTIM)
self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
self.register_model("F", self.F, self.optim_F, self.sched_F)
fdim = self.F.fdim
print("Building C")
self.C = nn.ModuleList(
[
PairClassifiers(fdim, self.num_classes)
for _ in range(self.num_source_domains)
]
)
self.C.to(self.device)
print("# params: {:,}".format(count_num_param(self.C)))
self.optim_C = build_optimizer(self.C, cfg.OPTIM)
self.sched_C = build_lr_scheduler(self.optim_C, cfg.OPTIM)
self.register_model("C", self.C, self.optim_C, self.sched_C)
def forward_backward(self, batch_x, batch_u):
parsed = self.parse_batch_train(batch_x, batch_u)
input_x, label_x, domain_x, input_u = parsed
input_x = torch.split(input_x, self.split_batch, 0)
label_x = torch.split(label_x, self.split_batch, 0)
domain_x = torch.split(domain_x, self.split_batch, 0)
domain_x = [d[0].item() for d in domain_x]
# Step A
loss_x = 0
feat_x = []
for x, y, d in zip(input_x, label_x, domain_x):
f = self.F(x)
z1, z2 = self.C[d](f)
loss_x += F.cross_entropy(z1, y) + F.cross_entropy(z2, y)
feat_x.append(f)
loss_x /= self.n_domain
feat_u = self.F(input_u)
loss_msda = self.moment_distance(feat_x, feat_u)
loss_step_A = loss_x + loss_msda * self.lmda
self.model_backward_and_update(loss_step_A)
# Step B
with torch.no_grad():
feat_u = self.F(input_u)
loss_x, loss_dis = 0, 0
for x, y, d in zip(input_x, label_x, domain_x):
with torch.no_grad():
f = self.F(x)
z1, z2 = self.C[d](f)
loss_x += F.cross_entropy(z1, y) + F.cross_entropy(z2, y)
z1, z2 = self.C[d](feat_u)
p1 = F.softmax(z1, 1)
p2 = F.softmax(z2, 1)
loss_dis += self.discrepancy(p1, p2)
loss_x /= self.n_domain
loss_dis /= self.n_domain
loss_step_B = loss_x - loss_dis
self.model_backward_and_update(loss_step_B, "C")
# Step C
for _ in range(self.n_step_F):
feat_u = self.F(input_u)
loss_dis = 0
for d in domain_x:
z1, z2 = self.C[d](feat_u)
p1 = F.softmax(z1, 1)
p2 = F.softmax(z2, 1)
loss_dis += self.discrepancy(p1, p2)
loss_dis /= self.n_domain
loss_step_C = loss_dis
self.model_backward_and_update(loss_step_C, "F")
loss_summary = {
"loss_step_A": loss_step_A.item(),
"loss_step_B": loss_step_B.item(),
"loss_step_C": loss_step_C.item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def moment_distance(self, x, u):
# x (list): a list of feature matrices, one per source domain.
# u (torch.Tensor): target feature matrix.
x_mean = [xi.mean(0) for xi in x]
u_mean = u.mean(0)
dist1 = self.pairwise_distance(x_mean, u_mean)
x_var = [xi.var(0) for xi in x]
u_var = u.var(0)
dist2 = self.pairwise_distance(x_var, u_var)
return (dist1+dist2) / 2
def pairwise_distance(self, x, u):
# x (list): a list of feature vectors (per-source-domain statistics).
# u (torch.Tensor): target feature vector.
dist = 0
count = 0
for xi in x:
dist += self.euclidean(xi, u)
count += 1
for i in range(len(x) - 1):
for j in range(i + 1, len(x)):
dist += self.euclidean(x[i], x[j])
count += 1
return dist / count
def euclidean(self, input1, input2):
return ((input1 - input2)**2).sum().sqrt()
def discrepancy(self, y1, y2):
return (y1 - y2).abs().mean()
def parse_batch_train(self, batch_x, batch_u):
input_x = batch_x["img"]
label_x = batch_x["label"]
domain_x = batch_x["domain"]
input_u = batch_u["img"]
input_x = input_x.to(self.device)
label_x = label_x.to(self.device)
input_u = input_u.to(self.device)
return input_x, label_x, domain_x, input_u
def model_inference(self, input):
f = self.F(input)
p = 0
for C_i in self.C:
z = C_i(f)
p += F.softmax(z, 1)
p = p / len(self.C)
return p
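The M3SDA moment-matching term above averages Euclidean distances between per-domain feature statistics: every source-vs-target pair plus every unordered source-vs-source pair, once for the mean and once for the variance. The following is a minimal pure-Python sketch of that computation, assuming features are given as lists of rows instead of tensors; the helper names (`column_means`, `column_vars`) are hypothetical, and `statistics.variance` is used because, like `torch.Tensor.var`'s default, it divides by N - 1.

```python
import math
import statistics


def column_means(rows):
    """Per-feature mean of a feature matrix stored as a list of rows."""
    return [statistics.mean(col) for col in zip(*rows)]


def column_vars(rows):
    """Per-feature sample variance (N - 1 denominator, like torch.var's default)."""
    return [statistics.variance(col) for col in zip(*rows)]


def euclidean(a, b):
    return math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))


def pairwise_distance(source_vecs, target_vec):
    # Every source-vs-target pair, plus every unordered source-vs-source pair.
    dists = [euclidean(s, target_vec) for s in source_vecs]
    for i in range(len(source_vecs) - 1):
        for j in range(i + 1, len(source_vecs)):
            dists.append(euclidean(source_vecs[i], source_vecs[j]))
    return sum(dists) / len(dists)


def moment_distance(source_feats, target_feats):
    # First moment (mean) and second moment (variance) distances, averaged.
    dist_mean = pairwise_distance(
        [column_means(x) for x in source_feats], column_means(target_feats)
    )
    dist_var = pairwise_distance(
        [column_vars(x) for x in source_feats], column_vars(target_feats)
    )
    return (dist_mean + dist_var) / 2
```

With two 1-D source domains `[[0.0], [2.0]]` and `[[1.0], [3.0]]` and a target `[[0.0], [2.0]]`, all variances equal 2, and the mean distances are 0, 1 and 1, so the result is (2/3 + 0) / 2 = 1/3.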
================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/da/mcd.py
================================================
import torch
import torch.nn as nn
from torch.nn import functional as F
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.engine.trainer import SimpleNet
@TRAINER_REGISTRY.register()
class MCD(TrainerXU):
"""Maximum Classifier Discrepancy.
https://arxiv.org/abs/1712.02560.
"""
def __init__(self, cfg):
super().__init__(cfg)
self.n_step_F = cfg.TRAINER.MCD.N_STEP_F
def build_model(self):
cfg = self.cfg
print("Building F")
self.F = SimpleNet(cfg, cfg.MODEL, 0)
self.F.to(self.device)
print("# params: {:,}".format(count_num_param(self.F)))
self.optim_F = build_optimizer(self.F, cfg.OPTIM)
self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
self.register_model("F", self.F, self.optim_F, self.sched_F)
fdim = self.F.fdim
print("Building C1")
self.C1 = nn.Linear(fdim, self.num_classes)
self.C1.to(self.device)
print("# params: {:,}".format(count_num_param(self.C1)))
self.optim_C1 = build_optimizer(self.C1, cfg.OPTIM)
self.sched_C1 = build_lr_scheduler(self.optim_C1, cfg.OPTIM)
self.register_model("C1", self.C1, self.optim_C1, self.sched_C1)
print("Building C2")
self.C2 = nn.Linear(fdim, self.num_classes)
self.C2.to(self.device)
print("# params: {:,}".format(count_num_param(self.C2)))
self.optim_C2 = build_optimizer(self.C2, cfg.OPTIM)
self.sched_C2 = build_lr_scheduler(self.optim_C2, cfg.OPTIM)
self.register_model("C2", self.C2, self.optim_C2, self.sched_C2)
def forward_backward(self, batch_x, batch_u):
parsed = self.parse_batch_train(batch_x, batch_u)
input_x, label_x, input_u = parsed
# Step A
feat_x = self.F(input_x)
logit_x1 = self.C1(feat_x)
logit_x2 = self.C2(feat_x)
loss_x1 = F.cross_entropy(logit_x1, label_x)
loss_x2 = F.cross_entropy(logit_x2, label_x)
loss_step_A = loss_x1 + loss_x2
self.model_backward_and_update(loss_step_A)
# Step B
with torch.no_grad():
feat_x = self.F(input_x)
logit_x1 = self.C1(feat_x)
logit_x2 = self.C2(feat_x)
loss_x1 = F.cross_entropy(logit_x1, label_x)
loss_x2 = F.cross_entropy(logit_x2, label_x)
loss_x = loss_x1 + loss_x2
with torch.no_grad():
feat_u = self.F(input_u)
pred_u1 = F.softmax(self.C1(feat_u), 1)
pred_u2 = F.softmax(self.C2(feat_u), 1)
loss_dis = self.discrepancy(pred_u1, pred_u2)
loss_step_B = loss_x - loss_dis
self.model_backward_and_update(loss_step_B, ["C1", "C2"])
# Step C
for _ in range(self.n_step_F):
feat_u = self.F(input_u)
pred_u1 = F.softmax(self.C1(feat_u), 1)
pred_u2 = F.softmax(self.C2(feat_u), 1)
loss_step_C = self.discrepancy(pred_u1, pred_u2)
self.model_backward_and_update(loss_step_C, "F")
loss_summary = {
"loss_step_A": loss_step_A.item(),
"loss_step_B": loss_step_B.item(),
"loss_step_C": loss_step_C.item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def discrepancy(self, y1, y2):
return (y1 - y2).abs().mean()
def model_inference(self, input):
feat = self.F(input)
return self.C1(feat)
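In MCD, the discrepancy is the mean absolute difference between the two classifiers' softmax outputs. Step B updates only C1 and C2 with `loss_x - loss_dis` (keep source accuracy, maximize disagreement on target), and Step C updates only F to minimize the same discrepancy. A minimal pure-Python sketch of these quantities, assuming per-sample probability rows as lists; `step_b_objective` is a hypothetical helper name:

```python
import math


def softmax(logits):
    # Numerically stable softmax over one row of logits.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def discrepancy(p1_rows, p2_rows):
    # Mean absolute difference over all batch elements and classes,
    # mirroring (y1 - y2).abs().mean() on (batch, classes) tensors.
    diffs = [abs(a - b) for r1, r2 in zip(p1_rows, p2_rows) for a, b in zip(r1, r2)]
    return sum(diffs) / len(diffs)


def step_b_objective(loss_x, loss_dis):
    # Classifiers minimize this: low source loss, high target disagreement.
    return loss_x - loss_dis
```

Step C's objective is simply `discrepancy(...)` evaluated with gradients flowing into F only.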
================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/da/mme.py
================================================
import torch
import torch.nn as nn
from torch.nn import functional as F
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling.ops import ReverseGrad
from dassl.engine.trainer import SimpleNet
class Prototypes(nn.Module):
def __init__(self, fdim, num_classes, temp=0.05):
super().__init__()
self.prototypes = nn.Linear(fdim, num_classes, bias=False)
self.temp = temp
def forward(self, x):
x = F.normalize(x, p=2, dim=1)
out = self.prototypes(x)
out = out / self.temp
return out
@TRAINER_REGISTRY.register()
class MME(TrainerXU):
"""Minimax Entropy.
https://arxiv.org/abs/1904.06487.
"""
def __init__(self, cfg):
super().__init__(cfg)
self.lmda = cfg.TRAINER.MME.LMDA
def build_model(self):
cfg = self.cfg
print("Building F")
self.F = SimpleNet(cfg, cfg.MODEL, 0)
self.F.to(self.device)
print("# params: {:,}".format(count_num_param(self.F)))
self.optim_F = build_optimizer(self.F, cfg.OPTIM)
self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
self.register_model("F", self.F, self.optim_F, self.sched_F)
print("Building C")
self.C = Prototypes(self.F.fdim, self.num_classes)
self.C.to(self.device)
print("# params: {:,}".format(count_num_param(self.C)))
self.optim_C = build_optimizer(self.C, cfg.OPTIM)
self.sched_C = build_lr_scheduler(self.optim_C, cfg.OPTIM)
self.register_model("C", self.C, self.optim_C, self.sched_C)
self.revgrad = ReverseGrad()
def forward_backward(self, batch_x, batch_u):
input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)
feat_x = self.F(input_x)
logit_x = self.C(feat_x)
loss_x = F.cross_entropy(logit_x, label_x)
self.model_backward_and_update(loss_x)
feat_u = self.F(input_u)
feat_u = self.revgrad(feat_u)
logit_u = self.C(feat_u)
prob_u = F.softmax(logit_u, 1)
loss_u = -(-prob_u * torch.log(prob_u + 1e-5)).sum(1).mean()
self.model_backward_and_update(loss_u * self.lmda)
loss_summary = {
"loss_x": loss_x.item(),
"acc_x": compute_accuracy(logit_x, label_x)[0].item(),
"loss_u": loss_u.item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def model_inference(self, input):
return self.C(self.F(input))
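`Prototypes` is a cosine-style classifier: the feature is L2-normalized, multiplied by the (unnormalized) prototype weights and divided by a temperature of 0.05. The unlabeled loss is the negative entropy, so C minimizing it maximizes entropy, while the gradient reversal makes F minimize entropy. A minimal pure-Python sketch, assuming features and weights as plain lists; the `eps` in `l2_normalize` mirrors the clamp used by `F.normalize`:

```python
import math


def l2_normalize(x, eps=1e-12):
    norm = math.sqrt(sum(v * v for v in x))
    return [v / max(norm, eps) for v in x]


def prototype_logits(feature, weights, temp=0.05):
    # Only the feature is normalized (as in Prototypes.forward); weight rows are used as-is.
    x = l2_normalize(feature)
    return [sum(w * v for w, v in zip(row, x)) / temp for row in weights]


def entropy(probs, eps=1e-5):
    # -sum_c p_c * log(p_c + eps), with the same 1e-5 guard as the trainer.
    return sum(-p * math.log(p + eps) for p in probs)


def mme_unlabeled_loss(probs):
    # Negative entropy; its gradient is reversed before reaching F.
    return -entropy(probs)
```

For a feature `[3.0, 4.0]` and identity prototypes, the normalized feature is `[0.6, 0.8]` and the logits are `[12.0, 16.0]`.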
================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/da/self_ensembling.py
================================================
import copy
from torch.nn import functional as F
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling.ops.utils import sigmoid_rampup, ema_model_update
@TRAINER_REGISTRY.register()
class SelfEnsembling(TrainerXU):
"""Self-ensembling for visual domain adaptation.
https://arxiv.org/abs/1706.05208.
"""
def __init__(self, cfg):
super().__init__(cfg)
self.ema_alpha = cfg.TRAINER.SE.EMA_ALPHA
self.conf_thre = cfg.TRAINER.SE.CONF_THRE
self.rampup = cfg.TRAINER.SE.RAMPUP
self.teacher = copy.deepcopy(self.model)
self.teacher.train()
for param in self.teacher.parameters():
param.requires_grad_(False)
def check_cfg(self, cfg):
assert cfg.DATALOADER.K_TRANSFORMS == 2
def forward_backward(self, batch_x, batch_u):
global_step = self.batch_idx + self.epoch * self.num_batches
parsed = self.parse_batch_train(batch_x, batch_u)
input_x, label_x, input_u1, input_u2 = parsed
logit_x = self.model(input_x)
loss_x = F.cross_entropy(logit_x, label_x)
prob_u = F.softmax(self.model(input_u1), 1)
t_prob_u = F.softmax(self.teacher(input_u2), 1)
loss_u = ((prob_u - t_prob_u)**2).sum(1)
if self.conf_thre:
max_prob = t_prob_u.max(1)[0]
mask = (max_prob > self.conf_thre).float()
loss_u = (loss_u * mask).mean()
else:
weight_u = sigmoid_rampup(global_step, self.rampup)
loss_u = loss_u.mean() * weight_u
loss = loss_x + loss_u
self.model_backward_and_update(loss)
ema_alpha = min(1 - 1 / (global_step+1), self.ema_alpha)
ema_model_update(self.model, self.teacher, ema_alpha)
loss_summary = {
"loss_x": loss_x.item(),
"acc_x": compute_accuracy(logit_x, label_x)[0].item(),
"loss_u": loss_u.item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def parse_batch_train(self, batch_x, batch_u):
input_x = batch_x["img"][0]
label_x = batch_x["label"]
input_u = batch_u["img"]
input_u1, input_u2 = input_u
input_x = input_x.to(self.device)
label_x = label_x.to(self.device)
input_u1 = input_u1.to(self.device)
input_u2 = input_u2.to(self.device)
return input_x, label_x, input_u1, input_u2
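The teacher in `SelfEnsembling` is an exponential moving average of the student. The effective decay `min(1 - 1 / (global_step + 1), ema_alpha)` starts at 0, so the teacher copies the student early on and only reaches the configured decay later. A minimal sketch on flat parameter lists follows; it assumes `ema_model_update` implements `teacher = alpha * teacher + (1 - alpha) * student`, and the helper names are hypothetical. The masked loss divides by the full batch size, as `(loss_u * mask).mean()` does.

```python
def ema_alpha(global_step, max_alpha):
    # Warm-up: small decay early so the teacher tracks the student closely.
    return min(1 - 1 / (global_step + 1), max_alpha)


def ema_update(student, teacher, alpha):
    # Assumed update rule: teacher <- alpha * teacher + (1 - alpha) * student.
    return [alpha * t + (1 - alpha) * s for s, t in zip(student, teacher)]


def masked_consistency(sq_errors, teacher_max_probs, conf_thre):
    # Keep only samples where the teacher is strictly more confident than conf_thre.
    mask = [1.0 if p > conf_thre else 0.0 for p in teacher_max_probs]
    return sum(e * m for e, m in zip(sq_errors, mask)) / len(sq_errors)
```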
================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/da/source_only.py
================================================
from torch.nn import functional as F
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
@TRAINER_REGISTRY.register()
class SourceOnly(TrainerXU):
"""Baseline model for domain adaptation, which is
trained using source data only.
"""
def forward_backward(self, batch_x, batch_u):
input, label = self.parse_batch_train(batch_x, batch_u)
output = self.model(input)
loss = F.cross_entropy(output, label)
self.model_backward_and_update(loss)
loss_summary = {
"loss": loss.item(),
"acc": compute_accuracy(output, label)[0].item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def parse_batch_train(self, batch_x, batch_u):
input = batch_x["img"]
label = batch_x["label"]
input = input.to(self.device)
label = label.to(self.device)
return input, label
================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/dg/__init__.py
================================================
from .ddaig import DDAIG
from .daeldg import DAELDG
from .vanilla import Vanilla
from .crossgrad import CrossGrad
================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/dg/crossgrad.py
================================================
import torch
from torch.nn import functional as F
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.engine.trainer import SimpleNet
@TRAINER_REGISTRY.register()
class CrossGrad(TrainerX):
"""Cross-gradient training.
https://arxiv.org/abs/1804.10745.
"""
def __init__(self, cfg):
super().__init__(cfg)
self.eps_f = cfg.TRAINER.CG.EPS_F
self.eps_d = cfg.TRAINER.CG.EPS_D
self.alpha_f = cfg.TRAINER.CG.ALPHA_F
self.alpha_d = cfg.TRAINER.CG.ALPHA_D
def build_model(self):
cfg = self.cfg
print("Building F")
self.F = SimpleNet(cfg, cfg.MODEL, self.num_classes)
self.F.to(self.device)
print("# params: {:,}".format(count_num_param(self.F)))
self.optim_F = build_optimizer(self.F, cfg.OPTIM)
self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
self.register_model("F", self.F, self.optim_F, self.sched_F)
print("Building D")
self.D = SimpleNet(cfg, cfg.MODEL, self.num_source_domains)
self.D.to(self.device)
print("# params: {:,}".format(count_num_param(self.D)))
self.optim_D = build_optimizer(self.D, cfg.OPTIM)
self.sched_D = build_lr_scheduler(self.optim_D, cfg.OPTIM)
self.register_model("D", self.D, self.optim_D, self.sched_D)
def forward_backward(self, batch):
input, label, domain = self.parse_batch_train(batch)
input.requires_grad = True
# Compute domain perturbation
loss_d = F.cross_entropy(self.D(input), domain)
loss_d.backward()
grad_d = torch.clamp(input.grad.data, min=-0.1, max=0.1)
input_d = input.data + self.eps_f * grad_d
# Compute label perturbation
input.grad.data.zero_()
loss_f = F.cross_entropy(self.F(input), label)
loss_f.backward()
grad_f = torch.clamp(input.grad.data, min=-0.1, max=0.1)
input_f = input.data + self.eps_d * grad_f
input = input.detach()
# Update label net
loss_f1 = F.cross_entropy(self.F(input), label)
loss_f2 = F.cross_entropy(self.F(input_d), label)
loss_f = (1 - self.alpha_f) * loss_f1 + self.alpha_f * loss_f2
self.model_backward_and_update(loss_f, "F")
# Update domain net
loss_d1 = F.cross_entropy(self.D(input), domain)
loss_d2 = F.cross_entropy(self.D(input_f), domain)
loss_d = (1 - self.alpha_d) * loss_d1 + self.alpha_d * loss_d2
self.model_backward_and_update(loss_d, "D")
loss_summary = {"loss_f": loss_f.item(), "loss_d": loss_d.item()}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def model_inference(self, input):
return self.F(input)
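CrossGrad perturbs each input along the clamped input-gradient of the *other* network's loss: the label net F is trained on `input_d` (domain-loss gradient scaled by `eps_f`), and the domain net D on `input_f` (label-loss gradient scaled by `eps_d`). Each net then mixes its clean and perturbed losses with weight `alpha`. A minimal sketch on flat lists, assuming the gradients are already computed; the helper names are hypothetical:

```python
def clamp(v, lo=-0.1, hi=0.1):
    return max(lo, min(hi, v))


def perturb(x, grad, eps):
    # x' = x + eps * clamp(grad, -0.1, 0.1): a small ascent step on the loss.
    return [xi + eps * clamp(gi) for xi, gi in zip(x, grad)]


def mixed_loss(loss_clean, loss_perturbed, alpha):
    # (1 - alpha) * clean + alpha * perturbed, as for both F and D.
    return (1 - alpha) * loss_clean + alpha * loss_perturbed
```

Clamping bounds the per-element step to `eps * 0.1`, so a single large gradient entry cannot dominate the perturbation.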
================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/dg/daeldg.py
================================================
import torch
import torch.nn as nn
from dassl.data import DataManager
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy
from dassl.engine.trainer import SimpleNet
from dassl.data.transforms import build_transform
from dassl.modeling.ops.utils import create_onehot
class Experts(nn.Module):
def __init__(self, n_source, fdim, num_classes):
super().__init__()
self.linears = nn.ModuleList(
[nn.Linear(fdim, num_classes) for _ in range(n_source)]
)
self.softmax = nn.Softmax(dim=1)
def forward(self, i, x):
x = self.linears[i](x)
x = self.softmax(x)
return x
@TRAINER_REGISTRY.register()
class DAELDG(TrainerX):
"""Domain Adaptive Ensemble Learning.
DG version: uses labeled source data only.
https://arxiv.org/abs/2003.07325.
"""
def __init__(self, cfg):
super().__init__(cfg)
n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN
batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
if n_domain <= 0:
n_domain = self.num_source_domains
self.split_batch = batch_size // n_domain
self.n_domain = n_domain
self.conf_thre = cfg.TRAINER.DAEL.CONF_THRE
def check_cfg(self, cfg):
assert cfg.DATALOADER.TRAIN_X.SAMPLER == "RandomDomainSampler"
assert len(cfg.TRAINER.DAEL.STRONG_TRANSFORMS) > 0
def build_data_loader(self):
cfg = self.cfg
tfm_train = build_transform(cfg, is_train=True)
custom_tfm_train = [tfm_train]
choices = cfg.TRAINER.DAEL.STRONG_TRANSFORMS
tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
custom_tfm_train += [tfm_train_strong]
dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
self.train_loader_x = dm.train_loader_x
self.train_loader_u = dm.train_loader_u
self.val_loader = dm.val_loader
self.test_loader = dm.test_loader
self.num_classes = dm.num_classes
self.num_source_domains = dm.num_source_domains
self.lab2cname = dm.lab2cname
def build_model(self):
cfg = self.cfg
print("Building F")
self.F = SimpleNet(cfg, cfg.MODEL, 0)
self.F.to(self.device)
print("# params: {:,}".format(count_num_param(self.F)))
self.optim_F = build_optimizer(self.F, cfg.OPTIM)
self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
self.register_model("F", self.F, self.optim_F, self.sched_F)
fdim = self.F.fdim
print("Building E")
self.E = Experts(self.num_source_domains, fdim, self.num_classes)
self.E.to(self.device)
print("# params: {:,}".format(count_num_param(self.E)))
self.optim_E = build_optimizer(self.E, cfg.OPTIM)
self.sched_E = build_lr_scheduler(self.optim_E, cfg.OPTIM)
self.register_model("E", self.E, self.optim_E, self.sched_E)
def forward_backward(self, batch):
parsed_data = self.parse_batch_train(batch)
input, input2, label, domain = parsed_data
input = torch.split(input, self.split_batch, 0)
input2 = torch.split(input2, self.split_batch, 0)
label = torch.split(label, self.split_batch, 0)
domain = torch.split(domain, self.split_batch, 0)
domain = [d[0].item() for d in domain]
loss_x = 0
loss_cr = 0
acc = 0
feat = [self.F(x) for x in input]
feat2 = [self.F(x) for x in input2]
for feat_i, feat2_i, label_i, i in zip(feat, feat2, label, domain):
cr_s = [j for j in domain if j != i]
# Learning expert
pred_i = self.E(i, feat_i)
loss_x += (-label_i * torch.log(pred_i + 1e-5)).sum(1).mean()
expert_label_i = pred_i.detach()
acc += compute_accuracy(pred_i.detach(),
label_i.max(1)[1])[0].item()
# Consistency regularization
cr_pred = []
for j in cr_s:
pred_j = self.E(j, feat2_i)
pred_j = pred_j.unsqueeze(1)
cr_pred.append(pred_j)
cr_pred = torch.cat(cr_pred, 1)
cr_pred = cr_pred.mean(1)
loss_cr += ((cr_pred - expert_label_i)**2).sum(1).mean()
loss_x /= self.n_domain
loss_cr /= self.n_domain
acc /= self.n_domain
loss = 0
loss += loss_x
loss += loss_cr
self.model_backward_and_update(loss)
loss_summary = {
"loss_x": loss_x.item(),
"acc": acc,
"loss_cr": loss_cr.item()
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def parse_batch_train(self, batch):
input = batch["img"]
input2 = batch["img2"]
label = batch["label"]
domain = batch["domain"]
label = create_onehot(label, self.num_classes)
input = input.to(self.device)
input2 = input2.to(self.device)
label = label.to(self.device)
return input, input2, label, domain
def model_inference(self, input):
f = self.F(input)
p = []
for k in range(self.num_source_domains):
p_k = self.E(k, f)
p_k = p_k.unsqueeze(1)
p.append(p_k)
p = torch.cat(p, 1)
p = p.mean(1)
return p
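In `DAELDG`, expert `i` is trained with cross-entropy on its own domain, and its detached prediction on the weakly augmented image serves as the target for the average of the *other* experts' predictions on the strongly augmented image (squared error summed over classes, then averaged over the batch). Inference averages all experts. A minimal per-sample sketch with plain lists; the helper names are hypothetical:

```python
def mean_prediction(preds):
    # Element-wise average of several probability vectors.
    return [sum(col) / len(preds) for col in zip(*preds)]


def consistency_loss(expert_pred, other_preds):
    # expert_pred is treated as a fixed target (detached in the trainer).
    avg = mean_prediction(other_preds)
    return sum((a - e) ** 2 for a, e in zip(avg, expert_pred))
```

For an expert target `[1.0, 0.0]` and non-expert predictions `[0.5, 0.5]` and `[1.0, 0.0]`, the average is `[0.75, 0.25]` and the loss is 0.0625 + 0.0625 = 0.125.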
================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/dg/ddaig.py
================================================
import torch
from torch.nn import functional as F
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.modeling import build_network
from dassl.engine.trainer import SimpleNet
@TRAINER_REGISTRY.register()
class DDAIG(TrainerX):
"""Deep Domain-Adversarial Image Generation.
https://arxiv.org/abs/2003.06054.
"""
def __init__(self, cfg):
super().__init__(cfg)
self.lmda = cfg.TRAINER.DDAIG.LMDA
self.clamp = cfg.TRAINER.DDAIG.CLAMP
self.clamp_min = cfg.TRAINER.DDAIG.CLAMP_MIN
self.clamp_max = cfg.TRAINER.DDAIG.CLAMP_MAX
self.warmup = cfg.TRAINER.DDAIG.WARMUP
self.alpha = cfg.TRAINER.DDAIG.ALPHA
def build_model(self):
cfg = self.cfg
print("Building F")
self.F = SimpleNet(cfg, cfg.MODEL, self.num_classes)
self.F.to(self.device)
print("# params: {:,}".format(count_num_param(self.F)))
self.optim_F = build_optimizer(self.F, cfg.OPTIM)
self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
self.register_model("F", self.F, self.optim_F, self.sched_F)
print("Building D")
self.D = SimpleNet(cfg, cfg.MODEL, self.num_source_domains)
self.D.to(self.device)
print("# params: {:,}".format(count_num_param(self.D)))
self.optim_D = build_optimizer(self.D, cfg.OPTIM)
self.sched_D = build_lr_scheduler(self.optim_D, cfg.OPTIM)
self.register_model("D", self.D, self.optim_D, self.sched_D)
print("Building G")
self.G = build_network(cfg.TRAINER.DDAIG.G_ARCH, verbose=cfg.VERBOSE)
self.G.to(self.device)
print("# params: {:,}".format(count_num_param(self.G)))
self.optim_G = build_optimizer(self.G, cfg.OPTIM)
self.sched_G = build_lr_scheduler(self.optim_G, cfg.OPTIM)
self.register_model("G", self.G, self.optim_G, self.sched_G)
def forward_backward(self, batch):
input, label, domain = self.parse_batch_train(batch)
#############
# Update G
#############
input_p = self.G(input, lmda=self.lmda)
if self.clamp:
input_p = torch.clamp(
input_p, min=self.clamp_min, max=self.clamp_max
)
loss_g = 0
# Minimize label loss
loss_g += F.cross_entropy(self.F(input_p), label)
# Maximize domain loss
loss_g -= F.cross_entropy(self.D(input_p), domain)
self.model_backward_and_update(loss_g, "G")
# Perturb data with new G
with torch.no_grad():
input_p = self.G(input, lmda=self.lmda)
if self.clamp:
input_p = torch.clamp(
input_p, min=self.clamp_min, max=self.clamp_max
)
#############
# Update F
#############
loss_f = F.cross_entropy(self.F(input), label)
if (self.epoch + 1) > self.warmup:
loss_fp = F.cross_entropy(self.F(input_p), label)
loss_f = (1.0 - self.alpha) * loss_f + self.alpha * loss_fp
self.model_backward_and_update(loss_f, "F")
#############
# Update D
#############
loss_d = F.cross_entropy(self.D(input), domain)
self.model_backward_and_update(loss_d, "D")
loss_summary = {
"loss_g": loss_g.item(),
"loss_f": loss_f.item(),
"loss_d": loss_d.item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def model_inference(self, input):
return self.F(input)
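DDAIG's generator G minimizes the label loss and maximizes the domain loss on its perturbed images, and the label net F starts mixing in the perturbed images only once `(epoch + 1) > warmup`. A minimal sketch of these objectives with single-sample logits as lists; the helper names are hypothetical and `cross_entropy` is a plain log-sum-exp formulation, not the PyTorch function itself:

```python
import math


def cross_entropy(logits, target):
    # -log softmax(logits)[target], computed with log-sum-exp for stability.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_z - logits[target]


def generator_loss(label_logits, label, domain_logits, domain):
    # Keep the class recognizable (low label CE), hide the domain (high domain CE).
    return cross_entropy(label_logits, label) - cross_entropy(domain_logits, domain)


def classifier_loss(loss_clean, loss_perturbed, epoch, warmup, alpha):
    # epoch is 0-based, matching self.epoch in the trainer.
    if (epoch + 1) > warmup:
        return (1.0 - alpha) * loss_clean + alpha * loss_perturbed
    return loss_clean
```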
================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/dg/vanilla.py
================================================
from torch.nn import functional as F
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy
@TRAINER_REGISTRY.register()
class Vanilla(TrainerX):
"""Vanilla baseline."""
def forward_backward(self, batch):
input, label = self.parse_batch_train(batch)
output = self.model(input)
loss = F.cross_entropy(output, label)
self.model_backward_and_update(loss)
loss_summary = {
"loss": loss.item(),
"acc": compute_accuracy(output, label)[0].item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def parse_batch_train(self, batch):
input = batch["img"]
label = batch["label"]
input = input.to(self.device)
label = label.to(self.device)
return input, label
================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/ssl/__init__.py
================================================
from .entmin import EntMin
from .fixmatch import FixMatch
from .mixmatch import MixMatch
from .mean_teacher import MeanTeacher
from .sup_baseline import SupBaseline
================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/ssl/entmin.py
================================================
import torch
from torch.nn import functional as F
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
@TRAINER_REGISTRY.register()
class EntMin(TrainerXU):
"""Entropy Minimization.
http://papers.nips.cc/paper/2740-semi-supervised-learning-by-entropy-minimization.pdf.
"""
def __init__(self, cfg):
super().__init__(cfg)
self.lmda = cfg.TRAINER.ENTMIN.LMDA
def forward_backward(self, batch_x, batch_u):
input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)
output_x = self.model(input_x)
loss_x = F.cross_entropy(output_x, label_x)
output_u = F.softmax(self.model(input_u), 1)
loss_u = (-output_u * torch.log(output_u + 1e-5)).sum(1).mean()
loss = loss_x + loss_u * self.lmda
self.model_backward_and_update(loss)
loss_summary = {
"loss_x": loss_x.item(),
"acc_x": compute_accuracy(output_x, label_x)[0].item(),
"loss_u": loss_u.item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
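Entropy minimization adds the mean prediction entropy on unlabeled data, weighted by `lmda`, to the supervised cross-entropy. A minimal pure-Python sketch, assuming logits as lists of rows; the helper names are hypothetical:

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy_loss(batch_logits, eps=1e-5):
    # Batch mean of -sum_c p_c * log(p_c + eps).
    total = 0.0
    for logits in batch_logits:
        p = softmax(logits)
        total += sum(-pc * math.log(pc + eps) for pc in p)
    return total / len(batch_logits)


def entmin_loss(loss_x, batch_logits_u, lmda):
    return loss_x + lmda * entropy_loss(batch_logits_u)
```

Uniform predictions give an entropy close to log 2 for two classes, while confident predictions drive it toward zero, which is what the unlabeled term rewards.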
================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/ssl/fixmatch.py
================================================
import torch
from torch.nn import functional as F
from dassl.data import DataManager
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.data.transforms import build_transform
@TRAINER_REGISTRY.register()
class FixMatch(TrainerXU):
"""FixMatch: Simplifying Semi-Supervised Learning with
Consistency and Confidence.
https://arxiv.org/abs/2001.07685.
"""
def __init__(self, cfg):
super().__init__(cfg)
self.weight_u = cfg.TRAINER.FIXMATCH.WEIGHT_U
self.conf_thre = cfg.TRAINER.FIXMATCH.CONF_THRE
def check_cfg(self, cfg):
assert len(cfg.TRAINER.FIXMATCH.STRONG_TRANSFORMS) > 0
def build_data_loader(self):
cfg = self.cfg
tfm_train = build_transform(cfg, is_train=True)
custom_tfm_train = [tfm_train]
choices = cfg.TRAINER.FIXMATCH.STRONG_TRANSFORMS
tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
custom_tfm_train += [tfm_train_strong]
self.dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
self.train_loader_x = self.dm.train_loader_x
self.train_loader_u = self.dm.train_loader_u
self.val_loader = self.dm.val_loader
self.test_loader = self.dm.test_loader
self.num_classes = self.dm.num_classes
def assess_y_pred_quality(self, y_pred, y_true, mask):
n_masked_correct = (y_pred.eq(y_true).float() * mask).sum()
acc_thre = n_masked_correct / (mask.sum() + 1e-5)
acc_raw = y_pred.eq(y_true).sum() / y_pred.numel() # raw accuracy
keep_rate = mask.sum() / mask.numel()
output = {
"acc_thre": acc_thre,
"acc_raw": acc_raw,
"keep_rate": keep_rate
}
return output
def forward_backward(self, batch_x, batch_u):
parsed_data = self.parse_batch_train(batch_x, batch_u)
input_x, input_x2, label_x, input_u, input_u2, label_u = parsed_data
input_u = torch.cat([input_x, input_u], 0)
input_u2 = torch.cat([input_x2, input_u2], 0)
n_x = input_x.size(0)
# Generate pseudo labels
with torch.no_grad():
output_u = F.softmax(self.model(input_u), 1)
max_prob, label_u_pred = output_u.max(1)
mask_u = (max_prob >= self.conf_thre).float()
# Evaluate pseudo labels' accuracy
y_u_pred_stats = self.assess_y_pred_quality(
label_u_pred[n_x:], label_u, mask_u[n_x:]
)
# Supervised loss
output_x = self.model(input_x)
loss_x = F.cross_entropy(output_x, label_x)
# Unsupervised loss
output_u = self.model(input_u2)
loss_u = F.cross_entropy(output_u, label_u_pred, reduction="none")
loss_u = (loss_u * mask_u).mean()
loss = loss_x + loss_u * self.weight_u
self.model_backward_and_update(loss)
loss_summary = {
"loss_x": loss_x.item(),
"acc_x": compute_accuracy(output_x, label_x)[0].item(),
"loss_u": loss_u.item(),
"y_u_pred_acc_raw": y_u_pred_stats["acc_raw"],
"y_u_pred_acc_thre": y_u_pred_stats["acc_thre"],
"y_u_pred_keep": y_u_pred_stats["keep_rate"],
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def parse_batch_train(self, batch_x, batch_u):
input_x = batch_x["img"]
input_x2 = batch_x["img2"]
label_x = batch_x["label"]
input_u = batch_u["img"]
input_u2 = batch_u["img2"]
# label_u is used only for evaluating pseudo labels' accuracy
label_u = batch_u["label"]
input_x = input_x.to(self.device)
input_x2 = input_x2.to(self.device)
label_x = label_x.to(self.device)
input_u = input_u.to(self.device)
input_u2 = input_u2.to(self.device)
label_u = label_u.to(self.device)
return input_x, input_x2, label_x, input_u, input_u2, label_u
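FixMatch takes the argmax of the weak-view prediction as a hard pseudo-label and keeps it only when its probability is at least `conf_thre`; `assess_y_pred_quality` then reports the accuracy among kept labels, the raw accuracy and the fraction kept. A minimal pure-Python sketch of both steps, assuming probability rows as lists; the helper names are hypothetical:

```python
def pseudo_label(probs, conf_thre):
    """Return (hard labels, float mask) for a batch of probability vectors."""
    labels, mask = [], []
    for p in probs:
        best = max(range(len(p)), key=p.__getitem__)
        labels.append(best)
        mask.append(1.0 if p[best] >= conf_thre else 0.0)
    return labels, mask


def masked_mean(losses, mask):
    # Same as (loss_u * mask_u).mean(): masked-out samples still count in the denominator.
    return sum(l * m for l, m in zip(losses, mask)) / len(losses)


def assess(y_pred, y_true, mask):
    correct = [1.0 if a == b else 0.0 for a, b in zip(y_pred, y_true)]
    acc_thre = sum(c * m for c, m in zip(correct, mask)) / (sum(mask) + 1e-5)
    acc_raw = sum(correct) / len(correct)
    keep_rate = sum(mask) / len(mask)
    return {"acc_thre": acc_thre, "acc_raw": acc_raw, "keep_rate": keep_rate}
```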
================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/ssl/mean_teacher.py
================================================
import copy
from torch.nn import functional as F
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling.ops.utils import sigmoid_rampup, ema_model_update
@TRAINER_REGISTRY.register()
class MeanTeacher(TrainerXU):
"""Mean teacher.
https://arxiv.org/abs/1703.01780.
"""
def __init__(self, cfg):
super().__init__(cfg)
self.weight_u = cfg.TRAINER.MEANTEA.WEIGHT_U
self.ema_alpha = cfg.TRAINER.MEANTEA.EMA_ALPHA
self.rampup = cfg.TRAINER.MEANTEA.RAMPUP
self.teacher = copy.deepcopy(self.model)
self.teacher.train()
for param in self.teacher.parameters():
param.requires_grad_(False)
def forward_backward(self, batch_x, batch_u):
input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)
logit_x = self.model(input_x)
loss_x = F.cross_entropy(logit_x, label_x)
target_u = F.softmax(self.teacher(input_u), 1)
prob_u = F.softmax(self.model(input_u), 1)
loss_u = ((prob_u - target_u)**2).sum(1).mean()
weight_u = self.weight_u * sigmoid_rampup(self.epoch, self.rampup)
loss = loss_x + loss_u*weight_u
self.model_backward_and_update(loss)
global_step = self.batch_idx + self.epoch * self.num_batches
ema_alpha = min(1 - 1 / (global_step+1), self.ema_alpha)
ema_model_update(self.model, self.teacher, ema_alpha)
loss_summary = {
"loss_x": loss_x.item(),
"acc_x": compute_accuracy(logit_x, label_x)[0].item(),
"loss_u": loss_u.item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
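Mean Teacher scales its squared-error consistency loss by `weight_u * sigmoid_rampup(epoch, rampup)`. The sketch below assumes `sigmoid_rampup` has the common form exp(-5 (1 - t)^2) with t = clip(current, 0, rampup_length) / rampup_length; that form is an assumption, not taken from this file, and the helper names are hypothetical:

```python
import math


def sigmoid_rampup(current, rampup_length):
    # Assumed form: exp(-5 * (1 - t)^2), t clipped to [0, 1].
    if rampup_length <= 0:
        raise ValueError("rampup_length must be positive")
    t = min(max(current, 0.0), rampup_length) / rampup_length
    return math.exp(-5.0 * (1.0 - t) ** 2)


def consistency_weight(weight_u, epoch, rampup):
    return weight_u * sigmoid_rampup(epoch, rampup)


def squared_consistency(prob_student, prob_teacher):
    # Per-sample ((prob_u - target_u) ** 2).sum(1).
    return sum((s - t) ** 2 for s, t in zip(prob_student, prob_teacher))
```

Under this form the weight starts at about exp(-5) ≈ 0.0067 of `weight_u` and reaches the full value at `epoch == rampup`.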
================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/ssl/mixmatch.py
================================================
import torch
from torch.nn import functional as F
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.modeling.ops import mixup
from dassl.modeling.ops.utils import (
sharpen_prob, create_onehot, linear_rampup, shuffle_index
)
@TRAINER_REGISTRY.register()
class MixMatch(TrainerXU):
"""MixMatch: A Holistic Approach to Semi-Supervised Learning.
https://arxiv.org/abs/1905.02249.
"""
def __init__(self, cfg):
super().__init__(cfg)
self.weight_u = cfg.TRAINER.MIXMATCH.WEIGHT_U
self.temp = cfg.TRAINER.MIXMATCH.TEMP
self.beta = cfg.TRAINER.MIXMATCH.MIXUP_BETA
self.rampup = cfg.TRAINER.MIXMATCH.RAMPUP
def check_cfg(self, cfg):
assert cfg.DATALOADER.K_TRANSFORMS > 1
def forward_backward(self, batch_x, batch_u):
input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)
num_x = input_x.shape[0]
global_step = self.batch_idx + self.epoch * self.num_batches
weight_u = self.weight_u * linear_rampup(global_step, self.rampup)
# Generate pseudo-labels for the unlabeled data
with torch.no_grad():
output_u = 0
for input_ui in input_u:
output_ui = F.softmax(self.model(input_ui), 1)
output_u += output_ui
output_u /= len(input_u)
label_u = sharpen_prob(output_u, self.temp)
label_u = [label_u] * len(input_u)
label_u = torch.cat(label_u, 0)
input_u = torch.cat(input_u, 0)
# Combine and shuffle labeled and unlabeled data
input_xu = torch.cat([input_x, input_u], 0)
label_xu = torch.cat([label_x, label_u], 0)
input_xu, label_xu = shuffle_index(input_xu, label_xu)
# Mixup
input_x, label_x = mixup(
input_x,
input_xu[:num_x],
label_x,
label_xu[:num_x],
self.beta,
preserve_order=True,
)
input_u, label_u = mixup(
input_u,
input_xu[num_x:],
label_u,
label_xu[num_x:],
self.beta,
preserve_order=True,
)
# Compute losses
output_x = F.softmax(self.model(input_x), 1)
loss_x = (-label_x * torch.log(output_x + 1e-5)).sum(1).mean()
output_u = F.softmax(self.model(input_u), 1)
loss_u = ((label_u - output_u)**2).mean()
loss = loss_x + loss_u*weight_u
self.model_backward_and_update(loss)
loss_summary = {"loss_x": loss_x.item(), "loss_u": loss_u.item()}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def parse_batch_train(self, batch_x, batch_u):
input_x = batch_x["img"][0]
label_x = batch_x["label"]
label_x = create_onehot(label_x, self.num_classes)
input_u = batch_u["img"]
input_x = input_x.to(self.device)
label_x = label_x.to(self.device)
input_u = [input_ui.to(self.device) for input_ui in input_u]
return input_x, label_x, input_u
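MixMatch sharpens the averaged prediction over K augmentations (p_c^(1/T), renormalized), then mixes each sample with a shuffled partner; with `preserve_order=True` the mixing coefficient is taken as max(λ, 1 - λ) so each mixed sample stays closer to its original. The sketch below assumes λ ~ Beta(β, β) per sample and that `linear_rampup` clips `current / rampup_length` to [0, 1]; both are assumptions, and the helper names are hypothetical:

```python
import random


def sharpen(p, temp):
    # Lower temp pushes the distribution toward one-hot.
    powered = [pc ** (1.0 / temp) for pc in p]
    total = sum(powered)
    return [v / total for v in powered]


def linear_rampup(current, rampup_length):
    # Assumed: current / rampup_length clipped to [0, 1].
    if rampup_length == 0:
        return 1.0
    return min(max(current / rampup_length, 0.0), 1.0)


def mixup_pair(x1, x2, y1, y2, beta, rng, preserve_order=True):
    lam = rng.betavariate(beta, beta)
    if preserve_order:
        # Keep the mix dominated by (x1, y1).
        lam = max(lam, 1 - lam)

    def mix(a, b):
        return [lam * ai + (1 - lam) * bi for ai, bi in zip(a, b)]

    return mix(x1, x2), mix(y1, y2), lam
```

Sharpening `[0.75, 0.25]` with T = 0.5 gives `[0.9, 0.1]`; the mixed label of two one-hot vectors still sums to 1.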
================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/ssl/sup_baseline.py
================================================
from torch.nn import functional as F
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
@TRAINER_REGISTRY.register()
class SupBaseline(TrainerXU):
"""Supervised Baseline."""
def forward_backward(self, batch_x, batch_u):
input, label = self.parse_batch_train(batch_x, batch_u)
output = self.model(input)
loss = F.cross_entropy(output, label)
self.model_backward_and_update(loss)
loss_summary = {
"loss": loss.item(),
"acc": compute_accuracy(output, label)[0].item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def parse_batch_train(self, batch_x, batch_u):
input = batch_x["img"]
label = batch_x["label"]
input = input.to(self.device)
label = label.to(self.device)
return input, label
================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/trainer.py
================================================
import json
import time
import numpy as np
import os.path as osp
import datetime
from collections import OrderedDict
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from dassl.data import DataManager
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import (
MetricMeter, AverageMeter, tolist_if_not, count_num_param, load_checkpoint,
save_checkpoint, mkdir_if_missing, resume_from_checkpoint,
load_pretrained_weights
)
from dassl.modeling import build_head, build_backbone
from dassl.evaluation import build_evaluator
class SimpleNet(nn.Module):
"""A simple neural network composed of a CNN backbone,
an optional head (e.g., an MLP) and an optional linear classifier.
"""
def __init__(self, cfg, model_cfg, num_classes, **kwargs):
super().__init__()
self.backbone = build_backbone(
model_cfg.BACKBONE.NAME,
verbose=cfg.VERBOSE,
pretrained=model_cfg.BACKBONE.PRETRAINED,
**kwargs,
)
fdim = self.backbone.out_features
self.head = None
if model_cfg.HEAD.NAME and model_cfg.HEAD.HIDDEN_LAYERS:
self.head = build_head(
model_cfg.HEAD.NAME,
verbose=cfg.VERBOSE,
in_features=fdim,
hidden_layers=model_cfg.HEAD.HIDDEN_LAYERS,
activation=model_cfg.HEAD.ACTIVATION,
bn=model_cfg.HEAD.BN,
dropout=model_cfg.HEAD.DROPOUT,
**kwargs,
)
fdim = self.head.out_features
self.classifier = None
if num_classes > 0:
self.classifier = nn.Linear(fdim, num_classes)
self._fdim = fdim
@property
def fdim(self):
return self._fdim
def forward(self, x, return_feature=False):
f = self.backbone(x)
if self.head is not None:
f = self.head(f)
if self.classifier is None:
return f
y = self.classifier(f)
if return_feature:
return y, f
return y
class TrainerBase:
"""Base class for iterative trainer."""
def __init__(self):
self._models = OrderedDict()
self._optims = OrderedDict()
self._scheds = OrderedDict()
self._writer = None
def register_model(self, name="model", model=None, optim=None, sched=None):
if self.__dict__.get("_models") is None:
raise AttributeError(
"Cannot assign model before super().__init__() call"
)
if self.__dict__.get("_optims") is None:
raise AttributeError(
"Cannot assign optim before super().__init__() call"
)
if self.__dict__.get("_scheds") is None:
raise AttributeError(
"Cannot assign sched before super().__init__() call"
)
assert name not in self._models, "Found duplicate model names"
self._models[name] = model
self._optims[name] = optim
self._scheds[name] = sched
def get_model_names(self, names=None):
names_real = list(self._models.keys())
if names is not None:
names = tolist_if_not(names)
for name in names:
assert name in names_real
return names
else:
return names_real
def save_model(self, epoch, directory, is_best=False, model_name=""):
names = self.get_model_names()
for name in names:
model_dict = self._models[name].state_dict()
optim_dict = None
if self._optims[name] is not None:
optim_dict = self._optims[name].state_dict()
sched_dict = None
if self._scheds[name] is not None:
sched_dict = self._scheds[name].state_dict()
save_checkpoint(
{
"state_dict": model_dict,
"epoch": epoch + 1,
"optimizer": optim_dict,
"scheduler": sched_dict,
},
osp.join(directory, name),
is_best=is_best,
model_name=model_name,
)
def resume_model_if_exist(self, directory):
names = self.get_model_names()
file_missing = False
for name in names:
path = osp.join(directory, name)
if not osp.exists(path):
file_missing = True
break
if file_missing:
print("No checkpoint found, train from scratch")
return 0
print(
'Found checkpoint in "{}". Will resume training'.format(directory)
)
for name in names:
path = osp.join(directory, name)
start_epoch = resume_from_checkpoint(
path, self._models[name], self._optims[name],
self._scheds[name]
)
return start_epoch
def load_model(self, directory, epoch=None):
if not directory:
print(
"Note that load_model() is skipped as no pretrained "
"model is given (ignore this if it's done on purpose)"
)
return
names = self.get_model_names()
# By default, the best model is loaded
model_file = "model-best.pth.tar"
if epoch is not None:
model_file = "model.pth.tar-" + str(epoch)
for name in names:
model_path = osp.join(directory, name, model_file)
if not osp.exists(model_path):
raise FileNotFoundError(
'Model not found at "{}"'.format(model_path)
)
checkpoint = load_checkpoint(model_path)
state_dict = checkpoint["state_dict"]
epoch = checkpoint["epoch"]
print(
"Loading weights to {} "
'from "{}" (epoch = {})'.format(name, model_path, epoch)
)
self._models[name].load_state_dict(state_dict)
def set_model_mode(self, mode="train", names=None):
names = self.get_model_names(names)
for name in names:
if mode == "train":
self._models[name].train()
elif mode in ["test", "eval"]:
self._models[name].eval()
else:
raise KeyError("Unknown mode: {}".format(mode))
def update_lr(self, names=None):
names = self.get_model_names(names)
for name in names:
if self._scheds[name] is not None:
self._scheds[name].step()
def detect_anomaly(self, loss):
if not torch.isfinite(loss).all():
raise FloatingPointError("Loss is infinite or NaN!")
def init_writer(self, log_dir):
if self.__dict__.get("_writer") is None or self._writer is None:
print(
"Initializing summary writer for tensorboard "
"with log_dir={}".format(log_dir)
)
self._writer = SummaryWriter(log_dir=log_dir)
def close_writer(self):
if self._writer is not None:
self._writer.close()
def write_scalar(self, tag, scalar_value, global_step=None):
if self._writer is None:
# Do nothing if writer is not initialized
# Note that writer is only used when training is needed
pass
else:
self._writer.add_scalar(tag, scalar_value, global_step)
def train(self, start_epoch, max_epoch):
"""Generic training loops."""
self.start_epoch = start_epoch
self.max_epoch = max_epoch
self.before_train()
for self.epoch in range(self.start_epoch, self.max_epoch):
self.before_epoch()
self.run_epoch()
self.after_epoch()
self.after_train()
def before_train(self):
pass
def after_train(self):
pass
def before_epoch(self):
pass
def after_epoch(self):
pass
def run_epoch(self):
raise NotImplementedError
def test(self):
raise NotImplementedError
def parse_batch_train(self, batch):
raise NotImplementedError
def parse_batch_test(self, batch):
raise NotImplementedError
def forward_backward(self, batch):
raise NotImplementedError
def model_inference(self, input):
raise NotImplementedError
def model_zero_grad(self, names=None):
names = self.get_model_names(names)
for name in names:
if self._optims[name] is not None:
self._optims[name].zero_grad()
def model_backward(self, loss):
self.detect_anomaly(loss)
loss.backward()
def model_update(self, names=None):
names = self.get_model_names(names)
for name in names:
if self._optims[name] is not None:
self._optims[name].step()
def model_backward_and_update(self, loss, names=None):
self.model_zero_grad(names)
self.model_backward(loss)
self.model_update(names)
def prograd_backward_and_update(
self, loss_a, loss_b, lambda_=1, names=None
):
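# loss_a must decrease; loss_b only needs not to increase.
# Where the two gradients conflict, the loss_a gradient is projected
# so that it no longer points against the loss_b gradient.
self.model_zero_grad(names)
# get the names of the registered models
names = self.get_model_names(names)
# backward loss_b (keep the graph for the second backward pass)
self.detect_anomaly(loss_b)
loss_b.backward(retain_graph=True)
# store a copy of the loss_b gradients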
b_grads = []
for name in names:
for p in self._models[name].parameters():
b_grads.append(p.grad.clone())
# clear gradients before the loss_a backward pass (no optimizer step yet)
for name in names:
self._optims[name].zero_grad()
# backward loss_a
self.detect_anomaly(loss_a)
loss_a.backward()
for name in names:
for p, b_grad in zip(self._models[name].parameters(), b_grads):
# if the cosine similarity is negative (conflicting directions),
# remove the component of a_grad along b_grad
b_grad_norm = b_grad / torch.linalg.norm(b_grad)
a_grad = p.grad.clone()
a_grad_norm = a_grad / torch.linalg.norm(a_grad)
if torch.dot(a_grad_norm.flatten(), b_grad_norm.flatten()) < 0:
p.grad = a_grad - lambda_ * torch.dot(
a_grad.flatten(), b_grad_norm.flatten()
) * b_grad_norm
# update parameters with the (possibly projected) gradients
for name in names:
self._optims[name].step()
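`prograd_backward_and_update` applies a gradient projection to each parameter tensor. The NumPy sketch below shows that rule for a single tensor. `prograd_project` is a hypothetical name. The small `eps` in the norm is an addition that guards against dividing by zero when `b_grad` is all zeros.

```python
import numpy as np


def prograd_project(a_grad, b_grad, lambda_=1.0, eps=1e-12):
    """Project a_grad away from b_grad when the two conflict.

    Mirrors the per-parameter rule in prograd_backward_and_update:
    if cos(a, b) < 0:  g = a - lambda_ * (a . b_hat) * b_hat,  b_hat = b / ||b||
    otherwise:         g = a (unchanged)
    """
    a = a_grad.ravel()
    b_hat = b_grad.ravel() / (np.linalg.norm(b_grad) + eps)
    dot = float(a @ b_hat)
    # dot has the same sign as the cosine similarity whenever a is non-zero
    if dot < 0:
        a = a - lambda_ * dot * b_hat
    return a.reshape(a_grad.shape)
```

With `lambda_=1` the projected gradient is orthogonal to `b_grad`. To first order, a small step along it leaves `loss_b` unchanged while not increasing `loss_a`.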
class SimpleTrainer(TrainerBase):
"""A simple trainer class implementing generic functions."""
def __init__(self, cfg):
super().__init__()
self.check_cfg(cfg)
if torch.cuda.is_available() and cfg.USE_CUDA:
self.device = torch.device("cuda")
else:
self.device = torch.device("cpu")
# Save as attributes some frequently used variables
self.start_epoch = self.epoch = 0
self.max_epoch = cfg.OPTIM.MAX_EPOCH
self.output_dir = cfg.OUTPUT_DIR
self.cfg = cfg
self.build_data_loader()
self.build_model()
self.evaluator = build_evaluator(cfg, lab2cname=self.lab2cname)
self.best_result = -np.inf
def check_cfg(self, cfg):
"""Check whether some variables are set correctly for
the trainer (optional).
For example, a trainer might require a particular sampler
for training such as 'RandomDomainSampler', so it is good
to do the checking:
assert cfg.DATALOADER.SAMPLER_TRAIN == 'RandomDomainSampler'
"""
pass
def build_data_loader(self):
"""Create essential data-related attributes.
A re-implementation of this method must create the
same attributes (except self.dm).
"""
dm = DataManager(self.cfg)
self.train_loader_x = dm.train_loader_x
self.train_loader_u = dm.train_loader_u # optional, can be None
self.val_loader = dm.val_loader # optional, can be None
self.test_loader = dm.test_loader
self.num_classes = dm.num_classes
self.num_source_domains = dm.num_source_domains
self.lab2cname = dm.lab2cname # dict {label: classname}
self.dm = dm
def build_model(self):
"""Build and register model.
The default builds a classification model along with its
optimizer and scheduler.
Custom trainers can re-implement this method if necessary.
"""
cfg = self.cfg
print("Building model")
self.model = SimpleNet(cfg, cfg.MODEL, self.num_classes)
if cfg.MODEL.INIT_WEIGHTS:
load_pretrained_weights(self.model, cfg.MODEL.INIT_WEIGHTS)
self.model.to(self.device)
print("# params: {:,}".format(count_num_param(self.model)))
self.optim = build_optimizer(self.model, cfg.OPTIM)
self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
self.register_model("model", self.model, self.optim, self.sched)
device_count = torch.cuda.device_count()
if device_count > 1:
print(
f"Detected {device_count} GPUs. Wrapping the model with nn.DataParallel"
)
self.model = nn.DataParallel(self.model)
def train(self):
super().train(self.start_epoch, self.max_epoch)
def before_train(self):
directory = self.cfg.OUTPUT_DIR
if self.cfg.RESUME:
directory = self.cfg.RESUME
self.start_epoch = self.resume_model_if_exist(directory)
# Initialize summary writer
writer_dir = osp.join(self.output_dir, "tensorboard")
mkdir_if_missing(writer_dir)
self.init_writer(writer_dir)
# Remember the starting time (for computing the elapsed time)
self.time_start = time.time()
def after_train(self):
print("Finished training")
do_test = not self.cfg.TEST.NO_TEST
if do_test:
if self.cfg.TEST.FINAL_MODEL == "best_val":
print("Deploy the model with the best val performance")
self.load_model(self.output_dir)
self.test()
# Show elapsed time
elapsed = round(time.time() - self.time_start)
elapsed = str(datetime.timedelta(seconds=elapsed))
print("Elapsed: {}".format(elapsed))
# Close writer
self.close_writer()
def after_epoch(self):
last_epoch = (self.epoch + 1) == self.max_epoch
do_test = not self.cfg.TEST.NO_TEST
meet_checkpoint_freq = (
(self.epoch + 1) % self.cfg.TRAIN.CHECKPOINT_FREQ == 0
if self.cfg.TRAIN.CHECKPOINT_FREQ > 0 else False
)
if do_test and self.cfg.TEST.FINAL_MODEL == "best_val":
curr_result = self.test(split="val")
is_best = curr_result > self.best_result
if is_best:
self.best_result = curr_result
self.save_model(
self.epoch,
self.output_dir,
model_name="model-best.pth.tar"
)
if meet_checkpoint_freq or last_epoch:
self.save_model(self.epoch, self.output_dir)
@torch.no_grad()
def output_test(self, split=None):
"""Testing pipeline that also writes per-image predictions and labels to output.json."""
self.set_model_mode("eval")
self.evaluator.reset()
output_file = osp.join(self.cfg.OUTPUT_DIR, 'output.json')
res_json = {}
if split is None:
split = self.cfg.TEST.SPLIT
if split == "val" and self.val_loader is not None:
data_loader = self.val_loader
print("Do evaluation on {} set".format(split))
else:
data_loader = self.test_loader
print("Do evaluation on test set")
for batch_idx, batch in enumerate(tqdm(data_loader)):
img_path = batch['impath']
input, label = self.parse_batch_test(batch)
output = self.model_inference(input)
self.evaluator.process(output, label)
for i in range(len(img_path)):
res_json[img_path[i]] = {
'predict': output[i].cpu().numpy().tolist(),
'gt': label[i].cpu().numpy().tolist()
}
with open(output_file, 'w') as f:
json.dump(res_json, f)
results = self.evaluator.evaluate()
for k, v in results.items():
tag = "{}/{}".format(split, k)
self.write_scalar(tag, v, self.epoch)
return list(results.values())[0]
@torch.no_grad()
def test(self, split=None):
"""A generic testing pipeline."""
self.set_model_mode("eval")
self.evaluator.reset()
if split is None:
split = self.cfg.TEST.SPLIT
if split == "val" and self.val_loader is not None:
data_loader = self.val_loader
print("Do evaluation on {} set".format(split))
else:
data_loader = self.test_loader
print("Do evaluation on test set")
for batch_idx, batch in enumerate(tqdm(data_loader)):
input, label = self.parse_batch_test(batch)
output = self.model_inference(input)
self.evaluator.process(output, label)
results = self.evaluator.evaluate()
for k, v in results.items():
tag = "{}/{}".format(split, k)
self.write_scalar(tag, v, self.epoch)
return list(results.values())[0]
def model_inference(self, input):
return self.model(input)
def parse_batch_test(self, batch):
input = batch["img"]
label = batch["label"]
input = input.to(self.device)
label = label.to(self.device)
return input, label
def get_current_lr(self, names=None):
names = self.get_model_names(names)
name = names[0]
return self._optims[name].param_groups[0]["lr"]
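Checkpointing in `after_epoch` follows two independent rules:

- `model-best.pth.tar` is rewritten whenever the validation result improves. This applies only when testing is enabled and `TEST.FINAL_MODEL` is `"best_val"`.
- A regular checkpoint is saved every `CHECKPOINT_FREQ` epochs, plus after the final epoch.

The pure-Python sketch below covers the second rule. The helper name is hypothetical. Epochs are reported 1-based, matching the `epoch + 1` that `save_model` stores.

```python
def checkpoint_epochs(max_epoch, checkpoint_freq):
    """Return the 1-based epochs after which a regular checkpoint is saved.

    Mirrors SimpleTrainer.after_epoch: save when
    (epoch + 1) % CHECKPOINT_FREQ == 0 (only if CHECKPOINT_FREQ > 0),
    and always after the last epoch.
    """
    saved = []
    for epoch in range(max_epoch):  # 0-based, as in TrainerBase.train
        last_epoch = (epoch + 1) == max_epoch
        meet_freq = checkpoint_freq > 0 and (epoch + 1) % checkpoint_freq == 0
        if meet_freq or last_epoch:
            saved.append(epoch + 1)
    return saved
```

For example, `checkpoint_epochs(10, 3)` yields `[3, 6, 9, 10]`. A freq of 0 saves only the last epoch.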
class TrainerXU(SimpleTrainer):
"""A base trainer using both labeled and unlabeled data.
In the context of domain adaptation, labeled and unlabeled data
come from source and target domains respectively.
When it comes to semi-supervised learning, all data comes from the
same domain.
"""
def run_epoch(self):
self.set_model_mode("train")
losses = MetricMeter()
batch_time = AverageMeter()
data_time = AverageMeter()
# Decide to iterate over labeled or unlabeled dataset
len_train_loader_x = len(self.train_loader_x)
len_train_loader_u = len(self.train_loader_u)
if self.cfg.TRAIN.COUNT_ITER == "train_x":
self.num_batches = len_train_loader_x
elif self.cfg.TRAIN.COUNT_ITER == "train_u":
self.num_batches = len_train_loader_u
elif self.cfg.TRAIN.COUNT_ITER == "smaller_one":
self.num_batches = min(len_train_loader_x, len_train_loader_u)
else:
raise ValueError
train_loader_x_iter = iter(self.train_loader_x)
train_loader_u_iter = iter(self.train_loader_u)
end = time.time()
for self.batch_idx in range(self.num_batches):
try:
batch_x = next(train_loader_x_iter)
except StopIteration:
train_loader_x_iter = iter(self.train_loader_x)
batch_x = next(train_loader_x_iter)
try:
batch_u = next(train_loader_u_iter)
except StopIteration:
train_loader_u_iter = iter(self.train_loader_u)
batch_u = next(train_loader_u_iter)
data_time.update(time.time() - end)
loss_summary = self.forward_backward(batch_x, batch_u)
batch_time.update(time.time() - end)
losses.update(loss_summary)
if (
self.batch_idx + 1
) % self.cfg.TRAIN.PRINT_FREQ == 0 or self.num_batches < self.cfg.TRAIN.PRINT_FREQ:
nb_remain = 0
nb_remain += self.num_batches - self.batch_idx - 1
nb_remain += (
self.max_epoch - self.epoch - 1
) * self.num_batches
eta_seconds = batch_time.avg * nb_remain
eta = str(datetime.timedelta(seconds=int(eta_seconds)))
print(
"epoch [{0}/{1}][{2}/{3}]\t"
"time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
"data {data_time.val:.3f} ({data_time.avg:.3f})\t"
"eta {eta}\t"
"{losses}\t"
"lr {lr:.6e}".format(
self.epoch + 1,
self.max_epoch,
self.batch_idx + 1,
self.num_batches,
batch_time=batch_time,
data_time=data_time,
eta=eta,
losses=losses,
lr=self.get_current_lr(),
)
)
n_iter = self.epoch * self.num_batches + self.batch_idx
for name, meter in losses.meters.items():
self.write_scalar("train/" + name, meter.avg, n_iter)
self.write_scalar("train/lr", self.get_current_lr(), n_iter)
end = time.time()
def parse_batch_train(self, batch_x, batch_u):
input_x = batch_x["img"]
label_x = batch_x["label"]
input_u = batch_u["img"]
input_x = input_x.to(self.device)
label_x = label_x.to(self.device)
input_u = input_u.to(self.device)
return input_x, label_x, input_u
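`TrainerXU.run_epoch` fixes the number of iterations via `TRAIN.COUNT_ITER` and restarts whichever loader is exhausted. The shorter loader is therefore cycled. The sketch below reproduces that pairing logic with plain Python lists standing in for the two `DataLoader`s, since both only need `len()` and `iter()`. `paired_batches` is a hypothetical name.

```python
def paired_batches(loader_x, loader_u, count_iter="train_x"):
    """Yield (batch_x, batch_u) pairs for one epoch, as TrainerXU.run_epoch does."""
    if count_iter == "train_x":
        num_batches = len(loader_x)
    elif count_iter == "train_u":
        num_batches = len(loader_u)
    elif count_iter == "smaller_one":
        num_batches = min(len(loader_x), len(loader_u))
    else:
        raise ValueError(count_iter)

    it_x, it_u = iter(loader_x), iter(loader_u)
    for _ in range(num_batches):
        # Restart an exhausted loader instead of ending the epoch early.
        try:
            batch_x = next(it_x)
        except StopIteration:
            it_x = iter(loader_x)
            batch_x = next(it_x)
        try:
            batch_u = next(it_u)
        except StopIteration:
            it_u = iter(loader_u)
            batch_u = next(it_u)
        yield batch_x, batch_u
```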
class TrainerX(SimpleTrainer):
"""A base trainer using labeled data only."""
def run_epoch(self):
self.set_model_mode("train")
losses = MetricMeter()
batch_time = AverageMeter()
data_time = AverageMeter()
self.num_batches = len(self.train_loader_x)
end = time.time()
for self.batch_idx, batch in enumerate(self.train_loader_x):
data_time.update(time.time() - end)
loss_summary = self.forward_backward(batch)
batch_time.update(time.time() - end)
losses.update(loss_summary)
if (
self.batch_idx + 1
) % self.cfg.TRAIN.PRINT_FREQ == 0 or self.num_batches < self.cfg.TRAIN.PRINT_FREQ:
nb_remain = 0
nb_remain += self.num_batches - self.batch_idx - 1
nb_remain += (
self.max_epoch - self.epoch - 1
) * self.num_batches
eta_seconds = batch_time.avg * nb_remain
eta = str(datetime.timedelta(seconds=int(eta_seconds)))
print(
"epoch [{0}/{1}][{2}/{3}]\t"
"time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
"data {data_time.val:.3f} ({data_time.avg:.3f})\t"
"eta {eta}\t"
"{losses}\t"
"lr {lr:.6e}".format(
self.epoch + 1,
self.max_epoch,
self.batch_idx + 1,
self.num_batches,
batch_time=batch_time,
data_time=data_time,
eta=eta,
losses=losses,
lr=self.get_current_lr(),
)
)
n_iter = self.epoch * self.num_batches + self.batch_idx
for name, meter in losses.meters.items():
self.write_scalar("train/" + name, meter.avg, n_iter)
self.write_scalar("train/lr", self.get_current_lr(), n_iter)
end = time.time()
def parse_batch_train(self, batch):
input = batch["img"]
label = batch["label"]
domain = batch["domain"]
input = input.to(self.device)
label = label.to(self.device)
domain = domain.to(self.device)
return input, label, domain
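Both `run_epoch` implementations estimate the remaining time the same way. They multiply the average batch time by the batches left in the current epoch plus all batches of the remaining epochs. A standalone sketch of that arithmetic follows (hypothetical helper name).

```python
import datetime


def eta_string(avg_batch_time, batch_idx, num_batches, epoch, max_epoch):
    """Remaining-time estimate as printed by run_epoch (batch_idx, epoch are 0-based)."""
    nb_remain = num_batches - batch_idx - 1              # batches left in this epoch
    nb_remain += (max_epoch - epoch - 1) * num_batches   # batches in later epochs
    return str(datetime.timedelta(seconds=int(avg_batch_time * nb_remain)))
```

For example, take batch index 49 of 100 in the first of two epochs, at 0.5 s per batch. Then 150 batches remain, giving `0:01:15`.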
================================================
FILE: Dassl.ProGrad.pytorch/dassl/evaluation/__init__.py
================================================
from .build import build_evaluator, EVALUATOR_REGISTRY # isort:skip
from .evaluator import EvaluatorBase, Classification
================================================
FILE: Dassl.ProGrad.pytorch/dassl/evaluation/build.py
================================================
from dassl.utils import Registry, check_availability
EVALUATOR_REGISTRY = Registry("EVALUATOR")
def build_evaluator(cfg, **kwargs):
avai_evaluators = EVALUATOR_REGISTRY.registered_names()
check_availability(cfg.TEST.EVALUATOR, avai_evaluators)
if cfg.VERBOSE:
print("Loading evaluator: {}".format(cfg.TEST.EVALUATOR))
return EVALUATOR_REGISTRY.get(cfg.TEST.EVALUATOR)(cfg, **kwargs)
================================================
FILE: Dassl.ProGrad.pytorch/dassl/evaluation/evaluator.py
================================================
import numpy as np
import os.path as osp
from collections import OrderedDict, defaultdict
import torch
from sklearn.metrics import f1_score, confusion_matrix
from .build import EVALUATOR_REGISTRY
class EvaluatorBase:
"""Base evaluator."""
def __init__(self, cfg):
self.cfg = cfg
def reset(self):
raise NotImplementedError
def process(self, mo, gt):
raise NotImplementedError
def evaluate(self):
raise NotImplementedError
@EVALUATOR_REGISTRY.register()
class Classification(EvaluatorBase):
"""Evaluator for classification."""
def __init__(self, cfg, lab2cname=None, **kwargs):
super().__init__(cfg)
self._lab2cname = lab2cname
self._correct = 0
self._total = 0
self._per_class_res = None
self._y_true = []
self._y_pred = []
if cfg.TEST.PER_CLASS_RESULT:
assert lab2cname is not None
self._per_class_res = defaultdict(list)
def reset(self):
self._correct = 0
self._total = 0
self._y_true = []
self._y_pred = []
if self._per_class_res is not None:
self._per_class_res = defaultdict(list)
def process(self, mo, gt):
# mo (torch.Tensor): model output [batch, num_classes]
# gt (torch.LongTensor): ground truth [batch]
pred = mo.max(1)[1]
matches = pred.eq(gt).float()
self._correct += int(matches.sum().item())
self._total += gt.shape[0]
self._y_true.extend(gt.data.cpu().numpy().tolist())
self._y_pred.extend(pred.data.cpu().numpy().tolist())
if self._per_class_res is not None:
for i, label in enumerate(gt):
label = label.item()
matches_i = int(matches[i].item())
self._per_class_res[label].append(matches_i)
def evaluate(self):
results = OrderedDict()
acc = 100.0 * self._correct / self._total
err = 100.0 - acc
macro_f1 = 100.0 * f1_score(
self._y_true,
self._y_pred,
average="macro",
labels=np.unique(self._y_true)
)
# The first value will be returned by trainer.test()
results["accuracy"] = acc
results["error_rate"] = err
results["macro_f1"] = macro_f1
print(
"=> result\n"
f"* total: {self._total:,}\n"
f"* correct: {self._correct:,}\n"
f"* accuracy: {acc:.2f}%\n"
f"* error: {err:.2f}%\n"
f"* macro_f1: {macro_f1:.2f}%"
)
if self._per_class_res is not None:
labels = list(self._per_class_res.keys())
labels.sort()
print("=> per-class result")
accs = []
for label in labels:
classname = self._lab2cname[label]
res = self._per_class_res[label]
correct = sum(res)
total = len(res)
acc = 100.0 * correct / total
accs.append(acc)
print(
"* class: {} ({})\t"
"total: {:,}\t"
"correct: {:,}\t"
"acc: {:.2f}%".format(
label, classname, total, correct, acc
)
)
mean_acc = np.mean(accs)
print("* average: {:.2f}%".format(mean_acc))
results["perclass_accuracy"] = mean_acc
if self.cfg.TEST.COMPUTE_CMAT:
cmat = confusion_matrix(
self._y_true, self._y_pred, normalize="true"
)
save_path = osp.join(self.cfg.OUTPUT_DIR, "cmat.pt")
torch.save(cmat, save_path)
print('Confusion matrix is saved to "{}"'.format(save_path))
return results
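`Classification.evaluate` always reports overall `accuracy`. When `TEST.PER_CLASS_RESULT` is set, it also reports the mean of per-class accuracies (`perclass_accuracy`). The two diverge on imbalanced data. A pure-Python sketch of both numbers follows (hypothetical function, no sklearn).

```python
from collections import defaultdict


def summarize(y_true, y_pred):
    """Return (overall accuracy, mean per-class accuracy), both in percent."""
    per_class = defaultdict(list)
    for t, p in zip(y_true, y_pred):
        per_class[t].append(int(t == p))
    # Every sample counts equally here...
    acc = 100.0 * sum(int(t == p) for t, p in zip(y_true, y_pred)) / len(y_true)
    # ...whereas here every ground-truth class counts equally.
    class_accs = [100.0 * sum(v) / len(v) for _, v in sorted(per_class.items())]
    return acc, sum(class_accs) / len(class_accs)
```

With `y_true = [0, 0, 0, 1]` and a model that always predicts class 0, accuracy is 75% but the per-class mean is only 50%.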
================================================
FILE: Dassl.ProGrad.pytorch/dassl/metrics/__init__.py
================================================
from .accuracy import compute_accuracy
from .distance import (
cosine_distance, compute_distance_matrix, euclidean_squared_distance
)
================================================
FILE: Dassl.ProGrad.pytorch/dassl/metrics/accuracy.py
================================================
def compute_accuracy(output, target, topk=(1, )):
"""Computes the accuracy over the k top predictions for
the specified values of k.
Args:
output (torch.Tensor): prediction matrix with shape (batch_size, num_classes).
target (torch.LongTensor): ground truth labels with shape (batch_size).
topk (tuple, optional): accuracy at top-k will be computed. For example,
topk=(1, 5) means accuracy at top-1 and top-5 will be computed.
Returns:
list: accuracy at top-k.
"""
maxk = max(topk)
batch_size = target.size(0)
if isinstance(output, (tuple, list)):
output = output[0]
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
# reshape (not view): correct comes from a transposed, non-contiguous tensor
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
acc = correct_k.mul_(100.0 / batch_size)
res.append(acc)
return res
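The NumPy formulation below is equivalent to `compute_accuracy` and is meant for checking the top-k logic outside PyTorch. A sample counts as correct at k if its target index appears among the k highest-scoring classes. `topk_accuracy` is a hypothetical name. Unlike the original, it returns plain floats rather than one-element tensors.

```python
import numpy as np


def topk_accuracy(output, target, topk=(1,)):
    """Percentage of samples whose target is among the k highest scores."""
    output = np.asarray(output)
    target = np.asarray(target)
    # Class indices sorted by descending score, one row per sample.
    ranked = np.argsort(-output, axis=1)
    res = []
    for k in topk:
        hit = (ranked[:, :k] == target[:, None]).any(axis=1)
        res.append(100.0 * hit.mean())
    return res
```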
================================================
FILE: Dassl.ProGrad.pytorch/dassl/metrics/distance.py
================================================
"""
Source: https://github.com/KaiyangZhou/deep-person-reid
"""
import torch
from torch.nn import functional as F
def compute_distance_matrix(input1, input2, metric="euclidean"):
"""A wrapper function for computing distance matrix.
Each input matrix has the shape (n_data, feature_dim).
Args:
input1 (torch.Tensor): 2-D feature matrix.
input2 (torch.Tensor): 2-D feature matrix.
metric (str, optional): "euclidean" or "cosine".
Default is "euclidean".
Returns:
torch.Tensor: distance matrix.
"""
# check input
assert isinstance(input1, torch.Tensor)
assert isinstance(input2, torch.Tensor)
assert input1.dim() == 2, "Expected 2-D tensor, but got {}-D".format(
input1.dim()
)
assert input2.dim() == 2, "Expected 2-D tensor, but got {}-D".format(
input2.dim()
)
assert input1.size(1) == input2.size(1)
if metric == "euclidean":
distmat = euclidean_squared_distance(input1, input2)
elif metric == "cosine":
distmat = cosine_distance(input1, input2)
else:
raise ValueError(
"Unknown distance metric: {}. "
'Please choose either "euclidean" or "cosine"'.format(metric)
)
return distmat
def euclidean_squared_distance(input1, input2):
"""Computes euclidean squared distance.
Args:
input1 (torch.Tensor): 2-D feature matrix.
input2 (torch.Tensor): 2-D feature matrix.
Returns:
torch.Tensor: distance matrix.
"""
m, n = input1.size(0), input2.size(0)
mat1 = torch.pow(input1, 2).sum(dim=1, keepdim=True).expand(m, n)
mat2 = torch.pow(input2, 2).sum(dim=1, keepdim=True).expand(n, m).t()
distmat = mat1 + mat2
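# distmat = 1 * distmat - 2 * input1 @ input2.T (keyword form; the
# positional (beta, alpha, ...) signature is deprecated in PyTorch)
distmat.addmm_(input1, input2.t(), beta=1, alpha=-2)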
return distmat
def cosine_distance(input1, input2):
"""Computes cosine distance.
Args:
input1 (torch.Tensor): 2-D feature matrix.
input2 (torch.Tensor): 2-D feature matrix.
Returns:
torch.Tensor: distance matrix.
"""
input1_normed = F.normalize(input1, p=2, dim=1)
input2_normed = F.normalize(input2, p=2, dim=1)
distmat = 1 - torch.mm(input1_normed, input2_normed.t())
return distmat
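`euclidean_squared_distance` avoids an explicit pairwise loop by expanding ||a - b||² = ||a||² + ||b||² - 2 a·b. `cosine_distance` is one minus the dot product of L2-normalised rows. The NumPy sketch below implements both and checks them against brute force. The helper names are hypothetical. Unlike `F.normalize`, the sketch has no epsilon, so it assumes no row is all zeros.

```python
import numpy as np


def euclidean_squared(a, b):
    """Pairwise ||a_i - b_j||^2 via ||a_i||^2 + ||b_j||^2 - 2 a_i . b_j."""
    sq_a = (a ** 2).sum(axis=1, keepdims=True)    # shape (m, 1)
    sq_b = (b ** 2).sum(axis=1, keepdims=True).T  # shape (1, n)
    return sq_a + sq_b - 2.0 * a @ b.T            # broadcasts to (m, n)


def cosine_dist(a, b):
    """Pairwise 1 - cosine similarity between rows of a and rows of b."""
    a_n = a / np.linalg.norm(a, axis=1, keepdims=True)
    b_n = b / np.linalg.norm(b, axis=1, keepdims=True)
    return 1.0 - a_n @ b_n.T
```

Floating-point cancellation in the expansion can produce tiny negative values. Clip at 0 if exact non-negativity matters.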
================================================
FILE: Dassl.ProGrad.pytorch/dassl/modeling/__init__.py
================================================
from .head import HEAD_REGISTRY, build_head
from .network import NETWORK_REGISTRY, build_network
from .backbone import BACKBONE_REGISTRY, Backbone, build_backbone
================================================
FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/__init__.py
================================================
from .build import build_backbone, BACKBONE_REGISTRY # isort:skip
from .backbone import Backbone # isort:skip
from .vgg import vgg16
from .resnet import (
resnet18, resnet34, resnet50, resnet101, resnet152, resnet18_ms_l1,
resnet50_ms_l1, resnet18_ms_l12, resnet50_ms_l12, resnet101_ms_l1,
resnet18_ms_l123, resnet50_ms_l123, resnet101_ms_l12, resnet101_ms_l123,
resnet18_efdmix_l1, resnet50_efdmix_l1, resnet18_efdmix_l12,
resnet50_efdmix_l12, resnet101_efdmix_l1, resnet18_efdmix_l123,
resnet50_efdmix_l123, resnet101_efdmix_l12, resnet101_efdmix_l123
)
from .alexnet import alexnet
from .mobilenetv2 import mobilenetv2
from .wide_resnet import wide_resnet_16_4, wide_resnet_28_2
from .cnn_digitsdg import cnn_digitsdg
from .efficientnet import (
efficientnet_b0, efficientnet_b1, efficientnet_b2, efficientnet_b3,
efficientnet_b4, efficientnet_b5, efficientnet_b6, efficientnet_b7
)
from .shufflenetv2 import (
shufflenet_v2_x0_5, shufflenet_v2_x1_0, shufflenet_v2_x1_5,
shufflenet_v2_x2_0
)
from .cnn_digitsingle import cnn_digitsingle
from .preact_resnet18 import preact_resnet18
from .cnn_digit5_m3sda import cnn_digit5_m3sda
================================================
FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/alexnet.py
================================================
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
from .build import BACKBONE_REGISTRY
from .backbone import Backbone
model_urls = {
"alexnet": "https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth",
}
class AlexNet(Backbone):
def __init__(self):
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
nn.Conv2d(64, 192, kernel_size=5, padding=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
nn.Conv2d(192, 384, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(384, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
)
self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
# Note that self.classifier outputs features rather than logits
self.classifier = nn.Sequential(
nn.Dropout(),
nn.Linear(256 * 6 * 6, 4096),
nn.ReLU(inplace=True),
nn.Dropout(),
nn.Linear(4096, 4096),
nn.ReLU(inplace=True),
)
self._out_features = 4096
def forward(self, x):
x = self.features(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
return self.classifier(x)
def init_pretrained_weights(model, model_url):
pretrain_dict = model_zoo.load_url(model_url)
model.load_state_dict(pretrain_dict, strict=False)
@BACKBONE_REGISTRY.register()
def alexnet(pretrained=True, **kwargs):
model = AlexNet()
if pretrained:
init_pretrained_weights(model, model_urls["alexnet"])
return model
================================================
FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/backbone.py
================================================
import torch.nn as nn
class Backbone(nn.Module):
def __init__(self):
super().__init__()
def forward(self):
pass
@property
def out_features(self):
"""Output feature dimension."""
if self.__dict__.get("_out_features") is None:
return None
return self._out_features
================================================
FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/build.py
================================================
from dassl.utils import Registry, check_availability
BACKBONE_REGISTRY = Registry("BACKBONE")
def build_backbone(name, verbose=True, **kwargs):
avai_backbones = BACKBONE_REGISTRY.registered_names()
check_availability(name, avai_backbones)
if verbose:
print("Backbone: {}".format(name))
return BACKBONE_REGISTRY.get(name)(**kwargs)
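Backbones are looked up by name, as are evaluators via `EVALUATOR_REGISTRY`. `@BACKBONE_REGISTRY.register()` records a constructor under its function name. `build_backbone` validates the name before calling the constructor with `**kwargs`. `Registry` itself lives in `dassl.utils` and is not shown here. The code below is therefore a hypothetical minimal version illustrating the pattern, not its actual implementation.

```python
class Registry:
    """Hypothetical minimal name -> callable registry (not dassl.utils.Registry)."""

    def __init__(self, name):
        self._name = name
        self._obj_map = {}

    def register(self):
        """Decorator used as @REGISTRY.register(); keys entries by __name__."""
        def deco(obj):
            key = obj.__name__
            if key in self._obj_map:
                raise KeyError(f"{key} already registered in {self._name}")
            self._obj_map[key] = obj
            return obj
        return deco

    def get(self, key):
        return self._obj_map[key]

    def registered_names(self):
        return list(self._obj_map.keys())


BACKBONE = Registry("BACKBONE")


@BACKBONE.register()
def tiny_backbone(width=8, **kwargs):
    # Stand-in for a real constructor such as alexnet(pretrained=True, **kwargs).
    return {"width": width}


def build(name, **kwargs):
    # Hypothetical analogue of build_backbone: validate, then construct.
    if name not in BACKBONE.registered_names():
        raise ValueError(f"Unknown backbone {name!r}")
    return BACKBONE.get(name)(**kwargs)
```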
================================================
FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/cnn_digit5_m3sda.py
================================================
"""
Reference
https://github.com/VisionLearningGroup/VisionLearningGroup.github.io/tree/master/M3SDA
"""
import torch.nn as nn
from torch.nn import functional as F
from .build import BACKBONE_REGISTRY
from .backbone import Backbone
class FeatureExtractor(Backbone):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=5, stride=1, padding=2)
self.bn1 = nn.BatchNorm2d(64)
self.conv2 = nn.Conv2d(64, 64, kernel_size=5, stride=1, padding=2)
self.bn2 = nn.BatchNorm2d(64)
self.conv3 = nn.Conv2d(64, 128, kernel_size=5, stride=1, padding=2)
self.bn3 = nn.BatchNorm2d(128)
self.fc1 = nn.Linear(8192, 3072)
self.bn1_fc = nn.BatchNorm1d(3072)
self.fc2 = nn.Linear(3072, 2048)
self.bn2_fc = nn.BatchNorm1d(2048)
self._out_features = 2048
def _check_input(self, x):
H, W = x.shape[2:]
assert (
H == 32 and W == 32
), "Input to network must be 32x32, " "but got {}x{}".format(H, W)
def forward(self, x):
self._check_input(x)
x = F.relu(self.bn1(self.conv1(x)))
x = F.max_pool2d(x, stride=2, kernel_size=3, padding=1)
x = F.relu(self.bn2(self.conv2(x)))
x = F.max_pool2d(x, stride=2, kernel_size=3, padding=1)
x = F.relu(self.bn3(self.conv3(x)))
x = x.view(x.size(0), 8192)
x = F.relu(self.bn1_fc(self.fc1(x)))
x = F.dropout(x, training=self.training)
x = F.relu(self.bn2_fc(self.fc2(x)))
return x
@BACKBONE_REGISTRY.register()
def cnn_digit5_m3sda(**kwargs):
"""
This architecture was used for the Digit-5 dataset in:
- Peng et al. Moment Matching for Multi-Source
Domain Adaptation. ICCV 2019.
"""
return FeatureExtractor()
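The hard-coded `nn.Linear(8192, 3072)` and the 32x32 check in `_check_input` go together:

- The three 5x5 convolutions with padding 2 preserve spatial size.
- Each `max_pool2d(kernel_size=3, stride=2, padding=1)` halves it, so 32 → 16 → 8.
- The flattened size is therefore 128 × 8 × 8 = 8192.

The sketch below reproduces that arithmetic using the standard output-size formula for floor-mode convolution and pooling. The helper names are hypothetical.

```python
def out_size(size, kernel, stride=1, padding=0):
    """Spatial output size of a conv/pool layer (floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


def digit5_flat_dim(h=32, channels=128):
    """Flattened feature size fed to fc1 for an h x h input."""
    h = out_size(h, 5, 1, 2)  # conv1: 5x5, padding 2 keeps the size
    h = out_size(h, 3, 2, 1)  # max_pool2d(kernel_size=3, stride=2, padding=1)
    h = out_size(h, 5, 1, 2)  # conv2
    h = out_size(h, 3, 2, 1)  # max_pool2d
    h = out_size(h, 5, 1, 2)  # conv3 (128 output channels)
    return channels * h * h
```

`digit5_flat_dim(28)` gives 6272, which would not match `fc1`. This is why the input size is asserted.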
================================================
FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/cnn_digitsdg.py
================================================
import torch.nn as nn
from torch.nn import functional as F
from dassl.utils import init_network_weights
from .build import BACKBONE_REGISTRY
from .backbone import Backbone
class Convolution(nn.Module):
def __init__(self, c_in, c_out):
super().__init__()
self.conv = nn.Conv2d(c_in, c_out, 3, stride=1, padding=1)
self.relu = nn.ReLU(True)
def forward(self, x):
return self.relu(self.conv(x))
class ConvNet(Backbone):
def __init__(self, c_hidden=64):
super().__init__()
self.conv1 = Convolution(3, c_hidden)
self.conv2 = Convolution(c_hidden, c_hidden)
self.conv3 = Convolution(c_hidden, c_hidden)
self.conv4 = Convolution(c_hidden, c_hidden)
self._out_features = 2**2 * c_hidden
def _check_input(self, x):
H, W = x.shape[2:]
assert (
H == 32 and W == 32
), "Input to network must be 32x32, " "but got {}x{}".format(H, W)
def forward(self, x):
self._check_input(x)
x = self.conv1(x)
x = F.max_pool2d(x, 2)
x = self.conv2(x)
x = F.max_pool2d(x, 2)
x = self.conv3(x)
x = F.max_pool2d(x, 2)
x = self.conv4(x)
x = F.max_pool2d(x, 2)
return x.view(x.size(0), -1)
@BACKBONE_REGISTRY.register()
def cnn_digitsdg(**kwargs):
"""
This architecture was used for the DigitsDG dataset in:
- Zhou et al. Deep Domain-Adversarial Image Generation
for Domain Generalisation. AAAI 2020.
"""
model = ConvNet(c_hidden=64)
init_network_weights(model, init_type="kaiming")
return model
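In `ConvNet`, `_out_features = 2**2 * c_hidden` assumes a 32x32 input. Each 3x3 `Convolution` (stride 1, padding 1) keeps the size. Each `F.max_pool2d(x, 2)` halves it, so the four pools give 32 -> 16 -> 8 -> 4 -> 2. The following is a minimal plain-Python check of that arithmetic. The function name is hypothetical.

```python
def convnet_out_features(c_hidden=64, size=32):
    for _ in range(4):
        # Convolution: 3x3, stride 1, padding 1 -> size unchanged
        size = (size + 2 * 1 - 3) // 1 + 1
        # F.max_pool2d(x, 2): kernel 2, stride 2 -> floor(size / 2)
        size = (size - 2) // 2 + 1
    return c_hidden * size * size


print(convnet_out_features())             # 256
print(convnet_out_features(c_hidden=32))  # 128
```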
================================================
FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/cnn_digitsingle.py
================================================
"""
This model is built based on
https://github.com/ricvolpi/generalize-unseen-domains/blob/master/model.py
"""
import torch.nn as nn
from torch.nn import functional as F
from dassl.utils import init_network_weights
from .build import BACKBONE_REGISTRY
from .backbone import Backbone
class CNN(Backbone):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 64, 5)
self.conv2 = nn.Conv2d(64, 128, 5)
self.fc3 = nn.Linear(5 * 5 * 128, 1024)
self.fc4 = nn.Linear(1024, 1024)
self._out_features = 1024
def _check_input(self, x):
H, W = x.shape[2:]
assert (
H == 32 and W == 32
), "Input to network must be 32x32, " "but got {}x{}".format(H, W)
def forward(self, x):
self._check_input(x)
x = self.conv1(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = x.view(x.size(0), -1)
x = self.fc3(x)
x = F.relu(x)
x = self.fc4(x)
x = F.relu(x)
return x
@BACKBONE_REGISTRY.register()
def cnn_digitsingle(**kwargs):
model = CNN()
init_network_weights(model, init_type="kaiming")
return model
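In this `CNN`, `fc3` expects `5 * 5 * 128` inputs. Both convs are 5x5 with no padding, so each one removes 4 pixels per axis, and each `max_pool2d(x, 2)` halves the size: 32 -> 28 -> 14 -> 10 -> 5. The sketch below mirrors those steps under the 32x32 assumption enforced by `_check_input`. The helper is hypothetical.

```python
def cnn_fc3_in_features(size=32):
    size = size - 5 + 1  # conv1: 5x5, stride 1, no padding
    size //= 2           # max_pool2d(2)
    size = size - 5 + 1  # conv2: 5x5, stride 1, no padding
    size //= 2           # max_pool2d(2)
    return 128 * size * size


print(cnn_fc3_in_features())  # 3200
```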
================================================
FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/efficientnet/__init__.py
================================================
"""
Source: https://github.com/lukemelas/EfficientNet-PyTorch.
"""
__version__ = "0.6.4"
from .model import (
EfficientNet, efficientnet_b0, efficientnet_b1, efficientnet_b2,
efficientnet_b3, efficientnet_b4, efficientnet_b5, efficientnet_b6,
efficientnet_b7
)
from .utils import (
BlockArgs, BlockDecoder, GlobalParams, efficientnet, get_model_params
)
================================================
FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/efficientnet/model.py
================================================
import torch
from torch import nn
from torch.nn import functional as F
from .utils import (
Swish, MemoryEfficientSwish, drop_connect, round_filters, round_repeats,
get_model_params, efficientnet_params, get_same_padding_conv2d,
load_pretrained_weights, calculate_output_image_size
)
from ..build import BACKBONE_REGISTRY
from ..backbone import Backbone
class MBConvBlock(nn.Module):
"""
Mobile Inverted Residual Bottleneck Block
Args:
block_args (namedtuple): BlockArgs, defined in utils.py
global_params (namedtuple): GlobalParams, defined in utils.py
Attributes:
has_se (bool): Whether the block contains a Squeeze and Excitation layer.
"""
def __init__(self, block_args, global_params, image_size=None):
super().__init__()
self._block_args = block_args
self._bn_mom = 1 - global_params.batch_norm_momentum
self._bn_eps = global_params.batch_norm_epsilon
self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
self.id_skip = block_args.id_skip # skip connection and drop connect
# Expansion phase
inp = self._block_args.input_filters # number of input channels
oup = (
self._block_args.input_filters * self._block_args.expand_ratio
) # number of output channels
if self._block_args.expand_ratio != 1:
Conv2d = get_same_padding_conv2d(image_size=image_size)
self._expand_conv = Conv2d(
in_channels=inp, out_channels=oup, kernel_size=1, bias=False
)
self._bn0 = nn.BatchNorm2d(
num_features=oup, momentum=self._bn_mom, eps=self._bn_eps
)
# image_size = calculate_output_image_size(image_size, 1) <-- this would do nothing
# Depthwise convolution phase
k = self._block_args.kernel_size
s = self._block_args.stride
Conv2d = get_same_padding_conv2d(image_size=image_size)
self._depthwise_conv = Conv2d(
in_channels=oup,
out_channels=oup,
groups=oup, # groups makes it depthwise
kernel_size=k,
stride=s,
bias=False,
)
self._bn1 = nn.BatchNorm2d(
num_features=oup, momentum=self._bn_mom, eps=self._bn_eps
)
image_size = calculate_output_image_size(image_size, s)
# Squeeze and Excitation layer, if desired
if self.has_se:
Conv2d = get_same_padding_conv2d(image_size=(1, 1))
num_squeezed_channels = max(
1,
int(
self._block_args.input_filters * self._block_args.se_ratio
)
)
self._se_reduce = Conv2d(
in_channels=oup,
out_channels=num_squeezed_channels,
kernel_size=1
)
self._se_expand = Conv2d(
in_channels=num_squeezed_channels,
out_channels=oup,
kernel_size=1
)
# Output phase
final_oup = self._block_args.output_filters
Conv2d = get_same_padding_conv2d(image_size=image_size)
self._project_conv = Conv2d(
in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False
)
self._bn2 = nn.BatchNorm2d(
num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps
)
self._swish = MemoryEfficientSwish()
def forward(self, inputs, drop_connect_rate=None):
"""
:param inputs: input tensor
:param drop_connect_rate: drop connect rate (float, between 0 and 1)
:return: output of block
"""
# Expansion and Depthwise Convolution
x = inputs
if self._block_args.expand_ratio != 1:
x = self._swish(self._bn0(self._expand_conv(inputs)))
x = self._swish(self._bn1(self._depthwise_conv(x)))
# Squeeze and Excitation
if self.has_se:
x_squeezed = F.adaptive_avg_pool2d(x, 1)
x_squeezed = self._se_expand(
self._swish(self._se_reduce(x_squeezed))
)
x = torch.sigmoid(x_squeezed) * x
x = self._bn2(self._project_conv(x))
# Skip connection and drop connect
input_filters, output_filters = (
self._block_args.input_filters,
self._block_args.output_filters,
)
if (
self.id_skip and self._block_args.stride == 1
and input_filters == output_filters
):
if drop_connect_rate:
x = drop_connect(
x, p=drop_connect_rate, training=self.training
)
x = x + inputs # skip connection
return x
def set_swish(self, memory_efficient=True):
"""Sets swish function as memory efficient (for training) or standard (for export)"""
self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
class EfficientNet(Backbone):
"""
An EfficientNet model. Most easily loaded with the .from_name or .from_pretrained methods.
Args:
blocks_args (list): A list of BlockArgs to construct blocks
global_params (namedtuple): A set of GlobalParams shared between blocks
Example:
model = EfficientNet.from_pretrained('efficientnet-b0')
"""
def __init__(self, blocks_args=None, global_params=None):
super().__init__()
assert isinstance(blocks_args, list), "blocks_args should be a list"
assert len(blocks_args) > 0, "blocks_args must not be empty"
self._global_params = global_params
self._blocks_args = blocks_args
# Batch norm parameters
bn_mom = 1 - self._global_params.batch_norm_momentum
bn_eps = self._global_params.batch_norm_epsilon
# Get stem static or dynamic convolution depending on image size
image_size = global_params.image_size
Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)
# Stem
in_channels = 3 # rgb
out_channels = round_filters(
32, self._global_params
) # number of output channels
self._conv_stem = Conv2d(
in_channels, out_channels, kernel_size=3, stride=2, bias=False
)
self._bn0 = nn.BatchNorm2d(
num_features=out_channels, momentum=bn_mom, eps=bn_eps
)
image_size = calculate_output_image_size(image_size, 2)
# Build blocks
self._blocks = nn.ModuleList([])
for block_args in self._blocks_args:
# Update block input and output filters based on depth multiplier.
block_args = block_args._replace(
input_filters=round_filters(
block_args.input_filters, self._global_params
),
output_filters=round_filters(
block_args.output_filters, self._global_params
),
num_repeat=round_repeats(
block_args.num_repeat, self._global_params
),
)
# The first block needs to take care of stride and filter size increase.
self._blocks.append(
MBConvBlock(
block_args, self._global_params, image_size=image_size
)
)
image_size = calculate_output_image_size(
image_size, block_args.stride
)
if block_args.num_repeat > 1:
block_args = block_args._replace(
input_filters=block_args.output_filters, stride=1
)
for _ in range(block_args.num_repeat - 1):
self._blocks.append(
MBConvBlock(
block_args, self._global_params, image_size=image_size
)
)
# image_size = calculate_output_image_size(image_size, block_args.stride) # ?
# Head
in_channels = block_args.output_filters # output of final block
out_channels = round_filters(1280, self._global_params)
Conv2d = get_same_padding_conv2d(image_size=image_size)
self._conv_head = Conv2d(
in_channels, out_channels, kernel_size=1, bias=False
)
self._bn1 = nn.BatchNorm2d(
num_features=out_channels, momentum=bn_mom, eps=bn_eps
)
# Final linear layer
self._avg_pooling = nn.AdaptiveAvgPool2d(1)
self._dropout = nn.Dropout(self._global_params.dropout_rate)
# self._fc = nn.Linear(out_channels, self._global_params.num_classes)
self._swish = MemoryEfficientSwish()
self._out_features = out_channels
def set_swish(self, memory_efficient=True):
"""Sets swish function as memory efficient (for training) or standard (for export)"""
self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
for block in self._blocks:
block.set_swish(memory_efficient)
def extract_features(self, inputs):
"""Returns output of the final convolution layer"""
# Stem
x = self._swish(self._bn0(self._conv_stem(inputs)))
# Blocks
for idx, block in enumerate(self._blocks):
drop_connect_rate = self._global_params.drop_connect_rate
if drop_connect_rate:
drop_connect_rate *= float(idx) / len(self._blocks)
x = block(x, drop_connect_rate=drop_connect_rate)
# Head
x = self._swish(self._bn1(self._conv_head(x)))
return x
def forward(self, inputs):
"""
Calls extract_features to extract features, applies
final linear layer, and returns logits.
"""
bs = inputs.size(0)
# Convolution layers
x = self.extract_features(inputs)
# Pooling and final linear layer
x = self._avg_pooling(x)
x = x.view(bs, -1)
x = self._dropout(x)
# x = self._fc(x)
return x
@classmethod
def from_name(cls, model_name, override_params=None):
cls._check_model_name_is_valid(model_name)
blocks_args, global_params = get_model_params(
model_name, override_params
)
return cls(blocks_args, global_params)
@classmethod
def from_pretrained(
cls, model_name, advprop=False, num_classes=1000, in_channels=3
):
model = cls.from_name(
model_name, override_params={"num_classes": num_classes}
)
load_pretrained_weights(
model, model_name, load_fc=(num_classes == 1000), advprop=advprop
)
model._change_in_channels(in_channels)
return model
@classmethod
def get_image_size(cls, model_name):
cls._check_model_name_is_valid(model_name)
_, _, res, _ = efficientnet_params(model_name)
return res
@classmethod
def _check_model_name_is_valid(cls, model_name):
"""Validates model name."""
valid_models = ["efficientnet-b" + str(i) for i in range(9)]
if model_name not in valid_models:
raise ValueError(
"model_name should be one of: " + ", ".join(valid_models)
)
def _change_in_channels(model, in_channels):
if in_channels != 3:
Conv2d = get_same_padding_conv2d(
image_size=model._global_params.image_size
)
out_channels = round_filters(32, model._global_params)
model._conv_stem = Conv2d(
in_channels, out_channels, kernel_size=3, stride=2, bias=False
)
def build_efficientnet(name, pretrained):
if pretrained:
return EfficientNet.from_pretrained("efficientnet-{}".format(name))
else:
return EfficientNet.from_name("efficientnet-{}".format(name))
@BACKBONE_REGISTRY.register()
def efficientnet_b0(pretrained=True, **kwargs):
return build_efficientnet("b0", pretrained)
@BACKBONE_REGISTRY.register()
def efficientnet_b1(pretrained=True, **kwargs):
return build_efficientnet("b1", pretrained)
@BACKBONE_REGISTRY.register()
def efficientnet_b2(pretrained=True, **kwargs):
return build_efficientnet("b2", pretrained)
@BACKBONE_REGISTRY.register()
def efficientnet_b3(pretrained=True, **kwargs):
return build_efficientnet("b3", pretrained)
@BACKBONE_REGISTRY.register()
def efficientnet_b4(pretrained=True, **kwargs):
return build_efficientnet("b4", pretrained)
@BACKBONE_REGISTRY.register()
def efficientnet_b5(pretrained=True, **kwargs):
return build_efficientnet("b5", pretrained)
@BACKBONE_REGISTRY.register()
def efficientnet_b6(pretrained=True, **kwargs):
return build_efficientnet("b6", pretrained)
@BACKBONE_REGISTRY.register()
def efficientnet_b7(pretrained=True, **kwargs):
return build_efficientnet("b7", pretrained)
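`EfficientNet.__init__` scales the base B0 configuration in two ways. It rounds channel counts with `round_filters` using the width coefficient, and it rounds per-stage repeats with `round_repeats` using the depth coefficient. The total number of `MBConvBlock`s is therefore the sum of the rounded repeats over the seven stages in `efficientnet()`. The sketch below re-implements both rules in plain Python, assuming `depth_divisor=8` and `min_depth=None` as set in `efficientnet()`. The coefficient table is copied from `efficientnet_params`, and the helper names are illustrative.

```python
import math

# (width, depth) coefficients copied from efficientnet_params()
COEFFS = {"b0": (1.0, 1.0), "b1": (1.0, 1.1), "b4": (1.4, 1.8)}
# num_repeat of the seven stages in efficientnet()'s blocks_args
BASE_REPEATS = [1, 2, 2, 3, 3, 4, 1]


def round_filters(filters, width, divisor=8):
    # Mirrors utils.round_filters with depth_divisor=8 and min_depth=None
    filters *= width
    new_filters = max(divisor, int(filters + divisor / 2) // divisor * divisor)
    if new_filters < 0.9 * filters:  # never round down by more than 10%
        new_filters += divisor
    return int(new_filters)


def round_repeats(repeats, depth):
    # Mirrors utils.round_repeats: always rounds up
    return int(math.ceil(depth * repeats))


def num_blocks(variant):
    _, depth = COEFFS[variant]
    return sum(round_repeats(r, depth) for r in BASE_REPEATS)


for name, (width, _) in COEFFS.items():
    print(name, round_filters(32, width), round_filters(1280, width), num_blocks(name))
# b0 32 1280 16
# b1 32 1280 23
# b4 48 1792 32
```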
================================================
FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/efficientnet/utils.py
================================================
"""
This file contains helper functions for building the model and for loading model parameters.
These helper functions are built to mirror those in the official TensorFlow implementation.
"""
import re
import math
import collections
from functools import partial
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import model_zoo
########################################################################
############### HELPERS FUNCTIONS FOR MODEL ARCHITECTURE ###############
########################################################################
# Parameters for the entire model (stem, all blocks, and head)
GlobalParams = collections.namedtuple(
"GlobalParams",
[
"batch_norm_momentum",
"batch_norm_epsilon",
"dropout_rate",
"num_classes",
"width_coefficient",
"depth_coefficient",
"depth_divisor",
"min_depth",
"drop_connect_rate",
"image_size",
],
)
# Parameters for an individual model block
BlockArgs = collections.namedtuple(
"BlockArgs",
[
"kernel_size",
"num_repeat",
"input_filters",
"output_filters",
"expand_ratio",
"id_skip",
"stride",
"se_ratio",
],
)
# Change namedtuple defaults
GlobalParams.__new__.__defaults__ = (None, ) * len(GlobalParams._fields)
BlockArgs.__new__.__defaults__ = (None, ) * len(BlockArgs._fields)
class SwishImplementation(torch.autograd.Function):
@staticmethod
def forward(ctx, i):
result = i * torch.sigmoid(i)
ctx.save_for_backward(i)
return result
@staticmethod
def backward(ctx, grad_output):
i = ctx.saved_tensors[0]
sigmoid_i = torch.sigmoid(i)
return grad_output * (sigmoid_i * (1 + i * (1-sigmoid_i)))
class MemoryEfficientSwish(nn.Module):
def forward(self, x):
return SwishImplementation.apply(x)
class Swish(nn.Module):
def forward(self, x):
return x * torch.sigmoid(x)
def round_filters(filters, global_params):
"""Calculate and round number of filters based on width multiplier."""
multiplier = global_params.width_coefficient
if not multiplier:
return filters
divisor = global_params.depth_divisor
min_depth = global_params.min_depth
filters *= multiplier
min_depth = min_depth or divisor
new_filters = max(min_depth, int(filters + divisor/2) // divisor * divisor)
if new_filters < 0.9 * filters: # prevent rounding by more than 10%
new_filters += divisor
return int(new_filters)
def round_repeats(repeats, global_params):
"""Round number of block repeats based on depth multiplier."""
multiplier = global_params.depth_coefficient
if not multiplier:
return repeats
return int(math.ceil(multiplier * repeats))
def drop_connect(inputs, p, training):
"""Drop connect."""
if not training:
return inputs
batch_size = inputs.shape[0]
keep_prob = 1 - p
random_tensor = keep_prob
random_tensor += torch.rand(
[batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device
)
binary_tensor = torch.floor(random_tensor)
output = inputs / keep_prob * binary_tensor
return output
def get_same_padding_conv2d(image_size=None):
"""Chooses static padding if you have specified an image size, and dynamic padding otherwise.
Static padding is necessary for ONNX exporting of models."""
if image_size is None:
return Conv2dDynamicSamePadding
else:
return partial(Conv2dStaticSamePadding, image_size=image_size)
def get_width_and_height_from_size(x):
"""Obtains width and height from an int, list, or tuple"""
if isinstance(x, int):
return x, x
if isinstance(x, list) or isinstance(x, tuple):
return x
else:
raise TypeError()
def calculate_output_image_size(input_image_size, stride):
"""
Calculates the output image size when using Conv2dSamePadding with a stride.
Necessary for static padding. Thanks to mannatsingh for pointing this out.
"""
if input_image_size is None:
return None
image_height, image_width = get_width_and_height_from_size(
input_image_size
)
stride = stride if isinstance(stride, int) else stride[0]
image_height = int(math.ceil(image_height / stride))
image_width = int(math.ceil(image_width / stride))
return [image_height, image_width]
class Conv2dDynamicSamePadding(nn.Conv2d):
"""2D Convolutions like TensorFlow, for a dynamic image size"""
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
dilation=1,
groups=1,
bias=True,
):
super().__init__(
in_channels, out_channels, kernel_size, stride, 0, dilation,
groups, bias
)
self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
def forward(self, x):
ih, iw = x.size()[-2:]
kh, kw = self.weight.size()[-2:]
sh, sw = self.stride
oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
pad_h = max(
(oh-1) * self.stride[0] + (kh-1) * self.dilation[0] + 1 - ih, 0
)
pad_w = max(
(ow-1) * self.stride[1] + (kw-1) * self.dilation[1] + 1 - iw, 0
)
if pad_h > 0 or pad_w > 0:
x = F.pad(
x,
[pad_w // 2, pad_w - pad_w//2, pad_h // 2, pad_h - pad_h//2]
)
return F.conv2d(
x,
self.weight,
self.bias,
self.stride,
self.padding,
self.dilation,
self.groups,
)
class Conv2dStaticSamePadding(nn.Conv2d):
"""2D Convolutions like TensorFlow, for a fixed image size"""
def __init__(
self,
in_channels,
out_channels,
kernel_size,
image_size=None,
**kwargs
):
super().__init__(in_channels, out_channels, kernel_size, **kwargs)
self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
# Calculate padding based on image size and save it
assert image_size is not None
ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
kh, kw = self.weight.size()[-2:]
sh, sw = self.stride
oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
pad_h = max(
(oh-1) * self.stride[0] + (kh-1) * self.dilation[0] + 1 - ih, 0
)
pad_w = max(
(ow-1) * self.stride[1] + (kw-1) * self.dilation[1] + 1 - iw, 0
)
if pad_h > 0 or pad_w > 0:
self.static_padding = nn.ZeroPad2d(
(pad_w // 2, pad_w - pad_w//2, pad_h // 2, pad_h - pad_h//2)
)
else:
self.static_padding = Identity()
def forward(self, x):
x = self.static_padding(x)
x = F.conv2d(
x,
self.weight,
self.bias,
self.stride,
self.padding,
self.dilation,
self.groups,
)
return x
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
def forward(self, input):
return input
########################################################################
############## HELPERS FUNCTIONS FOR LOADING MODEL PARAMS ##############
########################################################################
def efficientnet_params(model_name):
"""Map EfficientNet model name to parameter coefficients."""
params_dict = {
# Coefficients: width,depth,res,dropout
"efficientnet-b0": (1.0, 1.0, 224, 0.2),
"efficientnet-b1": (1.0, 1.1, 240, 0.2),
"efficientnet-b2": (1.1, 1.2, 260, 0.3),
"efficientnet-b3": (1.2, 1.4, 300, 0.3),
"efficientnet-b4": (1.4, 1.8, 380, 0.4),
"efficientnet-b5": (1.6, 2.2, 456, 0.4),
"efficientnet-b6": (1.8, 2.6, 528, 0.5),
"efficientnet-b7": (2.0, 3.1, 600, 0.5),
"efficientnet-b8": (2.2, 3.6, 672, 0.5),
"efficientnet-l2": (4.3, 5.3, 800, 0.5),
}
return params_dict[model_name]
class BlockDecoder(object):
"""Block Decoder for readability, straight from the official TensorFlow repository"""
@staticmethod
def _decode_block_string(block_string):
"""Gets a block through a string notation of arguments."""
assert isinstance(block_string, str)
ops = block_string.split("_")
options = {}
for op in ops:
splits = re.split(r"(\d.*)", op)
if len(splits) >= 2:
key, value = splits[:2]
options[key] = value
# Check stride
assert ("s" in options and len(options["s"]) == 1) or (
len(options["s"]) == 2 and options["s"][0] == options["s"][1]
)
return BlockArgs(
kernel_size=int(options["k"]),
num_repeat=int(options["r"]),
input_filters=int(options["i"]),
output_filters=int(options["o"]),
expand_ratio=int(options["e"]),
id_skip=("noskip" not in block_string),
se_ratio=float(options["se"]) if "se" in options else None,
stride=[int(options["s"][0])],
)
@staticmethod
def _encode_block_string(block):
"""Encodes a block to a string."""
args = [
"r%d" % block.num_repeat,
"k%d" % block.kernel_size,
"s{0}{0}".format(block.stride if isinstance(block.stride, int) else block.stride[0]),
"e%s" % block.expand_ratio,
"i%d" % block.input_filters,
"o%d" % block.output_filters,
]
if block.se_ratio is not None and 0 < block.se_ratio <= 1:
args.append("se%s" % block.se_ratio)
if block.id_skip is False:
args.append("noskip")
return "_".join(args)
@staticmethod
def decode(string_list):
"""
Decodes a list of string notations to specify blocks inside the network.
:param string_list: a list of strings, each string is a notation of block
:return: a list of BlockArgs namedtuples of block args
"""
assert isinstance(string_list, list)
blocks_args = []
for block_string in string_list:
blocks_args.append(BlockDecoder._decode_block_string(block_string))
return blocks_args
@staticmethod
def encode(blocks_args):
"""
Encodes a list of BlockArgs to a list of strings.
:param blocks_args: a list of BlockArgs namedtuples of block args
:return: a list of strings, each string is a notation of block
"""
block_strings = []
for block in blocks_args:
block_strings.append(BlockDecoder._encode_block_string(block))
return block_strings
def efficientnet(
width_coefficient=None,
depth_coefficient=None,
dropout_rate=0.2,
drop_connect_rate=0.2,
image_size=None,
num_classes=1000,
):
"""Creates the block args and global params for an EfficientNet model."""
blocks_args = [
"r1_k3_s11_e1_i32_o16_se0.25",
"r2_k3_s22_e6_i16_o24_se0.25",
"r2_k5_s22_e6_i24_o40_se0.25",
"r3_k3_s22_e6_i40_o80_se0.25",
"r3_k5_s11_e6_i80_o112_se0.25",
"r4_k5_s22_e6_i112_o192_se0.25",
"r1_k3_s11_e6_i192_o320_se0.25",
]
blocks_args = BlockDecoder.decode(blocks_args)
global_params = GlobalParams(
batch_norm_momentum=0.99,
batch_norm_epsilon=1e-3,
dropout_rate=dropout_rate,
drop_connect_rate=drop_connect_rate,
# data_format='channels_last', # removed, this is always true in PyTorch
num_classes=num_classes,
width_coefficient=width_coefficient,
depth_coefficient=depth_coefficient,
depth_divisor=8,
min_depth=None,
image_size=image_size,
)
return blocks_args, global_params
def get_model_params(model_name, override_params):
"""Get the block args and global params for a given model"""
if model_name.startswith("efficientnet"):
w, d, s, p = efficientnet_params(model_name)
# note: all models have drop connect rate = 0.2
blocks_args, global_params = efficientnet(
width_coefficient=w,
depth_coefficient=d,
dropout_rate=p,
image_size=s
)
else:
raise NotImplementedError(
"model name is not pre-defined: %s" % model_name
)
if override_params:
# ValueError will be raised here if override_params has fields not included in global_params.
global_params = global_params._replace(**override_params)
return blocks_args, global_params
url_map = {
"efficientnet-b0":
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth",
"efficientnet-b1":
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth",
"efficientnet-b2":
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth",
"efficientnet-b3":
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth",
"efficientnet-b4":
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth",
"efficientnet-b5":
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth",
"efficientnet-b6":
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth",
"efficientnet-b7":
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth",
}
url_map_advprop = {
"efficientnet-b0":
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth",
"efficientnet-b1":
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth",
"efficientnet-b2":
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth",
"efficientnet-b3":
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth",
"efficientnet-b4":
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth",
"efficientnet-b5":
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth",
"efficientnet-b6":
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth",
"efficientnet-b7":
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth",
"efficientnet-b8":
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth",
}
def load_pretrained_weights(model, model_name, load_fc=True, advprop=False):
"""Loads pretrained weights, and downloads if loading for the first time."""
# AutoAugment or Advprop (different preprocessing)
url_map_ = url_map_advprop if advprop else url_map
state_dict = model_zoo.load_url(url_map_[model_name])
model.load_state_dict(state_dict, strict=False)
"""
if load_fc:
model.load_state_dict(state_dict)
else:
state_dict.pop('_fc.weight')
state_dict.pop('_fc.bias')
res = model.load_state_dict(state_dict, strict=False)
assert set(res.missing_keys) == set(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights'
print('Loaded pretrained weights for {}'.format(model_name))
"""
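`Conv2dDynamicSamePadding` and `Conv2dStaticSamePadding` reproduce TensorFlow's "SAME" padding. The output size is `ceil(in / stride)`, which matches `calculate_output_image_size`. The total padding is whatever the kernel needs to reach that size. When the total is odd, the extra pixel goes after the input (bottom/right), because the code pads `pad // 2` before and `pad - pad // 2` after. The plain-Python sketch below computes the per-axis amounts. The function name is hypothetical.

```python
import math


def same_padding(in_size, kernel, stride, dilation=1):
    """Return (before, after) padding on one axis, as the SamePadding convs compute it."""
    out_size = math.ceil(in_size / stride)  # same rule as calculate_output_image_size
    pad = max((out_size - 1) * stride + (kernel - 1) * dilation + 1 - in_size, 0)
    # An odd total puts the extra pixel after the input, as in TensorFlow
    return pad // 2, pad - pad // 2


print(same_padding(224, 3, 2))  # (0, 1): the B0 stem conv on a 224x224 image
print(same_padding(112, 3, 1))  # (1, 1)
print(same_padding(112, 5, 2))  # (1, 2)
```

PyTorch's own `padding=` argument on `nn.Conv2d` can only pad both sides by the same amount (since PyTorch 1.9 it also accepts `padding="same"`, but only for stride 1). That is why these classes compute asymmetric padding themselves and apply it with `F.pad` or `nn.ZeroPad2d`.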
================================================
FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/mobilenetv2.py
================================================
import torch.utils.model_zoo as model_zoo
from torch import nn
from .build import BACKBONE_REGISTRY
from .backbone import Backbone
model_urls = {
"mobilenet_v2":
"https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
}
def _make_divisible(v, divisor, min_value=None):
"""
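This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by `divisor` (8 in this file).
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
:param v: number of channels before rounding
:param divisor: value the result must be a multiple of
:param min_value: lower bound on the result (defaults to divisor)
:return: rounded number of channels, never rounded down by more than 10%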
"""
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor/2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
class ConvBNReLU(nn.Sequential):
def __init__(
self, in_planes, out_planes, kernel_size=3, stride=1, groups=1
):
padding = (kernel_size-1) // 2
super().__init__(
nn.Conv2d(
in_planes,
out_planes,
kernel_size,
stride,
padding,
groups=groups,
bias=False,
),
nn.BatchNorm2d(out_planes),
nn.ReLU6(inplace=True),
)
class InvertedResidual(nn.Module):
def __init__(self, inp, oup, stride, expand_ratio):
super().__init__()
self.stride = stride
assert stride in [1, 2]
hidden_dim = int(round(inp * expand_ratio))
self.use_res_connect = self.stride == 1 and inp == oup
layers = []
if expand_ratio != 1:
# pw
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
layers.extend(
[
# dw
ConvBNReLU(
hidden_dim, hidden_dim, stride=stride, groups=hidden_dim
),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
]
)
self.conv = nn.Sequential(*layers)
def forward(self, x):
if self.use_res_connect:
return x + self.conv(x)
else:
return self.conv(x)
class MobileNetV2(Backbone):
def __init__(
self,
width_mult=1.0,
inverted_residual_setting=None,
round_nearest=8,
block=None,
):
"""
MobileNet V2.
Args:
width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
inverted_residual_setting: Network structure
round_nearest (int): Round the number of channels in each layer to be a multiple of this number
Set to 1 to turn off rounding
block: Module specifying inverted residual building block for mobilenet
"""
super().__init__()
if block is None:
block = InvertedResidual
input_channel = 32
last_channel = 1280
if inverted_residual_setting is None:
inverted_residual_setting = [
# t, c, n, s
[1, 16, 1, 1],
[6, 24, 2, 2],
[6, 32, 3, 2],
[6, 64, 4, 2],
[6, 96, 3, 1],
[6, 160, 3, 2],
[6, 320, 1, 1],
]
# only check the first element, assuming user knows t,c,n,s are required
if (
len(inverted_residual_setting) == 0
or len(inverted_residual_setting[0]) != 4
):
raise ValueError(
"inverted_residual_setting should be non-empty "
"and each element should be a 4-element list, got {}".
format(inverted_residual_setting)
)
# building first layer
input_channel = _make_divisible(
input_channel * width_mult, round_nearest
)
self.last_channel = _make_divisible(
last_channel * max(1.0, width_mult), round_nearest
)
features = [ConvBNReLU(3, input_channel, stride=2)]
# building inverted residual blocks
for t, c, n, s in inverted_residual_setting:
output_channel = _make_divisible(c * width_mult, round_nearest)
for i in range(n):
stride = s if i == 0 else 1
features.append(
block(
input_channel, output_channel, stride, expand_ratio=t
)
)
input_channel = output_channel
# building last several layers
features.append(
ConvBNReLU(input_channel, self.last_channel, kernel_size=1)
)
# make it nn.Sequential
self.features = nn.Sequential(*features)
self._out_features = self.last_channel
# weight initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out")
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.zeros_(m.bias)
def _forward_impl(self, x):
# This exists since TorchScript doesn't support inheritance, so the superclass method
# (this one) needs to have a name other than `forward` that can be accessed in a subclass
x = self.features(x)
x = x.mean([2, 3])
return x
def forward(self, x):
return self._forward_impl(x)
def init_pretrained_weights(model, model_url):
"""Initializes model with pretrained weights.
Layers that don't match with pretrained layers in name or size are kept unchanged.
"""
if model_url is None:
import warnings
warnings.warn(
"ImageNet pretrained weights are unavailable for this model"
)
return
pretrain_dict = model_zoo.load_url(model_url)
model_dict = model.state_dict()
pretrain_dict = {
k: v
for k, v in pretrain_dict.items()
if k in model_dict and model_dict[k].size() == v.size()
}
model_dict.update(pretrain_dict)
model.load_state_dict(model_dict)
@BACKBONE_REGISTRY.register()
def mobilenetv2(pretrained=True, **kwargs):
model = MobileNetV2(**kwargs)
if pretrained:
init_pretrained_weights(model, model_urls["mobilenet_v2"])
return model
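`MobileNetV2.__init__` derives each stage's channel count by passing `c * width_mult` through `_make_divisible`. It also widens the final 1280-channel layer only when `width_mult > 1`. The overall downsampling factor is the stem stride (2) multiplied by the first-block stride `s` of every stage. The sketch below recomputes these values from the default `(t, c, n, s)` table in plain Python. The helper names are illustrative.

```python
def make_divisible(v, divisor=8, min_value=None):
    # Same rounding rule as _make_divisible above
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


# Default (t, c, n, s) table from MobileNetV2.__init__
SETTING = [[1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 2], [6, 64, 4, 2],
           [6, 96, 3, 1], [6, 160, 3, 2], [6, 320, 1, 1]]


def stage_channels(width_mult=1.0):
    return [make_divisible(c * width_mult) for _, c, _, _ in SETTING]


def num_inverted_residuals():
    return sum(n for _, _, n, _ in SETTING)


def total_stride():
    stride = 2  # stem: ConvBNReLU(3, input_channel, stride=2)
    for _, _, _, s in SETTING:
        stride *= s  # only the first block of each stage uses stride s
    return stride


def out_features(width_mult=1.0):
    # Mirrors self.last_channel, i.e. the backbone's _out_features
    return make_divisible(1280 * max(1.0, width_mult))


print(stage_channels())     # [16, 24, 32, 64, 96, 160, 320]
print(stage_channels(0.5))  # [8, 16, 16, 32, 48, 80, 160]
print(num_inverted_residuals(), total_stride(), out_features(0.5))  # 17 32 1280
```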
================================================
FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/preact_resnet18.py
================================================
import torch.nn as nn
import torch.nn.functional as F
from .build import BACKBONE_REGISTRY
from .backbone import Backbone
class PreActBlock(nn.Module):
expansion = 1
def __init__(self, in_planes, planes, stride=1):
super().__init__()
self.bn1 = nn.BatchNorm2d(in_planes)
self.conv1 = nn.Conv2d(
in_planes,
planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False
)
self.bn2 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(
planes, planes, kernel_size=3, stride=1, padding=1, bias=False
)
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(
in_planes,
self.expansion * planes,
kernel_size=1,
stride=stride,
bias=False,
)
)
def forward(self, x):
out = F.relu(self.bn1(x))
shortcut = self.shortcut(out) if hasattr(self, "shortcut") else x
out = self.conv1(out)
out = self.conv2(F.relu(self.bn2(out)))
out += shortcut
return out
class PreActBottleneck(nn.Module):
expansion = 4
def __init__(self, in_planes, planes, stride=1):
super().__init__()
self.bn1 = nn.BatchNorm2d(in_planes)
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(
planes,
planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False
)
self.bn3 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(
planes, self.expansion * planes, kernel_size=1, bias=False
)
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(
in_planes,
self.expansion * planes,
kernel_size=1,
stride=stride,
bias=False,
)
)
def forward(self, x):
out = F.relu(self.bn1(x))
shortcut = self.shortcut(out) if hasattr(self, "shortcut") else x
out = self.conv1(out)
out = self.conv2(F.relu(self.bn2(out)))
out = self.conv3(F.relu(self.bn3(out)))
out += shortcut
return out
class PreActResNet(Backbone):
def __init__(self, block, num_blocks):
super().__init__()
self.in_planes = 64
self.conv1 = nn.Conv2d(
3, 64, kernel_size=3, stride=1, padding=1, bias=False
)
self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
self._out_features = 512 * block.expansion
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1] * (num_blocks-1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
def forward(self, x):
out = self.conv1(x)
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
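# the fixed 4x4 window assumes 32x32 inputs (e.g. CIFAR/SVHN), whose layer4
# output is 4x4; larger inputs leave a spatial grid that the flatten below
# folds into more than 512 features
out = F.avg_pool2d(out, 4)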
out = out.view(out.size(0), -1)
return out
"""
Preact-ResNet18 was used for the CIFAR10 and
SVHN datasets (both are SSL tasks) in
- Wang et al. Semi-Supervised Learning by
Augmented Distribution Alignment. ICCV 2019.
"""
@BACKBONE_REGISTRY.register()
def preact_resnet18(**kwargs):
return PreActResNet(PreActBlock, [2, 2, 2, 2])
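The next sketch traces the spatial size through `PreActResNet` to show why `F.avg_pool2d(out, 4)` yields exactly `_out_features = 512` only when the input is 32x32. It is pure arithmetic, and the helper names are hypothetical. As in `_make_layer`, only the first block of each stage uses the stage stride, and that stride sits on `conv1`.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a convolution or pooling window along one axis."""
    return (size + 2 * padding - kernel) // stride + 1


def preact_resnet18_feature_dim(input_size, planes=512, num_blocks=(2, 2, 2, 2)):
    """Return (layer4 spatial size, flattened feature width after pooling)."""
    size = conv_out(input_size, 3, 1, 1)  # stem conv: 3x3, stride 1, padding 1
    for n, first_stride in zip(num_blocks, (1, 2, 2, 2)):
        for stride in [first_stride] + [1] * (n - 1):
            size = conv_out(size, 3, stride, 1)  # conv1 carries the stride
            size = conv_out(size, 3, 1, 1)  # conv2 keeps the size
    pooled = conv_out(size, 4, 4, 0)  # avg_pool2d(out, 4): stride defaults to 4
    return size, planes * pooled * pooled


print(preact_resnet18_feature_dim(32))  # (4, 512)
print(preact_resnet18_feature_dim(64))  # (8, 2048)
```

A 64x64 input would produce 2048-d vectors, which disagrees with the advertised 512 features. Replacing the fixed window with adaptive pooling is one way to lift that restriction.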
================================================
FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/resnet.py
================================================
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
from .build import BACKBONE_REGISTRY
from .backbone import Backbone
model_urls = {
"resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
"resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
"resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
"resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
"resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
}
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False
)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super().__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None):
super().__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(
planes,
planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False
)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(
planes, planes * self.expansion, kernel_size=1, bias=False
)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class ResNet(Backbone):
def __init__(
self,
block,
layers,
ms_class=None,
ms_layers=[],
ms_p=0.5,
ms_a=0.1,
**kwargs
):
self.inplanes = 64
super().__init__()
# backbone network
self.conv1 = nn.Conv2d(
3, 64, kernel_size=7, stride=2, padding=3, bias=False
)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
self._out_features = 512 * block.expansion
self.mixstyle = None
if ms_layers:
self.mixstyle = ms_class(p=ms_p, alpha=ms_a)
for layer_name in ms_layers:
assert layer_name in ["layer1", "layer2", "layer3"]
print(f"Insert MixStyle after {ms_layers}")
self.ms_layers = ms_layers
self._init_params()
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(
self.inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
bias=False,
),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def _init_params(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode="fan_out", nonlinearity="relu"
)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def featuremaps(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
if "layer1" in self.ms_layers:
x = self.mixstyle(x)
x = self.layer2(x)
if "layer2" in self.ms_layers:
x = self.mixstyle(x)
x = self.layer3(x)
if "layer3" in self.ms_layers:
x = self.mixstyle(x)
return self.layer4(x)
def forward(self, x):
f = self.featuremaps(x)
v = self.global_avgpool(f)
return v.view(v.size(0), -1)
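In `ResNet._make_layer`, a 1x1 projection (`downsample`) is added to a stage's first block when the stride is not 1 or the channel count changes. The short sketch below replays that condition per stage. `downsample_plan` is a hypothetical helper that assumes the four stages used above: planes 64/128/256/512 with strides 1/2/2/2.

```python
def downsample_plan(expansion):
    """For each stage, report whether its first block gets a projection shortcut."""
    inplanes = 64  # channels leaving the stem (conv1 + maxpool)
    plan = []
    for planes, stride in zip((64, 128, 256, 512), (1, 2, 2, 2)):
        plan.append(stride != 1 or inplanes != planes * expansion)
        inplanes = planes * expansion  # input width for the remaining blocks
    return plan


print(downsample_plan(1))  # BasicBlock: [False, True, True, True]
print(downsample_plan(4))  # Bottleneck: [True, True, True, True]
```

This explains why the Bottleneck variants have a projection in `layer1` even though that stage does not downsample spatially: the 64 stem channels must be widened to 256.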
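def init_pretrained_weights(model, model_url):
# strict=False ignores keys this backbone lacks (e.g. the ImageNet classifier
# fc.weight / fc.bias); tensors whose shapes differ would still raise an error
pretrain_dict = model_zoo.load_url(model_url)
model.load_state_dict(pretrain_dict, strict=False)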
"""
Residual network configurations:
--
resnet18: block=BasicBlock, layers=[2, 2, 2, 2]
resnet34: block=BasicBlock, layers=[3, 4, 6, 3]
resnet50: block=Bottleneck, layers=[3, 4, 6, 3]
resnet101: block=Bottleneck, layers=[3, 4, 23, 3]
resnet152: block=Bottleneck, layers=[3, 8, 36, 3]
"""
@BACKBONE_REGISTRY.register()
def resnet18(pretrained=True, **kwargs):
model = ResNet(block=BasicBlock, layers=[2, 2, 2, 2])
if pretrained:
init_pretrained_weights(model, model_urls["resnet18"])
return model
@BACKBONE_REGISTRY.register()
def resnet34(pretrained=True, **kwargs):
model = ResNet(block=BasicBlock, layers=[3, 4, 6, 3])
if pretrained:
init_pretrained_weights(model, model_urls["resnet34"])
return model
@BACKBONE_REGISTRY.register()
def resnet50(pretrained=True, **kwargs):
model = ResNet(block=Bottleneck, layers=[3, 4, 6, 3])
if pretrained:
init_pretrained_weights(model, model_urls["resnet50"])
return model
@BACKBONE_REGISTRY.register()
def resnet101(pretrained=True, **kwargs):
model = ResNet(block=Bottleneck, layers=[3, 4, 23, 3])
if pretrained:
init_pretrained_weights(model, model_urls["resnet101"])
return model
@BACKBONE_REGISTRY.register()
def resnet152(pretrained=True, **kwargs):
model = ResNet(block=Bottleneck, layers=[3, 8, 36, 3])
if pretrained:
init_pretrained_weights(model, model_urls["resnet152"])
return model
"""
Residual networks with mixstyle
"""
@BACKBONE_REGISTRY.register()
def resnet18_ms_l123(pretrained=True, **kwargs):
from dassl.modeling.ops import MixStyle
model = ResNet(
block=BasicBlock,
layers=[2, 2, 2, 2],
ms_class=MixStyle,
ms_layers=["layer1", "layer2", "layer3"],
)
if pretrained:
init_pretrained_weights(model, model_urls["resnet18"])
return model
@BACKBONE_REGISTRY.register()
def resnet18_ms_l12(pretrained=True, **kwargs):
from dassl.modeling.ops import MixStyle
model = ResNet(
block=BasicBlock,
layers=[2, 2, 2, 2],
ms_class=MixStyle,
ms_layers=["layer1", "layer2"],
)
if pretrained:
init_pretrained_weights(model, model_urls["resnet18"])
return model
@BACKBONE_REGISTRY.register()
def resnet18_ms_l1(pretrained=True, **kwargs):
from dassl.modeling.ops import MixStyle
model = ResNet(
block=BasicBlock,
layers=[2, 2, 2, 2],
ms_class=MixStyle,
ms_layers=["layer1"]
)
if pretrained:
init_pretrained_weights(model, model_urls["resnet18"])
return model
@BACKBONE_REGISTRY.register()
def resnet50_ms_l123(pretrained=True, **kwargs):
from dassl.modeling.ops import MixStyle
model = ResNet(
block=Bottleneck,
layers=[3, 4, 6, 3],
ms_class=MixStyle,
ms_layers=["layer1", "layer2", "layer3"],
)
if pretrained:
init_pretrained_weights(model, model_urls["resnet50"])
return model
@BACKBONE_REGISTRY.register()
def resnet50_ms_l12(pretrained=True, **kwargs):
from dassl.modeling.ops import MixStyle
model = ResNet(
block=Bottleneck,
layers=[3, 4, 6, 3],
ms_class=MixStyle,
ms_layers=["layer1", "layer2"],
)
if pretrained:
init_pretrained_weights(model, model_urls["resnet50"])
return model
@BACKBONE_REGISTRY.register()
def resnet50_ms_l1(pretrained=True, **kwargs):
from dassl.modeling.ops import MixStyle
model = ResNet(
block=Bottleneck,
layers=[3, 4, 6, 3],
ms_class=MixStyle,
ms_layers=["layer1"]
)
if pretrained:
init_pretrained_weights(model, model_urls["resnet50"])
return model
@BACKBONE_REGISTRY.register()
def resnet101_ms_l123(pretrained=True, **kwargs):
from dassl.modeling.ops import MixStyle
model = ResNet(
block=Bottleneck,
layers=[3, 4, 23, 3],
ms_class=MixStyle,
ms_layers=["layer1", "layer2", "layer3"],
)
if pretrained:
init_pretrained_weights(model, model_urls["resnet101"])
return model
@BACKBONE_REGISTRY.register()
def resnet101_ms_l12(pretrained=True, **kwargs):
from dassl.modeling.ops import MixStyle
model = ResNet(
block=Bottleneck,
layers=[3, 4, 23, 3],
ms_class=MixStyle,
ms_layers=["layer1", "layer2"],
)
if pretrained:
init_pretrained_weights(model, model_urls["resnet101"])
return model
@BACKBONE_REGISTRY.register()
def resnet101_ms_l1(pretrained=True, **kwargs):
from dassl.modeling.ops import MixStyle
model = ResNet(
block=Bottleneck,
layers=[3, 4, 23, 3],
ms_class=MixStyle,
ms_layers=["layer1"]
)
if pretrained:
init_pretrained_weights(model, model_urls["resnet101"])
return model
"""
Residual networks with efdmix
"""
@BACKBONE_REGISTRY.register()
def resnet18_efdmix_l123(pretrained=True, **kwargs):
from dassl.modeling.ops import EFDMix
model = ResNet(
block=BasicBlock,
layers=[2, 2, 2, 2],
ms_class=EFDMix,
ms_layers=["layer1", "layer2", "layer3"],
)
if pretrained:
init_pretrained_weights(model, model_urls["resnet18"])
return model
@BACKBONE_REGISTRY.register()
def resnet18_efdmix_l12(pretrained=True, **kwargs):
from dassl.modeling.ops import EFDMix
model = ResNet(
block=BasicBlock,
layers=[2, 2, 2, 2],
ms_class=EFDMix,
ms_layers=["layer1", "layer2"],
)
if pretrained:
init_pretrained_weights(model, model_urls["resnet18"])
return model
@BACKBONE_REGISTRY.register()
def resnet18_efdmix_l1(pretrained=True, **kwargs):
from dassl.modeling.ops import EFDMix
model = ResNet(
block=BasicBlock,
layers=[2, 2, 2, 2],
ms_class=EFDMix,
ms_layers=["layer1"]
)
if pretrained:
init_pretrained_weights(model, model_urls["resnet18"])
return model
@BACKBONE_REGISTRY.register()
def resnet50_efdmix_l123(pretrained=True, **kwargs):
from dassl.modeling.ops import EFDMix
model = ResNet(
block=Bottleneck,
layers=[3, 4, 6, 3],
ms_class=EFDMix,
ms_layers=["layer1", "layer2", "layer3"],
)
if pretrained:
init_pretrained_weights(model, model_urls["resnet50"])
return model
@BACKBONE_REGISTRY.register()
def resnet50_efdmix_l12(pretrained=True, **kwargs):
from dassl.modeling.ops import EFDMix
model = ResNet(
block=Bottleneck,
layers=[3, 4, 6, 3],
ms_class=EFDMix,
ms_layers=["layer1", "layer2"],
)
if pretrained:
init_pretrained_weights(model, model_urls["resnet50"])
return model
@BACKBONE_REGISTRY.register()
def resnet50_efdmix_l1(pretrained=True, **kwargs):
from dassl.modeling.ops import EFDMix
model = ResNet(
block=Bottleneck,
layers=[3, 4, 6, 3],
ms_class=EFDMix,
ms_layers=["layer1"]
)
if pretrained:
init_pretrained_weights(model, model_urls["resnet50"])
return model
@BACKBONE_REGISTRY.register()
def resnet101_efdmix_l123(pretrained=True, **kwargs):
from dassl.modeling.ops import EFDMix
model = ResNet(
block=Bottleneck,
layers=[3, 4, 23, 3],
ms_class=EFDMix,
ms_layers=["layer1", "layer2", "layer3"],
)
if pretrained:
init_pretrained_weights(model, model_urls["resnet101"])
return model
@BACKBONE_REGISTRY.register()
def resnet101_efdmix_l12(pretrained=True, **kwargs):
from dassl.modeling.ops import EFDMix
model = ResNet(
block=Bottleneck,
layers=[3, 4, 23, 3],
ms_class=EFDMix,
ms_layers=["layer1", "layer2"],
)
if pretrained:
init_pretrained_weights(model, model_urls["resnet101"])
return model
@BACKBONE_REGISTRY.register()
def resnet101_efdmix_l1(pretrained=True, **kwargs):
from dassl.modeling.ops import EFDMix
model = ResNet(
block=Bottleneck,
layers=[3, 4, 23, 3],
ms_class=EFDMix,
ms_layers=["layer1"]
)
if pretrained:
init_pretrained_weights(model, model_urls["resnet101"])
return model
================================================
FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/shufflenetv2.py
================================================
"""
Code source: https://github.com/pytorch/vision
"""
import torch
import torch.utils.model_zoo as model_zoo
from torch import nn
from .build import BACKBONE_REGISTRY
from .backbone import Backbone
model_urls = {
"shufflenetv2_x0.5":
"https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
"shufflenetv2_x1.0":
"https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
"shufflenetv2_x1.5": None,
"shufflenetv2_x2.0": None,
}
def channel_shuffle(x, groups):
batchsize, num_channels, height, width = x.size()
channels_per_group = num_channels // groups
# reshape
x = x.view(batchsize, groups, channels_per_group, height, width)
x = torch.transpose(x, 1, 2).contiguous()
# flatten
x = x.view(batchsize, -1, height, width)
return x
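`channel_shuffle` reshapes the channels to `(groups, channels_per_group)`, transposes, and flattens. The index-only sketch below (the helper name is hypothetical) lists which input channel ends up at each output position. In a stride-1 `InvertedResidual`, this interleaves the untouched half `x1` with the processed half from `branch2`.

```python
def channel_shuffle_order(num_channels, groups):
    """Input channel index at each output position after the shuffle."""
    per_group = num_channels // groups
    # (groups, per_group) -> transpose -> (per_group, groups) -> flatten
    return [g * per_group + c for c in range(per_group) for g in range(groups)]


print(channel_shuffle_order(6, 2))  # [0, 3, 1, 4, 2, 5]
print(channel_shuffle_order(6, 3))  # [0, 2, 4, 1, 3, 5]
```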
class InvertedResidual(nn.Module):
def __init__(self, inp, oup, stride):
super().__init__()
if not (1 <= stride <= 3):
raise ValueError("illegal stride value")
self.stride = stride
branch_features = oup // 2
assert (self.stride != 1) or (inp == branch_features << 1)
if self.stride > 1:
self.branch1 = nn.Sequential(
self.depthwise_conv(
inp, inp, kernel_size=3, stride=self.stride, padding=1
),
nn.BatchNorm2d(inp),
nn.Conv2d(
inp,
branch_features,
kernel_size=1,
stride=1,
padding=0,
bias=False
),
nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True),
)
self.branch2 = nn.Sequential(
nn.Conv2d(
inp if (self.stride > 1) else branch_features,
branch_features,
kernel_size=1,
stride=1,
padding=0,
bias=False,
),
nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True),
self.depthwise_conv(
branch_features,
branch_features,
kernel_size=3,
stride=self.stride,
padding=1,
),
nn.BatchNorm2d(branch_features),
nn.Conv2d(
branch_features,
branch_features,
kernel_size=1,
stride=1,
padding=0,
bias=False,
),
nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True),
)
@staticmethod
def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
return nn.Conv2d(
i, o, kernel_size, stride, padding, bias=bias, groups=i
)
def forward(self, x):
if self.stride == 1:
x1, x2 = x.chunk(2, dim=1)
out = torch.cat((x1, self.branch2(x2)), dim=1)
else:
out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
out = channel_shuffle(out, 2)
return out
class ShuffleNetV2(Backbone):
def __init__(self, stages_repeats, stages_out_channels, **kwargs):
super().__init__()
if len(stages_repeats) != 3:
raise ValueError(
"expected stages_repeats as list of 3 positive ints"
)
if len(stages_out_channels) != 5:
raise ValueError(
"expected stages_out_channels as list of 5 positive ints"
)
self._stage_out_channels = stages_out_channels
input_channels = 3
output_channels = self._stage_out_channels[0]
self.conv1 = nn.Sequential(
nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
nn.BatchNorm2d(output_channels),
nn.ReLU(inplace=True),
)
input_channels = output_channels
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
stage_names = ["stage{}".format(i) for i in [2, 3, 4]]
for name, repeats, output_channels in zip(
stage_names, stages_repeats, self._stage_out_channels[1:]
):
seq = [InvertedResidual(input_channels, output_channels, 2)]
for i in range(repeats - 1):
seq.append(
InvertedResidual(output_channels, output_channels, 1)
)
setattr(self, name, nn.Sequential(*seq))
input_channels = output_channels
output_channels = self._stage_out_channels[-1]
self.conv5 = nn.Sequential(
nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
nn.BatchNorm2d(output_channels),
nn.ReLU(inplace=True),
)
self.global_avgpool = nn.AdaptiveAvgPool2d((1, 1))
self._out_features = output_channels
def featuremaps(self, x):
x = self.conv1(x)
x = self.maxpool(x)
x = self.stage2(x)
x = self.stage3(x)
x = self.stage4(x)
x = self.conv5(x)
return x
def forward(self, x):
f = self.featuremaps(x)
v = self.global_avgpool(f)
return v.view(v.size(0), -1)
def init_pretrained_weights(model, model_url):
"""Initializes model with pretrained weights.
Layers that don't match with pretrained layers in name or size are kept unchanged.
"""
if model_url is None:
import warnings
warnings.warn(
"ImageNet pretrained weights are unavailable for this model"
)
return
pretrain_dict = model_zoo.load_url(model_url)
model_dict = model.state_dict()
pretrain_dict = {
k: v
for k, v in pretrain_dict.items()
if k in model_dict and model_dict[k].size() == v.size()
}
model_dict.update(pretrain_dict)
model.load_state_dict(model_dict)
@BACKBONE_REGISTRY.register()
def shufflenet_v2_x0_5(pretrained=True, **kwargs):
model = ShuffleNetV2([4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
if pretrained:
init_pretrained_weights(model, model_urls["shufflenetv2_x0.5"])
return model
@BACKBONE_REGISTRY.register()
def shufflenet_v2_x1_0(pretrained=True, **kwargs):
model = ShuffleNetV2([4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
if pretrained:
init_pretrained_weights(model, model_urls["shufflenetv2_x1.0"])
return model
@BACKBONE_REGISTRY.register()
def shufflenet_v2_x1_5(pretrained=True, **kwargs):
model = ShuffleNetV2([4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
if pretrained:
init_pretrained_weights(model, model_urls["shufflenetv2_x1.5"])
return model
@BACKBONE_REGISTRY.register()
def shufflenet_v2_x2_0(pretrained=True, **kwargs):
model = ShuffleNetV2([4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)
if pretrained:
init_pretrained_weights(model, model_urls["shufflenetv2_x2.0"])
return model
================================================
FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/vgg.py
================================================
import torch
import torch.nn as nn
from .build import BACKBONE_REGISTRY
from .backbone import Backbone
try:
from torch.hub import load_state_dict_from_url
except ImportError:
from torch.utils.model_zoo import load_url as load_state_dict_from_url
model_urls = {
"vgg11": "https://download.pytorch.org/models/vgg11-bbd30ac9.pth",
"vgg13": "https://download.pytorch.org/models/vgg13-c768596a.pth",
"vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth",
"vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
"vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
"vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
"vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
"vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
}
class VGG(Backbone):
def __init__(self, features, init_weights=True):
super().__init__()
self.features = features
self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
# Note that self.classifier outputs features rather than logits
self.classifier = nn.Sequential(
nn.Linear(512 * 7 * 7, 4096),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(4096, 4096),
nn.ReLU(True),
nn.Dropout(),
)
self._out_features = 4096
if init_weights:
self._initialize_weights()
def forward(self, x):
x = self.features(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
return self.classifier(x)
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode="fan_out", nonlinearity="relu"
)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)
def make_layers(cfg, batch_norm=False):
layers = []
in_channels = 3
for v in cfg:
if v == "M":
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else:
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
if batch_norm:
layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
else:
layers += [conv2d, nn.ReLU(inplace=True)]
in_channels = v
return nn.Sequential(*layers)
cfgs = {
"A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
"B":
[64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
"D": [
64,
64,
"M",
128,
128,
"M",
256,
256,
256,
"M",
512,
512,
512,
"M",
512,
512,
512,
"M",
],
"E": [
64,
64,
"M",
128,
128,
"M",
256,
256,
256,
256,
"M",
512,
512,
512,
512,
"M",
512,
512,
512,
512,
"M",
],
}
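The following sketch walks a configuration the way `make_layers` does, counting 3x3 convolutions (padding 1, so size-preserving) and halving the spatial size at each `"M"`. `summarize_cfg` is a hypothetical helper, and the 224x224 input size is an assumption. Configuration `"D"` gives the 13 convolutions of VGG16 and a 7x7 map, which matches `AdaptiveAvgPool2d((7, 7))`.

```python
CFG_A = [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"]
CFG_D = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
         512, 512, 512, "M", 512, 512, 512, "M"]


def summarize_cfg(cfg, in_channels=3, input_size=224):
    """Return (number of conv layers, output channels, output spatial size)."""
    convs, size = 0, input_size
    for v in cfg:
        if v == "M":
            size //= 2  # MaxPool2d(kernel_size=2, stride=2)
        else:
            convs += 1  # 3x3 conv with padding=1 keeps the spatial size
            in_channels = v
    return convs, in_channels, size


print(summarize_cfg(CFG_D))  # (13, 512, 7)
print(summarize_cfg(CFG_A))  # (8, 512, 7)
```

In torchvision's VGG16, the 13 convolutions are followed by 3 fully connected layers. The `VGG` backbone above keeps only the first two of them, so it returns 4096-d features.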
def _vgg(arch, cfg, batch_norm, pretrained):
init_weights = False if pretrained else True
model = VGG(
make_layers(cfgs[cfg], batch_norm=batch_norm),
init_weights=init_weights
)
if pretrained:
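state_dict = load_state_dict_from_url(model_urls[arch], progress=True)
# strict=False skips the ImageNet output layer (classifier.6.*), which this
# backbone omits so that it returns 4096-d features instead of logits
model.load_state_dict(state_dict, strict=False)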
return model
@BACKBONE_REGISTRY.register()
def vgg16(pretrained=True, **kwargs):
return _vgg("vgg16", "D", False, pretrained)
================================================
FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/wide_resnet.py
================================================
"""
Modified from https://github.com/xternalz/WideResNet-pytorch
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from .build import BACKBONE_REGISTRY
from .backbone import Backbone
class BasicBlock(nn.Module):
def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
super().__init__()
self.bn1 = nn.BatchNorm2d(in_planes)
self.relu1 = nn.LeakyReLU(0.01, inplace=True)
self.conv1 = nn.Conv2d(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False
)
self.bn2 = nn.BatchNorm2d(out_planes)
self.relu2 = nn.LeakyReLU(0.01, inplace=True)
self.conv2 = nn.Conv2d(
out_planes,
out_planes,
kernel_size=3,
stride=1,
padding=1,
bias=False
)
self.droprate = dropRate
self.equalInOut = in_planes == out_planes
self.convShortcut = (
nn.Conv2d(
in_planes,
out_planes,
kernel_size=1,
stride=stride,
padding=0,
bias=False,
) if not self.equalInOut else None
)
def forward(self, x):
if not self.equalInOut:
x = self.relu1(self.bn1(x))
else:
out = self.relu1(self.bn1(x))
out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
if self.droprate > 0:
out = F.dropout(out, p=self.droprate, training=self.training)
out = self.conv2(out)
return torch.add(x if self.equalInOut else self.convShortcut(x), out)
class NetworkBlock(nn.Module):
def __init__(
self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0
):
super().__init__()
self.layer = self._make_layer(
block, in_planes, out_planes, nb_layers, stride, dropRate
)
def _make_layer(
self, block, in_planes, out_planes, nb_layers, stride, dropRate
):
layers = []
for i in range(int(nb_layers)):
layers.append(
block(
in_planes if i == 0 else out_planes,
out_planes,
stride if i == 0 else 1,
dropRate,
)
)
return nn.Sequential(*layers)
def forward(self, x):
return self.layer(x)
class WideResNet(Backbone):
def __init__(self, depth, widen_factor, dropRate=0.0):
super().__init__()
nChannels = [
16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor
]
assert (depth - 4) % 6 == 0
n = (depth - 4) // 6  # BasicBlocks per NetworkBlock
block = BasicBlock
# 1st conv before any network block
self.conv1 = nn.Conv2d(
3, nChannels[0], kernel_size=3, stride=1, padding=1, bias=False
)
# 1st block
self.block1 = NetworkBlock(
n, nChannels[0], nChannels[1], block, 1, dropRate
)
# 2nd block
self.block2 = NetworkBlock(
n, nChannels[1], nChannels[2], block, 2, dropRate
)
# 3rd block
self.block3 = NetworkBlock(
n, nChannels[2], nChannels[3], block, 2, dropRate
)
# final BN + activation; forward() then applies global average pooling
# (no classifier: the backbone returns features)
self.bn1 = nn.BatchNorm2d(nChannels[3])
self.relu = nn.LeakyReLU(0.01, inplace=True)
self._out_features = nChannels[3]
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode="fan_out", nonlinearity="relu"
)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.bias.data.zero_()
def forward(self, x):
out = self.conv1(x)
out = self.block1(out)
out = self.block2(out)
out = self.block3(out)
out = self.relu(self.bn1(out))
out = F.adaptive_avg_pool2d(out, 1)
return out.view(out.size(0), -1)
@BACKBONE_REGISTRY.register()
def wide_resnet_28_2(**kwargs):
return WideResNet(28, 2)
@BACKBONE_REGISTRY.register()
def wide_resnet_16_4(**kwargs):
return WideResNet(16, 4)
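`WideResNet` derives the number of blocks per stage from the depth as `n = (depth - 4) / 6`, and the channel widths from `widen_factor`. The sketch below reproduces that arithmetic for the two registered variants. `wrn_config` is a hypothetical helper. The last channel entry equals the backbone's `_out_features`.

```python
def wrn_config(depth, widen_factor):
    """Return (BasicBlocks per NetworkBlock, channel widths) as WideResNet computes them."""
    if (depth - 4) % 6 != 0:
        raise ValueError("depth - 4 must be divisible by 6")
    n = (depth - 4) // 6
    channels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
    return n, channels


print(wrn_config(28, 2))  # (4, [16, 32, 64, 128])
print(wrn_config(16, 4))  # (2, [16, 64, 128, 256])
```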
================================================
FILE: Dassl.ProGrad.pytorch/dassl/modeling/head/__init__.py
================================================
from .build import build_head, HEAD_REGISTRY # isort:skip
from .mlp import mlp
================================================
FILE: Dassl.ProGrad.pytorch/dassl/modeling/head/build.py
================================================
from dassl.utils import Registry, check_availability
HEAD_REGISTRY = Registry("HEAD")
def build_head(name, verbose=True, **kwargs):
avai_heads = HEAD_REGISTRY.registered_names()
check_availability(name, avai_heads)
if verbose:
print("Head: {}".format(name))
return HEAD_REGISTRY.get(name)(**kwargs)
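`build_head` (and `build_network` below) follows a name-to-factory registry pattern: check that the name is registered, then call the factory with `**kwargs`. The internals of `dassl.utils.Registry` and `check_availability` are not shown in this section. The sketch below is therefore a hypothetical stand-in (`MiniRegistry`, `build`) that exposes only the three methods used here: `register()`, `registered_names()`, and `get()`.

```python
class MiniRegistry:
    """Hypothetical name -> factory registry; not dassl's Registry implementation."""

    def __init__(self, name):
        self._name = name
        self._obj_map = {}

    def register(self):
        def decorator(fn):
            key = fn.__name__
            if key in self._obj_map:
                raise KeyError(f"{key!r} is already registered in {self._name}")
            self._obj_map[key] = fn
            return fn
        return decorator

    def registered_names(self):
        return list(self._obj_map)

    def get(self, name):
        return self._obj_map[name]


def build(registry, name, **kwargs):
    available = registry.registered_names()
    if name not in available:
        raise ValueError(f"{name!r} not found; available: {available}")
    return registry.get(name)(**kwargs)


TOY_HEADS = MiniRegistry("HEAD")


@TOY_HEADS.register()
def toy_head(in_features=8, hidden=4):
    return ("toy_head", in_features, hidden)


print(build(TOY_HEADS, "toy_head", in_features=16))  # ('toy_head', 16, 4)
```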
================================================
FILE: Dassl.ProGrad.pytorch/dassl/modeling/head/mlp.py
================================================
import functools
import torch.nn as nn
from .build import HEAD_REGISTRY
class MLP(nn.Module):
def __init__(
self,
in_features=2048,
hidden_layers=[],
activation="relu",
bn=True,
dropout=0.0,
):
super().__init__()
if isinstance(hidden_layers, int):
hidden_layers = [hidden_layers]
assert len(hidden_layers) > 0
self.out_features = hidden_layers[-1]
mlp = []
if activation == "relu":
act_fn = functools.partial(nn.ReLU, inplace=True)
elif activation == "leaky_relu":
act_fn = functools.partial(nn.LeakyReLU, inplace=True)
else:
raise NotImplementedError
for hidden_dim in hidden_layers:
mlp += [nn.Linear(in_features, hidden_dim)]
if bn:
mlp += [nn.BatchNorm1d(hidden_dim)]
mlp += [act_fn()]
if dropout > 0:
mlp += [nn.Dropout(dropout)]
in_features = hidden_dim
self.mlp = nn.Sequential(*mlp)
def forward(self, x):
return self.mlp(x)
@HEAD_REGISTRY.register()
def mlp(**kwargs):
return MLP(**kwargs)
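For each hidden dimension, `MLP` stacks `Linear`, optional `BatchNorm1d`, the activation, and optional `Dropout`. The sketch below lists that sequence as strings so the effect of `bn`, `dropout`, and `activation` is easy to see. `mlp_layer_plan` is a hypothetical helper. Like the original, it rejects an empty `hidden_layers`, so the class's default `[]` must always be overridden.

```python
def mlp_layer_plan(in_features=2048, hidden_layers=(), activation="relu",
                   bn=True, dropout=0.0):
    """Describe the layers MLP would build, in order."""
    if isinstance(hidden_layers, int):
        hidden_layers = [hidden_layers]
    if len(hidden_layers) == 0:
        raise ValueError("hidden_layers must be non-empty")
    if activation not in ("relu", "leaky_relu"):
        raise NotImplementedError(activation)
    act = "ReLU" if activation == "relu" else "LeakyReLU"
    plan = []
    for hidden_dim in hidden_layers:
        plan.append(f"Linear({in_features}, {hidden_dim})")
        if bn:
            plan.append(f"BatchNorm1d({hidden_dim})")
        plan.append(act)
        if dropout > 0:
            plan.append(f"Dropout({dropout})")
        in_features = hidden_dim
    return plan


print(mlp_layer_plan(2048, 512, dropout=0.5))
```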
================================================
FILE: Dassl.ProGrad.pytorch/dassl/modeling/network/__init__.py
================================================
from .build import build_network, NETWORK_REGISTRY # isort:skip
from .ddaig_fcn import (
fcn_3x32_gctx, fcn_3x64_gctx, fcn_3x32_gctx_stn, fcn_3x64_gctx_stn
)
================================================
FILE: Dassl.ProGrad.pytorch/dassl/modeling/network/build.py
================================================
from dassl.utils import Registry, check_availability
NETWORK_REGISTRY = Registry("NETWORK")
def build_network(name, verbose=True, **kwargs):
avai_models = NETWORK_REGISTRY.registered_names()
check_availability(name, avai_models)
if verbose:
print("Network: {}".format(name))
return NETWORK_REGISTRY.get(name)(**kwargs)
================================================
FILE: Dassl.ProGrad.pytorch/dassl/modeling/network/ddaig_fcn.py
================================================
"""
Credit to: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
"""
import functools
import torch
import torch.nn as nn
from torch.nn import functional as F
from .build import NETWORK_REGISTRY
def init_network_weights(model, init_type="normal", gain=0.02):
def _init_func(m):
classname = m.__class__.__name__
if hasattr(m, "weight") and (
classname.find("Conv") != -1 or classname.find("Linear") != -1
):
if init_type == "normal":
nn.init.normal_(m.weight.data, 0.0, gain)
elif init_type == "xavier":
nn.init.xavier_normal_(m.weight.data, gain=gain)
elif init_type == "kaiming":
nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
elif init_type == "orthogonal":
nn.init.orthogonal_(m.weight.data, gain=gain)
else:
raise NotImplementedError(
"initialization method {} is not implemented".format(init_type)
)
if hasattr(m, "bias") and m.bias is not None:
nn.init.constant_(m.bias.data, 0.0)
elif classname.find("BatchNorm2d") != -1:
nn.init.constant_(m.weight.data, 1.0)
nn.init.constant_(m.bias.data, 0.0)
elif classname.find("InstanceNorm2d") != -1:
if m.weight is not None and m.bias is not None:
nn.init.constant_(m.weight.data, 1.0)
nn.init.constant_(m.bias.data, 0.0)
model.apply(_init_func)
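`_init_func` chooses a branch by substring-matching the module's class name. The sketch below replays that dispatch on class names alone. `init_rule` is a hypothetical helper, and `has_weight` approximates the `hasattr(m, "weight")` check. Substring matching means `ConvTranspose2d` also receives the conv/linear init, whereas `BatchNorm1d` matches no branch and keeps PyTorch's defaults.

```python
def init_rule(classname, has_weight=True, init_type="normal"):
    """Which branch of _init_func a module with this class name would take."""
    if has_weight and (classname.find("Conv") != -1 or classname.find("Linear") != -1):
        if init_type not in ("normal", "xavier", "kaiming", "orthogonal"):
            raise NotImplementedError(init_type)
        return f"{init_type} weight, zero bias"
    if classname.find("BatchNorm2d") != -1:
        return "weight=1, bias=0"
    if classname.find("InstanceNorm2d") != -1:
        return "weight=1, bias=0 if affine"
    return None  # left untouched


for name in ["Conv2d", "ConvTranspose2d", "Linear", "BatchNorm2d", "BatchNorm1d"]:
    print(name, "->", init_rule(name))
```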
def get_norm_layer(norm_type="instance"):
if norm_type == "batch":
norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
elif norm_type == "instance":
norm_layer = functools.partial(
nn.InstanceNorm2d, affine=False, track_running_stats=False
)
elif norm_type == "none":
norm_layer = None
else:
raise NotImplementedError(
"normalization layer [%s] is not found" % norm_type
)
return norm_layer
class ResnetBlock(nn.Module):
def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
super().__init__()
self.conv_block = self.build_conv_block(
dim, padding_type, norm_layer, use_dropout, use_bias
)
def build_conv_block(
self, dim, padding_type, norm_layer, use_dropout, use_bias
):
conv_block = []
p = 0
if padding_type == "reflect":
conv_block += [nn.ReflectionPad2d(1)]
elif padding_type == "replicate":
conv_block += [nn.ReplicationPad2d(1)]
elif padding_type == "zero":
p = 1
else:
raise NotImplementedError(
"padding [%s] is not implemented" % padding_type
)
conv_block += [
nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
norm_layer(dim),
nn.ReLU(True),
]
if use_dropout:
conv_block += [nn.Dropout(0.5)]
p = 0
if padding_type == "reflect":
conv_block += [nn.ReflectionPad2d(1)]
elif padding_type == "replicate":
conv_block += [nn.ReplicationPad2d(1)]
elif padding_type == "zero":
p = 1
else:
raise NotImplementedError(
"padding [%s] is not implemented" % padding_type
)
conv_block += [
nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
norm_layer(dim),
]
return nn.Sequential(*conv_block)
def forward(self, x):
return x + self.conv_block(x)
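The residual sum `x + self.conv_block(x)` only works if `conv_block` preserves the spatial size for every supported `padding_type`. The plain-Python check below shows this. The helper names are hypothetical, and it uses the standard Conv2d output-size formula with dilation 1.

```python
def conv_out_size(size, kernel_size, stride=1, padding=0):
    # Standard Conv2d / MaxPool2d output-size formula (dilation = 1).
    return (size + 2 * padding - kernel_size) // stride + 1


def resnet_block_out_size(size, padding_type):
    # Each of the two 3x3 convs in ResnetBlock either follows an explicit
    # 1-pixel pad layer (reflect / replicate) or uses padding=1 (zero).
    for _ in range(2):
        if padding_type in ("reflect", "replicate"):
            size = conv_out_size(size + 2, 3)
        elif padding_type == "zero":
            size = conv_out_size(size, 3, padding=1)
        else:
            raise NotImplementedError(padding_type)
    return size


print([resnet_block_out_size(32, p) for p in ("reflect", "replicate", "zero")])
```

The same formula shows why LocNet's stride-2 stem conv (kernel 3, padding 1) halves an even input: `conv_out_size(32, 3, stride=2, padding=1)` gives 16.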
class LocNet(nn.Module):
"""Localization network."""
def __init__(
self,
input_nc,
nc=32,
n_blocks=3,
use_dropout=False,
padding_type="zero",
image_size=32,
):
super().__init__()
backbone = []
backbone += [
nn.Conv2d(
input_nc, nc, kernel_size=3, stride=2, padding=1, bias=False
)
]
backbone += [nn.BatchNorm2d(nc)]
backbone += [nn.ReLU(True)]
for _ in range(n_blocks):
backbone += [
ResnetBlock(
nc,
padding_type=padding_type,
norm_layer=nn.BatchNorm2d,
use_dropout=use_dropout,
use_bias=False,
)
]
backbone += [nn.MaxPool2d(2, stride=2)]
self.backbone = nn.Sequential(*backbone)
reduced_imsize = int(image_size * 0.5**(n_blocks + 1))
self.fc_loc = nn.Linear(nc * reduced_imsize**2, 2 * 2)
def forward(self, x):
x = self.backbone(x)
x = x.view(x.size(0), -1)
x = self.fc_loc(x)
x = torch.tanh(x)
x = x.view(-1, 2, 2)
theta = x.data.new_zeros(x.size(0), 2, 3)
theta[:, :, :2] = x
return theta
class FCN(nn.Module):
"""Fully convolutional network."""
def __init__(
self,
input_nc,
output_nc,
nc=32,
n_blocks=3,
norm_layer=nn.BatchNorm2d,
use_dropout=False,
padding_type="reflect",
gctx=True,
stn=False,
image_size=32,
):
super().__init__()
backbone = []
p = 0
if padding_type == "reflect":
backbone += [nn.ReflectionPad2d(1)]
elif padding_type == "replicate":
backbone += [nn.ReplicationPad2d(1)]
elif padding_type == "zero":
p = 1
else:
raise NotImplementedError
backbone += [
nn.Conv2d(
input_nc, nc, kernel_size=3, stride=1, padding=p, bias=False
)
]
backbone += [norm_layer(nc)]
backbone += [nn.ReLU(True)]
for _ in range(n_blocks):
backbone += [
ResnetBlock(
nc,
padding_type=padding_type,
norm_layer=norm_layer,
use_dropout=use_dropout,
use_bias=False,
)
]
self.backbone = nn.Sequential(*backbone)
# global context fusion layer
self.gctx_fusion = None
if gctx:
self.gctx_fusion = nn.Sequential(
nn.Conv2d(
2 * nc, nc, kernel_size=1, stride=1, padding=0, bias=False
),
norm_layer(nc),
nn.ReLU(True),
)
self.regress = nn.Sequential(
nn.Conv2d(
nc, output_nc, kernel_size=1, stride=1, padding=0, bias=True
),
nn.Tanh(),
)
self.locnet = None
if stn:
self.locnet = LocNet(
input_nc, nc=nc, n_blocks=n_blocks, image_size=image_size
)
def init_loc_layer(self):
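"""Initialize fc_loc so that theta starts close to the identity.
fc_loc.weight is zeroed and its bias is set to [1, 0, 0, 1]; because
LocNet.forward() applies tanh, the initial diagonal is tanh(1), about
0.76, rather than exactly 1.
"""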
if self.locnet is not None:
self.locnet.fc_loc.weight.data.zero_()
self.locnet.fc_loc.bias.data.copy_(
torch.tensor([1, 0, 0, 1], dtype=torch.float)
)
def stn(self, x):
"""Spatial transformer network."""
theta = self.locnet(x)
grid = F.affine_grid(theta, x.size())
return F.grid_sample(x, grid), theta
def forward(self, x, lmda=1.0, return_p=False, return_stn_output=False):
"""
Args:
x (torch.Tensor): input mini-batch.
lmda (float): multiplier for perturbation.
return_p (bool): return perturbation.
return_stn_output (bool): return the output of stn.
"""
theta = None
if self.locnet is not None:
x, theta = self.stn(x)
input = x
x = self.backbone(x)
if self.gctx_fusion is not None:
c = F.adaptive_avg_pool2d(x, (1, 1))
c = c.expand_as(x)
x = torch.cat([x, c], 1)
x = self.gctx_fusion(x)
p = self.regress(x)
x_p = input + lmda*p
if return_stn_output:
return x_p, p, input
if return_p:
return x_p, p
return x_p
@NETWORK_REGISTRY.register()
def fcn_3x32_gctx(**kwargs):
norm_layer = get_norm_layer(norm_type="instance")
net = FCN(3, 3, nc=32, n_blocks=3, norm_layer=norm_layer)
init_network_weights(net, init_type="normal", gain=0.02)
return net
@NETWORK_REGISTRY.register()
def fcn_3x64_gctx(**kwargs):
norm_layer = get_norm_layer(norm_type="instance")
net = FCN(3, 3, nc=64, n_blocks=3, norm_layer=norm_layer)
init_network_weights(net, init_type="normal", gain=0.02)
return net
@NETWORK_REGISTRY.register()
def fcn_3x32_gctx_stn(image_size=32, **kwargs):
norm_layer = get_norm_layer(norm_type="instance")
net = FCN(
3,
3,
nc=32,
n_blocks=3,
norm_layer=norm_layer,
stn=True,
image_size=image_size
)
init_network_weights(net, init_type="normal", gain=0.02)
net.init_loc_layer()
return net
@NETWORK_REGISTRY.register()
def fcn_3x64_gctx_stn(image_size=224, **kwargs):
norm_layer = get_norm_layer(norm_type="instance")
net = FCN(
3,
3,
nc=64,
n_blocks=3,
norm_layer=norm_layer,
stn=True,
image_size=image_size
)
init_network_weights(net, init_type="normal", gain=0.02)
net.init_loc_layer()
return net
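The spatial bookkeeping in `LocNet` is easy to get wrong when `image_size` changes. The plain-Python sketch below (no torch; helper names are hypothetical) recomputes the input width of `fc_loc` and the `theta` that `LocNet.forward()` returns right after `init_loc_layer()`. It assumes even spatial sizes, for which the stride-2 stem conv and each `MaxPool2d(2)` exactly halve the resolution.

```python
import math


def locnet_fc_in_features(nc, n_blocks, image_size):
    """Input width of LocNet.fc_loc (hypothetical helper)."""
    # One stride-2 stem conv plus one MaxPool2d(2) per residual block
    # gives n_blocks + 1 halvings of the spatial size.
    reduced = int(image_size * 0.5 ** (n_blocks + 1))
    return nc * reduced ** 2


def initial_theta():
    """theta produced by LocNet.forward() right after init_loc_layer()."""
    # fc_loc.weight is zeroed, so the linear output equals the bias
    # [1, 0, 0, 1]; forward() then applies tanh and a zero translation column.
    a, b, c, d = (math.tanh(v) for v in (1.0, 0.0, 0.0, 1.0))
    return [[a, b, 0.0], [c, d, 0.0]]


print(locnet_fc_in_features(32, 3, 32))    # 128, fcn_3x32_gctx_stn default
print(locnet_fc_in_features(64, 3, 224))   # 12544, fcn_3x64_gctx_stn default
print(initial_theta())
```

The diagonal starts at tanh(1), about 0.76, so the initial warp is a mild zoom rather than an exact identity.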
================================================
FILE: Dassl.ProGrad.pytorch/dassl/modeling/ops/__init__.py
================================================
from .mmd import MaximumMeanDiscrepancy
from .dsbn import DSBN1d, DSBN2d
from .mixup import mixup
from .efdmix import (
EFDMix, random_efdmix, activate_efdmix, run_with_efdmix, deactivate_efdmix,
crossdomain_efdmix, run_without_efdmix
)
from .mixstyle import (
MixStyle, random_mixstyle, activate_mixstyle, run_with_mixstyle,
deactivate_mixstyle, crossdomain_mixstyle, run_without_mixstyle
)
from .transnorm import TransNorm1d, TransNorm2d
from .sequential2 import Sequential2
from .reverse_grad import ReverseGrad
from .cross_entropy import cross_entropy
from .optimal_transport import SinkhornDivergence, MinibatchEnergyDistance
================================================
FILE: Dassl.ProGrad.pytorch/dassl/modeling/ops/cross_entropy.py
================================================
import torch
from torch.nn import functional as F
def cross_entropy(input, target, label_smooth=0, reduction="mean"):
"""Cross entropy loss.
Args:
input (torch.Tensor): logit matrix with shape of (batch, num_classes).
target (torch.LongTensor): class indices with shape of (batch,).
label_smooth (float, optional): label smoothing hyper-parameter.
Default is 0.
reduction (str, optional): how the losses for a mini-batch
will be aggregated. Default is 'mean'.
"""
num_classes = input.shape[1]
log_prob = F.log_softmax(input, dim=1)
zeros = torch.zeros(log_prob.size())
target = zeros.scatter_(1, target.unsqueeze(1).data.cpu(), 1)
target = target.type_as(input)
target = (1-label_smooth) * target + label_smooth/num_classes
loss = (-target * log_prob).sum(1)
if reduction == "mean":
return loss.mean()
elif reduction == "sum":
return loss.sum()
elif reduction == "none":
return loss
else:
raise ValueError("Unsupported reduction: {}".format(reduction))
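`cross_entropy` builds a soft target, `(1 - label_smooth) * one_hot + label_smooth / num_classes`, and takes its dot product with the log-softmax. The single-sample, plain-Python sketch below mirrors that arithmetic. The function name is hypothetical.

```python
import math


def smoothed_cross_entropy(logits, target, label_smooth=0.0):
    """Label-smoothed cross entropy for a single sample (plain Python)."""
    n = len(logits)
    # Numerically stable log-softmax.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    log_prob = [v - log_z for v in logits]
    # Soft target: (1 - eps) * one_hot(target) + eps / num_classes.
    soft = [
        (1 - label_smooth) * (1.0 if k == target else 0.0) + label_smooth / n
        for k in range(n)
    ]
    return -sum(t * lp for t, lp in zip(soft, log_prob))


print(smoothed_cross_entropy([2.0, 0.0, 0.0], target=0))
print(smoothed_cross_entropy([2.0, 0.0, 0.0], target=0, label_smooth=0.1))
```

For a confident, correct prediction, smoothing raises the loss, which is what discourages over-confident logits.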
================================================
FILE: Dassl.ProGrad.pytorch/dassl/modeling/ops/dsbn.py
================================================
import torch.nn as nn
class _DSBN(nn.Module):
"""Domain Specific Batch Normalization.
Args:
num_features (int): number of features.
n_domain (int): number of domains.
bn_type (str): type of bn. Choices are ['1d', '2d'].
"""
def __init__(self, num_features, n_domain, bn_type):
super().__init__()
if bn_type == "1d":
BN = nn.BatchNorm1d
elif bn_type == "2d":
BN = nn.BatchNorm2d
else:
raise ValueError
self.bn = nn.ModuleList(BN(num_features) for _ in range(n_domain))
self.valid_domain_idxs = list(range(n_domain))
self.n_domain = n_domain
self.domain_idx = 0
def select_bn(self, domain_idx=0):
assert domain_idx in self.valid_domain_idxs
self.domain_idx = domain_idx
def forward(self, x):
return self.bn[self.domain_idx](x)
class DSBN1d(_DSBN):
def __init__(self, num_features, n_domain):
super().__init__(num_features, n_domain, "1d")
class DSBN2d(_DSBN):
def __init__(self, num_features, n_domain):
super().__init__(num_features, n_domain, "2d")
================================================
FILE: Dassl.ProGrad.pytorch/dassl/modeling/ops/efdmix.py
================================================
import random
from contextlib import contextmanager
import torch
import torch.nn as nn
def deactivate_efdmix(m):
if type(m) == EFDMix:
m.set_activation_status(False)
def activate_efdmix(m):
if type(m) == EFDMix:
m.set_activation_status(True)
def random_efdmix(m):
if type(m) == EFDMix:
m.update_mix_method("random")
def crossdomain_efdmix(m):
if type(m) == EFDMix:
m.update_mix_method("crossdomain")
@contextmanager
def run_without_efdmix(model):
# Assume EFDMix was initially activated
try:
model.apply(deactivate_efdmix)
yield
finally:
model.apply(activate_efdmix)
@contextmanager
def run_with_efdmix(model, mix=None):
# Assume EFDMix was initially deactivated
if mix == "random":
model.apply(random_efdmix)
elif mix == "crossdomain":
model.apply(crossdomain_efdmix)
try:
model.apply(activate_efdmix)
yield
finally:
model.apply(deactivate_efdmix)
class EFDMix(nn.Module):
"""EFDMix.
Reference:
Zhang et al. Exact Feature Distribution Matching for Arbitrary Style Transfer and Domain Generalization. CVPR 2022.
"""
def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix="random"):
"""
Args:
p (float): probability of using EFDMix.
alpha (float): parameter of the Beta distribution.
eps (float): scaling parameter to avoid numerical issues.
mix (str): how to mix.
"""
super().__init__()
self.p = p
self.beta = torch.distributions.Beta(alpha, alpha)
self.eps = eps
self.alpha = alpha
self.mix = mix
self._activated = True
def __repr__(self):
return (
f"EFDMix(p={self.p}, alpha={self.alpha}, eps={self.eps}, mix={self.mix})"
)
def set_activation_status(self, status=True):
self._activated = status
def update_mix_method(self, mix="random"):
self.mix = mix
def forward(self, x):
if not self.training or not self._activated:
return x
if random.random() > self.p:
return x
B, C, W, H = x.size(0), x.size(1), x.size(2), x.size(3)
x_view = x.view(B, C, -1)
value_x, index_x = torch.sort(x_view) # sort inputs
lmda = self.beta.sample((B, 1, 1))
lmda = lmda.to(x.device)
if self.mix == "random":
# random shuffle
perm = torch.randperm(B)
elif self.mix == "crossdomain":
# split into two halves and swap the order
perm = torch.arange(B - 1, -1, -1) # inverse index
perm_b, perm_a = perm.chunk(2)
perm_b = perm_b[torch.randperm(perm_b.shape[0])]
perm_a = perm_a[torch.randperm(perm_a.shape[0])]
perm = torch.cat([perm_b, perm_a], 0)
else:
raise NotImplementedError
inverse_index = index_x.argsort(-1)
x_view_copy = value_x[perm].gather(-1, inverse_index)
new_x = x_view + (x_view_copy - x_view.detach()) * (1-lmda)
return new_x.view(B, C, W, H)
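In `EFDMix.forward`, `value_x[perm].gather(-1, inverse_index)` replaces each activation with the equal-rank activation of another sample. This matches the whole empirical distribution of a channel rather than only its mean and standard deviation. The sketch below reproduces that rank matching for one channel in plain Python. It omits the `detach()`, which only affects gradients, and the helper name is hypothetical.

```python
def efdm_mix(x, y, lmda):
    """Exact feature distribution mixing for one channel (plain Python).

    x, y: flat lists of activations for the same channel of two samples.
    """
    # Rank every element of x ...
    order = sorted(range(len(x)), key=lambda i: x[i])
    # ... and give it the value of y that has the same rank.
    y_sorted = sorted(y)
    matched = [0.0] * len(x)
    for rank, i in enumerate(order):
        matched[i] = y_sorted[rank]
    # Same blend as EFDMix.forward: x + (matched - x) * (1 - lmda).
    return [xi + (mi - xi) * (1 - lmda) for xi, mi in zip(x, matched)]


print(efdm_mix([3.0, 1.0, 2.0], [10.0, 30.0, 20.0], lmda=0.0))  # [30.0, 10.0, 20.0]
```

With `lmda=0` the output keeps the ordering of `x` but takes on exactly the values of `y`. With `lmda=1` it returns `x` unchanged.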
================================================
FILE: Dassl.ProGrad.pytorch/dassl/modeling/ops/mixstyle.py
================================================
import random
from contextlib import contextmanager
import torch
import torch.nn as nn
def deactivate_mixstyle(m):
if type(m) == MixStyle:
m.set_activation_status(False)
def activate_mixstyle(m):
if type(m) == MixStyle:
m.set_activation_status(True)
def random_mixstyle(m):
if type(m) == MixStyle:
m.update_mix_method("random")
def crossdomain_mixstyle(m):
if type(m) == MixStyle:
m.update_mix_method("crossdomain")
@contextmanager
def run_without_mixstyle(model):
# Assume MixStyle was initially activated
try:
model.apply(deactivate_mixstyle)
yield
finally:
model.apply(activate_mixstyle)
@contextmanager
def run_with_mixstyle(model, mix=None):
# Assume MixStyle was initially deactivated
if mix == "random":
model.apply(random_mixstyle)
elif mix == "crossdomain":
model.apply(crossdomain_mixstyle)
try:
model.apply(activate_mixstyle)
yield
finally:
model.apply(deactivate_mixstyle)
class MixStyle(nn.Module):
"""MixStyle.
Reference:
Zhou et al. Domain Generalization with MixStyle. ICLR 2021.
"""
def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix="random"):
"""
Args:
p (float): probability of using MixStyle.
alpha (float): parameter of the Beta distribution.
eps (float): scaling parameter to avoid numerical issues.
mix (str): how to mix.
"""
super().__init__()
self.p = p
self.beta = torch.distributions.Beta(alpha, alpha)
self.eps = eps
self.alpha = alpha
self.mix = mix
self._activated = True
def __repr__(self):
return (
f"MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps}, mix={self.mix})"
)
def set_activation_status(self, status=True):
self._activated = status
def update_mix_method(self, mix="random"):
self.mix = mix
def forward(self, x):
if not self.training or not self._activated:
return x
if random.random() > self.p:
return x
B = x.size(0)
mu = x.mean(dim=[2, 3], keepdim=True)
var = x.var(dim=[2, 3], keepdim=True)
sig = (var + self.eps).sqrt()
mu, sig = mu.detach(), sig.detach()
x_normed = (x-mu) / sig
lmda = self.beta.sample((B, 1, 1, 1))
lmda = lmda.to(x.device)
if self.mix == "random":
# random shuffle
perm = torch.randperm(B)
elif self.mix == "crossdomain":
# split into two halves and swap the order
perm = torch.arange(B - 1, -1, -1) # inverse index
perm_b, perm_a = perm.chunk(2)
perm_b = perm_b[torch.randperm(perm_b.shape[0])]
perm_a = perm_a[torch.randperm(perm_a.shape[0])]
perm = torch.cat([perm_b, perm_a], 0)
else:
raise NotImplementedError
mu2, sig2 = mu[perm], sig[perm]
mu_mix = mu*lmda + mu2 * (1-lmda)
sig_mix = sig*lmda + sig2 * (1-lmda)
return x_normed*sig_mix + mu_mix
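`MixStyle.forward` normalizes each channel by its own mean and standard deviation, then re-scales it with a convex mix of its own statistics and those of a permuted sample. The plain-Python sketch below shows this for one channel. It also shows the `"crossdomain"` permutation, assuming, as the comment in `forward` suggests, that the two halves of the batch come from different domains. Both helper names are hypothetical.

```python
import random
import statistics


def mixstyle_channel(x, y, lmda, eps=1e-6):
    """Re-style channel x with a mix of its own and y's statistics."""
    mu_x, mu_y = statistics.fmean(x), statistics.fmean(y)
    # torch.Tensor.var is unbiased by default, like statistics.variance.
    sig_x = (statistics.variance(x) + eps) ** 0.5
    sig_y = (statistics.variance(y) + eps) ** 0.5
    mu_mix = lmda * mu_x + (1 - lmda) * mu_y
    sig_mix = lmda * sig_x + (1 - lmda) * sig_y
    return [(v - mu_x) / sig_x * sig_mix + mu_mix for v in x]


def crossdomain_perm(batch_size, rng=random):
    """Permutation used by mix="crossdomain" (plain-Python version)."""
    perm = list(range(batch_size - 1, -1, -1))  # reversed indices
    half = (batch_size + 1) // 2  # torch.chunk(2) puts the extra item first
    perm_b, perm_a = perm[:half], perm[half:]
    rng.shuffle(perm_b)
    rng.shuffle(perm_a)
    return perm_b + perm_a


out = mixstyle_channel([1.0, 2.0, 3.0], [10.0, 20.0, 30.0], lmda=0.0)
print([round(v, 3) for v in out])
print(crossdomain_perm(4, random.Random(0)))
```

The first half of the permuted batch always draws its statistics from the second half, and vice versa, so every sample is restyled with statistics from the other domain.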
================================================
FILE: Dassl.ProGrad.pytorch/dassl/modeling/ops/mixup.py
================================================
import torch
def mixup(x1, x2, y1, y2, beta, preserve_order=False):
"""Mixup.
Args:
x1 (torch.Tensor): data with shape of (b, c, h, w).
x2 (torch.Tensor): data with shape of (b, c, h, w).
y1 (torch.Tensor): label with shape of (b, n).
y2 (torch.Tensor): label with shape of (b, n).
beta (float): hyper-parameter for Beta sampling.
preserve_order (bool): apply lmda=max(lmda, 1-lmda).
Default is False.
"""
lmda = torch.distributions.Beta(beta, beta).sample([x1.shape[0], 1, 1, 1])
if preserve_order:
lmda = torch.max(lmda, 1 - lmda)
lmda = lmda.to(x1.device)
xmix = x1*lmda + x2 * (1-lmda)
lmda = lmda[:, :, 0, 0]
ymix = y1*lmda + y2 * (1-lmda)
return xmix, ymix
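`mixup` draws one lambda per sample from Beta(beta, beta) and blends both inputs and labels with it, so `y1`/`y2` must be one-hot or soft labels of shape (b, n) rather than class indices. The plain-Python sketch below does the same for a single pair. The helper name is hypothetical.

```python
import random


def mixup_pair(x1, x2, y1, y2, alpha, preserve_order=False, rng=random):
    """Mix one pair of flat inputs and one-hot labels (plain Python)."""
    lmda = rng.betavariate(alpha, alpha)
    if preserve_order:
        # Keep the first sample dominant, as mixup(preserve_order=True) does.
        lmda = max(lmda, 1 - lmda)
    xmix = [lmda * a + (1 - lmda) * b for a, b in zip(x1, x2)]
    ymix = [lmda * a + (1 - lmda) * b for a, b in zip(y1, y2)]
    return xmix, ymix, lmda


xmix, ymix, lmda = mixup_pair(
    [1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0],
    alpha=0.4, preserve_order=True, rng=random.Random(0),
)
print(lmda, xmix, ymix)
```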
================================================
FILE: Dassl.ProGrad.pytorch/dassl/modeling/ops/mmd.py
================================================
import torch
import torch.nn as nn
from torch.nn import functional as F
class MaximumMeanDiscrepancy(nn.Module):
def __init__(self, kernel_type="rbf", normalize=False):
super().__init__()
self.kernel_type = kernel_type
self.normalize = normalize
def forward(self, x, y):
# x, y: two batches of data with shape (batch, dim)
# MMD^2(x, y) = E[k(x, x')] - 2E[k(x, y)] + E[k(y, y')]
if self.normalize:
x = F.normalize(x, dim=1)
y = F.normalize(y, dim=1)
if self.kernel_type == "linear":
return self.linear_mmd(x, y)
elif self.kernel_type == "poly":
return self.poly_mmd(x, y)
elif self.kernel_type == "rbf":
return self.rbf_mmd(x, y)
else:
raise NotImplementedError
def linear_mmd(self, x, y):
# k(x, y) = x^T y
k_xx = self.remove_self_distance(torch.mm(x, x.t()))
k_yy = self.remove_self_distance(torch.mm(y, y.t()))
k_xy = torch.mm(x, y.t())
return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()
def poly_mmd(self, x, y, alpha=1.0, c=2.0, d=2):
# k(x, y) = (alpha * x^T y + c)^d
k_xx = self.remove_self_distance(torch.mm(x, x.t()))
k_xx = (alpha*k_xx + c).pow(d)
k_yy = self.remove_self_distance(torch.mm(y, y.t()))
k_yy = (alpha*k_yy + c).pow(d)
k_xy = torch.mm(x, y.t())
k_xy = (alpha*k_xy + c).pow(d)
return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()
def rbf_mmd(self, x, y):
# k_xx
d_xx = self.euclidean_squared_distance(x, x)
d_xx = self.remove_self_distance(d_xx)
k_xx = self.rbf_kernel_mixture(d_xx)
# k_yy
d_yy = self.euclidean_squared_distance(y, y)
d_yy = self.remove_self_distance(d_yy)
k_yy = self.rbf_kernel_mixture(d_yy)
# k_xy
d_xy = self.euclidean_squared_distance(x, y)
k_xy = self.rbf_kernel_mixture(d_xy)
return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()
@staticmethod
def rbf_kernel_mixture(exponent, sigmas=(1, 5, 10)):
K = 0
for sigma in sigmas:
gamma = 1.0 / (2.0 * sigma**2)
K += torch.exp(-gamma * exponent)
return K
@staticmethod
def remove_self_distance(distmat):
tmp_list = []
for i, row in enumerate(distmat):
row1 = torch.cat([row[:i], row[i + 1:]])
tmp_list.append(row1)
return torch.stack(tmp_list)
@staticmethod
def euclidean_squared_distance(x, y):
m, n = x.size(0), y.size(0)
distmat = (
torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) +
torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
)
# distmat.addmm_(1, -2, x, y.t())
distmat.addmm_(x, y.t(), beta=1, alpha=-2)
return distmat
if __name__ == "__main__":
mmd = MaximumMeanDiscrepancy(kernel_type="rbf")
input1, input2 = torch.rand(3, 100), torch.rand(3, 100)
d = mmd(input1, input2)
print(d.item())
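`remove_self_distance` drops the diagonal before averaging the within-batch kernel terms, which gives the unbiased estimate of MMD^2. The plain-Python sketch below does the same for the linear kernel. The function names are hypothetical, and each batch needs at least two samples.

```python
def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


def linear_mmd(x, y):
    """Linear-kernel MMD^2 estimate with self-pairs removed (plain Python)."""
    def mean_offdiag(batch):
        # Same as remove_self_distance(...).mean(): skip the i == j terms.
        vals = [
            dot(a, b)
            for i, a in enumerate(batch)
            for j, b in enumerate(batch)
            if i != j
        ]
        return sum(vals) / len(vals)

    k_xy = sum(dot(a, b) for a in x for b in y) / (len(x) * len(y))
    return mean_offdiag(x) + mean_offdiag(y) - 2 * k_xy


print(linear_mmd([[1.0, 0.0], [1.0, 0.0]], [[1.0, 0.0], [1.0, 0.0]]))  # 0.0
print(linear_mmd([[1.0, 0.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 1.0]]))  # 2.0
```

Because it is unbiased, this estimate can come out slightly negative for two samples drawn from the same distribution.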
================================================
FILE: Dassl.ProGrad.pytorch/dassl/modeling/ops/optimal_transport.py
================================================
import torch
import torch.nn as nn
from torch.nn import functional as F
class OptimalTransport(nn.Module):
@staticmethod
def distance(batch1, batch2, dist_metric="cosine"):
if dist_metric == "cosine":
batch1 = F.normalize(batch1, p=2, dim=1)
batch2 = F.normalize(batch2, p=2, dim=1)
dist_mat = 1 - torch.mm(batch1, batch2.t())
elif dist_metric == "euclidean":
m, n = batch1.size(0), batch2.size(0)
dist_mat = (
torch.pow(batch1, 2).sum(dim=1, keepdim=True).expand(m, n) +
torch.pow(batch2, 2).sum(dim=1, keepdim=True).expand(n, m).t()
)
dist_mat.addmm_(
batch1, batch2.t(), beta=1, alpha=-2
) # squared euclidean distance
elif dist_metric == "fast_euclidean":
batch1 = batch1.unsqueeze(-2)
batch2 = batch2.unsqueeze(-3)
dist_mat = torch.sum((torch.abs(batch1 - batch2))**2, -1)
else:
raise ValueError(
"Unknown cost function: {}. Expected to be one of "
"[cosine | euclidean | fast_euclidean]".format(dist_metric)
)
return dist_mat
class SinkhornDivergence(OptimalTransport):
thre = 1e-3
def __init__(
self,
dist_metric="cosine",
eps=0.01,
max_iter=5,
bp_to_sinkhorn=False
):
super().__init__()
self.dist_metric = dist_metric
self.eps = eps
self.max_iter = max_iter
self.bp_to_sinkhorn = bp_to_sinkhorn
def forward(self, x, y):
# x, y: two batches of data with shape (batch, dim)
W_xy = self.transport_cost(x, y)
W_xx = self.transport_cost(x, x)
W_yy = self.transport_cost(y, y)
return 2*W_xy - W_xx - W_yy
def transport_cost(self, x, y, return_pi=False):
C = self.distance(x, y, dist_metric=self.dist_metric)
pi = self.sinkhorn_iterate(C, self.eps, self.max_iter, self.thre)
if not self.bp_to_sinkhorn:
pi = pi.detach()
cost = torch.sum(pi * C)
if return_pi:
return cost, pi
return cost
@staticmethod
def sinkhorn_iterate(C, eps, max_iter, thre):
nx, ny = C.shape
mu = torch.ones(nx, dtype=C.dtype, device=C.device) * (1.0/nx)
nu = torch.ones(ny, dtype=C.dtype, device=C.device) * (1.0/ny)
u = torch.zeros_like(mu)
v = torch.zeros_like(nu)
def M(_C, _u, _v):
"""Modified cost for logarithmic updates.
Eq: M_{ij} = (-c_{ij} + u_i + v_j) / epsilon
"""
return (-_C + _u.unsqueeze(-1) + _v.unsqueeze(-2)) / eps
real_iter = 0 # check if algorithm terminates before max_iter
# Sinkhorn iterations
for i in range(max_iter):
u0 = u
u = eps * (
torch.log(mu + 1e-8) - torch.logsumexp(M(C, u, v), dim=1)
) + u
v = (
eps * (
torch.log(nu + 1e-8) -
torch.logsumexp(M(C, u, v).permute(1, 0), dim=1)
) + v
)
err = (u - u0).abs().sum()
real_iter += 1
if err.item() < thre:
break
# Transport plan pi = diag(a)*K*diag(b)
return torch.exp(M(C, u, v))
class MinibatchEnergyDistance(SinkhornDivergence):
def __init__(
self,
dist_metric="cosine",
eps=0.01,
max_iter=5,
bp_to_sinkhorn=False
):
super().__init__(
dist_metric=dist_metric,
eps=eps,
max_iter=max_iter,
bp_to_sinkhorn=bp_to_sinkhorn,
)
def forward(self, x, y):
x1, x2 = torch.split(x, x.size(0) // 2, dim=0)
y1, y2 = torch.split(y, y.size(0) // 2, dim=0)
cost = 0
cost += self.transport_cost(x1, y1)
cost += self.transport_cost(x1, y2)
cost += self.transport_cost(x2, y1)
cost += self.transport_cost(x2, y2)
cost -= 2 * self.transport_cost(x1, x2)
cost -= 2 * self.transport_cost(y1, y2)
return cost
if __name__ == "__main__":
# example: https://dfdazac.github.io/sinkhorn.html
import numpy as np
n_points = 5
a = np.array([[i, 0] for i in range(n_points)])
b = np.array([[i, 1] for i in range(n_points)])
x = torch.tensor(a, dtype=torch.float)
y = torch.tensor(b, dtype=torch.float)
sinkhorn = SinkhornDivergence(
dist_metric="euclidean", eps=0.01, max_iter=5
)
dist, pi = sinkhorn.transport_cost(x, y, True)
print("transport cost: {:.4f}".format(dist.item()))
print(pi)
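`sinkhorn_iterate` runs Sinkhorn in the log domain. It alternately updates the dual potentials `u` (rows) and `v` (columns) with a log-sum-exp over the modified cost `M_ij = (-C_ij + u_i + v_j) / eps`, and returns `exp(M)` as the transport plan. The plain-Python sketch below follows the same update order. It differs in three small ways, all assumptions of the sketch: it uses the exact `log(1/n)` instead of `log(mu + 1e-8)`, it defaults to more iterations than `max_iter=5`, and the function names are hypothetical.

```python
import math


def logsumexp(vals):
    m = max(vals)
    return m + math.log(sum(math.exp(v - m) for v in vals))


def sinkhorn_plan(C, eps=0.05, max_iter=100, thre=1e-9):
    """Log-domain Sinkhorn with uniform marginals (plain Python)."""
    nx, ny = len(C), len(C[0])
    log_mu, log_nu = math.log(1.0 / nx), math.log(1.0 / ny)
    u, v = [0.0] * nx, [0.0] * ny

    def M(i, j):
        # Modified cost M_ij = (-C_ij + u_i + v_j) / eps, read at call time.
        return (-C[i][j] + u[i] + v[j]) / eps

    for _ in range(max_iter):
        u_prev = u
        # Row update uses the old u and v; the column update uses the new u.
        u = [
            eps * (log_mu - logsumexp([M(i, j) for j in range(ny)])) + u[i]
            for i in range(nx)
        ]
        v = [
            eps * (log_nu - logsumexp([M(i, j) for i in range(nx)])) + v[j]
            for j in range(ny)
        ]
        if sum(abs(a - b) for a, b in zip(u, u_prev)) < thre:
            break
    return [[math.exp(M(i, j)) for j in range(ny)] for i in range(nx)]


pi = sinkhorn_plan([[0.0, 1.0], [1.0, 0.0]])
print([[round(p, 4) for p in row] for row in pi])
```

With a small `eps` the plan concentrates on the zero-cost diagonal. Its column sums equal the uniform marginal after every full iteration, because the `v` update runs last.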
================================================
FILE: Dassl.ProGrad.pytorch/dassl/modeling/ops/reverse_grad.py
================================================
import torch.nn as nn
from torch.autograd import Function
class _ReverseGrad(Function):
@staticmethod
def forward(ctx, input, grad_scaling):
ctx.grad_scaling = grad_scaling
return input.view_as(input)
@staticmethod
def backward(ctx, grad_output):
grad_scaling = ctx.grad_scaling
return -grad_scaling * grad_output, None
reverse_grad = _ReverseGrad.apply
class ReverseGrad(nn.Module):
"""Gradient reversal layer.
It acts as an identity layer in the forward,
but reverses the sign of the gradient in
the backward.
"""
def forward(self, x, grad_scaling=1.0):
assert grad_scaling >= 0, (
"grad_scaling must be non-negative, but got {}".format(grad_scaling)
)
return reverse_grad(x, grad_scaling)
================================================
FILE: Dassl.ProGrad.pytorch/dassl/modeling/ops/sequential2.py
================================================
import torch.nn as nn
class Sequential2(nn.Sequential):
"""An alternative sequential container to nn.Sequential,
which accepts an arbitrary number of input arguments.
"""
def forward(self, *inputs):
for module in self._modules.values():
if isinstance(inputs, tuple):
inputs = module(*inputs)
else:
inputs = module(inputs)
return inputs
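`Sequential2.forward` unpacks any tuple returned by one stage into positional arguments for the next stage, and passes any other value through unchanged. The sketch below applies the same dispatch rule to plain callables. All names are hypothetical.

```python
def run_sequential2(modules, *inputs):
    """Plain-callable version of Sequential2.forward's dispatch rule."""
    for module in modules:
        if isinstance(inputs, tuple):
            inputs = module(*inputs)  # a tuple is unpacked into arguments
        else:
            inputs = module(inputs)  # anything else is passed as-is
    return inputs


def add_and_keep(a, b):
    return a + b, b  # returns a tuple, so the next stage gets two arguments


def scale(s, b):
    return s * b  # returns a single value


print(run_sequential2([add_and_keep, scale], 2, 3))  # 15
```

One consequence of this rule: a stage that returns a tuple meant as a single value will still have it unpacked by the next stage.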
================================================
FILE: Dassl.ProGrad.pytorch/dassl/modeling/ops/transnorm.py
================================================
import torch
import torch.nn as nn
class _TransNorm(nn.Module):
"""Transferable normalization.
Reference:
- Wang et al. Transferable Normalization: Towards Improving
Transferability of Deep Neural Networks. NeurIPS 2019.
Args:
num_features (int): number of features.
eps (float): epsilon.
momentum (float): value for updating running_mean and running_var.
adaptive_alpha (bool): apply domain adaptive alpha.
"""
def __init__(
self, num_features, eps=1e-5, momentum=0.1, adaptive_alpha=True
):
super().__init__()
self.num_features = num_features
self.eps = eps
self.momentum = momentum
self.adaptive_alpha = adaptive_alpha
self.register_buffer("running_mean_s", torch.zeros(num_features))
self.register_buffer("running_var_s", torch.ones(num_features))
self.register_buffer("running_mean_t", torch.zeros(num_features))
self.register_buffer("running_var_t", torch.ones(num_features))
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
def reset_running_stats(self):
self.running_mean_s.zero_()
self.running_var_s.fill_(1)
self.running_mean_t.zero_()
self.running_var_t.fill_(1)
def reset_parameters(self):
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
def _check_input(self, x):
raise NotImplementedError
def _compute_alpha(self, mean_s, var_s, mean_t, var_t):
C = self.num_features
ratio_s = mean_s / (var_s + self.eps).sqrt()
ratio_t = mean_t / (var_t + self.eps).sqrt()
dist = (ratio_s - ratio_t).abs()
dist_inv = 1 / (1+dist)
return C * dist_inv / dist_inv.sum()
def forward(self, input):
self._check_input(input)
C = self.num_features
if input.dim() == 2:
new_shape = (1, C)
elif input.dim() == 4:
new_shape = (1, C, 1, 1)
else:
raise ValueError
weight = self.weight.view(*new_shape)
bias = self.bias.view(*new_shape)
if not self.training:
mean_t = self.running_mean_t.view(*new_shape)
var_t = self.running_var_t.view(*new_shape)
output = (input-mean_t) / (var_t + self.eps).sqrt()
output = output*weight + bias
if self.adaptive_alpha:
mean_s = self.running_mean_s.view(*new_shape)
var_s = self.running_var_s.view(*new_shape)
alpha = self._compute_alpha(mean_s, var_s, mean_t, var_t)
alpha = alpha.reshape(*new_shape)
output = (1 + alpha.detach()) * output
return output
input_s, input_t = torch.split(input, input.shape[0] // 2, dim=0)
x_s = input_s.transpose(0, 1).reshape(C, -1)
mean_s = x_s.mean(1)
var_s = x_s.var(1)
self.running_mean_s.mul_(self.momentum)
self.running_mean_s.add_((1 - self.momentum) * mean_s.data)
self.running_var_s.mul_(self.momentum)
self.running_var_s.add_((1 - self.momentum) * var_s.data)
mean_s = mean_s.reshape(*new_shape)
var_s = var_s.reshape(*new_shape)
output_s = (input_s-mean_s) / (var_s + self.eps).sqrt()
output_s = output_s*weight + bias
x_t = input_t.transpose(0, 1).reshape(C, -1)
mean_t = x_t.mean(1)
var_t = x_t.var(1)
self.running_mean_t.mul_(self.momentum)
self.running_mean_t.add_((1 - self.momentum) * mean_t.data)
self.running_var_t.mul_(self.momentum)
self.running_var_t.add_((1 - self.momentum) * var_t.data)
mean_t = mean_t.reshape(*new_shape)
var_t = var_t.reshape(*new_shape)
output_t = (input_t-mean_t) / (var_t + self.eps).sqrt()
output_t = output_t*weight + bias
output = torch.cat([output_s, output_t], 0)
if self.adaptive_alpha:
alpha = self._compute_alpha(mean_s, var_s, mean_t, var_t)
alpha = alpha.reshape(*new_shape)
output = (1 + alpha.detach()) * output
return output
class TransNorm1d(_TransNorm):
def _check_input(self, x):
if x.dim() != 2:
raise ValueError(
"Expected the input to be 2-D, "
"but got {}-D".format(x.dim())
)
class TransNorm2d(_TransNorm):
def _check_input(self, x):
if x.dim() != 4:
raise ValueError(
"Expected the input to be 4-D, "
"but got {}-D".format(x.dim())
)
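`_TransNorm._compute_alpha` compares the standardized means `mean / sqrt(var + eps)` of the source and target domains per channel. It turns the gap into `1 / (1 + dist)` and rescales the result so the alphas sum to the number of channels. The output is then multiplied by `(1 + alpha)`, so channels whose statistics agree across domains are amplified. The plain-Python sketch below uses a hypothetical function name.

```python
def transnorm_alpha(mean_s, var_s, mean_t, var_t, eps=1e-5):
    """Per-channel alpha as in _TransNorm._compute_alpha (plain Python)."""
    C = len(mean_s)
    dist = [
        abs(ms / (vs + eps) ** 0.5 - mt / (vt + eps) ** 0.5)
        for ms, vs, mt, vt in zip(mean_s, var_s, mean_t, var_t)
    ]
    dist_inv = [1.0 / (1.0 + d) for d in dist]
    total = sum(dist_inv)
    # Normalized so that the alphas sum to the number of channels.
    return [C * d / total for d in dist_inv]


alpha = transnorm_alpha([0.0, 1.0], [1.0, 1.0], [0.0, 0.0], [1.0, 1.0])
print(alpha)
```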
================================================
FILE: Dassl.ProGrad.pytorch/dassl/modeling/ops/utils.py
================================================
import numpy as np
import torch
def sharpen_prob(p, temperature=2):
"""Sharpening probability with a temperature.
Args:
p (torch.Tensor): probability matrix (batch_size, n_classes)
temperature (float): temperature.
"""
p = p.pow(temperature)
return p / p.sum(1, keepdim=True)
def reverse_index(data, label):
"""Reverse order."""
inv_idx = torch.arange(data.size(0) - 1, -1, -1).long()
return data[inv_idx], label[inv_idx]
def shuffle_index(data, label):
"""Shuffle order."""
rnd_idx = torch.randperm(data.shape[0])
return data[rnd_idx], label[rnd_idx]
def create_onehot(label, num_classes):
"""Create one-hot tensor.
We suggest using nn.functional.one_hot.
Args:
label (torch.Tensor): 1-D tensor.
num_classes (int): number of classes.
"""
onehot = torch.zeros(label.shape[0], num_classes)
return onehot.scatter(1, label.unsqueeze(1).data.cpu(), 1)
def sigmoid_rampup(current, rampup_length):
"""Sigmoid-shaped rampup: exp(-5 * (1 - current/rampup_length)^2).
Args:
current (int): current step.
rampup_length (int): maximum step.
"""
assert rampup_length > 0
current = np.clip(current, 0.0, rampup_length)
phase = 1.0 - current/rampup_length
return float(np.exp(-5.0 * phase * phase))
def linear_rampup(current, rampup_length):
"""Linear rampup.
Args:
current (int): current step.
rampup_length (int): maximum step.
"""
assert rampup_length > 0
ratio = np.clip(current / rampup_length, 0.0, 1.0)
return float(ratio)
def ema_model_update(model, ema_model, alpha):
"""Exponential moving average of model parameters.
Args:
model (nn.Module): model being trained.
ema_model (nn.Module): ema of the model.
alpha (float): ema decay rate.
"""
for ema_param, param in zip(ema_model.parameters(), model.parameters()):
ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)
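For quick sanity checks, here are scalar versions of the helpers in this file. They cover the two rampup schedules, temperature sharpening (raise probabilities to a power, then renormalize), and the EMA rule `ema <- alpha * ema + (1 - alpha) * param`. These plain-Python functions are hypothetical stand-ins, not the module's API.

```python
import math


def sigmoid_rampup(current, rampup_length):
    t = min(max(current, 0.0), rampup_length) / rampup_length
    return math.exp(-5.0 * (1.0 - t) ** 2)


def linear_rampup(current, rampup_length):
    return min(max(current / rampup_length, 0.0), 1.0)


def sharpen(p, temperature=2):
    powered = [v ** temperature for v in p]
    total = sum(powered)
    return [v / total for v in powered]


def ema_update(params, ema_params, alpha):
    # ema <- alpha * ema + (1 - alpha) * param, as in ema_model_update
    return [alpha * e + (1 - alpha) * p for p, e in zip(params, ema_params)]


for step in (0, 5, 10):
    print(step, round(sigmoid_rampup(step, 10), 4), linear_rampup(step, 10))
print(sharpen([0.6, 0.4]))
print(ema_update([1.0], [0.0], alpha=0.9))
```

The sigmoid rampup starts at exp(-5), about 0.0067, and reaches 1 at `rampup_length`. The linear rampup passes through 0.5 halfway.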
================================================
FILE: Dassl.ProGrad.pytorch/dassl/optim/__init__.py
================================================
from .optimizer import build_optimizer
from .lr_scheduler import build_lr_scheduler
================================================
FILE: Dassl.ProGrad.pytorch/dassl/optim/lr_scheduler.py
================================================
"""
Modified from https://github.com/KaiyangZhou/deep-person-reid
"""
import torch
from torch.optim.lr_scheduler import _LRScheduler
AVAI_SCHEDS = ["single_step", "multi_step", "cosine"]
class _BaseWarmupScheduler(_LRScheduler):
def __init__(
self,
optimizer,
successor,
warmup_epoch,
last_epoch=-1,
verbose=False
):
self.successor = successor
self.warmup_epoch = warmup_epoch
super().__init__(optimizer, last_epoch, verbose)
def get_lr(self):
raise NotImplementedError
def step(self, epoch=None):
if self.last_epoch >= self.warmup_epoch:
self.successor.step(epoch)
self._last_lr = self.successor.get_last_lr()
else:
super().step(epoch)
class ConstantWarmupScheduler(_BaseWarmupScheduler):
def __init__(
self,
optimizer,
successor,
warmup_epoch,
cons_lr,
last_epoch=-1,
verbose=False
):
self.cons_lr = cons_lr
super().__init__(
optimizer, successor, warmup_epoch, last_epoch, verbose
)
def get_lr(self):
if self.last_epoch >= self.warmup_epoch:
return self.successor.get_last_lr()
return [self.cons_lr for _ in self.base_lrs]
class LinearWarmupScheduler(_BaseWarmupScheduler):
def __init__(
self,
optimizer,
successor,
warmup_epoch,
min_lr,
last_epoch=-1,
verbose=False
):
self.min_lr = min_lr
super().__init__(
optimizer, successor, warmup_epoch, last_epoch, verbose
)
def get_lr(self):
if self.last_epoch >= self.warmup_epoch:
return self.successor.get_last_lr()
if self.last_epoch == 0:
return [self.min_lr for _ in self.base_lrs]
return [
lr * self.last_epoch / self.warmup_epoch for lr in self.base_lrs
]
def build_lr_scheduler(optimizer, optim_cfg):
"""A function wrapper for building a learning rate scheduler.
Args:
optimizer (Optimizer): an Optimizer.
optim_cfg (CfgNode): optimization config.
"""
lr_scheduler = optim_cfg.LR_SCHEDULER
stepsize = optim_cfg.STEPSIZE
gamma = optim_cfg.GAMMA
max_epoch = optim_cfg.MAX_EPOCH
if lr_scheduler not in AVAI_SCHEDS:
raise ValueError(
"Unsupported scheduler: {}. Must be one of {}".format(
lr_scheduler, AVAI_SCHEDS
)
)
if lr_scheduler == "single_step":
if isinstance(stepsize, (list, tuple)):
stepsize = stepsize[-1]
if not isinstance(stepsize, int):
raise TypeError(
"For single_step lr_scheduler, stepsize must "
"be an integer, but got {}".format(type(stepsize))
)
if stepsize <= 0:
stepsize = max_epoch
scheduler = torch.optim.lr_scheduler.StepLR(
optimizer, step_size=stepsize, gamma=gamma
)
elif lr_scheduler == "multi_step":
if not isinstance(stepsize, (list, tuple)):
raise TypeError(
"For multi_step lr_scheduler, stepsize must "
"be a list, but got {}".format(type(stepsize))
)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer, milestones=stepsize, gamma=gamma
)
elif lr_scheduler == "cosine":
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, float(max_epoch)
)
if optim_cfg.WARMUP_EPOCH > 0:
if not optim_cfg.WARMUP_RECOUNT:
scheduler.last_epoch = optim_cfg.WARMUP_EPOCH
if optim_cfg.WARMUP_TYPE == "constant":
scheduler = ConstantWarmupScheduler(
optimizer, scheduler, optim_cfg.WARMUP_EPOCH,
optim_cfg.WARMUP_CONS_LR
)
elif optim_cfg.WARMUP_TYPE == "linear":
scheduler = LinearWarmupScheduler(
optimizer, scheduler, optim_cfg.WARMUP_EPOCH,
optim_cfg.WARMUP_MIN_LR
)
else:
raise ValueError(
"Unsupported warmup type: {}".format(optim_cfg.WARMUP_TYPE)
)
return scheduler
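The plain-Python sketch below shows the two ingredients chained by `build_lr_scheduler`: the linear warmup of `LinearWarmupScheduler.get_lr()`, and the closed form of `CosineAnnealingLR` with its default `eta_min=0`. The exact epoch offsets of the chained scheduler depend on `WARMUP_RECOUNT` and on PyTorch's recursive cosine update, so treat the printed schedule as a rough picture rather than a reproduction. The function names and numbers are hypothetical.

```python
import math


def cosine_lr(t, base_lr, max_epoch, eta_min=0.0):
    # Closed form of CosineAnnealingLR; build_lr_scheduler keeps eta_min = 0.
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t / max_epoch)) / 2


def linear_warmup_lr(epoch, base_lr, warmup_epoch, min_lr):
    # Mirrors LinearWarmupScheduler.get_lr() while epoch < warmup_epoch.
    if epoch == 0:
        return min_lr
    return base_lr * epoch / warmup_epoch


base_lr, warmup_epoch, max_epoch = 0.002, 2, 10
for epoch in range(warmup_epoch):
    print("warmup", epoch, linear_warmup_lr(epoch, base_lr, warmup_epoch, 1e-5))
for t in (0, 5, 10):
    print("cosine", t, cosine_lr(t, base_lr, max_epoch))
```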
================================================
FILE: Dassl.ProGrad.pytorch/dassl/optim/optimizer.py
================================================
"""
Modified from https://github.com/KaiyangZhou/deep-person-reid
"""
import warnings
import torch
import torch.nn as nn
from .radam import RAdam
AVAI_OPTIMS = ["adam", "amsgrad", "sgd", "rmsprop", "radam", "adamw"]
def build_optimizer(model, optim_cfg):
"""A function wrapper for building an optimizer.
Args:
model (nn.Module or iterable): model.
optim_cfg (CfgNode): optimization config.
"""
optim = optim_cfg.NAME
lr = optim_cfg.LR
weight_decay = optim_cfg.WEIGHT_DECAY
momentum = optim_cfg.MOMENTUM
sgd_dampening = optim_cfg.SGD_DAMPNING
sgd_nesterov = optim_cfg.SGD_NESTEROV
rmsprop_alpha = optim_cfg.RMSPROP_ALPHA
adam_beta1 = optim_cfg.ADAM_BETA1
adam_beta2 = optim_cfg.ADAM_BETA2
staged_lr = optim_cfg.STAGED_LR
new_layers = optim_cfg.NEW_LAYERS
base_lr_mult = optim_cfg.BASE_LR_MULT
if optim not in AVAI_OPTIMS:
raise ValueError(
"Unsupported optim: {}. Must be one of {}".format(
optim, AVAI_OPTIMS
)
)
if staged_lr:
if not isinstance(model, nn.Module):
raise TypeError(
"When staged_lr is True, model given to "
"build_optimizer() must be an instance of nn.Module"
)
if isinstance(model, nn.DataParallel):
model = model.module
if not new_layers:
warnings.warn(
"new_layers is empty, therefore, staged_lr is useless"
)
new_layers = []
if isinstance(new_layers, str):
new_layers = [new_layers]
base_params = []
base_layers = []
new_params = []
for name, module in model.named_children():
if name in new_layers:
new_params += [p for p in module.parameters()]
else:
base_params += [p for p in module.parameters()]
base_layers.append(name)
param_groups = [
{
"params": base_params,
"lr": lr * base_lr_mult
},
{
"params": new_params
},
]
else:
if isinstance(model, nn.Module):
param_groups = model.parameters()
else:
param_groups = model
if optim == "adam":
optimizer = torch.optim.Adam(
param_groups,
lr=lr,
weight_decay=weight_decay,
betas=(adam_beta1, adam_beta2),
)
elif optim == "amsgrad":
optimizer = torch.optim.Adam(
param_groups,
lr=lr,
weight_decay=weight_decay,
betas=(adam_beta1, adam_beta2),
amsgrad=True,
)
elif optim == "sgd":
optimizer = torch.optim.SGD(
param_groups,
lr=lr,
momentum=momentum,
weight_decay=weight_decay,
dampening=sgd_dampening,
nesterov=sgd_nesterov,
)
elif optim == "rmsprop":
optimizer = torch.optim.RMSprop(
param_groups,
lr=lr,
momentum=momentum,
weight_decay=weight_decay,
alpha=rmsprop_alpha,
)
elif optim == "radam":
optimizer = RAdam(
param_groups,
lr=lr,
weight_decay=weight_decay,
betas=(adam_beta1, adam_beta2),
)
elif optim == "adamw":
optimizer = torch.optim.AdamW(
param_groups,
lr=lr,
weight_decay=weight_decay,
betas=(adam_beta1, adam_beta2),
)
return optimizer
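When `STAGED_LR` is enabled, `build_optimizer` puts every top-level child not listed in `NEW_LAYERS` into a base group whose learning rate is scaled by `BASE_LR_MULT`, while the new layers use the optimizer's default `lr`. The sketch below reproduces that grouping on plain `(name, params)` pairs so it runs without PyTorch; `split_param_groups` is a hypothetical helper and the strings stand in for parameter tensors.

```python
def split_param_groups(named_children, new_layers, lr, base_lr_mult):
    """Hypothetical helper mirroring the staged_lr branch of build_optimizer.

    named_children: list of (child_name, list_of_params) pairs, standing in
    for model.named_children() and module.parameters().
    """
    if isinstance(new_layers, str):
        new_layers = [new_layers]
    base_params, new_params = [], []
    for name, params in named_children:
        if name in new_layers:
            new_params += params
        else:
            base_params += params
    return [
        # Pretrained layers: learning rate scaled by base_lr_mult.
        {"params": base_params, "lr": lr * base_lr_mult},
        # New layers: no "lr" key, so the optimizer's default lr applies.
        {"params": new_params},
    ]


children = [("backbone", ["conv1.w", "conv2.w"]), ("classifier", ["fc.w", "fc.b"])]
groups = split_param_groups(children, "classifier", lr=0.01, base_lr_mult=0.1)
print(groups)
```

In PyTorch, passing such a list to an optimizer such as `torch.optim.SGD(groups, lr=0.01, momentum=0.9)` applies per-group options on top of the keyword defaults, which is why the second group needs no explicit `lr`.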
================================================
FILE: Dassl.ProGrad.pytorch/dassl/optim/radam.py
================================================
"""
Imported from: https://github.com/LiyuanLucasLiu/RAdam
https://arxiv.org/abs/1908.03265
@article{liu2019radam,
title={On the Variance of the Adaptive Learning Rate and Beyond},
author={Liu, Liyuan and Jiang, Haoming and He, Pengcheng and Chen, Weizhu and Liu, Xiaodong and Gao, Jianfeng and Han, Jiawei},
journal={arXiv preprint arXiv:1908.03265},
year={2019}
}
"""
import math
import torch
from torch.optim.optimizer import Optimizer
class RAdam(Optimizer):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
degenerated_to_sgd=True,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError(
"Invalid beta parameter at index 0: {}".format(betas[0])
)
if not 0.0 <= betas[1] < 1.0:
raise ValueError(
"Invalid beta parameter at index 1: {}".format(betas[1])
)
self.degenerated_to_sgd = degenerated_to_sgd
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
self.buffer = [[None, None, None] for ind in range(10)]
super(RAdam, self).__init__(params, defaults)
def __setstate__(self, state):
super(RAdam, self).__setstate__(state)
def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad.data.float()
if grad.is_sparse:
raise RuntimeError(
"RAdam does not support sparse gradients"
)
p_data_fp32 = p.data.float()
state = self.state[p]
if len(state) == 0:
state["step"] = 0
state["exp_avg"] = torch.zeros_like(p_data_fp32)
state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
else:
state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
state["exp_avg_sq"] = state["exp_avg_sq"].type_as(
p_data_fp32
)
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = group["betas"]
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
state["step"] += 1
buffered = self.buffer[int(state["step"] % 10)]
if state["step"] == buffered[0]:
N_sma, step_size = buffered[1], buffered[2]
else:
buffered[0] = state["step"]
beta2_t = beta2**state["step"]
N_sma_max = 2 / (1-beta2) - 1
N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1-beta2_t)
buffered[1] = N_sma
# more conservative since it's an approximated value
if N_sma >= 5:
step_size = math.sqrt(
(1-beta2_t) * (N_sma-4) / (N_sma_max-4) *
(N_sma-2) / N_sma * N_sma_max / (N_sma_max-2)
) / (1 - beta1**state["step"])
elif self.degenerated_to_sgd:
step_size = 1.0 / (1 - beta1**state["step"])
else:
step_size = -1
buffered[2] = step_size
# more conservative since it's an approximated value
if N_sma >= 5:
if group["weight_decay"] != 0:
p_data_fp32.add_(
p_data_fp32, alpha=-group["weight_decay"] * group["lr"]
)
denom = exp_avg_sq.sqrt().add_(group["eps"])
p_data_fp32.addcdiv_(
exp_avg, denom, value=-step_size * group["lr"]
)
p.data.copy_(p_data_fp32)
elif step_size > 0:
if group["weight_decay"] != 0:
p_data_fp32.add_(
p_data_fp32, alpha=-group["weight_decay"] * group["lr"]
)
p_data_fp32.add_(exp_avg, alpha=-step_size * group["lr"])
p.data.copy_(p_data_fp32)
return loss
class PlainRAdam(Optimizer):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
degenerated_to_sgd=True,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError(
"Invalid beta parameter at index 0: {}".format(betas[0])
)
if not 0.0 <= betas[1] < 1.0:
raise ValueError(
"Invalid beta parameter at index 1: {}".format(betas[1])
)
self.degenerated_to_sgd = degenerated_to_sgd
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(PlainRAdam, self).__init__(params, defaults)
def __setstate__(self, state):
super(PlainRAdam, self).__setstate__(state)
def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad.data.float()
if grad.is_sparse:
raise RuntimeError(
"RAdam does not support sparse gradients"
)
p_data_fp32 = p.data.float()
state = self.state[p]
if len(state) == 0:
state["step"] = 0
state["exp_avg"] = torch.zeros_like(p_data_fp32)
state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
else:
state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
state["exp_avg_sq"] = state["exp_avg_sq"].type_as(
p_data_fp32
)
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = group["betas"]
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
state["step"] += 1
beta2_t = beta2**state["step"]
N_sma_max = 2 / (1-beta2) - 1
N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1-beta2_t)
# more conservative since it's an approximated value
if N_sma >= 5:
if group["weight_decay"] != 0:
p_data_fp32.add_(
p_data_fp32, alpha=-group["weight_decay"] * group["lr"]
)
step_size = (
group["lr"] * math.sqrt(
(1-beta2_t) * (N_sma-4) / (N_sma_max-4) *
(N_sma-2) / N_sma * N_sma_max / (N_sma_max-2)
) / (1 - beta1**state["step"])
)
denom = exp_avg_sq.sqrt().add_(group["eps"])
p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
p.data.copy_(p_data_fp32)
elif self.degenerated_to_sgd:
if group["weight_decay"] != 0:
p_data_fp32.add_(
p_data_fp32, alpha=-group["weight_decay"] * group["lr"]
)
step_size = group["lr"] / (1 - beta1**state["step"])
p_data_fp32.add_(exp_avg, alpha=-step_size)
p.data.copy_(p_data_fp32)
return loss
class AdamW(Optimizer):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
warmup=0
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError(
"Invalid beta parameter at index 0: {}".format(betas[0])
)
if not 0.0 <= betas[1] < 1.0:
raise ValueError(
"Invalid beta parameter at index 1: {}".format(betas[1])
)
defaults = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
warmup=warmup
)
super(AdamW, self).__init__(params, defaults)
def __setstate__(self, state):
super(AdamW, self).__setstate__(state)
def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad.data.float()
if grad.is_sparse:
raise RuntimeError(
"Adam does not support sparse gradients, please consider SparseAdam instead"
)
p_data_fp32 = p.data.float()
state = self.state[p]
if len(state) == 0:
state["step"] = 0
state["exp_avg"] = torch.zeros_like(p_data_fp32)
state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
else:
state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
state["exp_avg_sq"] = state["exp_avg_sq"].type_as(
p_data_fp32
)
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = group["betas"]
state["step"] += 1
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
denom = exp_avg_sq.sqrt().add_(group["eps"])
bias_correction1 = 1 - beta1**state["step"]
bias_correction2 = 1 - beta2**state["step"]
if group["warmup"] > state["step"]:
scheduled_lr = 1e-8 + state["step"] * group["lr"] / group["warmup"]
else:
scheduled_lr = group["lr"]
step_size = (
scheduled_lr * math.sqrt(bias_correction2) /
bias_correction1
)
if group["weight_decay"] != 0:
p_data_fp32.add_(
p_data_fp32, alpha=-group["weight_decay"] * scheduled_lr
)
p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
p.data.copy_(p_data_fp32)
return loss
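Both `RAdam` and `PlainRAdam` compute `N_sma`, the length of the approximated simple moving average, and apply the rectified adaptive update only once `N_sma >= 5`; before that they either take an SGD-with-momentum style step (`degenerated_to_sgd=True`) or skip the update. The pure-Python sketch below reproduces that per-step arithmetic so the warmup behaviour can be inspected without PyTorch; `radam_step_size` is a hypothetical helper that returns the multiplier later scaled by `lr`.

```python
import math


def radam_step_size(step, beta1=0.9, beta2=0.999, degenerated_to_sgd=True):
    """Hypothetical helper: (N_sma, step_size, rectified) for a given step.

    Follows the arithmetic in RAdam.step(); step_size is later multiplied by lr.
    """
    beta2_t = beta2 ** step
    n_sma_max = 2 / (1 - beta2) - 1
    n_sma = n_sma_max - 2 * step * beta2_t / (1 - beta2_t)
    if n_sma >= 5:
        # Variance rectification term r_t, divided by the first-moment bias correction.
        r_t = math.sqrt(
            (1 - beta2_t) * (n_sma - 4) / (n_sma_max - 4)
            * (n_sma - 2) / n_sma * n_sma_max / (n_sma_max - 2)
        )
        return n_sma, r_t / (1 - beta1 ** step), True
    if degenerated_to_sgd:
        # Un-rectified phase: momentum step with bias correction only.
        return n_sma, 1.0 / (1 - beta1 ** step), False
    return n_sma, -1, False  # step_size <= 0 means the update is skipped


for step in (1, 5, 6, 10, 100, 1000):
    print(step, radam_step_size(step))
```

One design detail worth noting: `RAdam.buffer` caches `(step, N_sma, step_size)` keyed only on `step % 10`, so the cached values are shared across all parameters and parameter groups; this appears to assume that every group uses the same `betas`.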
================================================
FILE: Dassl.ProGrad.pytorch/dassl/utils/__init__.py
================================================
from .tools import *
from .logger import *
from .meters import *
from .registry import *
from .torchtools import *
================================================
FILE: Dassl.ProGrad.pytorch/dassl/utils/logger.py
================================================
import os
import sys
import time
import os.path as osp
from .tools import mkdir_if_missing
__all__ = ["Logger", "setup_logger"]
class Logger:
"""Write console output to external text file.
Imported from ``_
Args:
fpath (str, optional): path to the log file. If None, output is only written to the console.
Examples::
>>> import sys
>>> import os.path as osp
>>> save_dir = 'output/experiment-1'
>>> log_name = 'train.log'
>>> sys.stdout = Logger(osp.join(save_dir, log_name))
"""
def __init__(self, fpath=None):
self.console = sys.stdout
self.file = None
if fpath is not None:
mkdir_if_missing(osp.dirname(fpath))
self.file = open(fpath, "w")
def __del__(self):
self.close()
def __enter__(self):
return self
def __exit__(self, *args):
self.close()
def write(self, msg):
self.console.write(msg)
if self.file is not None:
self.file.write(msg)
def flush(self):
self.console.flush()
if self.file is not None:
self.file.flush()
os.fsync(self.file.fileno())
def close(self):
self.console.close()
if self.file is not None:
self.file.close()
def setup_logger(output=None):
if output is None:
return
if output.endswith(".txt") or output.endswith(".log"):
fpath = output
else:
fpath = osp.join(output, "log.txt")
if osp.exists(fpath):
# make sure the existing log file is not over-written
fpath += time.strftime("-%Y-%m-%d-%H-%M-%S")
sys.stdout = Logger(fpath)
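`Logger` works by replacing `sys.stdout` with an object that duplicates every `write()` to the console and to a file. A minimal stdlib-only sketch of the same idea follows; `Tee` is a hypothetical class, two `io.StringIO` buffers stand in for the console and the log file, and `contextlib.redirect_stdout` is used instead of assigning `sys.stdout` directly.

```python
import contextlib
import io


class Tee:
    """Hypothetical minimal analogue of Logger: forwards writes to several streams."""

    def __init__(self, *streams):
        self.streams = streams

    def write(self, msg):
        for stream in self.streams:
            stream.write(msg)

    def flush(self):
        for stream in self.streams:
            stream.flush()


console, log_file = io.StringIO(), io.StringIO()
with contextlib.redirect_stdout(Tee(console, log_file)):
    print("epoch 1 done")  # captured by both buffers

print(repr(log_file.getvalue()))  # sys.stdout has been restored here
```

Unlike `setup_logger`, which replaces `sys.stdout` for the rest of the process, `redirect_stdout` restores the original stream when the block exits, and the sketch never closes the streams it forwards to, whereas `Logger.close()` also closes the console stream it captured.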
================================================
FILE: Dassl.ProGrad.pytorch/dassl/utils/meters.py
================================================
from collections import defaultdict
import torch
__all__ = ["AverageMeter", "MetricMeter"]
class AverageMeter:
"""Compute and store the average and current value.
Examples::
>>> # 1. Initialize a meter to record loss
>>> losses = AverageMeter()
>>> # 2. Update meter after every mini-batch update
>>> losses.update(loss_value, batch_size)
"""
def __init__(self, ema=False):
"""
Args:
ema (bool, optional): apply exponential moving average.
"""
self.ema = ema
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
if isinstance(val, torch.Tensor):
val = val.item()
self.val = val
self.sum += val * n
self.count += n
if self.ema:
self.avg = self.avg * 0.9 + self.val * 0.1
else:
self.avg = self.sum / self.count
class MetricMeter:
"""Store the average and current value for a set of metrics.
Examples::
>>> # 1. Create an instance of MetricMeter
>>> metric = MetricMeter()
>>> # 2. Update using a dictionary as input
>>> input_dict = {'loss_1': value_1, 'loss_2': value_2}
>>> metric.update(input_dict)
>>> # 3. Convert to string and print
>>> print(str(metric))
"""
def __init__(self, delimiter="\t"):
self.meters = defaultdict(AverageMeter)
self.delimiter = delimiter
def update(self, input_dict):
if input_dict is None:
return
if not isinstance(input_dict, dict):
raise TypeError(
"Input to MetricMeter.update() must be a dictionary"
)
for k, v in input_dict.items():
if isinstance(v, torch.Tensor):
v = v.item()
self.meters[k].update(v)
def __str__(self):
output_str = []
for name, meter in self.meters.items():
output_str.append(f"{name} {meter.val:.4f} ({meter.avg:.4f})")
return self.delimiter.join(output_str)
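With `ema=True`, `AverageMeter.avg` is an exponential moving average with a fixed 0.9 decay that starts from 0 and ignores the batch size `n`; otherwise it is the `n`-weighted running mean. The torch-free copy below (a hypothetical `SimpleAverageMeter`, identical to `AverageMeter` except that it skips the `torch.Tensor.item()` conversion) makes the difference visible.

```python
class SimpleAverageMeter:
    """Hypothetical torch-free copy of AverageMeter."""

    def __init__(self, ema=False):
        self.ema = ema
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        if self.ema:
            # Fixed decay; n plays no role in the EMA.
            self.avg = self.avg * 0.9 + self.val * 0.1
        else:
            # Mean weighted by the batch sizes seen so far.
            self.avg = self.sum / self.count


plain, ema = SimpleAverageMeter(), SimpleAverageMeter(ema=True)
for loss, batch_size in [(1.0, 32), (2.0, 32), (3.0, 16)]:
    plain.update(loss, batch_size)
    ema.update(loss, batch_size)
print(plain.avg, ema.avg)
```

Because the EMA starts from 0, it underestimates the loss for the first several updates, which is worth keeping in mind when reading early training logs.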
================================================
FILE: Dassl.ProGrad.pytorch/dassl/utils/registry.py
================================================
"""
Modified from https://github.com/facebookresearch/fvcore
"""
__all__ = ["Registry"]
class Registry:
"""A registry providing name -> object mapping, to support
custom modules.
To create a registry (e.g. a backbone registry):
.. code-block:: python
BACKBONE_REGISTRY = Registry('BACKBONE')
To register an object:
.. code-block:: python
@BACKBONE_REGISTRY.register()
class MyBackbone(nn.Module):
...
Or:
.. code-block:: python
BACKBONE_REGISTRY.register(MyBackbone)
"""
def __init__(self, name):
self._name = name
self._obj_map = dict()
def _do_register(self, name, obj, force=False):
if name in self._obj_map and not force:
raise KeyError(
'An object named "{}" was already '
'registered in "{}" registry'.format(name, self._name)
)
self._obj_map[name] = obj
def register(self, obj=None, force=False):
if obj is None:
# Used as a decorator
def wrapper(fn_or_class):
name = fn_or_class.__name__
self._do_register(name, fn_or_class, force=force)
return fn_or_class
return wrapper
# Used as a function call
name = obj.__name__
self._do_register(name, obj, force=force)
def get(self, name):
if name not in self._obj_map:
raise KeyError(
'Object name "{}" does not exist '
'in "{}" registry'.format(name, self._name)
)
return self._obj_map[name]
def registered_names(self):
return list(self._obj_map.keys())
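A typical way to use `Registry` is a small build function that looks up a class by its registered name (usually read from the config) and instantiates it. The sketch below is self-contained, so it includes a condensed copy of `Registry`; `BACKBONE_REGISTRY`, `TinyBackbone`, and `build_backbone` are hypothetical names for illustration.

```python
class Registry:
    """Condensed copy of the Registry above (name -> object mapping)."""

    def __init__(self, name):
        self._name = name
        self._obj_map = {}

    def _do_register(self, name, obj, force=False):
        if name in self._obj_map and not force:
            raise KeyError(
                'An object named "{}" was already registered in "{}" registry'.format(
                    name, self._name
                )
            )
        self._obj_map[name] = obj

    def register(self, obj=None, force=False):
        if obj is None:
            # Used as a decorator
            def wrapper(fn_or_class):
                self._do_register(fn_or_class.__name__, fn_or_class, force=force)
                return fn_or_class
            return wrapper
        # Used as a function call
        self._do_register(obj.__name__, obj, force=force)

    def get(self, name):
        if name not in self._obj_map:
            raise KeyError(
                'Object name "{}" does not exist in "{}" registry'.format(name, self._name)
            )
        return self._obj_map[name]

    def registered_names(self):
        return list(self._obj_map.keys())


BACKBONE_REGISTRY = Registry("BACKBONE")


@BACKBONE_REGISTRY.register()
class TinyBackbone:
    def __init__(self, dim=16):
        self.dim = dim


def build_backbone(name, **kwargs):
    """Hypothetical builder: look up a registered class by name and instantiate it."""
    return BACKBONE_REGISTRY.get(name)(**kwargs)


model = build_backbone("TinyBackbone", dim=8)
print(BACKBONE_REGISTRY.registered_names(), model.dim)

try:
    BACKBONE_REGISTRY.register(TinyBackbone)  # same __name__ and force=False
except KeyError as err:
    print("duplicate rejected:", err)
```

Registration is keyed on `__name__`, so two classes with the same name defined in different modules collide unless `force=True` is passed.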
================================================
FILE: Dassl.ProGrad.pytorch/dassl/utils/tools.py
================================================
"""
Modified from https://github.com/KaiyangZhou/deep-person-reid
"""
import os
import sys
import json
import time
import errno
import numpy as np
import random
import os.path as osp
import warnings
from difflib import SequenceMatcher
import PIL
import torch
from PIL import Image
__all__ = [
"mkdir_if_missing",
"check_isfile",
"read_json",
"write_json",
"set_random_seed",
"download_url",
"read_image",
"collect_env_info",
"listdir_nohidden",
"get_most_similar_str_to_a_from_b",
"check_availability",
"tolist_if_not",
]
def mkdir_if_missing(dirname):
"""Create dirname if it is missing."""
if not osp.exists(dirname):
try:
os.makedirs(dirname)
except OSError as e:
if e.errno != errno.EEXIST:
raise
def check_isfile(fpath):
"""Check if the given path is a file.
Args:
fpath (str): file path.
Returns:
bool
"""
isfile = osp.isfile(fpath)
if not isfile:
warnings.warn('No file found at "{}"'.format(fpath))
return isfile
def read_json(fpath):
"""Read json file from a path."""
with open(fpath, "r") as f:
obj = json.load(f)
return obj
def write_json(obj, fpath):
"""Writes to a json file."""
mkdir_if_missing(osp.dirname(fpath))
with open(fpath, "w") as f:
json.dump(obj, f, indent=4, separators=(",", ": "))
def set_random_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def download_url(url, dst):
"""Download file from a url to a destination.
Args:
url (str): url to download file.
dst (str): destination path.
"""
import urllib.request
print('* url="{}"'.format(url))
print('* destination="{}"'.format(dst))
def _reporthook(count, block_size, total_size):
global start_time
if count == 0:
start_time = time.time()
return
duration = time.time() - start_time
progress_size = int(count * block_size)
speed = int(progress_size / (1024*duration))
percent = int(count * block_size * 100 / total_size)
sys.stdout.write(
"\r...%d%%, %d MB, %d KB/s, %d seconds passed" %
(percent, progress_size / (1024*1024), speed, duration)
)
sys.stdout.flush()
urllib.request.urlretrieve(url, dst, _reporthook)
sys.stdout.write("\n")
def read_image(path):
"""Read image from path using ``PIL.Image``.
Args:
path (str): path to an image.
Returns:
PIL image
"""
if not osp.exists(path):
raise IOError("No file exists at {}".format(path))
while True:
try:
img = Image.open(path).convert("RGB")
return img
except IOError:
print(
"Cannot read image from {}, "
"probably due to heavy IO. Will re-try".format(path)
)
def collect_env_info():
"""Return env info as a string.
Code source: github.com/facebookresearch/maskrcnn-benchmark
"""
from torch.utils.collect_env import get_pretty_env_info
env_str = get_pretty_env_info()
env_str += "\n Pillow ({})".format(PIL.__version__)
return env_str
def listdir_nohidden(path, sort=False):
"""List non-hidden items in a directory.
Args:
path (str): directory path.
sort (bool): sort the items.
"""
items = [f for f in os.listdir(path) if not f.startswith(".")]
if sort:
items.sort()
return items
def get_most_similar_str_to_a_from_b(a, b):
"""Return the most similar string to a in b.
Args:
a (str): probe string.
b (list): a list of candidate strings.
"""
highest_sim = 0
chosen = None
for candidate in b:
sim = SequenceMatcher(None, a, candidate).ratio()
if sim >= highest_sim:
highest_sim = sim
chosen = candidate
return chosen
def check_availability(requested, available):
"""Check if an element is available in a list.
Args:
requested (str): probe string.
available (list): a list of available strings.
"""
if requested not in available:
psb_ans = get_most_similar_str_to_a_from_b(requested, available)
raise ValueError(
"The requested one is expected "
"to belong to {}, but got [{}] "
"(do you mean [{}]?)".format(available, requested, psb_ans)
)
def tolist_if_not(x):
"""Convert to a list."""
if not isinstance(x, list):
x = [x]
return x
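`check_availability` turns a typo in a config value into an actionable error by suggesting the closest valid option, scored with `difflib.SequenceMatcher.ratio()`. The copy below is self-contained; the candidate names are hypothetical examples.

```python
from difflib import SequenceMatcher


def get_most_similar_str_to_a_from_b(a, b):
    """Return the candidate in b with the highest similarity ratio to a."""
    highest_sim = 0
    chosen = None
    for candidate in b:
        sim = SequenceMatcher(None, a, candidate).ratio()
        if sim >= highest_sim:  # ties go to the later candidate
            highest_sim = sim
            chosen = candidate
    return chosen


def check_availability(requested, available):
    """Raise ValueError with a suggestion if requested is not in available."""
    if requested not in available:
        psb_ans = get_most_similar_str_to_a_from_b(requested, available)
        raise ValueError(
            "The requested one is expected "
            "to belong to {}, but got [{}] "
            "(do you mean [{}]?)".format(available, requested, psb_ans)
        )


message = None
try:
    check_availability("resnt50", ["resnet50", "resnet101", "vit_b16"])
except ValueError as err:
    message = str(err)
print(message)
```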
================================================
FILE: Dassl.ProGrad.pytorch/dassl/utils/torchtools.py
================================================
"""
Modified from https://github.com/KaiyangZhou/deep-person-reid
"""
import pickle
import shutil
import os.path as osp
import warnings
from functools import partial
from collections import OrderedDict
import torch
import torch.nn as nn
from .tools import mkdir_if_missing
__all__ = [
"save_checkpoint",
"load_checkpoint",
"resume_from_checkpoint",
"open_all_layers",
"open_specified_layers",
"count_num_param",
"load_pretrained_weights",
"init_network_weights",
]
def save_checkpoint(
state,
save_dir,
is_best=False,
remove_module_from_keys=True,
model_name=""
):
r"""Save checkpoint.
Args:
state (dict): dictionary.
save_dir (str): directory to save checkpoint.
is_best (bool, optional): if True, this checkpoint will be copied and named
``model-best.pth.tar``. Default is False.
remove_module_from_keys (bool, optional): whether to remove "module."
from layer names. Default is True.
model_name (str, optional): model name to save.
Examples::
>>> state = {
>>> 'state_dict': model.state_dict(),
>>> 'epoch': 10,
>>> 'optimizer': optimizer.state_dict()
>>> }
>>> save_checkpoint(state, 'log/my_model')
"""
mkdir_if_missing(save_dir)
if remove_module_from_keys:
# remove 'module.' in state_dict's keys
state_dict = state["state_dict"]
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k.startswith("module."):
k = k[7:]
new_state_dict[k] = v
state["state_dict"] = new_state_dict
# save model
epoch = state["epoch"]
if not model_name:
model_name = "model.pth.tar-" + str(epoch)
fpath = osp.join(save_dir, model_name)
torch.save(state, fpath)
print('Checkpoint saved to "{}"'.format(fpath))
# save current model name
checkpoint_file = osp.join(save_dir, "checkpoint")
with open(checkpoint_file, "w") as checkpoint:
checkpoint.write("{}\n".format(osp.basename(fpath)))
if is_best:
best_fpath = osp.join(osp.dirname(fpath), "model-best.pth.tar")
shutil.copy(fpath, best_fpath)
print('Best checkpoint saved to "{}"'.format(best_fpath))
def load_checkpoint(fpath):
r"""Load checkpoint.
``UnicodeDecodeError`` is handled, so checkpoints saved with
Python 2 can be read from Python 3.
Args:
fpath (str): path to checkpoint.
Returns:
dict
Examples::
>>> fpath = 'log/my_model/model.pth.tar-10'
>>> checkpoint = load_checkpoint(fpath)
"""
if fpath is None:
raise ValueError("File path is None")
if not osp.exists(fpath):
raise FileNotFoundError('File is not found at "{}"'.format(fpath))
map_location = None if torch.cuda.is_available() else "cpu"
try:
checkpoint = torch.load(fpath, map_location=map_location)
except UnicodeDecodeError:
pickle.load = partial(pickle.load, encoding="latin1")
pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1")
checkpoint = torch.load(
fpath, pickle_module=pickle, map_location=map_location
)
except Exception:
print('Unable to load checkpoint from "{}"'.format(fpath))
raise
return checkpoint
def resume_from_checkpoint(fdir, model, optimizer=None, scheduler=None):
r"""Resume training from a checkpoint.
This will load (1) model weights, (2) the ``state_dict`` of the
optimizer if ``optimizer`` is not None, and (3) the ``state_dict``
of the scheduler if ``scheduler`` is not None.
Args:
fdir (str): directory where the model was saved.
model (nn.Module): model.
optimizer (Optimizer, optional): an Optimizer.
scheduler (Scheduler, optional): a learning rate scheduler.
Returns:
int: start_epoch.
Examples::
>>> fdir = 'log/my_model'
>>> start_epoch = resume_from_checkpoint(fdir, model, optimizer, scheduler)
"""
with open(osp.join(fdir, "checkpoint"), "r") as checkpoint:
model_name = checkpoint.readlines()[0].strip("\n")
fpath = osp.join(fdir, model_name)
print('Loading checkpoint from "{}"'.format(fpath))
checkpoint = load_checkpoint(fpath)
model.load_state_dict(checkpoint["state_dict"])
print("Loaded model weights")
if optimizer is not None and "optimizer" in checkpoint.keys():
optimizer.load_state_dict(checkpoint["optimizer"])
print("Loaded optimizer")
if scheduler is not None and "scheduler" in checkpoint.keys():
scheduler.load_state_dict(checkpoint["scheduler"])
print("Loaded scheduler")
start_epoch = checkpoint["epoch"]
print("Previous epoch: {}".format(start_epoch))
return start_epoch
def adjust_learning_rate(
optimizer,
base_lr,
epoch,
stepsize=20,
gamma=0.1,
linear_decay=False,
final_lr=0,
max_epoch=100,
):
r"""Adjust learning rate.
Deprecated.
"""
if linear_decay:
# linearly decay learning rate from base_lr to final_lr
frac_done = epoch / max_epoch
lr = frac_done*final_lr + (1.0-frac_done) * base_lr
else:
# decay learning rate by gamma for every stepsize
lr = base_lr * (gamma**(epoch // stepsize))
for param_group in optimizer.param_groups:
param_group["lr"] = lr
def set_bn_to_eval(m):
r"""Set BatchNorm layers to eval mode."""
# 1. no update for running mean and var
# 2. scale and shift parameters are still trainable
classname = m.__class__.__name__
if classname.find("BatchNorm") != -1:
m.eval()
def open_all_layers(model):
r"""Open all layers in model for training.
Examples::
>>> open_all_layers(model)
"""
model.train()
for p in model.parameters():
p.requires_grad = True
def open_specified_layers(model, open_layers):
r"""Open specified layers in model for training while keeping
other layers frozen.
Args:
model (nn.Module): neural net model.
open_layers (str or list): layers open for training.
Examples::
>>> # Only model.classifier will be updated.
>>> open_layers = 'classifier'
>>> open_specified_layers(model, open_layers)
>>> # Only model.fc and model.classifier will be updated.
>>> open_layers = ['fc', 'classifier']
>>> open_specified_layers(model, open_layers)
"""
if isinstance(model, nn.DataParallel):
model = model.module
if isinstance(open_layers, str):
open_layers = [open_layers]
for layer in open_layers:
assert hasattr(
model, layer
), '"{}" is not an attribute of the model, please provide the correct name'.format(
layer
)
for name, module in model.named_children():
if name in open_layers:
module.train()
for p in module.parameters():
p.requires_grad = True
else:
module.eval()
for p in module.parameters():
p.requires_grad = False
def count_num_param(model):
r"""Count number of parameters in a model.
Args:
model (nn.Module): network model.
Examples::
>>> model_size = count_num_param(model)
"""
return sum(p.numel() for p in model.parameters())
def load_pretrained_weights(model, weight_path):
r"""Load pretrained weights to model.
Features::
- Incompatible layers (unmatched in name or size) will be ignored.
- Can automatically deal with keys containing "module.".
Args:
model (nn.Module): network model.
weight_path (str): path to pretrained weights.
Examples::
>>> weight_path = 'log/my_model/model-best.pth.tar'
>>> load_pretrained_weights(model, weight_path)
"""
checkpoint = load_checkpoint(weight_path)
if "state_dict" in checkpoint:
state_dict = checkpoint["state_dict"]
else:
state_dict = checkpoint
model_dict = model.state_dict()
new_state_dict = OrderedDict()
matched_layers, discarded_layers = [], []
for k, v in state_dict.items():
if k.startswith("module."):
k = k[7:] # discard module.
if k in model_dict and model_dict[k].size() == v.size():
new_state_dict[k] = v
matched_layers.append(k)
else:
discarded_layers.append(k)
model_dict.update(new_state_dict)
model.load_state_dict(model_dict)
if len(matched_layers) == 0:
warnings.warn(
'The pretrained weights "{}" cannot be loaded, '
"please check the key names manually "
"(** ignored and continue **)".format(weight_path)
)
else:
print(
'Successfully loaded pretrained weights from "{}"'.
format(weight_path)
)
if len(discarded_layers) > 0:
print(
"** The following layers are discarded "
"due to unmatched keys or layer size: {}".
format(discarded_layers)
)
def init_network_weights(model, init_type="normal", gain=0.02):
def _init_func(m):
classname = m.__class__.__name__
if hasattr(m, "weight") and (
classname.find("Conv") != -1 or classname.find("Linear") != -1
):
if init_type == "normal":
nn.init.normal_(m.weight.data, 0.0, gain)
elif init_type == "xavier":
nn.init.xavier_normal_(m.weight.data, gain=gain)
elif init_type == "kaiming":
nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
elif init_type == "orthogonal":
nn.init.orthogonal_(m.weight.data, gain=gain)
else:
raise NotImplementedError(
"initialization method {} is not implemented".
format(init_type)
)
if hasattr(m, "bias") and m.bias is not None:
nn.init.constant_(m.bias.data, 0.0)
elif classname.find("BatchNorm") != -1:
nn.init.constant_(m.weight.data, 1.0)
nn.init.constant_(m.bias.data, 0.0)
elif classname.find("InstanceNorm") != -1:
if m.weight is not None and m.bias is not None:
nn.init.constant_(m.weight.data, 1.0)
nn.init.constant_(m.bias.data, 0.0)
model.apply(_init_func)
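`load_pretrained_weights` keeps a checkpoint entry only if, after stripping a leading `module.` prefix (as added by `nn.DataParallel`), its key exists in the model and the shapes agree; every other entry is reported as discarded and the model keeps its own values for those layers. The sketch below reproduces that filtering on plain `{key: shape}` dictionaries so it runs without PyTorch; `filter_state_dict` is a hypothetical helper and the shapes are made-up examples (a classifier head resized from 1000 to 10 classes).

```python
from collections import OrderedDict


def filter_state_dict(model_shapes, ckpt_shapes):
    """Hypothetical helper mirroring load_pretrained_weights on {key: shape} dicts."""
    matched, discarded = OrderedDict(), []
    for k, shape in ckpt_shapes.items():
        if k.startswith("module."):
            k = k[len("module."):]  # same as k[7:] in the original
        if k in model_shapes and model_shapes[k] == shape:
            matched[k] = shape
        else:
            discarded.append(k)
    return matched, discarded


model_shapes = {"conv1.weight": (64, 3, 7, 7), "bn1.weight": (64,), "fc.weight": (10, 512)}
ckpt_shapes = {
    "module.conv1.weight": (64, 3, 7, 7),
    "module.bn1.weight": (64,),
    "module.fc.weight": (1000, 512),  # size mismatch: dropped
    "module.fc.bias": (1000,),  # not in the model: dropped
}
matched, discarded = filter_state_dict(model_shapes, ckpt_shapes)

# Equivalent of model_dict.update(new_state_dict): unmatched layers keep their own values.
merged = dict(model_shapes)
merged.update(matched)
print(list(matched), discarded)
```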
================================================
FILE: Dassl.ProGrad.pytorch/datasets/da/cifar_stl.py
================================================
import sys
import pprint as pp
import os.path as osp
from torchvision.datasets import STL10, CIFAR10
from dassl.utils import mkdir_if_missing
cifar_label2name = {
0: "airplane",
1: "car", # the original name was 'automobile'
2: "bird",
3: "cat",
4: "deer",
5: "dog",
6: "frog", # conflict class
7: "horse",
8: "ship",
9: "truck",
}
stl_label2name = {
0: "airplane",
1: "bird",
2: "car",
3: "cat",
4: "deer",
5: "dog",
6: "horse",
7: "monkey", # conflict class
8: "ship",
9: "truck",
}
new_name2label = {
"airplane": 0,
"bird": 1,
"car": 2,
"cat": 3,
"deer": 4,
"dog": 5,
"horse": 6,
"ship": 7,
"truck": 8,
}
def extract_and_save_image(dataset, save_dir, discard, label2name):
if osp.exists(save_dir):
print('Folder "{}" already exists'.format(save_dir))
return
print('Extracting images to "{}" ...'.format(save_dir))
mkdir_if_missing(save_dir)
for i in range(len(dataset)):
img, label = dataset[i]
if label == discard:
continue
class_name = label2name[label]
label_new = new_name2label[class_name]
class_dir = osp.join(
save_dir,
str(label_new).zfill(3) + "_" + class_name
)
mkdir_if_missing(class_dir)
impath = osp.join(class_dir, str(i + 1).zfill(5) + ".jpg")
img.save(impath)
def download_and_prepare(name, root, discarded_label, label2name):
print("Dataset: {}".format(name))
print("Root: {}".format(root))
print("Old labels:")
pp.pprint(label2name)
print("Discarded label: {}".format(discarded_label))
print("New labels:")
pp.pprint(new_name2label)
if name == "cifar":
train = CIFAR10(root, train=True, download=True)
test = CIFAR10(root, train=False)
else:
train = STL10(root, split="train", download=True)
test = STL10(root, split="test")
train_dir = osp.join(root, name, "train")
test_dir = osp.join(root, name, "test")
extract_and_save_image(train, train_dir, discarded_label, label2name)
extract_and_save_image(test, test_dir, discarded_label, label2name)
if __name__ == "__main__":
download_and_prepare("cifar", sys.argv[1], 6, cifar_label2name)
download_and_prepare("stl", sys.argv[1], 7, stl_label2name)
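CIFAR-10 and STL-10 share nine classes. The script drops the one class the other dataset lacks (frog from CIFAR-10, monkey from STL-10) and re-indexes the rest through `new_name2label`, so the same class gets the same label in both domains. The snippet below derives the resulting old-to-new label mapping and the class-folder names that `extract_and_save_image` would create; `build_remap` and `folder_name` are hypothetical helpers, and the dictionaries are copied from this script.

```python
cifar_label2name = {0: "airplane", 1: "car", 2: "bird", 3: "cat", 4: "deer",
                    5: "dog", 6: "frog", 7: "horse", 8: "ship", 9: "truck"}
stl_label2name = {0: "airplane", 1: "bird", 2: "car", 3: "cat", 4: "deer",
                  5: "dog", 6: "horse", 7: "monkey", 8: "ship", 9: "truck"}
new_name2label = {"airplane": 0, "bird": 1, "car": 2, "cat": 3, "deer": 4,
                  "dog": 5, "horse": 6, "ship": 7, "truck": 8}


def build_remap(label2name, discard):
    """Hypothetical helper: old label -> shared new label, skipping the discarded class."""
    return {
        old: new_name2label[name]
        for old, name in label2name.items()
        if old != discard
    }


def folder_name(old_label, label2name):
    """Hypothetical helper: class folder name as built in extract_and_save_image."""
    name = label2name[old_label]
    return str(new_name2label[name]).zfill(3) + "_" + name


cifar_remap = build_remap(cifar_label2name, discard=6)  # drop frog
stl_remap = build_remap(stl_label2name, discard=7)  # drop monkey
print(cifar_remap)
print(stl_remap)
print(folder_name(1, cifar_label2name))
```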
================================================
FILE: Dassl.ProGrad.pytorch/datasets/da/digit5.py
================================================
import os
import numpy as np
import os.path as osp
import argparse
from PIL import Image
from scipy.io import loadmat
def mkdir_if_missing(directory):
if not osp.exists(directory):
os.makedirs(directory)
def extract_and_save(data, label, save_dir):
for i, (x, y) in enumerate(zip(data, label)):
if x.shape[2] == 1:
x = np.repeat(x, 3, axis=2)
if y == 10:
y = 0
x = Image.fromarray(x, mode="RGB")
save_path = osp.join(
save_dir,
str(i + 1).zfill(6) + "_" + str(y) + ".jpg"
)
x.save(save_path)
def load_mnist(data_dir, raw_data_dir):
filepath = osp.join(raw_data_dir, "mnist_data.mat")
data = loadmat(filepath)
train_data = np.reshape(data["train_32"], (55000, 32, 32, 1))
test_data = np.reshape(data["test_32"], (10000, 32, 32, 1))
train_label = np.nonzero(data["label_train"])[1]
test_label = np.nonzero(data["label_test"])[1]
return train_data, test_data, train_label, test_label
def load_mnist_m(data_dir, raw_data_dir):
filepath = osp.join(raw_data_dir, "mnistm_with_label.mat")
data = loadmat(filepath)
train_data = data["train"]
test_data = data["test"]
train_label = np.nonzero(data["label_train"])[1]
test_label = np.nonzero(data["label_test"])[1]
return train_data, test_data, train_label, test_label
def load_svhn(data_dir, raw_data_dir):
train = loadmat(osp.join(raw_data_dir, "svhn_train_32x32.mat"))
train_data = train["X"].transpose(3, 0, 1, 2)
train_label = train["y"][:, 0]
test = loadmat(osp.join(raw_data_dir, "svhn_test_32x32.mat"))
test_data = test["X"].transpose(3, 0, 1, 2)
test_label = test["y"][:, 0]
return train_data, test_data, train_label, test_label
def load_syn(data_dir, raw_data_dir):
filepath = osp.join(raw_data_dir, "syn_number.mat")
data = loadmat(filepath)
train_data = data["train_data"]
test_data = data["test_data"]
train_label = data["train_label"][:, 0]
test_label = data["test_label"][:, 0]
return train_data, test_data, train_label, test_label
def load_usps(data_dir, raw_data_dir):
filepath = osp.join(raw_data_dir, "usps_28x28.mat")
data = loadmat(filepath)["dataset"]
train_data = data[0][0].transpose(0, 2, 3, 1)
test_data = data[1][0].transpose(0, 2, 3, 1)
train_data *= 255
test_data *= 255
train_data = train_data.astype(np.uint8)
test_data = test_data.astype(np.uint8)
train_label = data[0][1][:, 0]
test_label = data[1][1][:, 0]
return train_data, test_data, train_label, test_label
def main(data_dir):
data_dir = osp.abspath(osp.expanduser(data_dir))
raw_data_dir = osp.join(data_dir, "Digit-Five")
if not osp.exists(data_dir):
raise FileNotFoundError('"{}" does not exist'.format(data_dir))
datasets = ["mnist", "mnist_m", "svhn", "syn", "usps"]
for name in datasets:
print("Creating {}".format(name))
output = globals()["load_" + name](data_dir, raw_data_dir)
train_data, test_data, train_label, test_label = output
print("# train: {}".format(train_data.shape[0]))
print("# test: {}".format(test_data.shape[0]))
train_dir = osp.join(data_dir, name, "train_images")
mkdir_if_missing(train_dir)
test_dir = osp.join(data_dir, name, "test_images")
mkdir_if_missing(test_dir)
extract_and_save(train_data, train_label, train_dir)
extract_and_save(test_data, test_label, test_dir)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"data_dir", type=str, help="directory containing Digit-Five/"
)
args = parser.parse_args()
main(args.data_dir)
================================================
FILE: Dassl.ProGrad.pytorch/datasets/da/visda17.sh
================================================
# ------------------------------------------------------------------------
# ROOT is the root directory where you put your domain datasets.
#
# Suppose you want to put the dataset under $DATA, which stores all the
# domain datasets. Run the following command in your terminal to
# download VisDA17:
#
# $ sh visda17.sh $DATA
#------------------------------------------------------------------------
ROOT=$1

mkdir -p "$ROOT/visda17"
cd "$ROOT/visda17" || exit 1

wget http://csr.bu.edu/ftp/visda17/clf/train.tar
tar xvf train.tar

wget http://csr.bu.edu/ftp/visda17/clf/validation.tar
tar xvf validation.tar

wget http://csr.bu.edu/ftp/visda17/clf/test.tar
tar xvf test.tar
wget https://raw.githubusercontent.com/VisionLearningGroup/taskcv-2017-public/master/classification/data/image_list.txt -O test/image_list.txt
================================================
FILE: Dassl.ProGrad.pytorch/datasets/dg/cifar_c.py
================================================
"""
This script
- creates a folder named "cifar10_c" (or "cifar100_c") under the same
  directory as "CIFAR-10-C" (or "CIFAR-100-C")
- extracts images from the .npy files and saves them as .jpg
"""
import os
import sys
import numpy as np
import os.path as osp
from PIL import Image

from dassl.utils import mkdir_if_missing


def extract_and_save(images, labels, level, dst):
    # level denotes the corruption intensity level (0-based)
    assert 0 <= level <= 4

    for i in range(10000):
        real_i = i + level*10000
        im = Image.fromarray(images[real_i])
        label = int(labels[real_i])
        category_dir = osp.join(dst, str(label).zfill(3))
        mkdir_if_missing(category_dir)
        save_path = osp.join(category_dir, str(i + 1).zfill(5) + ".jpg")
        im.save(save_path)


def main(npy_folder):
    npy_folder = osp.abspath(osp.expanduser(npy_folder))
    dataset_cap = osp.basename(npy_folder)

    assert dataset_cap in ["CIFAR-10-C", "CIFAR-100-C"]

    if dataset_cap == "CIFAR-10-C":
        dataset = "cifar10_c"
    else:
        dataset = "cifar100_c"

    if not osp.exists(npy_folder):
        raise FileNotFoundError(
            'The given folder "{}" does not exist'.format(npy_folder)
        )

    root = osp.dirname(npy_folder)
    im_folder = osp.join(root, dataset)

    mkdir_if_missing(im_folder)

    dirnames = os.listdir(npy_folder)
    dirnames.remove("labels.npy")
    if "README.txt" in dirnames:
        dirnames.remove("README.txt")
    assert len(dirnames) == 19
    labels = np.load(osp.join(npy_folder, "labels.npy"))

    for dirname in dirnames:
        corruption = dirname.split(".")[0]
        corruption_folder = osp.join(im_folder, corruption)
        mkdir_if_missing(corruption_folder)

        npy_filename = osp.join(npy_folder, dirname)
        images = np.load(npy_filename)
        assert images.shape[0] == 50000

        for level in range(5):
            dst = osp.join(corruption_folder, str(level + 1))
            mkdir_if_missing(dst)
            print('Saving images to "{}"'.format(dst))
            extract_and_save(images, labels, level, dst)


if __name__ == "__main__":
    # sys.argv[1] contains the path to CIFAR-10-C or CIFAR-100-C
    main(sys.argv[1])
================================================
FILE: Dassl.ProGrad.pytorch/datasets/ssl/cifar10_cifar100_svhn.py
================================================
import sys
import os.path as osp
from torchvision.datasets import SVHN, CIFAR10, CIFAR100

from dassl.utils import mkdir_if_missing


def extract_and_save_image(dataset, save_dir):
    if osp.exists(save_dir):
        print('Folder "{}" already exists'.format(save_dir))
        return

    print('Extracting images to "{}" ...'.format(save_dir))
    mkdir_if_missing(save_dir)

    for i in range(len(dataset)):
        img, label = dataset[i]
        class_dir = osp.join(save_dir, str(label).zfill(3))
        mkdir_if_missing(class_dir)
        impath = osp.join(class_dir, str(i + 1).zfill(5) + ".jpg")
        img.save(impath)


def download_and_prepare(name, root):
    print("Dataset: {}".format(name))
    print("Root: {}".format(root))

    if name == "cifar10":
        train = CIFAR10(root, train=True, download=True)
        test = CIFAR10(root, train=False)
    elif name == "cifar100":
        train = CIFAR100(root, train=True, download=True)
        test = CIFAR100(root, train=False)
    elif name == "svhn":
        train = SVHN(root, split="train", download=True)
        test = SVHN(root, split="test", download=True)
    else:
        raise ValueError

    train_dir = osp.join(root, name, "train")
    test_dir = osp.join(root, name, "test")

    extract_and_save_image(train, train_dir)
    extract_and_save_image(test, test_dir)


if __name__ == "__main__":
    download_and_prepare("cifar10", sys.argv[1])
    download_and_prepare("cifar100", sys.argv[1])
    download_and_prepare("svhn", sys.argv[1])
================================================
FILE: Dassl.ProGrad.pytorch/datasets/ssl/stl10.py
================================================
import sys
import os.path as osp
from torchvision.datasets import STL10

from dassl.utils import mkdir_if_missing


def extract_and_save_image(dataset, save_dir):
    if osp.exists(save_dir):
        print('Folder "{}" already exists'.format(save_dir))
        return

    print('Extracting images to "{}" ...'.format(save_dir))
    mkdir_if_missing(save_dir)

    for i in range(len(dataset)):
        img, label = dataset[i]
        if label == -1:
            label_name = "none"
        else:
            label_name = str(label)
        imname = str(i).zfill(6) + "_" + label_name + ".jpg"
        impath = osp.join(save_dir, imname)
        img.save(impath)


def download_and_prepare(root):
    train = STL10(root, split="train", download=True)
    test = STL10(root, split="test")
    unlabeled = STL10(root, split="unlabeled")

    train_dir = osp.join(root, "train")
    test_dir = osp.join(root, "test")
    unlabeled_dir = osp.join(root, "unlabeled")

    extract_and_save_image(train, train_dir)
    extract_and_save_image(test, test_dir)
    extract_and_save_image(unlabeled, unlabeled_dir)


if __name__ == "__main__":
    download_and_prepare(sys.argv[1])
================================================
FILE: Dassl.ProGrad.pytorch/linter.sh
================================================
echo "Running isort"
isort -y -sp .
echo "Done"
echo "Running yapf"
yapf -i -r -vv -e build .
echo "Done"
echo "Running flake8"
flake8 .
echo "Done"
================================================
FILE: Dassl.ProGrad.pytorch/requirements.txt
================================================
flake8==3.7.9
yapf==0.29.0
isort==4.3.21
yacs
gdown
tb-nightly
future
scipy
scikit-learn
tqdm
================================================
FILE: Dassl.ProGrad.pytorch/setup.py
================================================
import numpy as np
import os.path as osp
from setuptools import setup, find_packages


def readme():
    with open('README.md') as f:
        content = f.read()
    return content


def find_version():
    version_file = 'dassl/__init__.py'
    # Execute into an explicit namespace so that __version__ can be read back
    # reliably (relying on locals() after exec is not guaranteed in newer Pythons)
    namespace = {}
    with open(version_file, 'r') as f:
        exec(compile(f.read(), version_file, 'exec'), namespace)
    return namespace['__version__']


def numpy_include():
    try:
        numpy_include = np.get_include()
    except AttributeError:
        numpy_include = np.get_numpy_include()
    return numpy_include


def get_requirements(filename='requirements.txt'):
    here = osp.dirname(osp.realpath(__file__))
    with open(osp.join(here, filename), 'r') as f:
        requires = [line.replace('\n', '') for line in f.readlines()]
    return requires


setup(
    name='dassl',
    version=find_version(),
    description='Dassl: Domain adaptation and semi-supervised learning',
    author='Kaiyang Zhou',
    license='MIT',
    long_description=readme(),
    url='https://github.com/KaiyangZhou/Dassl.pytorch',
    packages=find_packages(),
    install_requires=get_requirements(),
    keywords=[
        'Domain Adaptation', 'Domain Generalization',
        'Semi-Supervised Learning', 'Pytorch'
    ]
)
================================================
FILE: Dassl.ProGrad.pytorch/tools/parse_test_res.py
================================================
"""
Goal
---
1. Read test results from log.txt files
2. Compute mean and std across different folders (seeds)

Usage
---
Assume the output files are saved under output/my_experiment,
which contains results of different seeds, e.g.,

my_experiment/
    seed1/
        log.txt
    seed2/
        log.txt
    seed3/
        log.txt

Run the following command from the root directory:

$ python tools/parse_test_res.py output/my_experiment

Add --ci95 to the arguments if you want the 95% confidence interval
instead of the standard deviation:

$ python tools/parse_test_res.py output/my_experiment --ci95

If my_experiment/ has the following structure,

my_experiment/
    exp-1/
        seed1/
            log.txt
            ...
        seed2/
            log.txt
            ...
        seed3/
            log.txt
            ...
    exp-2/
        ...
    exp-3/
        ...

Run

$ python tools/parse_test_res.py output/my_experiment --multi-exp
"""
import re
import numpy as np
import os.path as osp
import argparse
from collections import OrderedDict, defaultdict

from dassl.utils import check_isfile, listdir_nohidden


def compute_ci95(res):
    return 1.96 * np.std(res) / np.sqrt(len(res))


def parse_function(*metrics, directory="", args=None, end_signal=None):
    print(f"Parsing files in {directory}")
    subdirs = listdir_nohidden(directory, sort=True)

    outputs = []

    for subdir in subdirs:
        fpath = osp.join(directory, subdir, "log.txt")
        assert check_isfile(fpath)
        good_to_go = False
        output = OrderedDict()

        with open(fpath, "r") as f:
            lines = f.readlines()

            for line in lines:
                line = line.strip()

                if line == end_signal:
                    good_to_go = True

                for metric in metrics:
                    match = metric["regex"].search(line)
                    if match and good_to_go:
                        if "file" not in output:
                            output["file"] = fpath
                        num = float(match.group(1))
                        name = metric["name"]
                        output[name] = num

        if output:
            outputs.append(output)

    assert len(outputs) > 0, f"Nothing found in {directory}"

    metrics_results = defaultdict(list)

    for output in outputs:
        msg = ""
        for key, value in output.items():
            if isinstance(value, float):
                msg += f"{key}: {value:.2f}%. "
            else:
                msg += f"{key}: {value}. "
            if key != "file":
                metrics_results[key].append(value)
        print(msg)

    output_results = OrderedDict()

    print("===")
    print(f"Summary of directory: {directory}")
    for key, values in metrics_results.items():
        avg = np.mean(values)
        std = compute_ci95(values) if args.ci95 else np.std(values)
        print(f"* {key}: {avg:.2f}% +- {std:.2f}%")
        output_results[key] = avg
    print("===")

    return output_results


def main(args, end_signal):
    metric = {
        "name": args.keyword,
        "regex": re.compile(fr"\* {args.keyword}: ([\.\deE+-]+)%"),
    }

    if args.multi_exp:
        final_results = defaultdict(list)

        for directory in listdir_nohidden(args.directory, sort=True):
            directory = osp.join(args.directory, directory)
            results = parse_function(
                metric, directory=directory, args=args, end_signal=end_signal
            )

            for key, value in results.items():
                final_results[key].append(value)

        print("Average performance")
        for key, values in final_results.items():
            avg = np.mean(values)
            print(f"* {key}: {avg:.2f}%")

    else:
        parse_function(
            metric, directory=args.directory, args=args, end_signal=end_signal
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("directory", type=str, help="path to directory")
    parser.add_argument(
        "--ci95",
        action="store_true",
        # argparse %-formats help strings, so a literal % must be written as %%
        help="compute 95%% confidence interval"
    )
    parser.add_argument(
        "--test-log", action="store_true", help="parse test-only logs"
    )
    parser.add_argument(
        "--multi-exp", action="store_true", help="parse multiple experiments"
    )
    parser.add_argument(
        "--keyword",
        default="accuracy",
        type=str,
        help="which keyword to extract"
    )
    args = parser.parse_args()

    end_signal = "Finished training"
    if args.test_log:
        end_signal = "=> result"

    main(args, end_signal)
================================================
FILE: Dassl.ProGrad.pytorch/tools/replace_text.py
================================================
"""
Replace text in python files.
"""
import glob
import os.path as osp
import argparse
import fileinput

EXTENSION = ".py"


def is_python_file(filename):
    ext = osp.splitext(filename)[1]
    return ext == EXTENSION


def update_file(filename, text_to_search, replacement_text):
    print("Processing {}".format(filename))
    with fileinput.FileInput(filename, inplace=True, backup="") as file:
        for line in file:
            print(line.replace(text_to_search, replacement_text), end="")


def recursive_update(directory, text_to_search, replacement_text):
    filenames = glob.glob(osp.join(directory, "*"))

    for filename in filenames:
        if osp.isfile(filename):
            if not is_python_file(filename):
                continue
            update_file(filename, text_to_search, replacement_text)
        elif osp.isdir(filename):
            recursive_update(filename, text_to_search, replacement_text)
        else:
            raise NotImplementedError


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "file_or_dir", type=str, help="path to file or directory"
    )
    parser.add_argument("text_to_search", type=str, help="name to be replaced")
    parser.add_argument("replacement_text", type=str, help="new name")
    parser.add_argument(
        "--ext", type=str, default=".py", help="file extension"
    )
    args = parser.parse_args()

    file_or_dir = args.file_or_dir
    text_to_search = args.text_to_search
    replacement_text = args.replacement_text
    extension = args.ext

    global EXTENSION
    EXTENSION = extension

    if osp.isfile(file_or_dir):
        if not is_python_file(file_or_dir):
            return
        update_file(file_or_dir, text_to_search, replacement_text)
    elif osp.isdir(file_or_dir):
        recursive_update(file_or_dir, text_to_search, replacement_text)
    else:
        raise NotImplementedError


if __name__ == "__main__":
    main()
================================================
FILE: Dassl.ProGrad.pytorch/tools/train.py
================================================
import argparse
import torch

from dassl.utils import setup_logger, set_random_seed, collect_env_info
from dassl.config import get_cfg_default
from dassl.engine import build_trainer


def print_args(args, cfg):
    print("***************")
    print("** Arguments **")
    print("***************")
    optkeys = list(args.__dict__.keys())
    optkeys.sort()
    for key in optkeys:
        print("{}: {}".format(key, args.__dict__[key]))
    print("************")
    print("** Config **")
    print("************")
    print(cfg)


def reset_cfg(cfg, args):
    if args.root:
        cfg.DATASET.ROOT = args.root

    if args.output_dir:
        cfg.OUTPUT_DIR = args.output_dir

    if args.resume:
        cfg.RESUME = args.resume

    if args.seed:
        cfg.SEED = args.seed

    if args.source_domains:
        cfg.DATASET.SOURCE_DOMAINS = args.source_domains

    if args.target_domains:
        cfg.DATASET.TARGET_DOMAINS = args.target_domains

    if args.transforms:
        cfg.INPUT.TRANSFORMS = args.transforms

    if args.trainer:
        cfg.TRAINER.NAME = args.trainer

    if args.backbone:
        cfg.MODEL.BACKBONE.NAME = args.backbone

    if args.head:
        cfg.MODEL.HEAD.NAME = args.head


def extend_cfg(cfg):
    """
    Add new config variables.

    E.g.
        from yacs.config import CfgNode as CN
        cfg.TRAINER.MY_MODEL = CN()
        cfg.TRAINER.MY_MODEL.PARAM_A = 1.
        cfg.TRAINER.MY_MODEL.PARAM_B = 0.5
        cfg.TRAINER.MY_MODEL.PARAM_C = False
    """
    pass


def setup_cfg(args):
    cfg = get_cfg_default()
    extend_cfg(cfg)

    # 1. From the dataset config file
    if args.dataset_config_file:
        cfg.merge_from_file(args.dataset_config_file)

    # 2. From the method config file
    if args.config_file:
        cfg.merge_from_file(args.config_file)

    # 3. From input arguments
    reset_cfg(cfg, args)

    # 4. From optional input arguments
    cfg.merge_from_list(args.opts)

    cfg.freeze()

    return cfg


def main(args):
    cfg = setup_cfg(args)
    if cfg.SEED >= 0:
        print("Setting fixed seed: {}".format(cfg.SEED))
        set_random_seed(cfg.SEED)
    setup_logger(cfg.OUTPUT_DIR)

    if torch.cuda.is_available() and cfg.USE_CUDA:
        torch.backends.cudnn.benchmark = True

    print_args(args, cfg)
    print("Collecting env info ...")
    print("** System info **\n{}\n".format(collect_env_info()))

    trainer = build_trainer(cfg)

    if args.eval_only:
        trainer.load_model(args.model_dir, epoch=args.load_epoch)
        trainer.test()
        return

    if not args.no_train:
        trainer.train()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--root", type=str, default="", help="path to dataset")
    parser.add_argument(
        "--output-dir", type=str, default="", help="output directory"
    )
    parser.add_argument(
        "--resume",
        type=str,
        default="",
        help="checkpoint directory (from which the training resumes)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=-1,
        help="only positive value enables a fixed seed"
    )
    parser.add_argument(
        "--source-domains",
        type=str,
        nargs="+",
        help="source domains for DA/DG"
    )
    parser.add_argument(
        "--target-domains",
        type=str,
        nargs="+",
        help="target domains for DA/DG"
    )
    parser.add_argument(
        "--transforms", type=str, nargs="+", help="data augmentation methods"
    )
    parser.add_argument(
        "--config-file", type=str, default="", help="path to config file"
    )
    parser.add_argument(
        "--dataset-config-file",
        type=str,
        default="",
        help="path to config file for dataset setup",
    )
    parser.add_argument(
        "--trainer", type=str, default="", help="name of trainer"
    )
    parser.add_argument(
        "--backbone", type=str, default="", help="name of CNN backbone"
    )
    parser.add_argument("--head", type=str, default="", help="name of head")
    parser.add_argument(
        "--eval-only", action="store_true", help="evaluation only"
    )
    parser.add_argument(
        "--model-dir",
        type=str,
        default="",
        help="load model from this directory for eval-only mode",
    )
    parser.add_argument(
        "--load-epoch",
        type=int,
        help="load model weights at this epoch for evaluation"
    )
    parser.add_argument(
        "--no-train", action="store_true", help="do not call trainer.train()"
    )
    parser.add_argument(
        "opts",
        default=None,
        nargs=argparse.REMAINDER,
        help="modify config options using the command-line",
    )
    args = parser.parse_args()
    main(args)
================================================
FILE: ProGrad.public/.gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# Custom
output/
debug.sh
================================================
FILE: ProGrad.public/DATASETS.md
================================================
# How to install datasets
We suggest putting all datasets under the same folder (say `$DATA`) for easier management, and following the instructions below to organize them so that the source code does not need to be modified. The file structure looks like this:
```
$DATA/
|–– imagenet/
|–– caltech-101/
|–– oxford_pets/
|–– stanford_cars/
```
If some datasets are already installed elsewhere, you can create symbolic links in `$DATA/dataset_name` that point to the original data to avoid downloading them again.
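A minimal Python sketch of the symlink approach, using only `os.symlink` from the standard library. The function name `link_existing_dataset` is hypothetical and not part of this repo.

```python
import os
import os.path as osp


def link_existing_dataset(src_dir, data_root, dataset_name):
    """Make data_root/dataset_name a symlink to an existing dataset copy.

    Returns the link path; an existing link (or folder) is left untouched.
    """
    src_dir = osp.abspath(osp.expanduser(src_dir))
    if not osp.isdir(src_dir):
        raise FileNotFoundError('"{}" does not exist'.format(src_dir))
    os.makedirs(data_root, exist_ok=True)
    link_path = osp.join(data_root, dataset_name)
    if not osp.lexists(link_path):
        # target_is_directory only matters on Windows; it is ignored elsewhere
        os.symlink(src_dir, link_path, target_is_directory=True)
    return link_path
```

For example, `link_existing_dataset("/mnt/shared/caltech-101", "/data", "caltech-101")` would make `/data/caltech-101` point at the existing copy (both paths are placeholders).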
Datasets list:
- [ImageNet](#imagenet)
- [Caltech101](#caltech101)
- [OxfordPets](#oxfordpets)
- [StanfordCars](#stanfordcars)
- [Flowers102](#flowers102)
- [Food101](#food101)
- [FGVCAircraft](#fgvcaircraft)
- [SUN397](#sun397)
- [DTD](#dtd)
- [EuroSAT](#eurosat)
- [UCF101](#ucf101)
- [ImageNetV2](#imagenetv2)
- [ImageNet-Sketch](#imagenet-sketch)
- [ImageNet-A](#imagenet-a)
- [ImageNet-R](#imagenet-r)
The instructions for preparing each dataset are detailed below. To ensure reproducibility and fair comparison for future work, we provide fixed train/val/test splits for all datasets except ImageNet, where the validation set is used as the test set. The fixed splits either come from the original datasets (if available) or were created by us.
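The fixed splits are distributed as `split_zhou_*.json` files. The reader below is a sketch that *assumes* each file maps `"train"`, `"val"` and `"test"` to lists of `[relative_image_path, label, classname]` entries; open one of the downloaded files to confirm that layout before relying on it. `read_split` is a hypothetical helper name.

```python
import json
import os.path as osp


def read_split(split_path, image_dir):
    """Load a fixed train/val/test split from a JSON file.

    Assumed layout (not confirmed here): {"train": [[impath, label, classname], ...],
    "val": [...], "test": [...]}, with impath relative to image_dir.
    """
    with open(split_path, "r") as f:
        split = json.load(f)

    def convert(items):
        # Turn each entry into (absolute_path, int_label, classname)
        return [
            (osp.join(image_dir, impath), int(label), classname)
            for impath, label, classname in items
        ]

    return convert(split["train"]), convert(split["val"]), convert(split["test"])
```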
### ImageNet
- Create a folder named `imagenet/` under `$DATA`.
- Create `images/` under `imagenet/`.
- Download the dataset from the [official website](https://image-net.org/index.php) and extract the training and validation sets to `$DATA/imagenet/images`. The directory structure should look like
```
imagenet/
|–– images/
| |–– train/ # contains 1,000 folders like n01440764, n01443537, etc.
| |–– val/
```
- If you have already downloaded the ImageNet dataset, you can create symbolic links that map the training and validation sets to `$DATA/imagenet/images`.
- Download the `classnames.txt` to `$DATA/imagenet/` from this [link](https://drive.google.com/file/d/1-61f_ol79pViBFDG_IDlUQSwoLcn2XXF/view?usp=sharing). The class names are copied from [CLIP](https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb).
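A sketch of reading `classnames.txt` into a folder-to-name mapping. It *assumes* (this is not stated above) one `<folder> <class name>` pair per line, e.g. `n01440764 tench`, where the class name may contain spaces. `read_classnames` is a hypothetical helper name.

```python
from collections import OrderedDict


def read_classnames(path):
    """Map ImageNet folder names (WordNet IDs) to human-readable class names."""
    classnames = OrderedDict()
    with open(path, "r") as f:
        for line in f:
            parts = line.strip().split(" ")
            if len(parts) < 2:
                continue  # skip blank or malformed lines
            folder, name = parts[0], " ".join(parts[1:])
            classnames[folder] = name
    return classnames
```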
### Caltech101
- Create a folder named `caltech-101/` under `$DATA`.
- Download `101_ObjectCategories.tar.gz` from http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz and extract the file under `$DATA/caltech-101`.
- Download `split_zhou_Caltech101.json` from this [link](https://drive.google.com/file/d/1hyarUivQE36mY6jSomru6Fjd-JzwcCzN/view?usp=sharing) and put it under `$DATA/caltech-101`.
The directory structure should look like
```
caltech-101/
|–– 101_ObjectCategories/
|–– split_zhou_Caltech101.json
```
### OxfordPets
- Create a folder named `oxford_pets/` under `$DATA`.
- Download the images from https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz.
- Download the annotations from https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz.
- Download `split_zhou_OxfordPets.json` from this [link](https://drive.google.com/file/d/1501r8Ber4nNKvmlFVQZ8SeUHTcdTTEqs/view?usp=sharing).
The directory structure should look like
```
oxford_pets/
|–– images/
|–– annotations/
|–– split_zhou_OxfordPets.json
```
### StanfordCars
- Create a folder named `stanford_cars/` under `$DATA`.
- Download the train images from http://ai.stanford.edu/~jkrause/car196/cars_train.tgz.
- Download the test images from http://ai.stanford.edu/~jkrause/car196/cars_test.tgz.
- Download the train labels from https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz.
- Download the test labels from http://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat.
- Download `split_zhou_StanfordCars.json` from this [link](https://drive.google.com/file/d/1ObCFbaAgVu0I-k_Au-gIUcefirdAuizT/view?usp=sharing).
The directory structure should look like
```
stanford_cars/
|–– cars_test/
|–– cars_test_annos_withlabels.mat
|–– cars_train/
|–– devkit/
|–– split_zhou_StanfordCars.json
```
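After downloading, one quick way to confirm that a dataset folder matches the expected layout is to list which expected entries are missing. This sketch uses the StanfordCars layout above as its example; `missing_entries` and `STANFORD_CARS_LAYOUT` are hypothetical names.

```python
import os.path as osp

# Top-level entries expected under stanford_cars/, taken from the tree above
STANFORD_CARS_LAYOUT = [
    "cars_test",
    "cars_test_annos_withlabels.mat",
    "cars_train",
    "devkit",
    "split_zhou_StanfordCars.json",
]


def missing_entries(dataset_dir, expected):
    """Return the expected files/folders that do not exist under dataset_dir."""
    return [name for name in expected if not osp.exists(osp.join(dataset_dir, name))]
```

`missing_entries("/data/stanford_cars", STANFORD_CARS_LAYOUT)` returns an empty list when everything is in place (the path is a placeholder); the same check works for any other tree in this file.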
### Flowers102
- Create a folder named `oxford_flowers/` under `$DATA`.
- Download the images and labels from https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz and https://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat respectively.
- Download `cat_to_name.json` from [here](https://drive.google.com/file/d/1AkcxCXeK_RCGCEC_GvmWxjcjaNhu-at0/view?usp=sharing).
- Download `split_zhou_OxfordFlowers.json` from [here](https://drive.google.com/file/d/1Pp0sRXzZFZq15zVOzKjKBu4A9i01nozT/view?usp=sharing).
The directory structure should look like
```
oxford_flowers/
|–– cat_to_name.json
|–– imagelabels.mat
|–– jpg/
|–– split_zhou_OxfordFlowers.json
```
### Food101
- Download the dataset from https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/ and extract the file `food-101.tar.gz` under `$DATA`, resulting in a folder named `$DATA/food-101/`.
- Download `split_zhou_Food101.json` from [here](https://drive.google.com/file/d/1QK0tGi096I0Ba6kggatX1ee6dJFIcEJl/view?usp=sharing).
The directory structure should look like
```
food-101/
|–– images/
|–– license_agreement.txt
|–– meta/
|–– README.txt
|–– split_zhou_Food101.json
```
### FGVCAircraft
- Download the data from https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz.
- Extract `fgvc-aircraft-2013b.tar.gz` and keep only `data/`.
- Move `data/` to `$DATA` and rename the folder to `fgvc_aircraft/`.
The directory structure should look like
```
fgvc_aircraft/
|–– images/
|–– ... # a bunch of .txt files
```
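The three FGVCAircraft steps (extract, keep only `data/`, move and rename) can be scripted with the standard `tarfile` module. This sketch *assumes* the archive's top-level folder is `fgvc-aircraft-2013b/`, so that `data/` appears as `fgvc-aircraft-2013b/data/`; adjust `prefix` if your copy differs. The function name is hypothetical.

```python
import os
import os.path as osp
import shutil
import tarfile
import tempfile


def extract_fgvc_aircraft(archive, data_root, prefix="fgvc-aircraft-2013b/data/"):
    """Extract only the data/ folder of the archive into data_root/fgvc_aircraft."""
    dst = osp.join(data_root, "fgvc_aircraft")
    if osp.exists(dst):
        raise FileExistsError(dst)
    os.makedirs(data_root, exist_ok=True)
    with tarfile.open(archive, "r:*") as tar, tempfile.TemporaryDirectory() as tmp:
        # Keep only members under data/; everything else in the archive is skipped
        members = [m for m in tar.getmembers() if m.name.startswith(prefix)]
        if not members:
            raise ValueError("no members under {!r}".format(prefix))
        # Use the safe "data" extraction filter where this Python provides it
        kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
        tar.extractall(tmp, members=members, **kwargs)
        # Move the extracted data/ folder and rename it to fgvc_aircraft/
        shutil.move(osp.join(tmp, prefix.rstrip("/")), dst)
    return dst
```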
### SUN397
- Create a folder named `sun397/` under `$DATA`.
- Download the images from http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz.
- Download the partitions from https://vision.princeton.edu/projects/2010/SUN/download/Partitions.zip.
- Extract these files under `$DATA/sun397/`.
- Download `split_zhou_SUN397.json` from this [link](https://drive.google.com/file/d/1y2RD81BYuiyvebdN-JymPfyWYcd8_MUq/view?usp=sharing).
The directory structure should look like
```
sun397/
|–– SUN397/
|–– split_zhou_SUN397.json
|–– ... # a bunch of .txt files
```
### DTD
- Download the dataset from https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz and extract it to `$DATA`. This should lead to `$DATA/dtd/`.
- Download `split_zhou_DescribableTextures.json` from this [link](https://drive.google.com/file/d/1u3_QfB467jqHgNXC00UIzbLZRQCg2S7x/view?usp=sharing).
The directory structure should look like
```
dtd/
|–– images/
|–– imdb/
|–– labels/
|–– split_zhou_DescribableTextures.json
```
### EuroSAT
- Create a folder named `eurosat/` under `$DATA`.
- Download the dataset from http://madm.dfki.de/files/sentinel/EuroSAT.zip and extract it to `$DATA/eurosat/`.
- Download `split_zhou_EuroSAT.json` from [here](https://drive.google.com/file/d/1Ip7yaCWFi0eaOFUGga0lUdVi_DDQth1o/view?usp=sharing).
The directory structure should look like
```
eurosat/
|–– 2750/
|–– split_zhou_EuroSAT.json
```
### UCF101
- Create a folder named `ucf101/` under `$DATA`.
- Download the zip file `UCF-101-midframes.zip` from [here](https://drive.google.com/file/d/10Jqome3vtUA2keJkNanAiFpgbyC9Hc2O/view?usp=sharing) and extract it to `$DATA/ucf101/`. This zip file contains the extracted middle video frames.
- Download `split_zhou_UCF101.json` from this [link](https://drive.google.com/file/d/1I0S0q91hJfsV9Gf4xDIjgDq4AqBNJb1y/view?usp=sharing).
The directory structure should look like
```
ucf101/
|–– UCF-101-midframes/
|–– split_zhou_UCF101.json
```
### ImageNetV2
- Create a folder named `imagenetv2/` under `$DATA`.
- Go to the GitHub repo https://github.com/modestyachts/ImageNetV2.
- Download the matched-frequency dataset from https://s3-us-west-2.amazonaws.com/imagenetv2public/imagenetv2-matched-frequency.tar.gz and extract it to `$DATA/imagenetv2/`.
- Copy `$DATA/imagenet/classnames.txt` to `$DATA/imagenetv2/`.
The directory structure should look like
```
imagenetv2/
|–– imagenetv2-matched-frequency-format-val/
|–– classnames.txt
```
### ImageNet-Sketch
- Download the dataset from https://github.com/HaohanWang/ImageNet-Sketch.
- Extract the dataset to `$DATA/imagenet-sketch`.
- Copy `$DATA/imagenet/classnames.txt` to `$DATA/imagenet-sketch/`.
The directory structure should look like
```
imagenet-sketch/
|–– images/ # contains 1,000 folders whose names have the format of n*
|–– classnames.txt
```
### ImageNet-A
- Create a folder named `imagenet-adversarial/` under `$DATA`.
- Download the dataset from https://github.com/hendrycks/natural-adv-examples and extract it to `$DATA/imagenet-adversarial/`.
- Copy `$DATA/imagenet/classnames.txt` to `$DATA/imagenet-adversarial/`.
The directory structure should look like
```
imagenet-adversarial/
|–– imagenet-a/ # contains 200 folders whose names have the format of n*
|–– classnames.txt
```
### ImageNet-R
- Create a folder named `imagenet-rendition/` under `$DATA`.
- Download the dataset from https://github.com/hendrycks/imagenet-r and extract it to `$DATA/imagenet-rendition/`.
- Copy `$DATA/imagenet/classnames.txt` to `$DATA/imagenet-rendition/`.
The directory structure should look like
```
imagenet-rendition/
|–– imagenet-r/ # contains 200 folders whose names have the format of n*
|–– classnames.txt
```
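The four ImageNet variants above (ImageNetV2, ImageNet-Sketch, ImageNet-A, ImageNet-R) all need a copy of `$DATA/imagenet/classnames.txt`. A small sketch that performs those copies with `shutil.copy`; the folder names come from the instructions above, while the function and constant names are hypothetical.

```python
import os
import os.path as osp
import shutil

# Target folders under $DATA, as named in the instructions above
VARIANT_DIRS = ("imagenetv2", "imagenet-sketch", "imagenet-adversarial", "imagenet-rendition")


def copy_classnames(data_root, variants=VARIANT_DIRS):
    """Copy imagenet/classnames.txt into each variant folder; return the new paths."""
    src = osp.join(data_root, "imagenet", "classnames.txt")
    if not osp.isfile(src):
        raise FileNotFoundError(src)
    copied = []
    for name in variants:
        dst_dir = osp.join(data_root, name)
        os.makedirs(dst_dir, exist_ok=True)
        copied.append(shutil.copy(src, dst_dir))  # returns the destination file path
    return copied
```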
================================================
FILE: ProGrad.public/LICENSE
================================================
MIT License
Copyright (c) 2021 Kaiyang Zhou
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: ProGrad.public/README.md
================================================
# How to Run
## GPU memory needed
All experiments can be run on a single GPU. However, **to get results on ImageNet, the GPU needs more than 24 GB of memory.** About 12 GB is enough for the other datasets.
## How to Install
This code is built on top of the toolbox [Dassl.pytorch](https://github.com/KaiyangZhou/Dassl.pytorch) with some modifications, so please install the provided Dassl.ProGrad.pytorch instead. Go to the folder Dassl.ProGrad.pytorch provided in the appendix and prepare the environment as follows:
```
# Create a conda environment
conda create -n dassl python=3.7
# Activate the environment
conda activate dassl
# Install dependencies
pip install -r requirements.txt
# Install torch (version >= 1.7.1) and torchvision
# Please make sure you have installed the gpu version due to the speed.
# For example:
conda install pytorch torchvision cudatoolkit=10.1 -c pytorch
# Install this library (no need to re-build if the source code is modified)
python setup.py develop
```
After that, run `pip install -r requirements.txt` under `ProGrad.public/` to install a few more packages required by [CLIP](https://github.com/openai/CLIP) (this should be done when `dassl` is activated). Then, you are ready to go.
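The install notes above require torch >= 1.7.1. A dependency-free sketch for checking a version string against that minimum; `parse_version` and `torch_is_recent_enough` are hypothetical helpers, and the parser only reads the leading numeric part of each component, so suffixes such as `+cu101` are ignored.

```python
def parse_version(version):
    """'1.10.0+cu113' -> (1, 10, 0); non-numeric suffixes are dropped."""
    core = version.split("+")[0]
    numbers = []
    for part in core.split("."):
        digits = ""
        for ch in part:
            if not ch.isdigit():
                break
            digits += ch
        numbers.append(int(digits) if digits else 0)
    return tuple(numbers)


def torch_is_recent_enough(version, minimum="1.7.1"):
    # Tuple comparison handles multi-digit components, e.g. 1.10 > 1.7
    return parse_version(version) >= parse_version(minimum)
```

Inside the activated `dassl` environment you could then run `import torch; print(torch_is_recent_enough(torch.__version__), torch.cuda.is_available())`; the second value should be `True` if the GPU build is installed.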
Follow [DATASETS.md](DATASETS.md) to install the datasets.
## Few-shot setting on 11 datasets
Basic format:
```
bash main.sh ${DATASET_NAME} ${CONFIG_NAME} end ${CONTEXT_TOKENS_NUMBER} ${SHOTS} False
```
For example, to run 1, 2, 4, 8, and 16 shots on stanford_cars:
**CLIP + CoOp (M=16, end)**:
- 1 shot: `bash main.sh stanford_cars rn50_ep50 end 16 1 False`
- 2 shots: `bash main.sh stanford_cars rn50_ep100 end 16 2 False`
- 4 shots: `bash main.sh stanford_cars rn50_ep100 end 16 4 False`
- 8 shots: `bash main.sh stanford_cars rn50 end 16 8 False`
- 16 shots: `bash main.sh stanford_cars rn50 end 16 16 False`
**CLIP + CoOp + ProGrad**:
**Note that the 8-shot and 16-shot results on Flowers102, DTD, and EuroSAT were obtained with lambda = 0.8.** To reproduce the results in our paper, change the variable LAMBDA in prograd.sh from 1.0 to 0.8 for these runs.
- 1 shot: `bash prograd.sh stanford_cars rn50_ep50 end 16 1 False`
- 2 shots: `bash prograd.sh stanford_cars rn50_ep100 end 16 2 False`
- 4 shots: `bash prograd.sh stanford_cars rn50_ep100 end 16 4 False`
- 8 shots: `bash prograd.sh stanford_cars rn50 end 16 8 False`
- 16 shots: `bash prograd.sh stanford_cars rn50 end 16 16 False`
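The commands above all follow one pattern, so they can also be generated programmatically. This is a hypothetical helper, not part of the repo. The shot-to-config mapping is copied from the stanford_cars examples and may differ for other datasets (for instance ImageNet). The lambda rule encodes the note about Flowers102, DTD and EuroSAT, assuming their dataset arguments are the config names `oxford_flowers`, `dtd` and `eurosat`. LAMBDA itself is edited inside `prograd.sh`; it is not a command-line argument.

```python
# Shot count -> config name, taken from the stanford_cars commands above.
SHOT_TO_CONFIG = {1: "rn50_ep50", 2: "rn50_ep100", 4: "rn50_ep100",
                  8: "rn50", 16: "rn50"}

# Datasets whose 8- and 16-shot ProGrad results used lambda = 0.8
# (assumed to be passed under these config names).
LAMBDA_08_DATASETS = {"oxford_flowers", "dtd", "eurosat"}


def build_command(script: str, dataset: str, shots: int,
                  n_ctx: int = 16, ctp: str = "end", csc: bool = False) -> str:
    """Return the shell command for one few-shot run of main.sh or prograd.sh."""
    cfg = SHOT_TO_CONFIG[shots]
    return f"bash {script} {dataset} {cfg} {ctp} {n_ctx} {shots} {csc}"


def prograd_lambda(dataset: str, shots: int) -> float:
    """LAMBDA value to set in prograd.sh to match the paper's setting."""
    return 0.8 if dataset in LAMBDA_08_DATASETS and shots >= 8 else 1.0


if __name__ == "__main__":
    for shots in sorted(SHOT_TO_CONFIG):
        print(build_command("prograd.sh", "stanford_cars", shots),
              "# LAMBDA =", prograd_lambda("stanford_cars", shots))
```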
Say the structure of `output/` after training looks like this:
```
output
|–– caltech101/
| |–– CoOp/
| | |–– rn50_16shots/
| | | |–– nctx16_cscFalse_ctpend/
| | | | |–– seed1/
| | | | |–– seed2/
| | | | |–– seed3/
| | |–– rn50_8shots/
| | | |–– nctx16_cscFalse_ctpend/
| | | | |–– seed1/
| | | | |–– seed2/
| | | | |–– seed3/
```
To calculate the average results for the folder `rn50_16shots/nctx16_cscFalse_ctpend/`, you can run
```bash
python parse_test_res.py output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend
```
Then you will see something like this in your terminal:
```bash
Parsing files in output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend
file: output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend/seed1/log.txt. accuracy: 91.81%. error: 8.19%.
file: output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend/seed2/log.txt. accuracy: 92.01%. error: 7.99%.
file: output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend/seed3/log.txt. accuracy: 92.17%. error: 7.83%.
===
Summary of directory: output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend
* accuracy: 92.00% +- 0.15%
* error: 8.00% +- 0.15%
===
```
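The summary above is consistent with taking the mean and the population standard deviation (dividing by N rather than N-1) of the per-seed accuracies; with the sample standard deviation the spread would be 0.18% instead of 0.15%. The sketch below reproduces those numbers. It is not the code of `parse_test_res.py`, and the log pattern it searches for is an assumption, so check it against your own `log.txt` files.

```python
import re
import statistics

# Assumed log format: test accuracy reported on lines like "* accuracy: 91.81%".
ACC_PATTERN = re.compile(r"accuracy: ([\d.]+)%")


def last_accuracy(log_text: str) -> float:
    """Return the last accuracy value reported in a log (the final test result)."""
    matches = ACC_PATTERN.findall(log_text)
    if not matches:
        raise ValueError("no accuracy line found")
    return float(matches[-1])


def summarize(values):
    """Mean and population standard deviation across seeds."""
    return statistics.mean(values), statistics.pstdev(values)


if __name__ == "__main__":
    # One line per seed, using the accuracies shown above.
    logs = ["* accuracy: 91.81%", "* accuracy: 92.01%", "* accuracy: 92.17%"]
    mean, std = summarize([last_accuracy(text) for text in logs])
    print(f"* accuracy: {mean:.2f}% +- {std:.2f}%")  # * accuracy: 92.00% +- 0.15%
```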
**How to visualize nearest words for the learned context tokens?** All you need is `interpret_prompt.py`. Say the learned tokens are saved in `a/b/c/prompt_learner/model.pth.tar` and you would like to see the top-3 nearest words for each token. In this case, run `python interpret_prompt.py a/b/c/prompt_learner/model.pth.tar 3`
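Conceptually, this lookup compares each learned context vector with every row of CLIP's token-embedding table (`model.token_embedding.weight`) and reports the closest entries, which can then be mapped back to strings through the tokenizer's `decoder`. The sketch below illustrates the idea with plain Python lists and toy vectors; it is not the code of `interpret_prompt.py`. Euclidean distance is an assumption here, and on real checkpoints a batched tensor op such as `torch.cdist` would be the practical choice.

```python
import heapq
import math


def euclidean(a, b):
    """Euclidean distance between two equal-length vectors (works on Python 3.7)."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def nearest_words(ctx_vectors, vocab_vectors, id_to_word, topk=3):
    """For each learned context vector, return its topk closest vocabulary words.

    ctx_vectors:   M learned context vectors
    vocab_vectors: V rows of the (frozen) token-embedding table
    id_to_word:    maps a vocabulary index to its string
    """
    results = []
    for ctx in ctx_vectors:
        scored = ((euclidean(ctx, emb), idx)
                  for idx, emb in enumerate(vocab_vectors))
        best = heapq.nsmallest(topk, scored)  # smallest distances first
        results.append([(id_to_word[idx], round(d, 4)) for d, idx in best])
    return results


if __name__ == "__main__":
    vocab = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    words = {0: "a", 1: "photo", 2: "of", 3: "car"}
    learned = [[0.9, 0.2], [0.1, 0.7]]  # toy "learned" context tokens
    for i, row in enumerate(nearest_words(learned, vocab, words, topk=2)):
        print(f"ctx token {i}: {row}")
```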
## Robustness to Distribution Shift
To reproduce the robustness experiments, you can simply load the models learned on ImageNet and evaluate them on the following datasets: `imagenetv2`, `imagenet-sketch`, `imagenet-a` and `imagenet-r`.
The command is provided in `scripts/eval.sh`. The key arguments are `--model-dir`, `--load-epoch` and `--eval-only`. `--model-dir` indicates the directory where the models are saved (i.e. the entire folder containing `log.txt`, the tensorboard file and `prompt_learner/`). `--load-epoch` tells the code to load the model saved at a specific epoch, like `--load-epoch 50` for ImageNet (see the source code for more details).
For example, to evaluate `CLIP + CoOp (M=16, end)` on ImageNetV2, you can do
```bash
# No need to use rn50_ep50 here as no training is performed
bash eval.sh imagenetv2 rn50
```
To get the results of our method, change the `TRAINER` variable in the script to `ProGrad`.
The default setting is `SHOTS=4`. Feel free to modify the script.
Again, you can use `parse_test_res.py` to automate the calculation of average performance. This time you should append `--test-log`, e.g., `python parse_test_res.py directory --test-log`.
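To see which values `--load-epoch` can take for a given `--model-dir`, you can list the saved checkpoints. This sketch assumes checkpoints follow the `model.pth.tar-<epoch>` naming used by Dassl inside each model subfolder (here `prompt_learner/`); `available_epochs` is a hypothetical helper, not part of the repo.

```python
import os
import re
import tempfile

# Assumed checkpoint naming: "model.pth.tar-<epoch>".
CKPT_PATTERN = re.compile(r"model\.pth\.tar-(\d+)")


def available_epochs(model_dir: str, subdir: str = "prompt_learner"):
    """Return the sorted epochs that have a saved checkpoint in model_dir/subdir."""
    folder = os.path.join(model_dir, subdir)
    epochs = []
    for name in os.listdir(folder):
        match = CKPT_PATTERN.fullmatch(name)
        if match:
            epochs.append(int(match.group(1)))
    return sorted(epochs)


if __name__ == "__main__":
    # Demo on a throwaway folder; point model_dir at a real --model-dir instead.
    with tempfile.TemporaryDirectory() as model_dir:
        os.makedirs(os.path.join(model_dir, "prompt_learner"))
        for name in ["model.pth.tar-50", "checkpoint"]:
            open(os.path.join(model_dir, "prompt_learner", name), "w").close()
        print(available_epochs(model_dir))
```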
## Zero-Shot CLIP
See `CoOp/scripts/zeroshot.sh`.
## Generalization From Base to New Classes
You will need `base2new_train_main.sh`, `base2new_test_main.sh`, `base2new_train_prograd.sh`, and `base2new_test_prograd.sh`. The scripts with the prefix `base2new_train` train a model on the base classes, while those with the prefix `base2new_test` evaluate the trained model on the new classes. Each script takes a single argument, `DATASET`, which is a dataset name such as `imagenet` or `caltech101`. Valid names are the file names (without `.yaml`) in `configs/datasets/`.
The scripts ending in `prograd.sh` run our proposed method, while those ending in `main.sh` run CoOp.
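For intuition about what base and new classes are, the sketch below shows one common way to build such a split. This is an assumption for illustration, not taken from this repo's code: sort the class labels, use the first half (rounded up) as base classes and the rest as new classes, and relabel each subset from 0. For StanfordCars, which has 196 classes, this gives 98 base and 98 new classes.

```python
import math


def split_base_new(labels):
    """Split class labels into base (first half, rounded up) and new (the rest).

    Returns two dicts mapping each original label to a new label starting at 0,
    so a classifier over each subset sees contiguous class indices.
    """
    labels = sorted(set(labels))
    m = math.ceil(len(labels) / 2)
    base, new = labels[:m], labels[m:]
    base_map = {y: i for i, y in enumerate(base)}
    new_map = {y: i for i, y in enumerate(new)}
    return base_map, new_map


if __name__ == "__main__":
    base_map, new_map = split_base_new(range(196))  # e.g. StanfordCars
    print(len(base_map), "base classes,", len(new_map), "new classes")
```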
Below we provide an example of how to train and evaluate the model on StanfordCars.
```bash
bash base2new_train_prograd.sh stanford_cars
bash base2new_test_prograd.sh stanford_cars
```
**If you want to test results on ImageNet, remember to change the CFG from "rn50_ep100" to "rn50_ep50", and change the LOADEP from 100 to 50 in the corresponding script.**
When the evaluation is done, you can use `parse_test_res.py` to automatically calculate the average results. For instance, after you finish the evaluation using the aforementioned commands, you would get
```
output
|–– base2new/
| |–– test_new/
| | |–– stanford_cars/
| | | |–– shots_16/
| | | | |–– CoCoOp/
| | | | | |–– rn50_ep100/
| | | | | | |–– seed1/
| | | | | | |–– seed2/
| | | | | | |–– seed3/
| |–– train_base/
| | |–– stanford_cars/
| | | |–– shots_16/
| | | | |–– CoCoOp/
| | | | | |–– rn50_ep100/
| | | | | | |–– seed1/
| | | | | | |–– seed2/
| | | | | | |–– seed3/
```
Then, to get the average performance on the base classes, run
```bash
python parse_test_res.py output/base2new/train_base/stanford_cars/shots_16/CoCoOp/rn50_ep100
```
To get the average performance on the new classes, run
```bash
python parse_test_res.py output/base2new/test_new/stanford_cars/shots_16/CoCoOp/rn50_ep100 --test-log
```
================================================
FILE: ProGrad.public/clip/__init__.py
================================================
from .clip import *
================================================
FILE: ProGrad.public/clip/clip.py
================================================
import hashlib
import os
import urllib.request
import warnings
from typing import Union, List
import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from tqdm import tqdm
from .model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
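try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC
# Compare versions numerically; comparing lists of strings would treat "10" < "7".
_TORCH_VERSION = tuple(
    int(part) for part in torch.__version__.split("+")[0].split(".")[:3]
    if part.isdigit())
if _TORCH_VERSION < (1, 7, 1):
    warnings.warn("PyTorch version 1.7.1 or higher is recommended")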
__all__ = ["available_models", "load", "tokenize"]
_tokenizer = _Tokenizer()
_MODELS = {
    "RN50":
        "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "RN101":
        "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "RN50x4":
        "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "RN50x16":
        "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
    "ViT-B/32":
        "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "ViT-B/16":
        "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
}
def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)
    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)
    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(
            f"{download_target} exists and is not a regular file")
    if os.path.isfile(download_target):
        if hashlib.sha256(open(download_target,
                               "rb").read()).hexdigest() == expected_sha256:
            return download_target
        else:
            warnings.warn(
                f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
            )
    with urllib.request.urlopen(url) as source, open(download_target,
                                                     "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")),
                  ncols=80,
                  unit='iB',
                  unit_scale=True) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break
                output.write(buffer)
                loop.update(len(buffer))
    if hashlib.sha256(open(download_target,
                           "rb").read()).hexdigest() != expected_sha256:
        raise RuntimeError(
            "Model has been downloaded but the SHA256 checksum does not match")
    return download_target
def _transform(n_px):
    return Compose([
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        lambda image: image.convert("RGB"),
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073),
                  (0.26862954, 0.26130258, 0.27577711)),
    ])
def available_models() -> List[str]:
    """Returns the names of available CLIP models"""
    return list(_MODELS.keys())
def load(name: str,
         device: Union[str, torch.device] = "cuda"
         if torch.cuda.is_available() else "cpu",
         jit=False):
    """Load a CLIP model
    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
    device : Union[str, torch.device]
        The device to put the loaded model
    jit : bool
        Whether to load the optimized JIT model or more hackable non-JIT model (default).
    Returns
    -------
    model : torch.nn.Module
        The CLIP model
    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """
    if name in _MODELS:
        model_path = _download(_MODELS[name])
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(
            f"Model {name} not found; available models = {available_models()}")
    try:
        # loading JIT archive
        model = torch.jit.load(model_path,
                               map_location=device if jit else "cpu").eval()
        state_dict = None
    except RuntimeError:
        # loading saved state dict
        if jit:
            warnings.warn(
                f"File {model_path} is not a JIT archive. Loading as a state dict instead"
            )
            jit = False
        state_dict = torch.load(model_path, map_location="cpu")
    if not jit:
        model = build_model(state_dict or model.state_dict()).to(device)
        if str(device) == "cpu":
            model.float()
        return model, _transform(model.visual.input_resolution)
    # patch the device names
    device_holder = torch.jit.trace(
        lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [
        n for n in device_holder.graph.findAllNodes("prim::Constant")
        if "Device" in repr(n)
    ][-1]
    def patch_device(module):
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []
        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)
        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(
                        node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)
    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)
    # patch dtype to float32 on CPU
    if str(device) == "cpu":
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(),
                                       example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()
        def patch_float(module):
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []
            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)
            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    # dtype can be the second or third argument to aten::to()
                    for i in [1, 2]:
                        if inputs[i].node()["value"] == 5:
                            inputs[i].node().copyAttributes(float_node)
        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)
        model.float()
    return model, _transform(model.input_resolution.item())
def tokenize(texts: Union[str, List[str]],
             context_length: int = 77,
             truncate: bool = False) -> torch.LongTensor:
    """
    Returns the tokenized representation of given input string(s)
    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize
    context_length : int
        The context length to use; all CLIP models use 77 as the context length
    truncate: bool
        Whether to truncate the text in case its encoding is longer than the context length
    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
    """
    if isinstance(texts, str):
        texts = [texts]
    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token]
                  for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(
                    f"Input {texts[i]} is too long for context length {context_length}"
                )
        result[i, :len(tokens)] = torch.tensor(tokens)
    return result
================================================
FILE: ProGrad.public/clip/model.py
================================================
from collections import OrderedDict
from typing import Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
class Bottleneck(nn.Module):
    expansion = 4
    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        self.stride = stride
        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = nn.Sequential(
                OrderedDict([("-1", nn.AvgPool2d(stride)),
                             ("0",
                              nn.Conv2d(inplanes,
                                        planes * self.expansion,
                                        1,
                                        stride=1,
                                        bias=False)),
                             ("1", nn.BatchNorm2d(planes * self.expansion))]))
    def forward(self, x: torch.Tensor):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out
class AttentionPool2d(nn.Module):
    def __init__(self,
                 spacial_dim: int,
                 embed_dim: int,
                 num_heads: int,
                 output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(
            torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads
    def forward(self, x):
        x = x.reshape(x.shape[0], x.shape[1],
                      x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x,
            key=x,
            value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat(
                [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False)
        return x[0]
class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """
    def __init__(self,
                 layers,
                 output_dim,
                 heads,
                 input_resolution=224,
                 width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution
        # the 3-layer stem
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3,
                               padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.avgpool = nn.AvgPool2d(2)
        self.relu = nn.ReLU(inplace=True)
        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim,
                                        heads, output_dim)
    def _make_layer(self, planes, blocks, stride=1):
        layers = [Bottleneck(self._inplanes, planes, stride)]
        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))
        return nn.Sequential(*layers)
    def forward(self, x):
        def stem(x):
            for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2),
                             (self.conv3, self.bn3)]:
                x = self.relu(bn(conv(x)))
            x = self.avgpool(x)
            return x
        x = x.type(self.conv1.weight.dtype)
        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)
        return x
class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""
    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)
class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)
class ResidualAttentionBlock(nn.Module):
    def __init__(self,
                 d_model: int,
                 n_head: int,
                 attn_mask: torch.Tensor = None):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(
            OrderedDict([("c_fc", nn.Linear(d_model, d_model * 4)),
                         ("gelu", QuickGELU()),
                         ("c_proj", nn.Linear(d_model * 4, d_model))]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask
    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(
            dtype=x.dtype,
            device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False,
                         attn_mask=self.attn_mask)[0]
    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
class Transformer(nn.Module):
    def __init__(self,
                 width: int,
                 layers: int,
                 heads: int,
                 attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[
            ResidualAttentionBlock(width, heads, attn_mask)
            for _ in range(layers)
        ])
    def forward(self, x: torch.Tensor):
        return self.resblocks(x)
class VisionTransformer(nn.Module):
    def __init__(self, input_resolution: int, patch_size: int, width: int,
                 layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=3,
                               out_channels=width,
                               kernel_size=patch_size,
                               stride=patch_size,
                               bias=False)
        scale = width**-0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn(
            (input_resolution // patch_size)**2 + 1, width))
        self.ln_pre = LayerNorm(width)
        self.transformer = Transformer(width, layers, heads)
        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
    def forward(self, x: torch.Tensor):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat(
            [
                self.class_embedding.to(x.dtype) + torch.zeros(
                    x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
                x
            ],
            dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_post(x[:, 0, :])
        if self.proj is not None:
            x = x @ self.proj
        return x
class CLIP(nn.Module):
    def __init__(
            self,
            embed_dim: int,
            # vision
            image_resolution: int,
            vision_layers: Union[Tuple[int, int, int, int], int],
            vision_width: int,
            vision_patch_size: int,
            # text
            context_length: int,
            vocab_size: int,
            transformer_width: int,
            transformer_heads: int,
            transformer_layers: int):
        super().__init__()
        self.context_length = context_length
        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(layers=vision_layers,
                                         output_dim=embed_dim,
                                         heads=vision_heads,
                                         input_resolution=image_resolution,
                                         width=vision_width)
        else:
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(input_resolution=image_resolution,
                                            patch_size=vision_patch_size,
                                            width=vision_width,
                                            layers=vision_layers,
                                            heads=vision_heads,
                                            output_dim=embed_dim)
        self.transformer = Transformer(width=transformer_width,
                                       layers=transformer_layers,
                                       heads=transformer_heads,
                                       attn_mask=self.build_attention_mask())
        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(
            torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)
        self.text_projection = nn.Parameter(
            torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.initialize_parameters()
    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)
        if isinstance(self.visual, ModifiedResNet):
            if self.visual.attnpool is not None:
                std = self.visual.attnpool.c_proj.in_features**-0.5
                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
            for resnet_block in [
                    self.visual.layer1, self.visual.layer2, self.visual.layer3,
                    self.visual.layer4
            ]:
                for name, param in resnet_block.named_parameters():
                    if name.endswith("bn3.weight"):
                        nn.init.zeros_(param)
        proj_std = (self.transformer.width**-0.5) * (
            (2 * self.transformer.layers)**-0.5)
        attn_std = self.transformer.width**-0.5
        fc_std = (2 * self.transformer.width)**-0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
        if self.text_projection is not None:
            nn.init.normal_(self.text_projection,
                            std=self.transformer.width**-0.5)
    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask
    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype
    def encode_image(self, image):
        return self.visual(image.type(self.dtype))
    def encode_text(self, text):
        x = self.token_embedding(text).type(
            self.dtype)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)
        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]),
              text.argmax(dim=-1)] @ self.text_projection
        return x
    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)
        # normalized features
        image_features = image_features / image_features.norm(dim=-1,
                                                              keepdim=True)
        text_features = text_features / text_features.norm(dim=-1,
                                                           keepdim=True)
        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logit_scale * text_features @ image_features.t()
        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text
def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp16"""
    def _convert_weights_to_fp16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.half()
            if l.bias is not None:
                l.bias.data = l.bias.data.half()
        if isinstance(l, nn.MultiheadAttention):
            for attr in [
                    *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
                    "in_proj_bias", "bias_k", "bias_v"
            ]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.half()
        for name in ["text_projection", "proj"]:
            if hasattr(l, name):
                attr = getattr(l, name)
                if attr is not None:
                    attr.data = attr.data.half()
    model.apply(_convert_weights_to_fp16)
def build_model(state_dict: dict):
    vit = "visual.proj" in state_dict
    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len([
            k for k in state_dict.keys()
            if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")
        ])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        grid_size = round(
            (state_dict["visual.positional_embedding"].shape[0] - 1)**0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        counts: list = [
            len(
                set(
                    k.split(".")[2] for k in state_dict
                    if k.startswith(f"visual.layer{b}")))
            for b in [1, 2, 3, 4]
        ]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round(
            (state_dict["visual.attnpool.positional_embedding"].shape[0] -
             1)**0.5)
        vision_patch_size = None
        assert output_width**2 + 1 == state_dict[
            "visual.attnpool.positional_embedding"].shape[0]
        image_resolution = output_width * 32
    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(
        set(
            k.split(".")[2] for k in state_dict
            if k.startswith("transformer.resblocks")))
    model = CLIP(embed_dim, image_resolution, vision_layers, vision_width,
                 vision_patch_size, context_length, vocab_size,
                 transformer_width, transformer_heads, transformer_layers)
    for key in ["input_resolution", "context_length", "vocab_size"]:
        if key in state_dict:
            del state_dict[key]
    convert_weights(model)
    model.load_state_dict(state_dict)
    return model.eval()
================================================
FILE: ProGrad.public/clip/simple_tokenizer.py
================================================
import gzip
import html
import os
from functools import lru_cache
import ftfy
import regex as re
@lru_cache()
def default_bpe():
    return os.path.join(os.path.dirname(os.path.abspath(__file__)),
                        "bpe_simple_vocab_16e6.txt.gz")
@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = (list(range(ord("!"), ord("~") + 1)) +
          list(range(ord("¡"), ord("¬") + 1)) +
          list(range(ord("®"), ord("ÿ") + 1)))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))
def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs
def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()
def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text
class SimpleTokenizer(object):
    def __init__(self, bpe_path: str = default_bpe()):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        merges = merges[1:49152 - 256 - 2 + 1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        # '</w>' marks the end of a word in the BPE vocabulary
        vocab = vocab + [v + '</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {
            '<|startoftext|>': '<|startoftext|>',
            '<|endoftext|>': '<|endoftext|>'
        }
        self.pat = re.compile(
            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            re.IGNORECASE)
    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + '</w>', )
        pairs = get_pairs(word)
        if not pairs:
            return token + '</w>'
        while True:
            bigram = min(
                pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    # `first` does not occur again; keep the remaining symbols
                    new_word.extend(word[i:])
                    break
                if word[i] == first and i < len(word) - 1 and word[
                        i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word
    def encode(self, text):
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b]
                            for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token]
                              for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens
    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text
                          ]).decode('utf-8',
                                    errors="replace").replace('</w>', ' ')
        return text
================================================
FILE: ProGrad.public/configs/datasets/caltech101.yaml
================================================
DATASET:
  NAME: "Caltech101"
================================================
FILE: ProGrad.public/configs/datasets/dtd.yaml
================================================
DATASET:
  NAME: "DescribableTextures"
================================================
FILE: ProGrad.public/configs/datasets/eurosat.yaml
================================================
DATASET:
  NAME: "EuroSAT"
================================================
FILE: ProGrad.public/configs/datasets/fgvc_aircraft.yaml
================================================
DATASET:
  NAME: "FGVCAircraft"
================================================
FILE: ProGrad.public/configs/datasets/food101.yaml
================================================
DATASET:
  NAME: "Food101"
================================================
FILE: ProGrad.public/configs/datasets/imagenet.yaml
================================================
DATASET:
  NAME: "ImageNet"
================================================
FILE: ProGrad.public/configs/datasets/imagenet_a.yaml
================================================
DATASET:
  NAME: "ImageNetA"
================================================
FILE: ProGrad.public/configs/datasets/imagenet_r.yaml
================================================
DATASET:
NAME: "ImageNetR"
================================================
FILE: ProGrad.public/configs/datasets/imagenet_sketch.yaml
================================================
DATASET:
NAME: "ImageNetSketch"
================================================
FILE: ProGrad.public/configs/datasets/imagenetv2.yaml
================================================
DATASET:
NAME: "ImageNetV2"
================================================
FILE: ProGrad.public/configs/datasets/oxford_flowers.yaml
================================================
DATASET:
NAME: "OxfordFlowers"
================================================
FILE: ProGrad.public/configs/datasets/oxford_pets.yaml
================================================
DATASET:
NAME: "OxfordPets"
================================================
FILE: ProGrad.public/configs/datasets/stanford_cars.yaml
================================================
DATASET:
NAME: "StanfordCars"
================================================
FILE: ProGrad.public/configs/datasets/sun397.yaml
================================================
DATASET:
NAME: "SUN397"
================================================
FILE: ProGrad.public/configs/datasets/ucf101.yaml
================================================
DATASET:
NAME: "UCF101"
================================================
FILE: ProGrad.public/configs/trainers/CoCoOp/rn50_c4_ep10_batch1_ctxv1.yaml
================================================
DATALOADER:
TRAIN_X:
BATCH_SIZE: 1
TEST:
BATCH_SIZE: 100
NUM_WORKERS: 8
INPUT:
SIZE: (224, 224)
INTERPOLATION: "bicubic"
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
OPTIM:
NAME: "sgd"
LR: 0.002
MAX_EPOCH: 10
LR_SCHEDULER: "cosine"
WARMUP_EPOCH: 1
WARMUP_TYPE: "constant"
WARMUP_CONS_LR: 1e-5
TRAIN:
PRINT_FREQ: 20
MODEL:
BACKBONE:
NAME: "RN50"
TRAINER:
COCOOP:
N_CTX: 4
CTX_INIT: True
PREC: "fp16"
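The `OPTIM` block above combines `LR: 0.002`, `LR_SCHEDULER: "cosine"`, and a one-epoch constant warmup at `WARMUP_CONS_LR: 1e-5`. The sketch below approximates the resulting per-epoch learning rate. It assumes the cosine curve anneals to zero over `MAX_EPOCH` epochs and starts right after warmup. Dassl's actual scheduler classes may switch at slightly different step boundaries, so treat this as illustrative only. The function name `lr_at_epoch` is hypothetical.

```python
import math


def lr_at_epoch(epoch, base_lr=0.002, max_epoch=10, warmup_epoch=1,
                warmup_cons_lr=1e-5):
    """Approximate LR for a 0-based epoch: constant warmup, then cosine."""
    if epoch < warmup_epoch:
        return warmup_cons_lr  # constant warmup value (WARMUP_CONS_LR)
    t = epoch - warmup_epoch  # assumed: cosine starts after warmup
    return 0.5 * base_lr * (1 + math.cos(math.pi * t / max_epoch))


for e in range(4):
    print(e, lr_at_epoch(e))
```

The tiny warmup value keeps the first epoch's updates small before the full `LR` takes effect.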
================================================
FILE: ProGrad.public/configs/trainers/CoCoOp/rn50_ep100_init.yaml
================================================
DATALOADER:
TRAIN_X:
BATCH_SIZE: 1
TEST:
BATCH_SIZE: 100
NUM_WORKERS: 8
INPUT:
SIZE: (224, 224)
INTERPOLATION: "bicubic"
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
OPTIM:
NAME: "sgd"
LR: 0.002
MAX_EPOCH: 100
LR_SCHEDULER: "cosine"
WARMUP_EPOCH: 1
WARMUP_TYPE: "constant"
WARMUP_CONS_LR: 1e-5
TRAIN:
PRINT_FREQ: 20
MODEL:
BACKBONE:
NAME: "RN50"
TRAINER:
COCOOP:
N_CTX: 16
CTX_INIT: True
PREC: "fp16"
================================================
FILE: ProGrad.public/configs/trainers/CoCoOp/rn50_ep50.yaml
================================================
DATALOADER:
TRAIN_X:
BATCH_SIZE: 1
TEST:
BATCH_SIZE: 100
NUM_WORKERS: 8
INPUT:
SIZE: (224, 224)
INTERPOLATION: "bicubic"
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
OPTIM:
NAME: "sgd"
LR: 0.002
MAX_EPOCH: 50
LR_SCHEDULER: "cosine"
WARMUP_EPOCH: 1
WARMUP_TYPE: "constant"
WARMUP_CONS_LR: 1e-5
TRAIN:
PRINT_FREQ: 20
MODEL:
BACKBONE:
NAME: "RN50"
TRAINER:
COCOOP:
N_CTX: 16
CTX_INIT: True
PREC: "fp16"
================================================
FILE: ProGrad.public/configs/trainers/CoCoOp/vit_b16_c16_ep10_batch1.yaml
================================================
DATALOADER:
TRAIN_X:
BATCH_SIZE: 1
TEST:
BATCH_SIZE: 100
NUM_WORKERS: 8
INPUT:
SIZE: (224, 224)
INTERPOLATION: "bicubic"
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
OPTIM:
NAME: "sgd"
LR: 0.002
MAX_EPOCH: 10
LR_SCHEDULER: "cosine"
WARMUP_EPOCH: 1
WARMUP_TYPE: "constant"
WARMUP_CONS_LR: 1e-5
TRAIN:
PRINT_FREQ: 20
MODEL:
BACKBONE:
NAME: "ViT-B/16"
TRAINER:
COCOOP:
N_CTX: 16
CTX_INIT: ""
PREC: "fp16"
================================================
FILE: ProGrad.public/configs/trainers/CoCoOp/vit_b16_c4_ep10_batch1.yaml
================================================
DATALOADER:
TRAIN_X:
BATCH_SIZE: 1
TEST:
BATCH_SIZE: 100
NUM_WORKERS: 8
INPUT:
SIZE: (224, 224)
INTERPOLATION: "bicubic"
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
OPTIM:
NAME: "sgd"
LR: 0.002
MAX_EPOCH: 10
LR_SCHEDULER: "cosine"
WARMUP_EPOCH: 1
WARMUP_TYPE: "constant"
WARMUP_CONS_LR: 1e-5
TRAIN:
PRINT_FREQ: 20
MODEL:
BACKBONE:
NAME: "ViT-B/16"
TRAINER:
COCOOP:
N_CTX: 4
CTX_INIT: ""
PREC: "fp16"
================================================
FILE: ProGrad.public/configs/trainers/CoCoOp/vit_b16_c4_ep10_batch1_ctxv1.yaml
================================================
DATALOADER:
TRAIN_X:
BATCH_SIZE: 1
TEST:
BATCH_SIZE: 100
NUM_WORKERS: 8
INPUT:
SIZE: (224, 224)
INTERPOLATION: "bicubic"
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
OPTIM:
NAME: "sgd"
LR: 0.002
MAX_EPOCH: 10
LR_SCHEDULER: "cosine"
WARMUP_EPOCH: 1
WARMUP_TYPE: "constant"
WARMUP_CONS_LR: 1e-5
TRAIN:
PRINT_FREQ: 20
MODEL:
BACKBONE:
NAME: "ViT-B/16"
TRAINER:
COCOOP:
N_CTX: 4
CTX_INIT: "a photo of a"
PREC: "fp16"
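`N_CTX` and `CTX_INIT` control the learnable context that precedes each class name. In CoOp-style trainers, an empty `CTX_INIT` typically means `N_CTX` placeholder tokens are randomly initialized. A string such as `"a photo of a"` supplies the initial words instead, which is why this config pairs it with `N_CTX: 4`. The sketch below builds only the prompt strings under that assumption and does not construct any embeddings. `build_prompts` is a hypothetical helper.

```python
def build_prompts(classnames, n_ctx=4, ctx_init=""):
    """Build CoOp-style prompt strings and return them with the context length."""
    if ctx_init:
        # The init words become the context, so their count sets the length.
        prefix = ctx_init.replace("_", " ")
        n_ctx = len(prefix.split(" "))
    else:
        # Placeholder tokens whose embeddings the learned context replaces.
        prefix = " ".join(["X"] * n_ctx)
    prompts = [f"{prefix} {name.replace('_', ' ')}." for name in classnames]
    return prompts, n_ctx


print(build_prompts(["golden_retriever"], n_ctx=4))
print(build_prompts(["cat"], ctx_init="a photo of a"))
```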
================================================
FILE: ProGrad.public/configs/trainers/CoCoOp/vit_b16_c8_ep10_batch1.yaml
================================================
DATALOADER:
TRAIN_X:
BATCH_SIZE: 1
TEST:
BATCH_SIZE: 100
NUM_WORKERS: 8
INPUT:
SIZE: (224, 224)
INTERPOLATION: "bicubic"
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
OPTIM:
NAME: "sgd"
LR: 0.002
MAX_EPOCH: 10
LR_SCHEDULER: "cosine"
WARMUP_EPOCH: 1
WARMUP_TYPE: "constant"
WARMUP_CONS_LR: 1e-5
TRAIN:
PRINT_FREQ: 20
MODEL:
BACKBONE:
NAME: "ViT-B/16"
TRAINER:
COCOOP:
N_CTX: 8
CTX_INIT: ""
PREC: "fp16"
================================================
FILE: ProGrad.public/configs/trainers/CoOp/rn50.yaml
================================================
DATALOADER:
TRAIN_X:
BATCH_SIZE: 32
TEST:
BATCH_SIZE: 100
NUM_WORKERS: 8
INPUT:
SIZE: (224, 224)
INTERPOLATION: "bicubic"
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
OPTIM:
NAME: "sgd"
LR: 0.002
MAX_EPOCH: 200
LR_SCHEDULER: "cosine"
WARMUP_EPOCH: 1
WARMUP_TYPE: "constant"
WARMUP_CONS_LR: 1e-5
TRAIN:
PRINT_FREQ: 5
MODEL:
BACKBONE:
NAME: "RN50"
================================================
FILE: ProGrad.public/configs/trainers/CoOp/rn50_ep100.yaml
================================================
DATALOADER:
TRAIN_X:
BATCH_SIZE: 32
TEST:
BATCH_SIZE: 100
NUM_WORKERS: 8
INPUT:
SIZE: (224, 224)
INTERPOLATION: "bicubic"
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
OPTIM:
NAME: "sgd"
LR: 0.002
MAX_EPOCH: 100
LR_SCHEDULER: "cosine"
WARMUP_EPOCH: 1
WARMUP_TYPE: "constant"
WARMUP_CONS_LR: 1e-5
TRAIN:
PRINT_FREQ: 5
MODEL:
BACKBONE:
NAME: "RN50"
================================================
FILE: ProGrad.public/configs/trainers/CoOp/rn50_ep50.yaml
================================================
DATALOADER:
TRAIN_X:
BATCH_SIZE: 32
TEST:
BATCH_SIZE: 100
NUM_WORKERS: 8
INPUT:
SIZE: (224, 224)
INTERPOLATION: "bicubic"
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
OPTIM:
NAME: "sgd"
LR: 0.002
MAX_EPOCH: 50
LR_SCHEDULER: "cosine"
WARMUP_EPOCH: 1
WARMUP_TYPE: "constant"
WARMUP_CONS_LR: 1e-5
TRAIN:
PRINT_FREQ: 5
MODEL:
BACKBONE:
NAME: "RN50"
================================================
FILE: ProGrad.public/configs/trainers/CoOp/rn50_val.yaml
================================================
DATALOADER:
TRAIN_X:
BATCH_SIZE: 32
TEST:
BATCH_SIZE: 32
NUM_WORKERS: 8
INPUT:
SIZE: (224, 224)
INTERPOLATION: "bicubic"
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
MODEL:
BACKBONE:
NAME: "RN50"
================================================
FILE: ProGrad.public/configs/trainers/ProGrad/rn50.yaml
================================================
DATALOADER:
TRAIN_X:
BATCH_SIZE: 32
TEST:
BATCH_SIZE: 100
NUM_WORKERS: 8
INPUT:
SIZE: (224, 224)
INTERPOLATION: "bicubic"
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
OPTIM:
NAME: "sgd"
LR: 0.002
MAX_EPOCH: 200
LR_SCHEDULER: "cosine"
WARMUP_EPOCH: 1
WARMUP_TYPE: "constant"
WARMUP_CONS_LR: 1e-5
LOSS:
NAME: "prograd"
T: 1.0
TRAIN:
PRINT_FREQ: 5
MODEL:
BACKBONE:
NAME: "RN50"
TRAINER:
COOP:
CTX_INIT: True
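`LOSS.NAME: "prograd"` presumably selects the prompt-aligned gradient update. Under that update, the few-shot cross-entropy gradient is kept when it agrees with the gradient of a KL term toward zero-shot CLIP predictions. When the two conflict, the conflicting component is projected out. The pure-Python sketch below illustrates that rule on flat gradient vectors. It assumes the rule as described in the ProGrad paper; it is not this repository's trainer code, and `lam` is a hypothetical name for the projection strength.

```python
def prograd_direction(g_ce, g_kl, lam=1.0):
    """Combine two flat gradients with a ProGrad-style projection.

    g_ce: gradient of the few-shot cross-entropy loss.
    g_kl: gradient of the KL term toward zero-shot predictions.
    """
    dot = sum(a * b for a, b in zip(g_ce, g_kl))
    if dot >= 0:
        return list(g_ce)  # no conflict: keep the task gradient unchanged
    norm_sq = sum(b * b for b in g_kl)
    coef = lam * dot / norm_sq
    # Remove (a lam-scaled share of) the component of g_ce along g_kl.
    return [a - coef * b for a, b in zip(g_ce, g_kl)]


print(prograd_direction([1.0, -1.0], [0.0, 1.0]))
```

With `lam=1.0`, the returned direction is orthogonal to `g_kl` whenever the two gradients conflict.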
================================================
FILE: ProGrad.public/configs/trainers/ProGrad/rn50_ep100.yaml
================================================
DATALOADER:
TRAIN_X:
BATCH_SIZE: 32
TEST:
BATCH_SIZE: 100
NUM_WORKERS: 8
INPUT:
SIZE: (224, 224)
INTERPOLATION: "bicubic"
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
OPTIM:
NAME: "sgd"
LR: 0.002
MAX_EPOCH: 100
LR_SCHEDULER: "cosine"
WARMUP_EPOCH: 1
WARMUP_TYPE: "constant"
WARMUP_CONS_LR: 1e-5
LOSS:
NAME: "prograd"
T: 1.0
TRAIN:
PRINT_FREQ: 5
MODEL:
BACKBONE:
NAME: "RN50"
TRAINER:
COOP:
CTX_INIT: True
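`LOSS.T: 1.0` is most likely a softmax temperature for the KL term between zero-shot (teacher) and prompted (student) predictions, as in standard knowledge distillation. That interpretation is an assumption. The sketch below computes a temperature-scaled KL divergence in pure Python. The `T * T` rescaling follows the common distillation convention and may differ from the trainer's exact normalization.

```python
import math


def softmax(logits, T=1.0):
    """Numerically stable softmax of logits divided by temperature T."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    s = sum(exps)
    return [e / s for e in exps]


def kl_teacher_student(teacher_logits, student_logits, T=1.0):
    """KL(teacher || student) on temperature-softened distributions."""
    p = softmax(teacher_logits, T)
    q = softmax(student_logits, T)
    kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)
    return kl * T * T  # conventional rescaling so gradients keep their scale


print(kl_teacher_student([2.0, 0.0], [0.0, 2.0], T=1.0))
```

Raising `T` flattens both distributions, which places more weight on the teacher's relative ranking of non-top classes.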
================================================
FILE: ProGrad.public/configs/trainers/ProGrad/rn50_ep50.yaml
================================================
DATALOADER:
TRAIN_X:
BATCH_SIZE: 32
TEST:
BATCH_SIZE: 100
NUM_WORKERS: 8
INPUT:
SIZE: (224, 224)
INTERPOLATION: "bicubic"
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
OPTIM:
NAME: "sgd"
LR: 0.002
MAX_EPOCH: 50
LR_SCHEDULER: "cosine"
WARMUP_EPOCH: 1
WARMUP_TYPE: "constant"
WARMUP_CONS_LR: 1e-5
LOSS:
NAME: "prograd"
T: 1.0
TRAIN:
PRINT_FREQ: 5
MODEL:
BACKBONE:
NAME: "RN50"
TRAINER:
COOP:
CTX_INIT: True
================================================
FILE: ProGrad.public/datasets/__init__.py
================================================
================================================
FILE: ProGrad.public/datasets/caltech101.py
================================================
import os
import pickle
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
from dassl.utils import mkdir_if_missing
from .oxford_pets import OxfordPets
from .dtd import DescribableTextures as DTD
IGNORED = ["BACKGROUND_Google", "Faces_easy"]
NEW_CNAMES = {
"airplanes": "airplane",
"Faces": "face",
"Leopards": "leopard",
"Motorbikes": "motorbike",
}
@DATASET_REGISTRY.register()
class Caltech101(DatasetBase):
dataset_dir = "caltech-101"
def __init__(self, cfg):
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = os.path.join(root, self.dataset_dir)
self.image_dir = os.path.join(self.dataset_dir, "101_ObjectCategories")
self.split_path = os.path.join(self.dataset_dir,
"split_zhou_Caltech101.json")
self.split_fewshot_dir = os.path.join(self.dataset_dir,
"split_fewshot")
mkdir_if_missing(self.split_fewshot_dir)
if os.path.exists(self.split_path):
train, val, test = OxfordPets.read_split(self.split_path,
self.image_dir)
else:
train, val, test = DTD.read_and_split_data(self.image_dir,
ignored=IGNORED,
new_cnames=NEW_CNAMES)
OxfordPets.save_split(train, val, test, self.split_path,
self.image_dir)
num_shots = cfg.DATASET.NUM_SHOTS
if num_shots >= 1:
seed = cfg.SEED
preprocessed = os.path.join(self.split_fewshot_dir,
f"shot_{num_shots}-seed_{seed}.pkl")
if os.path.exists(preprocessed):
print(
f"Loading preprocessed few-shot data from {preprocessed}")
with open(preprocessed, "rb") as file:
data = pickle.load(file)
train, val = data["train"], data["val"]
else:
train = self.generate_fewshot_dataset(train,
num_shots=num_shots)
val = self.generate_fewshot_dataset(val,
num_shots=min(
num_shots, 4))
data = {"train": train, "val": val}
print(f"Saving preprocessed few-shot data to {preprocessed}")
with open(preprocessed, "wb") as file:
pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
subsample = cfg.DATASET.SUBSAMPLE_CLASSES
train, val, test = OxfordPets.subsample_classes(train,
val,
test,
subsample=subsample)
super().__init__(train_x=train, val=val, test=test)
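Each dataset class caches its few-shot split in `split_fewshot/shot_{num_shots}-seed_{seed}.pkl`. Later runs with the same shot count and seed therefore reuse exactly the same images. The sketch below reproduces that load-or-build pattern in isolation. The per-class sampler `sample_fewshot` is a hypothetical stand-in for Dassl's `generate_fewshot_dataset`, and items are plain `(path, label)` tuples rather than `Datum` objects.

```python
import os
import pickle
import random
import tempfile
from collections import defaultdict


def sample_fewshot(items, num_shots, seed):
    """Pick up to num_shots items per label with a seeded RNG (hypothetical)."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for item in items:
        by_label[item[1]].append(item)
    out = []
    for label in sorted(by_label):
        group = by_label[label]
        out.extend(rng.sample(group, min(num_shots, len(group))))
    return out


def load_or_build_fewshot(cache_dir, items, num_shots, seed):
    """Load the cached split if present, otherwise sample it and save it."""
    path = os.path.join(cache_dir, f"shot_{num_shots}-seed_{seed}.pkl")
    if os.path.exists(path):
        with open(path, "rb") as f:
            return pickle.load(f)["train"]
    train = sample_fewshot(items, num_shots, seed)
    with open(path, "wb") as f:
        pickle.dump({"train": train}, f, protocol=pickle.HIGHEST_PROTOCOL)
    return train


if __name__ == "__main__":
    demo = [(f"img_{i}.jpg", i % 3) for i in range(30)]
    with tempfile.TemporaryDirectory() as d:
        print(load_or_build_fewshot(d, demo, num_shots=2, seed=1))
```

The cache file name encodes only the shot count and seed. A stale pickle is therefore reused even if the underlying split changes, so delete `split_fewshot/` after regenerating a split.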
================================================
FILE: ProGrad.public/datasets/dtd.py
================================================
import os
import pickle
import random
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
from dassl.utils import listdir_nohidden, mkdir_if_missing
from .oxford_pets import OxfordPets
@DATASET_REGISTRY.register()
class DescribableTextures(DatasetBase):
dataset_dir = "dtd"
def __init__(self, cfg):
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = os.path.join(root, self.dataset_dir)
self.image_dir = os.path.join(self.dataset_dir, "images")
self.split_path = os.path.join(self.dataset_dir,
"split_zhou_DescribableTextures.json")
self.split_fewshot_dir = os.path.join(self.dataset_dir,
"split_fewshot")
mkdir_if_missing(self.split_fewshot_dir)
if os.path.exists(self.split_path):
train, val, test = OxfordPets.read_split(self.split_path,
self.image_dir)
else:
train, val, test = self.read_and_split_data(self.image_dir)
OxfordPets.save_split(train, val, test, self.split_path,
self.image_dir)
num_shots = cfg.DATASET.NUM_SHOTS
if num_shots >= 1:
seed = cfg.SEED
preprocessed = os.path.join(self.split_fewshot_dir,
f"shot_{num_shots}-seed_{seed}.pkl")
if os.path.exists(preprocessed):
print(
f"Loading preprocessed few-shot data from {preprocessed}")
with open(preprocessed, "rb") as file:
data = pickle.load(file)
train, val = data["train"], data["val"]
else:
train = self.generate_fewshot_dataset(train,
num_shots=num_shots)
val = self.generate_fewshot_dataset(val,
num_shots=min(
num_shots, 4))
data = {"train": train, "val": val}
print(f"Saving preprocessed few-shot data to {preprocessed}")
with open(preprocessed, "wb") as file:
pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
subsample = cfg.DATASET.SUBSAMPLE_CLASSES
train, val, test = OxfordPets.subsample_classes(train,
val,
test,
subsample=subsample)
super().__init__(train_x=train, val=val, test=test)
@staticmethod
def read_and_split_data(image_dir,
p_trn=0.5,
p_val=0.2,
ignored=[],
new_cnames=None):
# The data are supposed to be organized into the following structure
# =============
# images/
# dog/
# cat/
# horse/
# =============
categories = listdir_nohidden(image_dir)
categories = [c for c in categories if c not in ignored]
categories.sort()
p_tst = 1 - p_trn - p_val
print(
f"Splitting into {p_trn:.0%} train, {p_val:.0%} val, and {p_tst:.0%} test"
)
def _collate(ims, y, c):
items = []
for im in ims:
item = Datum(impath=im, label=y,
classname=c) # is already 0-based
items.append(item)
return items
train, val, test = [], [], []
for label, category in enumerate(categories):
category_dir = os.path.join(image_dir, category)
images = listdir_nohidden(category_dir)
images = [os.path.join(category_dir, im) for im in images]
random.shuffle(images)
n_total = len(images)
n_train = round(n_total * p_trn)
n_val = round(n_total * p_val)
n_test = n_total - n_train - n_val
assert n_train > 0 and n_val > 0 and n_test > 0
if new_cnames is not None and category in new_cnames:
category = new_cnames[category]
train.extend(_collate(images[:n_train], label, category))
val.extend(
_collate(images[n_train:n_train + n_val], label, category))
test.extend(_collate(images[n_train + n_val:], label, category))
return train, val, test
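`read_and_split_data` splits each class independently. It rounds `n_total * p_trn` and `n_total * p_val` and assigns the remainder to test. Python's `round` uses round-half-to-even, so the counts for small classes can be surprising. For example, a class with only 3 images gets `(2, 1, 0)` and would trip the `n_test > 0` assertion. The sketch below isolates that arithmetic. It uses a local seeded RNG instead of the global `random.shuffle`; that substitution is an assumption made for reproducibility.

```python
import random


def split_counts(n_total, p_trn=0.5, p_val=0.2):
    """Per-class train/val/test sizes, computed as in read_and_split_data."""
    n_train = round(n_total * p_trn)
    n_val = round(n_total * p_val)
    n_test = n_total - n_train - n_val
    return n_train, n_val, n_test


def split_class(images, p_trn=0.5, p_val=0.2, seed=0):
    """Shuffle one class's images and cut them into train/val/test lists."""
    images = list(images)
    random.Random(seed).shuffle(images)
    n_train, n_val, _ = split_counts(len(images), p_trn, p_val)
    return (images[:n_train],
            images[n_train:n_train + n_val],
            images[n_train + n_val:])


print(split_counts(10), split_counts(5), split_counts(3))
```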
================================================
FILE: ProGrad.public/datasets/eurosat.py
================================================
import os
import pickle
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
from dassl.utils import mkdir_if_missing
from .oxford_pets import OxfordPets
from .dtd import DescribableTextures as DTD
NEW_CNAMES = {
"AnnualCrop": "Annual Crop Land",
"Forest": "Forest",
"HerbaceousVegetation": "Herbaceous Vegetation Land",
"Highway": "Highway or Road",
"Industrial": "Industrial Buildings",
"Pasture": "Pasture Land",
"PermanentCrop": "Permanent Crop Land",
"Residential": "Residential Buildings",
"River": "River",
"SeaLake": "Sea or Lake",
}
@DATASET_REGISTRY.register()
class EuroSAT(DatasetBase):
dataset_dir = "eurosat"
def __init__(self, cfg):
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = os.path.join(root, self.dataset_dir)
self.image_dir = os.path.join(self.dataset_dir, "2750")
self.split_path = os.path.join(self.dataset_dir,
"split_zhou_EuroSAT.json")
self.split_fewshot_dir = os.path.join(self.dataset_dir,
"split_fewshot")
mkdir_if_missing(self.split_fewshot_dir)
if os.path.exists(self.split_path):
train, val, test = OxfordPets.read_split(self.split_path,
self.image_dir)
else:
train, val, test = DTD.read_and_split_data(self.image_dir,
new_cnames=NEW_CNAMES)
OxfordPets.save_split(train, val, test, self.split_path,
self.image_dir)
num_shots = cfg.DATASET.NUM_SHOTS
if num_shots >= 1:
seed = cfg.SEED
preprocessed = os.path.join(self.split_fewshot_dir,
f"shot_{num_shots}-seed_{seed}.pkl")
if os.path.exists(preprocessed):
print(
f"Loading preprocessed few-shot data from {preprocessed}")
with open(preprocessed, "rb") as file:
data = pickle.load(file)
train, val = data["train"], data["val"]
else:
train = self.generate_fewshot_dataset(train,
num_shots=num_shots)
val = self.generate_fewshot_dataset(val,
num_shots=min(
num_shots, 4))
data = {"train": train, "val": val}
print(f"Saving preprocessed few-shot data to {preprocessed}")
with open(preprocessed, "wb") as file:
pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
subsample = cfg.DATASET.SUBSAMPLE_CLASSES
train, val, test = OxfordPets.subsample_classes(train,
val,
test,
subsample=subsample)
super().__init__(train_x=train, val=val, test=test)
def update_classname(self, dataset_old):
dataset_new = []
for item_old in dataset_old:
cname_old = item_old.classname
cname_new = NEW_CNAMES[cname_old]
item_new = Datum(impath=item_old.impath,
label=item_old.label,
classname=cname_new)
dataset_new.append(item_new)
return dataset_new
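`update_classname` rebuilds each item with a human-readable class name taken from `NEW_CNAMES`. The sketch below shows the same remapping on an immutable record. It uses `dict.get` with a fallback, so names missing from the mapping are kept rather than raising `KeyError`. `Item` is a hypothetical stand-in for Dassl's `Datum`, and `RENAMES` copies two entries from `NEW_CNAMES`.

```python
from collections import namedtuple

# Hypothetical stand-in for dassl's Datum, used only in this sketch.
Item = namedtuple("Item", ["impath", "label", "classname"])

RENAMES = {"SeaLake": "Sea or Lake", "Highway": "Highway or Road"}


def rename_classes(items, mapping):
    """Return new items with remapped class names; unknown names are kept."""
    return [
        item._replace(classname=mapping.get(item.classname, item.classname))
        for item in items
    ]


print(rename_classes([Item("a.jpg", 0, "SeaLake")], RENAMES))
```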
================================================
FILE: ProGrad.public/datasets/fgvc_aircraft.py
================================================
import os
import pickle
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
from dassl.utils import mkdir_if_missing
from .oxford_pets import OxfordPets
@DATASET_REGISTRY.register()
class FGVCAircraft(DatasetBase):
dataset_dir = "fgvc_aircraft"
def __init__(self, cfg):
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = os.path.join(root, self.dataset_dir)
self.image_dir = os.path.join(self.dataset_dir, "images")
self.split_fewshot_dir = os.path.join(self.dataset_dir,
"split_fewshot")
mkdir_if_missing(self.split_fewshot_dir)
classnames = []
with open(os.path.join(self.dataset_dir, "variants.txt"), "r") as f:
lines = f.readlines()
for line in lines:
classnames.append(line.strip())
cname2lab = {c: i for i, c in enumerate(classnames)}
train = self.read_data(cname2lab, "images_variant_train.txt")
val = self.read_data(cname2lab, "images_variant_val.txt")
test = self.read_data(cname2lab, "images_variant_test.txt")
num_shots = cfg.DATASET.NUM_SHOTS
if num_shots >= 1:
seed = cfg.SEED
preprocessed = os.path.join(self.split_fewshot_dir,
f"shot_{num_shots}-seed_{seed}.pkl")
if os.path.exists(preprocessed):
print(
f"Loading preprocessed few-shot data from {preprocessed}")
with open(preprocessed, "rb") as file:
data = pickle.load(file)
train, val = data["train"], data["val"]
else:
train = self.generate_fewshot_dataset(train,
num_shots=num_shots)
val = self.generate_fewshot_dataset(val,
num_shots=min(
num_shots, 4))
data = {"train": train, "val": val}
print(f"Saving preprocessed few-shot data to {preprocessed}")
with open(preprocessed, "wb") as file:
pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
subsample = cfg.DATASET.SUBSAMPLE_CLASSES
train, val, test = OxfordPets.subsample_classes(train,
val,
test,
subsample=subsample)
super().__init__(train_x=train, val=val, test=test)
def read_data(self, cname2lab, split_file):
filepath = os.path.join(self.dataset_dir, split_file)
items = []
with open(filepath, "r") as f:
lines = f.readlines()
for line in lines:
line = line.strip().split(" ")
imname = line[0] + ".jpg"
classname = " ".join(line[1:])
impath = os.path.join(self.image_dir, imname)
label = cname2lab[classname]
item = Datum(impath=impath, label=label, classname=classname)
items.append(item)
return items
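`FGVCAircraft.read_data` treats the first space-separated field of each split-file line as the image ID. It rejoins the remaining fields as the class name, which can itself contain spaces. Labels come from each name's position in `variants.txt`. The sketch below isolates both steps. The sample line is made up and only mimics the `<image id> <class name>` layout; the helper names are hypothetical.

```python
def parse_split_line(line):
    """Split '<image id> <class name with spaces>' into (file name, class)."""
    parts = line.strip().split(" ")
    return parts[0] + ".jpg", " ".join(parts[1:])


def build_label_map(classnames):
    """Map each class name to its 0-based index in the given order."""
    return {c: i for i, c in enumerate(classnames)}


print(parse_split_line("0000001 Example Variant 100\n"))
```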
================================================
FILE: ProGrad.public/datasets/food101.py
================================================
import os
import pickle
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
from dassl.utils import mkdir_if_missing
from .oxford_pets import OxfordPets
from .dtd import DescribableTextures as DTD
@DATASET_REGISTRY.register()
class Food101(DatasetBase):
dataset_dir = "food-101"
def __init__(self, cfg):
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = os.path.join(root, self.dataset_dir)
self.image_dir = os.path.join(self.dataset_dir, "images")
self.split_path = os.path.join(self.dataset_dir,
"split_zhou_Food101.json")
self.split_fewshot_dir = os.path.join(self.dataset_dir,
"split_fewshot")
mkdir_if_missing(self.split_fewshot_dir)
if os.path.exists(self.split_path):
train, val, test = OxfordPets.read_split(self.split_path,
self.image_dir)
else:
train, val, test = DTD.read_and_split_data(self.image_dir)
OxfordPets.save_split(train, val, test, self.split_path,
self.image_dir)
num_shots = cfg.DATASET.NUM_SHOTS
if num_shots >= 1:
seed = cfg.SEED
preprocessed = os.path.join(self.split_fewshot_dir,
f"shot_{num_shots}-seed_{seed}.pkl")
if os.path.exists(preprocessed):
print(
f"Loading preprocessed few-shot data from {preprocessed}")
with open(preprocessed, "rb") as file:
data = pickle.load(file)
train, val = data["train"], data["val"]
else:
train = self.generate_fewshot_dataset(train,
num_shots=num_shots)
val = self.generate_fewshot_dataset(val,
num_shots=min(
num_shots, 4))
data = {"train": train, "val": val}
print(f"Saving preprocessed few-shot data to {preprocessed}")
with open(preprocessed, "wb") as file:
pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
subsample = cfg.DATASET.SUBSAMPLE_CLASSES
train, val, test = OxfordPets.subsample_classes(train,
val,
test,
subsample=subsample)
super().__init__(train_x=train, val=val, test=test)
================================================
FILE: ProGrad.public/datasets/imagenet.py
================================================
import os
import pickle
from collections import OrderedDict
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
from dassl.utils import listdir_nohidden, mkdir_if_missing
from .oxford_pets import OxfordPets
@DATASET_REGISTRY.register()
class ImageNet(DatasetBase):
dataset_dir = "imagenet"
def __init__(self, cfg):
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = os.path.join(root, self.dataset_dir)
self.image_dir = os.path.join(self.dataset_dir, "images")
self.preprocessed = os.path.join(self.dataset_dir, "preprocessed.pkl")
self.split_fewshot_dir = os.path.join(self.dataset_dir,
"split_fewshot")
mkdir_if_missing(self.split_fewshot_dir)
if os.path.exists(self.preprocessed):
with open(self.preprocessed, "rb") as f:
preprocessed = pickle.load(f)
train = preprocessed["train"]
test = preprocessed["test"]
else:
text_file = os.path.join(self.dataset_dir, "classnames.txt")
classnames = self.read_classnames(text_file)
train = self.read_data(classnames, "train")
# Follow standard practice to perform evaluation on the val set
# Also used as the val set (so evaluate the last-step model)
test = self.read_data(classnames, "val")
preprocessed = {"train": train, "test": test}
with open(self.preprocessed, "wb") as f:
pickle.dump(preprocessed, f, protocol=pickle.HIGHEST_PROTOCOL)
num_shots = cfg.DATASET.NUM_SHOTS
if num_shots >= 1:
seed = cfg.SEED
preprocessed = os.path.join(self.split_fewshot_dir,
f"shot_{num_shots}-seed_{seed}.pkl")
if os.path.exists(preprocessed):
print(
f"Loading preprocessed few-shot data from {preprocessed}")
with open(preprocessed, "rb") as file:
data = pickle.load(file)
train = data["train"]
else:
train = self.generate_fewshot_dataset(train,
num_shots=num_shots)
data = {"train": train}
print(f"Saving preprocessed few-shot data to {preprocessed}")
with open(preprocessed, "wb") as file:
pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
subsample = cfg.DATASET.SUBSAMPLE_CLASSES
train, test = OxfordPets.subsample_classes(train,
test,
subsample=subsample)
super().__init__(train_x=train, val=test, test=test)
@staticmethod
def read_classnames(text_file):
"""Return an ordered dictionary mapping each folder name
(e.g. a WordNet ID) to its class name.
"""
classnames = OrderedDict()
with open(text_file, "r") as f:
lines = f.readlines()
for line in lines:
line = line.strip().split(" ")
folder = line[0]
classname = " ".join(line[1:])
classnames[folder] = classname
return classnames
def read_data(self, classnames, split_dir):
split_dir = os.path.join(self.image_dir, split_dir)
folders = sorted(f.name for f in os.scandir(split_dir) if f.is_dir())
items = []
for label, folder in enumerate(folders):
imnames = listdir_nohidden(os.path.join(split_dir, folder))
classname = classnames[folder]
for imname in imnames:
impath = os.path.join(split_dir, folder, imname)
item = Datum(impath=impath, label=label, classname=classname)
items.append(item)
return items
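`ImageNet.read_data` assigns labels by enumerating the split's folders in sorted order. `ImageNetV2.read_data`, by contrast, maps label `i` to the `i`-th entry of `classnames.txt`. The two label spaces agree only if `classnames.txt` lists folders in sorted order. The sketch below makes the difference concrete on a deliberately unsorted two-line file. The helper names are hypothetical.

```python
from collections import OrderedDict


def parse_classnames(lines):
    """Parse '<folder> <class name>' lines into an ordered mapping."""
    classnames = OrderedDict()
    for line in lines:
        parts = line.strip().split(" ")
        classnames[parts[0]] = " ".join(parts[1:])
    return classnames


def labels_from_folders(folders):
    """Label by sorted folder name, as ImageNet.read_data does."""
    return {folder: label for label, folder in enumerate(sorted(folders))}


def labels_from_file_order(classnames):
    """Label by line order of classnames.txt, as ImageNetV2.read_data does."""
    return {folder: label for label, folder in enumerate(classnames)}


cn = parse_classnames(["n01443537 goldfish", "n01440764 tench"])
print(labels_from_folders(cn), labels_from_file_order(cn))
```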
================================================
FILE: ProGrad.public/datasets/imagenet_a.py
================================================
import os
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
from dassl.utils import listdir_nohidden
from .imagenet import ImageNet
TO_BE_IGNORED = ["README.txt"]
@DATASET_REGISTRY.register()
class ImageNetA(DatasetBase):
"""ImageNet-A(dversarial).
This dataset is used for testing only.
"""
dataset_dir = "imagenet-adversarial"
def __init__(self, cfg):
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = os.path.join(root, self.dataset_dir)
self.image_dir = os.path.join(self.dataset_dir, "imagenet-a")
text_file = os.path.join(self.dataset_dir, "classnames.txt")
classnames = ImageNet.read_classnames(text_file)
data = self.read_data(classnames)
super().__init__(train_x=data, test=data)
def read_data(self, classnames):
image_dir = self.image_dir
folders = listdir_nohidden(image_dir, sort=True)
folders = [f for f in folders if f not in TO_BE_IGNORED]
items = []
for label, folder in enumerate(folders):
imnames = listdir_nohidden(os.path.join(image_dir, folder))
classname = classnames[folder]
for imname in imnames:
impath = os.path.join(image_dir, folder, imname)
item = Datum(impath=impath, label=label, classname=classname)
items.append(item)
return items
================================================
FILE: ProGrad.public/datasets/imagenet_r.py
================================================
import os
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
from dassl.utils import listdir_nohidden
from .imagenet import ImageNet
TO_BE_IGNORED = ["README.txt"]
@DATASET_REGISTRY.register()
class ImageNetR(DatasetBase):
"""ImageNet-R(endition).
This dataset is used for testing only.
"""
dataset_dir = "imagenet-rendition"
def __init__(self, cfg):
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = os.path.join(root, self.dataset_dir)
self.image_dir = os.path.join(self.dataset_dir, "imagenet-r")
text_file = os.path.join(self.dataset_dir, "classnames.txt")
classnames = ImageNet.read_classnames(text_file)
data = self.read_data(classnames)
super().__init__(train_x=data, test=data)
def read_data(self, classnames):
image_dir = self.image_dir
folders = listdir_nohidden(image_dir, sort=True)
folders = [f for f in folders if f not in TO_BE_IGNORED]
items = []
for label, folder in enumerate(folders):
imnames = listdir_nohidden(os.path.join(image_dir, folder))
classname = classnames[folder]
for imname in imnames:
impath = os.path.join(image_dir, folder, imname)
item = Datum(impath=impath, label=label, classname=classname)
items.append(item)
return items
================================================
FILE: ProGrad.public/datasets/imagenet_sketch.py
================================================
import os
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
from dassl.utils import listdir_nohidden
from .imagenet import ImageNet
@DATASET_REGISTRY.register()
class ImageNetSketch(DatasetBase):
"""ImageNet-Sketch.
This dataset is used for testing only.
"""
dataset_dir = "imagenet-sketch"
def __init__(self, cfg):
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = os.path.join(root, self.dataset_dir)
self.image_dir = os.path.join(self.dataset_dir, "images")
text_file = os.path.join(self.dataset_dir, "classnames.txt")
classnames = ImageNet.read_classnames(text_file)
data = self.read_data(classnames)
super().__init__(train_x=data, test=data)
def read_data(self, classnames):
image_dir = self.image_dir
folders = listdir_nohidden(image_dir, sort=True)
items = []
for label, folder in enumerate(folders):
imnames = listdir_nohidden(os.path.join(image_dir, folder))
classname = classnames[folder]
for imname in imnames:
impath = os.path.join(image_dir, folder, imname)
item = Datum(impath=impath, label=label, classname=classname)
items.append(item)
return items
================================================
FILE: ProGrad.public/datasets/imagenetv2.py
================================================
import os
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
from dassl.utils import listdir_nohidden
from .imagenet import ImageNet
@DATASET_REGISTRY.register()
class ImageNetV2(DatasetBase):
"""ImageNetV2.
This dataset is used for testing only.
"""
dataset_dir = "imagenetv2"
def __init__(self, cfg):
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = os.path.join(root, self.dataset_dir)
image_dir = "imagenetv2-matched-frequency-format-val"
self.image_dir = os.path.join(self.dataset_dir, image_dir)
text_file = os.path.join(self.dataset_dir, "classnames.txt")
classnames = ImageNet.read_classnames(text_file)
data = self.read_data(classnames)
super().__init__(train_x=data, test=data)
def read_data(self, classnames):
image_dir = self.image_dir
folders = list(classnames.keys())
items = []
for label in range(1000):
class_dir = os.path.join(image_dir, str(label))
imnames = listdir_nohidden(class_dir)
folder = folders[label]
classname = classnames[folder]
for imname in imnames:
impath = os.path.join(class_dir, imname)
item = Datum(impath=impath, label=label, classname=classname)
items.append(item)
return items
================================================
FILE: ProGrad.public/datasets/oxford_flowers.py
================================================
import os
import pickle
import random
from scipy.io import loadmat
from collections import defaultdict
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
from dassl.utils import read_json, mkdir_if_missing
from .oxford_pets import OxfordPets
@DATASET_REGISTRY.register()
class OxfordFlowers(DatasetBase):
dataset_dir = "oxford_flowers"
def __init__(self, cfg):
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = os.path.join(root, self.dataset_dir)
self.image_dir = os.path.join(self.dataset_dir, "jpg")
self.label_file = os.path.join(self.dataset_dir, "imagelabels.mat")
self.lab2cname_file = os.path.join(self.dataset_dir,
"cat_to_name.json")
self.split_path = os.path.join(self.dataset_dir,
"split_zhou_OxfordFlowers.json")
self.split_fewshot_dir = os.path.join(self.dataset_dir,
"split_fewshot")
mkdir_if_missing(self.split_fewshot_dir)
if os.path.exists(self.split_path):
train, val, test = OxfordPets.read_split(self.split_path,
self.image_dir)
else:
train, val, test = self.read_data()
OxfordPets.save_split(train, val, test, self.split_path,
self.image_dir)
num_shots = cfg.DATASET.NUM_SHOTS
if num_shots >= 1:
seed = cfg.SEED
preprocessed = os.path.join(self.split_fewshot_dir,
f"shot_{num_shots}-seed_{seed}.pkl")
if os.path.exists(preprocessed):
print(
f"Loading preprocessed few-shot data from {preprocessed}")
with open(preprocessed, "rb") as file:
data = pickle.load(file)
train, val = data["train"], data["val"]
else:
train = self.generate_fewshot_dataset(train,
num_shots=num_shots)
val = self.generate_fewshot_dataset(val,
num_shots=min(
num_shots, 4))
data = {"train": train, "val": val}
print(f"Saving preprocessed few-shot data to {preprocessed}")
with open(preprocessed, "wb") as file:
pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
subsample = cfg.DATASET.SUBSAMPLE_CLASSES
train, val, test = OxfordPets.subsample_classes(train,
val,
test,
subsample=subsample)
super().__init__(train_x=train, val=val, test=test)
def read_data(self):
tracker = defaultdict(list)
label_file = loadmat(self.label_file)["labels"][0]
for i, label in enumerate(label_file):
imname = f"image_{str(i + 1).zfill(5)}.jpg"
impath = os.path.join(self.image_dir, imname)
label = int(label)
tracker[label].append(impath)
print("Splitting data into 50% train, 20% val, and 30% test")
def _collate(ims, y, c):
items = []
for im in ims:
item = Datum(impath=im, label=y - 1,
classname=c) # convert to 0-based label
items.append(item)
return items
lab2cname = read_json(self.lab2cname_file)
train, val, test = [], [], []
for label, impaths in tracker.items():
random.shuffle(impaths)
n_total = len(impaths)
n_train = round(n_total * 0.5)
n_val = round(n_total * 0.2)
n_test = n_total - n_train - n_val
assert n_train > 0 and n_val > 0 and n_test > 0
cname = lab2cname[str(label)]
train.extend(_collate(impaths[:n_train], label, cname))
val.extend(_collate(impaths[n_train:n_train + n_val], label,
cname))
test.extend(_collate(impaths[n_train + n_val:], label, cname))
return train, val, test
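`OxfordFlowers.read_data` splits each class independently: it shuffles the class's images, then takes `round(n * 0.5)` for train and `round(n * 0.2)` for val, and the remainder goes to test. The `.mat` labels are shifted to 0-based. A self-contained sketch of that stratified split follows. It uses `(impath, label)` tuples and a local `random.Random(seed)` instead of the global `random` state. The function name `split_per_class` is hypothetical.

```python
import random
from typing import Dict, List, Optional


def split_per_class(impaths_by_label: Dict[int, List[str]],
                    p_train: float = 0.5, p_val: float = 0.2,
                    seed: Optional[int] = None):
    """Stratified three-way split mirroring OxfordFlowers.read_data."""
    rng = random.Random(seed)
    train, val, test = [], [], []
    for label, impaths in impaths_by_label.items():
        impaths = list(impaths)  # copy so the caller's list is not shuffled
        rng.shuffle(impaths)
        n_total = len(impaths)
        n_train = round(n_total * p_train)
        n_val = round(n_total * p_val)
        n_test = n_total - n_train - n_val
        assert n_train > 0 and n_val > 0 and n_test > 0, f"class {label} too small"
        y = label - 1  # 1-based .mat labels -> 0-based
        train += [(p, y) for p in impaths[:n_train]]
        val += [(p, y) for p in impaths[n_train:n_train + n_val]]
        test += [(p, y) for p in impaths[n_train + n_val:]]
    return train, val, test
```

Python's `round` rounds exact halves to the nearest even integer. A 45-image class therefore gets 22 training images rather than 23, and the extra image ends up in test.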
================================================
FILE: ProGrad.public/datasets/oxford_pets.py
================================================
import os
import pickle
import math
import random
from collections import defaultdict
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
from dassl.utils import read_json, write_json, mkdir_if_missing
@DATASET_REGISTRY.register()
class OxfordPets(DatasetBase):
dataset_dir = "oxford_pets"
def __init__(self, cfg):
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = os.path.join(root, self.dataset_dir)
self.image_dir = os.path.join(self.dataset_dir, "images")
self.anno_dir = os.path.join(self.dataset_dir, "annotations")
self.split_path = os.path.join(self.dataset_dir,
"split_zhou_OxfordPets.json")
self.split_fewshot_dir = os.path.join(self.dataset_dir,
"split_fewshot")
mkdir_if_missing(self.split_fewshot_dir)
if os.path.exists(self.split_path):
train, val, test = self.read_split(self.split_path, self.image_dir)
else:
trainval = self.read_data(split_file="trainval.txt")
test = self.read_data(split_file="test.txt")
train, val = self.split_trainval(trainval)
self.save_split(train, val, test, self.split_path, self.image_dir)
num_shots = cfg.DATASET.NUM_SHOTS
if num_shots >= 1:
seed = cfg.SEED
preprocessed = os.path.join(self.split_fewshot_dir,
f"shot_{num_shots}-seed_{seed}.pkl")
if os.path.exists(preprocessed):
print(
f"Loading preprocessed few-shot data from {preprocessed}")
with open(preprocessed, "rb") as file:
data = pickle.load(file)
train, val = data["train"], data["val"]
else:
train = self.generate_fewshot_dataset(train,
num_shots=num_shots)
val = self.generate_fewshot_dataset(val,
num_shots=min(
num_shots, 4))
data = {"train": train, "val": val}
print(f"Saving preprocessed few-shot data to {preprocessed}")
with open(preprocessed, "wb") as file:
pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
subsample = cfg.DATASET.SUBSAMPLE_CLASSES
train, val, test = self.subsample_classes(train,
val,
test,
subsample=subsample)
super().__init__(train_x=train, val=val, test=test)
def read_data(self, split_file):
filepath = os.path.join(self.anno_dir, split_file)
items = []
with open(filepath, "r") as f:
lines = f.readlines()
for line in lines:
line = line.strip()
imname, label, species, _ = line.split(" ")
breed = imname.split("_")[:-1]
breed = "_".join(breed)
breed = breed.lower()
imname += ".jpg"
impath = os.path.join(self.image_dir, imname)
label = int(label) - 1 # convert to 0-based index
item = Datum(impath=impath, label=label, classname=breed)
items.append(item)
return items
@staticmethod
def split_trainval(trainval, p_val=0.2):
p_trn = 1 - p_val
print(f"Splitting trainval into {p_trn:.0%} train and {p_val:.0%} val")
tracker = defaultdict(list)
for idx, item in enumerate(trainval):
label = item.label
tracker[label].append(idx)
train, val = [], []
for label, idxs in tracker.items():
n_val = round(len(idxs) * p_val)
assert n_val > 0
random.shuffle(idxs)
for n, idx in enumerate(idxs):
item = trainval[idx]
if n < n_val:
val.append(item)
else:
train.append(item)
return train, val
@staticmethod
def save_split(train, val, test, filepath, path_prefix):
def _extract(items):
out = []
for item in items:
impath = item.impath
label = item.label
classname = item.classname
impath = impath.replace(path_prefix, "")
if impath.startswith("/"):
impath = impath[1:]
out.append((impath, label, classname))
return out
train = _extract(train)
val = _extract(val)
test = _extract(test)
split = {"train": train, "val": val, "test": test}
write_json(split, filepath)
print(f"Saved split to {filepath}")
@staticmethod
def read_split(filepath, path_prefix):
def _convert(items):
out = []
for impath, label, classname in items:
impath = os.path.join(path_prefix, impath)
item = Datum(impath=impath,
label=int(label),
classname=classname)
out.append(item)
return out
print(f"Reading split from {filepath}")
split = read_json(filepath)
train = _convert(split["train"])
val = _convert(split["val"])
test = _convert(split["test"])
return train, val, test
@staticmethod
def subsample_classes(*args, subsample="all"):
"""Divide classes into two halves. The first half (by sorted
label) are the base classes; the second half are the new
classes.
Args:
args: a list of datasets, e.g. train, val and test.
subsample (str): which classes to keep: "all", "base" or "new".
Kept classes are relabeled to 0..k-1.
"""
assert subsample in ["all", "base", "new"]
if subsample == "all":
return args
dataset = args[0]
labels = set()
for item in dataset:
labels.add(item.label)
labels = list(labels)
labels.sort()
n = len(labels)
# Divide classes into two halves
m = math.ceil(n / 2)
print(f"SUBSAMPLE {subsample.upper()} CLASSES!")
if subsample == "base":
selected = labels[:m] # take the first half
else:
selected = labels[m:] # take the second half
relabeler = {y: y_new for y_new, y in enumerate(selected)}
output = []
for dataset in args:
dataset_new = []
for item in dataset:
if item.label not in selected:
continue
item_new = Datum(impath=item.impath,
label=relabeler[item.label],
classname=item.classname)
dataset_new.append(item_new)
output.append(dataset_new)
return output
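`OxfordPets.subsample_classes` implements the base-to-new protocol. The sorted labels of the first split are cut at `ceil(n / 2)`. The first half are "base" classes and the rest are "new", and the kept labels are remapped to `0..k-1` in every split. A minimal sketch on `(impath, label)` tuples (a stand-in for `Datum`) follows. It uses a dict for membership instead of a list lookup, and it returns lists for `"all"` where the original returns the argument tuple. `subsample_base_new` is a hypothetical name.

```python
import math
from typing import List, Sequence, Tuple

Item = Tuple[str, int]  # (impath, label); stand-in for Dassl's Datum


def subsample_base_new(*splits: Sequence[Item], subsample: str = "all") -> List[List[Item]]:
    """Mirror of OxfordPets.subsample_classes on (impath, label) tuples."""
    assert subsample in ("all", "base", "new")
    if subsample == "all":
        return [list(s) for s in splits]
    # Class set is taken from the first split only (usually train).
    labels = sorted({label for _, label in splits[0]})
    m = math.ceil(len(labels) / 2)
    selected = labels[:m] if subsample == "base" else labels[m:]
    relabel = {y: y_new for y_new, y in enumerate(selected)}
    return [[(p, relabel[y]) for p, y in split if y in relabel] for split in splits]
```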
================================================
FILE: ProGrad.public/datasets/stanford_cars.py
================================================
import os
import pickle
from scipy.io import loadmat
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
from dassl.utils import mkdir_if_missing
from .oxford_pets import OxfordPets
@DATASET_REGISTRY.register()
class StanfordCars(DatasetBase):
dataset_dir = "stanford_cars"
def __init__(self, cfg):
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = os.path.join(root, self.dataset_dir)
self.split_path = os.path.join(self.dataset_dir,
"split_zhou_StanfordCars.json")
self.split_fewshot_dir = os.path.join(self.dataset_dir,
"split_fewshot")
mkdir_if_missing(self.split_fewshot_dir)
if os.path.exists(self.split_path):
train, val, test = OxfordPets.read_split(self.split_path,
self.dataset_dir)
else:
trainval_file = os.path.join(self.dataset_dir, "devkit",
"cars_train_annos.mat")
test_file = os.path.join(self.dataset_dir,
"cars_test_annos_withlabels.mat")
meta_file = os.path.join(self.dataset_dir, "devkit",
"cars_meta.mat")
trainval = self.read_data("cars_train", trainval_file, meta_file)
test = self.read_data("cars_test", test_file, meta_file)
train, val = OxfordPets.split_trainval(trainval)
OxfordPets.save_split(train, val, test, self.split_path,
self.dataset_dir)
num_shots = cfg.DATASET.NUM_SHOTS
if num_shots >= 1:
seed = cfg.SEED
preprocessed = os.path.join(self.split_fewshot_dir,
f"shot_{num_shots}-seed_{seed}.pkl")
if os.path.exists(preprocessed):
print(
f"Loading preprocessed few-shot data from {preprocessed}")
with open(preprocessed, "rb") as file:
data = pickle.load(file)
train, val = data["train"], data["val"]
else:
train = self.generate_fewshot_dataset(train,
num_shots=num_shots)
val = self.generate_fewshot_dataset(val,
num_shots=min(
num_shots, 4))
data = {"train": train, "val": val}
print(f"Saving preprocessed few-shot data to {preprocessed}")
with open(preprocessed, "wb") as file:
pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
subsample = cfg.DATASET.SUBSAMPLE_CLASSES
train, val, test = OxfordPets.subsample_classes(train,
val,
test,
subsample=subsample)
super().__init__(train_x=train, val=val, test=test)
def read_data(self, image_dir, anno_file, meta_file):
anno_file = loadmat(anno_file)["annotations"][0]
meta_file = loadmat(meta_file)["class_names"][0]
items = []
for i in range(len(anno_file)):
imname = anno_file[i]["fname"][0]
impath = os.path.join(self.dataset_dir, image_dir, imname)
label = anno_file[i]["class"][0, 0]
label = int(label) - 1 # convert to 0-based index
classname = meta_file[label][0]
names = classname.split(" ")
year = names.pop(-1)
names.insert(0, year)
classname = " ".join(names)
item = Datum(impath=impath, label=label, classname=classname)
items.append(item)
return items
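`StanfordCars.read_data` converts the 1-based `.mat` class index to 0-based and rewrites each class name so that the trailing model year comes first. A small sketch of just that transformation follows, assuming `class_names` is a plain list of strings (in the original, `loadmat` returns nested arrays, which is why it indexes `[0]`). The helper name and sample names are illustrative.

```python
from typing import List, Tuple


def cars_label_and_name(class_names: List[str], label_1based: int) -> Tuple[int, str]:
    """Hypothetical helper: 0-based label plus a year-first class name."""
    label = int(label_1based) - 1  # .mat annotations are 1-based
    names = class_names[label].split(" ")
    year = names.pop(-1)  # assumes the name ends with the model year
    names.insert(0, year)
    return label, " ".join(names)
```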
================================================
FILE: ProGrad.public/datasets/sun397.py
================================================
import os
import pickle
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
from dassl.utils import mkdir_if_missing
from .oxford_pets import OxfordPets
@DATASET_REGISTRY.register()
class SUN397(DatasetBase):
dataset_dir = "sun397"
def __init__(self, cfg):
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = os.path.join(root, self.dataset_dir)
self.image_dir = os.path.join(self.dataset_dir, "SUN397")
self.split_path = os.path.join(self.dataset_dir,
"split_zhou_SUN397.json")
self.split_fewshot_dir = os.path.join(self.dataset_dir,
"split_fewshot")
mkdir_if_missing(self.split_fewshot_dir)
if os.path.exists(self.split_path):
train, val, test = OxfordPets.read_split(self.split_path,
self.image_dir)
else:
classnames = []
with open(os.path.join(self.dataset_dir, "ClassName.txt"),
"r") as f:
lines = f.readlines()
for line in lines:
line = line.strip()[1:] # remove /
classnames.append(line)
cname2lab = {c: i for i, c in enumerate(classnames)}
trainval = self.read_data(cname2lab, "Training_01.txt")
test = self.read_data(cname2lab, "Testing_01.txt")
train, val = OxfordPets.split_trainval(trainval)
OxfordPets.save_split(train, val, test, self.split_path,
self.image_dir)
num_shots = cfg.DATASET.NUM_SHOTS
if num_shots >= 1:
seed = cfg.SEED
preprocessed = os.path.join(self.split_fewshot_dir,
f"shot_{num_shots}-seed_{seed}.pkl")
if os.path.exists(preprocessed):
print(
f"Loading preprocessed few-shot data from {preprocessed}")
with open(preprocessed, "rb") as file:
data = pickle.load(file)
train, val = data["train"], data["val"]
else:
train = self.generate_fewshot_dataset(train,
num_shots=num_shots)
val = self.generate_fewshot_dataset(val,
num_shots=min(
num_shots, 4))
data = {"train": train, "val": val}
print(f"Saving preprocessed few-shot data to {preprocessed}")
with open(preprocessed, "wb") as file:
pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
subsample = cfg.DATASET.SUBSAMPLE_CLASSES
train, val, test = OxfordPets.subsample_classes(train,
val,
test,
subsample=subsample)
super().__init__(train_x=train, val=val, test=test)
def read_data(self, cname2lab, text_file):
text_file = os.path.join(self.dataset_dir, text_file)
items = []
with open(text_file, "r") as f:
lines = f.readlines()
for line in lines:
imname = line.strip()[1:] # remove /
classname = os.path.dirname(imname)
label = cname2lab[classname]
impath = os.path.join(self.image_dir, imname)
names = classname.split("/")[1:] # remove 1st letter
names = names[::-1] # put qualifiers such as indoor/outdoor first
classname = " ".join(names)
item = Datum(impath=impath, label=label, classname=classname)
items.append(item)
return items
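In `SUN397.read_data`, each split-file line is a path such as `/a/abbey/<image>.jpg`. The directory part without the leading `/` (e.g. `a/abbey`) is the key into `cname2lab`. The readable name drops the first-letter bucket and reverses the remaining parts, so a nested class reads qualifier-first. A standalone sketch follows; `sun397_classname` is a hypothetical name and the example paths are illustrative.

```python
import os
from typing import Tuple


def sun397_classname(line: str) -> Tuple[str, str]:
    """Return (cname2lab key, readable class name) for one split-file line."""
    imname = line.strip()[1:]           # "/a/abbey/x.jpg" -> "a/abbey/x.jpg"
    folder = os.path.dirname(imname)    # "a/abbey": the cname2lab key
    names = folder.split("/")[1:]       # drop the first-letter bucket
    return folder, " ".join(names[::-1])
```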
================================================
FILE: ProGrad.public/datasets/ucf101.py
================================================
import os
import pickle
import re
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
from dassl.utils import mkdir_if_missing
from .oxford_pets import OxfordPets
@DATASET_REGISTRY.register()
class UCF101(DatasetBase):
dataset_dir = "ucf101"
def __init__(self, cfg):
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = os.path.join(root, self.dataset_dir)
self.image_dir = os.path.join(self.dataset_dir, "UCF-101-midframes")
self.split_path = os.path.join(self.dataset_dir,
"split_zhou_UCF101.json")
self.split_fewshot_dir = os.path.join(self.dataset_dir,
"split_fewshot")
mkdir_if_missing(self.split_fewshot_dir)
if os.path.exists(self.split_path):
train, val, test = OxfordPets.read_split(self.split_path,
self.image_dir)
else:
cname2lab = {}
filepath = os.path.join(self.dataset_dir,
"ucfTrainTestlist/classInd.txt")
with open(filepath, "r") as f:
lines = f.readlines()
for line in lines:
label, classname = line.strip().split(" ")
label = int(label) - 1 # convert to 0-based index
cname2lab[classname] = label
trainval = self.read_data(cname2lab,
"ucfTrainTestlist/trainlist01.txt")
test = self.read_data(cname2lab, "ucfTrainTestlist/testlist01.txt")
train, val = OxfordPets.split_trainval(trainval)
OxfordPets.save_split(train, val, test, self.split_path,
self.image_dir)
num_shots = cfg.DATASET.NUM_SHOTS
if num_shots >= 1:
seed = cfg.SEED
preprocessed = os.path.join(self.split_fewshot_dir,
f"shot_{num_shots}-seed_{seed}.pkl")
if os.path.exists(preprocessed):
print(
f"Loading preprocessed few-shot data from {preprocessed}")
with open(preprocessed, "rb") as file:
data = pickle.load(file)
train, val = data["train"], data["val"]
else:
train = self.generate_fewshot_dataset(train,
num_shots=num_shots)
val = self.generate_fewshot_dataset(val,
num_shots=min(
num_shots, 4))
data = {"train": train, "val": val}
print(f"Saving preprocessed few-shot data to {preprocessed}")
with open(preprocessed, "wb") as file:
pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
subsample = cfg.DATASET.SUBSAMPLE_CLASSES
train, val, test = OxfordPets.subsample_classes(train,
val,
test,
subsample=subsample)
super().__init__(train_x=train, val=val, test=test)
def read_data(self, cname2lab, text_file):
text_file = os.path.join(self.dataset_dir, text_file)
items = []
with open(text_file, "r") as f:
lines = f.readlines()
for line in lines:
line = line.strip().split(" ")[0] # trainlist: filename, label
action, filename = line.split("/")
label = cname2lab[action]
elements = re.findall("[A-Z][^A-Z]*", action)
renamed_action = "_".join(elements)
filename = filename.replace(".avi", ".jpg")
impath = os.path.join(self.image_dir, renamed_action, filename)
item = Datum(impath=impath,
label=label,
classname=renamed_action)
items.append(item)
return items
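`UCF101.read_data` handles both split-list formats. Training lines carry a trailing label and test lines do not, so only the first space-separated field is used. The CamelCase action is split on capital letters and joined with `_`, and the `.avi` clip name is mapped to the `.jpg` middle frame under `UCF-101-midframes`. A standalone sketch follows. `ucf101_entry` is a hypothetical name, and it joins the relative path with `/` rather than `os.path.join`.

```python
import re
from typing import Tuple


def ucf101_entry(line: str) -> Tuple[str, str]:
    """Return (renamed_action, relative_frame_path) for one split-list line."""
    path = line.strip().split(" ")[0]  # trainlist lines also carry a label
    action, filename = path.split("/")
    renamed = "_".join(re.findall("[A-Z][^A-Z]*", action))
    return renamed, f"{renamed}/{filename.replace('.avi', '.jpg')}"
```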
================================================
FILE: ProGrad.public/interpret_prompt.py
================================================
import os
import sys
import argparse
import torch
from clip.simple_tokenizer import SimpleTokenizer
from clip import clip
def load_clip_to_cpu(backbone_name="RN50"):
url = clip._MODELS[backbone_name]
model_path = clip._download(url)
try:
# loading JIT archive
model = torch.jit.load(model_path, map_location="cpu").eval()
state_dict = None
except RuntimeError:
state_dict = torch.load(model_path, map_location="cpu")
model = clip.build_model(state_dict or model.state_dict())
return model
parser = argparse.ArgumentParser()
parser.add_argument("fpath", type=str, help="Path to the learned prompt")
parser.add_argument("topk", type=int, help="Select top-k similar words")
args = parser.parse_args()
fpath = args.fpath
topk = args.topk
assert os.path.exists(fpath)
print(f"Return the top-{topk} matched words")
tokenizer = SimpleTokenizer()
clip_model = load_clip_to_cpu()
token_embedding = clip_model.token_embedding.weight
print(f"Size of token embedding: {token_embedding.shape}")
prompt_learner = torch.load(fpath, map_location="cpu")["state_dict"]
ctx = prompt_learner["ctx"]
ctx = ctx.float()
print(f"Size of context: {ctx.shape}")
if ctx.dim() == 2:
# Generic context
distance = torch.cdist(ctx, token_embedding)
print(f"Size of distance matrix: {distance.shape}")
sorted_idxs = torch.argsort(distance, dim=1)
sorted_idxs = sorted_idxs[:, :topk]
for m, idxs in enumerate(sorted_idxs):
words = [tokenizer.decoder[idx.item()] for idx in idxs]
dist = [f"{distance[m, idx].item():.4f}" for idx in idxs]
print(f"{m+1}: {words} {dist}")
elif ctx.dim() == 3:
# Class-specific context
raise NotImplementedError
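For a generic context, `interpret_prompt.py` computes the Euclidean distance from each learned context vector to every row of CLIP's token embedding with `torch.cdist`. It then keeps the `topk` closest token IDs, which `tokenizer.decoder` turns back into BPE strings. The following is a NumPy analogue of that nearest-neighbour step; it swaps in NumPy for PyTorch, and `nearest_tokens` is a hypothetical name. It uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b so memory stays at `(n_ctx, vocab)` instead of `(n_ctx, vocab, d)`.

```python
import numpy as np


def nearest_tokens(ctx: np.ndarray, token_embedding: np.ndarray, topk: int):
    """ctx: (n_ctx, d); token_embedding: (vocab, d).

    Returns (indices, distances), each (n_ctx, topk), sorted by increasing
    Euclidean distance.
    """
    sq = ((ctx ** 2).sum(axis=1)[:, None]
          + (token_embedding ** 2).sum(axis=1)[None, :]
          - 2.0 * ctx @ token_embedding.T)
    distance = np.sqrt(np.maximum(sq, 0.0))  # clamp tiny negatives from rounding
    idxs = np.argsort(distance, axis=1)[:, :topk]
    return idxs, np.take_along_axis(distance, idxs, axis=1)
```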
================================================
FILE: ProGrad.public/lpclip/README.md
================================================
# Linear Probe CLIP
To run linear probe baselines, make sure that your current working directory is `lpclip/`.
Step 1: Extract Features using the CLIP Image Encoder
```bash
bash feat_extractor.sh
```
Step 2: Train few-shot linear probe
```bash
sh linear_probe.sh
```
We follow the instructions stated in Appendix A.3 (p. 38) of [the original CLIP paper](https://arxiv.org/pdf/2103.00020.pdf), with a careful hyperparameter sweep.
Note: please pull the latest Dassl (commit `606a2c6` or later).
================================================
FILE: ProGrad.public/lpclip/feat_extractor.py
================================================
import os, argparse
import numpy as np
import torch
import sys
sys.path.append(os.path.abspath(".."))
from datasets.oxford_pets import OxfordPets
from datasets.oxford_flowers import OxfordFlowers
from datasets.fgvc_aircraft import FGVCAircraft
from datasets.dtd import DescribableTextures
from datasets.eurosat import EuroSAT
from datasets.stanford_cars import StanfordCars
from datasets.food101 import Food101
from datasets.sun397 import SUN397
from datasets.caltech101 import Caltech101
from datasets.ucf101 import UCF101
from datasets.imagenet import ImageNet
from datasets.imagenetv2 import ImageNetV2
from datasets.imagenet_sketch import ImageNetSketch
from datasets.imagenet_a import ImageNetA
from datasets.imagenet_r import ImageNetR
from dassl.utils import setup_logger, set_random_seed, collect_env_info
from dassl.config import get_cfg_default
from dassl.data.transforms import build_transform
from dassl.data import DatasetWrapper
import clip
def print_args(args, cfg):
print("***************")
print("** Arguments **")
print("***************")
optkeys = list(args.__dict__.keys())
optkeys.sort()
for key in optkeys:
print("{}: {}".format(key, args.__dict__[key]))
print("************")
print("** Config **")
print("************")
print(cfg)
def reset_cfg(cfg, args):
if args.root:
cfg.DATASET.ROOT = args.root
if args.output_dir:
cfg.OUTPUT_DIR = args.output_dir
if args.trainer:
cfg.TRAINER.NAME = args.trainer
if args.backbone:
cfg.MODEL.BACKBONE.NAME = args.backbone
if args.head:
cfg.MODEL.HEAD.NAME = args.head
if args.seed >= 0:
cfg.SEED = args.seed
def extend_cfg(cfg):
"""
Add new config variables.
E.g.
from yacs.config import CfgNode as CN
cfg.TRAINER.MY_MODEL = CN()
cfg.TRAINER.MY_MODEL.PARAM_A = 1.
cfg.TRAINER.MY_MODEL.PARAM_B = 0.5
cfg.TRAINER.MY_MODEL.PARAM_C = False
"""
from yacs.config import CfgNode as CN
cfg.TRAINER.OURS = CN()
cfg.TRAINER.OURS.N_CTX = 10 # number of context vectors
cfg.TRAINER.OURS.CSC = False # class-specific context
cfg.TRAINER.OURS.CTX_INIT = "" # initialize context vectors with given words
cfg.TRAINER.OURS.WEIGHT_U = 0.1 # weight for the unsupervised loss
cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new
def setup_cfg(args):
cfg = get_cfg_default()
extend_cfg(cfg)
# 1. From the dataset config file
if args.dataset_config_file:
cfg.merge_from_file(args.dataset_config_file)
# 2. From the method config file
if args.config_file:
cfg.merge_from_file(args.config_file)
# 3. From input arguments
reset_cfg(cfg, args)
cfg.freeze()
return cfg
def main(args):
cfg = setup_cfg(args)
if cfg.SEED >= 0:
print("Setting fixed seed: {}".format(cfg.SEED))
set_random_seed(cfg.SEED)
setup_logger(cfg.OUTPUT_DIR)
if torch.cuda.is_available() and cfg.USE_CUDA:
torch.backends.cudnn.benchmark = True
print_args(args, cfg)
print("Collecting env info ...")
print("** System info **\n{}\n".format(collect_env_info()))
######################################
# Setup DataLoader
######################################
dataset = globals()[cfg.DATASET.NAME](cfg)
if args.split == "train":
dataset_input = dataset.train_x
elif args.split == "val":
dataset_input = dataset.val
else:
dataset_input = dataset.test
tfm_test = build_transform(cfg, is_train=False)
data_loader = torch.utils.data.DataLoader(
DatasetWrapper(cfg, dataset_input, transform=tfm_test,
is_train=False),
batch_size=cfg.DATALOADER.TRAIN_X.BATCH_SIZE,
sampler=None,
shuffle=False,
num_workers=cfg.DATALOADER.NUM_WORKERS,
drop_last=False,
pin_memory=(torch.cuda.is_available() and cfg.USE_CUDA),
)
########################################
# Setup Network
########################################
clip_model, _ = clip.load("RN50", "cuda", jit=False)
clip_model.eval()
###################################################################################################################
# Start Feature Extractor
feature_list = []
label_list = []
train_dataiter = iter(data_loader)
for train_step in range(1, len(train_dataiter) + 1):
batch = next(train_dataiter)
data = batch["img"].cuda()
feature = clip_model.visual(data)
feature = feature.cpu()
for idx in range(len(data)):
feature_list.append(feature[idx].tolist())
label_list.extend(batch["label"].tolist())
save_dir = os.path.join(cfg.OUTPUT_DIR, cfg.DATASET.NAME)
os.makedirs(save_dir, exist_ok=True)
save_filename = f"{args.split}"
np.savez(
os.path.join(save_dir, save_filename),
feature_list=feature_list,
label_list=label_list,
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--root", type=str, default="", help="path to dataset")
parser.add_argument("--output-dir",
type=str,
default="",
help="output directory")
parser.add_argument("--config-file",
type=str,
default="",
help="path to config file")
parser.add_argument(
"--dataset-config-file",
type=str,
default="",
help="path to config file for dataset setup",
)
parser.add_argument("--num-shot",
type=int,
default=1,
help="number of shots")
parser.add_argument("--split",
type=str,
choices=["train", "val", "test"],
help="which split")
parser.add_argument("--trainer",
type=str,
default="",
help="name of trainer")
parser.add_argument("--backbone",
type=str,
default="",
help="name of CNN backbone")
parser.add_argument("--head", type=str, default="", help="name of head")
parser.add_argument("--seed",
type=int,
default=-1,
help="only a non-negative value enables a fixed seed")
parser.add_argument("--eval-only",
action="store_true",
help="evaluation only")
args = parser.parse_args()
main(args)
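`feat_extractor.py` writes one file per split to `<OUTPUT_DIR>/<DATASET.NAME>/<split>.npz` with the keys `feature_list` and `label_list`, and `linear_probe.py` reads them back under those exact names. The following sketch shows that file contract in isolation, with hypothetical helper names. Note that `np.savez` appends the `.npz` suffix itself.

```python
import os

import numpy as np


def save_split_features(save_dir, split, features, labels):
    """Write one split using the key names linear_probe.py expects."""
    os.makedirs(save_dir, exist_ok=True)
    # "train" is saved as "train.npz"; lists are converted to arrays.
    np.savez(os.path.join(save_dir, split),
             feature_list=np.asarray(features),
             label_list=np.asarray(labels))


def load_split_features(save_dir, split):
    """Read (features, labels) back; arrays are loaded before the file closes."""
    with np.load(os.path.join(save_dir, f"{split}.npz")) as f:
        return f["feature_list"], f["label_list"]
```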
================================================
FILE: ProGrad.public/lpclip/feat_extractor.sh
================================================
# bash feat_extractor.sh (needs bash for the GPULIST array, and gpustat)
DATA=/data1/CoOpData
OUTPUT='/data1/CoOpData/clip_feat/'
SEED=1
GPULIST=(0 1 2 3)
GPUIDX=0
# oxford_pets oxford_flowers fgvc_aircraft dtd eurosat stanford_cars food101 sun397 caltech101 ucf101 imagenet
# imagenet oxford_pets oxford_flowers stanford_cars food101 caltech101
for DATASET in imagenetv2 imagenet_sketch imagenet_a imagenet_r
do
for SPLIT in train val test
do
while true
do
sleep 10
let STATIDX=GPULIST[GPUIDX]+2
stat=$(gpustat | awk '{print $11}' | sed -n ${STATIDX}'p')
if [ "$stat" -lt 20 ]
then
break
fi
let GPUIDX=(GPUIDX+1)%${#GPULIST[@]}
echo $GPUIDX'N'
done
CUDA_VISIBLE_DEVICES=${GPULIST[${GPUIDX}]} python feat_extractor.py \
--split ${SPLIT} \
--root ${DATA} \
--seed ${SEED} \
--dataset-config-file ../configs/datasets/${DATASET}.yaml \
--config-file ../configs/trainers/CoOp/rn50_val.yaml \
--output-dir ${OUTPUT} \
--eval-only &
sleep 10
done
done
================================================
FILE: ProGrad.public/lpclip/linear_probe.py
================================================
import numpy as np
import os
from sklearn.linear_model import LogisticRegression
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, default="", help="dataset name (sub-directory of --feature_dir)")
parser.add_argument("--num_step", type=int, default=8, help="number of search steps for C")
parser.add_argument("--num_run", type=int, default=10, help="number of runs")
parser.add_argument("--feature_dir",
type=str,
default="clip_feat",
help="feature dir path")
args = parser.parse_args()
dataset = args.dataset
dataset_path = os.path.join(f"{args.feature_dir}", dataset)
train_file = np.load(os.path.join(dataset_path, "train.npz"))
train_feature, train_label = train_file["feature_list"], train_file[
"label_list"]
val_file = np.load(os.path.join(dataset_path, "val.npz"))
val_feature, val_label = val_file["feature_list"], val_file["label_list"]
test_file = np.load(os.path.join(dataset_path, "test.npz"))
test_feature, test_label = test_file["feature_list"], test_file["label_list"]
os.makedirs("report", exist_ok=True)
val_shot_list = {1: 1, 2: 2, 4: 4, 8: 4, 16: 4}
# for num_shot in [1, 2, 4, 8, 16]:
for num_shot in [4, 16]:
test_acc_step_list = np.zeros([args.num_run, args.num_step])
for seed in range(1, args.num_run + 1):
np.random.seed(seed)
print(
f"-- Seed: {seed} --------------------------------------------------------------"
)
# Sampling
all_label_list = np.unique(train_label)
selected_idx_list = []
for label in all_label_list:
label_collection = np.where(train_label == label)[0]
selected_idx = np.random.choice(label_collection,
size=num_shot,
replace=False)
selected_idx_list.extend(selected_idx)
fewshot_train_feature = train_feature[selected_idx_list]
fewshot_train_label = train_label[selected_idx_list]
val_num_shot = val_shot_list[num_shot]
val_selected_idx_list = []
for label in all_label_list:
label_collection = np.where(val_label == label)[0]
selected_idx = np.random.choice(label_collection,
size=val_num_shot,
replace=False)
val_selected_idx_list.extend(selected_idx)
fewshot_val_feature = val_feature[val_selected_idx_list]
fewshot_val_label = val_label[val_selected_idx_list]
# search initialization
search_list = [1e6, 1e4, 1e2, 1, 1e-2, 1e-4, 1e-6]
acc_list = []
for c_weight in search_list:
clf = LogisticRegression(solver="lbfgs",
max_iter=1000,
penalty="l2",
C=c_weight).fit(fewshot_train_feature,
fewshot_train_label)
pred = clf.predict(fewshot_val_feature)
acc_val = sum(pred == fewshot_val_label) / len(fewshot_val_label)
acc_list.append(acc_val)
print(acc_list, flush=True)
# binary search
peak_idx = np.argmax(acc_list)
c_peak = search_list[peak_idx]
c_left, c_right = 1e-1 * c_peak, 1e1 * c_peak
def binary_search(c_left, c_right, seed, step, test_acc_step_list):
clf_left = LogisticRegression(solver="lbfgs",
max_iter=1000,
penalty="l2",
C=c_left).fit(
fewshot_train_feature,
fewshot_train_label)
pred_left = clf_left.predict(fewshot_val_feature)
acc_left = sum(
pred_left == fewshot_val_label) / len(fewshot_val_label)
print("Val accuracy (Left): {:.2f}".format(100 * acc_left),
flush=True)
clf_right = LogisticRegression(solver="lbfgs",
max_iter=1000,
penalty="l2",
C=c_right).fit(
fewshot_train_feature,
fewshot_train_label)
pred_right = clf_right.predict(fewshot_val_feature)
acc_right = sum(
pred_right == fewshot_val_label) / len(fewshot_val_label)
print("Val accuracy (Right): {:.2f}".format(100 * acc_right),
flush=True)
# find maximum and update ranges
if acc_left < acc_right:
c_final = c_right
clf_final = clf_right
# range for the next step
c_left = 0.5 * (np.log10(c_right) + np.log10(c_left))
c_right = np.log10(c_right)
else:
c_final = c_left
clf_final = clf_left
# range for the next step
c_right = 0.5 * (np.log10(c_right) + np.log10(c_left))
c_left = np.log10(c_left)
pred = clf_final.predict(test_feature)
test_acc = 100 * sum(pred == test_label) / len(pred)
print("Test Accuracy: {:.2f}".format(test_acc), flush=True)
test_acc_step_list[seed - 1, step] = test_acc
saveline = "{}, seed {}, {} shot, weight {}, test_acc {:.2f}\n".format(
dataset, seed, num_shot, c_final, test_acc)
with open(
"./report/{}_s{}r{}_details.txt".format(
'clip_feat', args.num_step, args.num_run),
"a+",
) as writer:
writer.write(saveline)
return (
np.power(10, c_left),
np.power(10, c_right),
seed,
step,
test_acc_step_list,
)
for step in range(args.num_step):
print(
f"{dataset}, {num_shot} Shot, Round {step}: {c_left}/{c_right}",
flush=True,
)
c_left, c_right, seed, step, test_acc_step_list = binary_search(
c_left, c_right, seed, step, test_acc_step_list)
# save results of last step
test_acc_list = test_acc_step_list[:, -1]
acc_mean = np.mean(test_acc_list)
acc_std = np.std(test_acc_list)
save_line = "{}, {} Shot, Test acc stat: {:.2f} ({:.2f})\n".format(
dataset, num_shot, acc_mean, acc_std)
print(save_line, flush=True)
with open(
"./report/{}_s{}r{}.txt".format('clip_feat', args.num_step,
args.num_run),
"a+",
) as writer:
writer.write(save_line)
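The nested `binary_search` in `linear_probe.py` is interval halving in log10(C) rather than a textbook binary search. It starts from `[c_peak / 10, c_peak * 10]` and scores both endpoints at each step. It keeps the better endpoint, with ties going to the left end, and moves the other endpoint to the geometric midpoint. The reported result is the choice from the last step. The following is a minimal sketch of that update rule with the scorer abstracted as a callable. In the script, the scorer is the validation accuracy of a `LogisticRegression` fitted with that `C`. `refine_c` is a hypothetical name.

```python
import math
from typing import Callable, Tuple


def refine_c(score: Callable[[float], float], c_peak: float,
             num_step: int = 8) -> Tuple[float, float]:
    """Return (chosen C, its score) from the final halving step."""
    lo, hi = math.log10(c_peak) - 1.0, math.log10(c_peak) + 1.0
    best = (c_peak, -math.inf)
    for _ in range(num_step):
        s_lo, s_hi = score(10 ** lo), score(10 ** hi)
        mid = 0.5 * (lo + hi)
        if s_lo < s_hi:
            best = (10 ** hi, s_hi)
            lo = mid  # keep the right end, move the left end to the midpoint
        else:
            best = (10 ** lo, s_lo)
            hi = mid  # keep the left end, move the right end to the midpoint
    return best
```

Each step fits two models, so `--num_step 8` costs 16 fits per seed on top of the 7-point coarse grid.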
================================================
FILE: ProGrad.public/lpclip/linear_probe.sh
================================================
feature_dir=/data1/CoOpData/clip_feat/
# ImageNet OxfordPets OxfordFlowers StanfordCars Food101 Caltech101
for DATASET in ImageNet
do
python linear_probe.py \
--dataset ${DATASET} \
--feature_dir ${feature_dir} \
--num_step 8 \
--num_run 3
done
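For each of the `--num_run` seeds, `linear_probe.py` draws `num_shot` training examples per class without replacement, plus a smaller per-class validation subset. The following sketch shows that per-class sampling with a local `np.random.Generator` in place of the global `np.random.seed` state. The sampled indices will therefore differ from the script's for the same seed. `sample_k_shot` is a hypothetical name. As in the script, a class with fewer than `k` examples raises `ValueError`.

```python
import numpy as np


def sample_k_shot(labels: np.ndarray, k: int, rng: np.random.Generator) -> np.ndarray:
    """Indices of k examples per class, drawn without replacement."""
    idxs = [rng.choice(np.flatnonzero(labels == y), size=k, replace=False)
            for y in np.unique(labels)]
    return np.concatenate(idxs)
```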
================================================
FILE: ProGrad.public/lpclip/linear_probe_transfer.py
================================================
import numpy as np
import os
from sklearn.linear_model import LogisticRegression
import argparse
parser = argparse.ArgumentParser()
# parser.add_argument("--train_dataset",
# type=str,
# default="",
# help="path to train dataset")
# parser.add_argument("--test_dataset",
# type=str,
# default="",
# help="path to test dataset")
parser.add_argument("--num_step", type=int, default=8, help="number of steps")
parser.add_argument("--num_run", type=int, default=10, help="number of runs")
parser.add_argument("--feature_dir",
type=str,
default="/data1/CoOpData/clip_feat/",
help="feature dir path")
args = parser.parse_args()
train_dataset = 'ImageNet'
train_dataset_path = os.path.join(f"{args.feature_dir}", train_dataset)
test_datasets = ['ImageNetV2', 'ImageNetSketch', 'ImageNetR', 'ImageNetA']
test_dataset_paths = [
os.path.join(f"{args.feature_dir}", test_dataset)
for test_dataset in test_datasets
]
train_file = np.load(os.path.join(train_dataset_path, "train.npz"))
train_feature, train_label = train_file["feature_list"], train_file[
"label_list"]
val_file = np.load(os.path.join(train_dataset_path, "val.npz"))
val_feature, val_label = val_file["feature_list"], val_file["label_list"]
test_files = [
np.load(os.path.join(test_dataset_path, "test.npz"))
for test_dataset_path in test_dataset_paths
]
test_features, test_labels = [
test_file["feature_list"] for test_file in test_files
], [test_file["label_list"] for test_file in test_files]
os.makedirs("report", exist_ok=True)
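# Illustrative sketch (hypothetical helper, not called by this script): the
# few-shot sampling loops below draw `num_shot` indices per class, without
# replacement, from the full feature/label arrays. Assuming `labels` is a 1-D
# integer array, the same selection can be written as a reusable function.
# It takes a np.random.Generator instead of relying on the global
# np.random.seed(), so the indices it picks will differ from the script's.
import numpy as np


def sample_few_shot_indices(labels, num_shot, rng):
    """Return `num_shot` indices per unique label, sampled without replacement."""
    selected = []
    for label in np.unique(labels):
        candidates = np.where(labels == label)[0]  # positions of this class
        selected.extend(rng.choice(candidates, size=num_shot, replace=False))
    return np.asarray(selected)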
val_shot_list = {1: 1, 2: 2, 4: 4, 8: 4, 16: 4}
# for num_shot in [1, 2, 4, 8, 16]:
for num_shot in [16]:
test_acc_step_list = np.zeros(
[len(test_datasets), args.num_run, args.num_step])
for seed in range(1, args.num_run + 1):
np.random.seed(seed)
print(
f"-- Seed: {seed} --------------------------------------------------------------"
)
# Sampling
all_label_list = np.unique(train_label)
selected_idx_list = []
for label in all_label_list:
label_collection = np.where(train_label == label)[0]
selected_idx = np.random.choice(label_collection,
size=num_shot,
replace=False)
selected_idx_list.extend(selected_idx)
fewshot_train_feature = train_feature[selected_idx_list]
fewshot_train_label = train_label[selected_idx_list]
val_num_shot = val_shot_list[num_shot]
val_selected_idx_list = []
for label in all_label_list:
label_collection = np.where(val_label == label)[0]
selected_idx = np.random.choice(label_collection,
size=val_num_shot,
replace=False)
val_selected_idx_list.extend(selected_idx)
fewshot_val_feature = val_feature[val_selected_idx_list]
fewshot_val_label = val_label[val_selected_idx_list]
        # coarse grid search over C (inverse L2 regularisation strength)
search_list = [1e6, 1e4, 1e2, 1, 1e-2, 1e-4, 1e-6]
acc_list = []
for c_weight in search_list:
clf = LogisticRegression(solver="lbfgs",
max_iter=1000,
penalty="l2",
C=c_weight).fit(fewshot_train_feature,
fewshot_train_label)
pred = clf.predict(fewshot_val_feature)
acc_val = sum(pred == fewshot_val_label) / len(fewshot_val_label)
acc_list.append(acc_val)
print(acc_list, flush=True)
        # refine C by bisection in log10 space around the best grid value
peak_idx = np.argmax(acc_list)
c_peak = search_list[peak_idx]
c_left, c_right = 1e-1 * c_peak, 1e1 * c_peak
def binary_search(c_left, c_right, seed, step, test_acc_step_list):
clf_left = LogisticRegression(solver="lbfgs",
max_iter=1000,
penalty="l2",
C=c_left).fit(
fewshot_train_feature,
fewshot_train_label)
pred_left = clf_left.predict(fewshot_val_feature)
acc_left = sum(
pred_left == fewshot_val_label) / len(fewshot_val_label)
print("Val accuracy (Left): {:.2f}".format(100 * acc_left),
flush=True)
clf_right = LogisticRegression(solver="lbfgs",
max_iter=1000,
penalty="l2",
C=c_right).fit(
fewshot_train_feature,
fewshot_train_label)
pred_right = clf_right.predict(fewshot_val_feature)
acc_right = sum(
pred_right == fewshot_val_label) / len(fewshot_val_label)
print("Val accuracy (Right): {:.2f}".format(100 * acc_right),
flush=True)
# find maximum and update ranges
if acc_left < acc_right:
c_final = c_right
clf_final = clf_right
# range for the next step
c_left = 0.5 * (np.log10(c_right) + np.log10(c_left))
c_right = np.log10(c_right)
else:
c_final = c_left
clf_final = clf_left
# range for the next step
c_right = 0.5 * (np.log10(c_right) + np.log10(c_left))
c_left = np.log10(c_left)
for i, (test_feature, test_label, test_dataset) in enumerate(
zip(test_features, test_labels, test_datasets)):
pred = clf_final.predict(test_feature)
test_acc = 100 * sum(pred == test_label) / len(pred)
print("Test Accuracy: {:.2f}".format(test_acc), flush=True)
test_acc_step_list[i, seed - 1, step] = test_acc
saveline = "{}, {}, seed {}, {} shot, weight {}, test_acc {:.2f}\n".format(
train_dataset, test_dataset, seed, num_shot, c_final,
test_acc)
with open(
"./report/{}_s{}r{}_details.txt".format(
'clip_feat', args.num_step, args.num_run),
"a+",
) as writer:
writer.write(saveline)
return (
np.power(10, c_left),
np.power(10, c_right),
seed,
step,
test_acc_step_list,
)
for step in range(args.num_step):
print(
f"{train_dataset}, {num_shot} Shot, Round {step}: {c_left}/{c_right}",
flush=True,
)
c_left, c_right, seed, step, test_acc_step_list = binary_search(
c_left, c_right, seed, step, test_acc_step_list)
# save results of last step
test_acc_list = test_acc_step_list[:, :, -1]
    acc_mean = np.mean(test_acc_list, axis=-1)  # mean over runs, per test set
    acc_std = np.std(test_acc_list, axis=-1)
for i in range(len(test_datasets)):
save_line = "{}, {}, {} Shot, Test acc stat: {:.2f} ({:.2f})\n".format(
train_dataset, test_datasets[i], num_shot, acc_mean[i], acc_std[i])
print(save_line, flush=True)
with open(
"./report/{}_s{}r{}.txt".format('clip_feat', args.num_step,
args.num_run),
"a+",
) as writer:
writer.write(save_line)
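# Illustrative sketch (hypothetical helper, not called above): the search for
# the LogisticRegression strength C works in log10 space. After the coarse grid
# picks a peak, the interval starts at [0.1 * peak, 10 * peak]; each step scores
# both ends, keeps the better end as the current answer (ties go left, as in
# binary_search above), and shrinks the interval to the half adjacent to it.
# `score_fn` is an assumed stand-in for "fit with C=c, return val accuracy".
import numpy as np


def log_space_search(score_fn, c_peak, num_step):
    """Return (best_c, history) after `num_step` halvings around `c_peak`."""
    lo, hi = np.log10(c_peak) - 1.0, np.log10(c_peak) + 1.0
    history = []
    best_c = None
    for _ in range(num_step):
        c_left, c_right = 10.0 ** lo, 10.0 ** hi
        if score_fn(c_left) < score_fn(c_right):
            best_c = c_right
            lo = 0.5 * (lo + hi)  # keep the right half in log10 space
        else:
            best_c = c_left
            hi = 0.5 * (lo + hi)  # keep the left half in log10 space
        history.append(best_c)
    return best_c, history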
================================================
FILE: ProGrad.public/parse_test_res.py
================================================
"""
Goal
---
1. Read test results from log.txt files
2. Compute mean and std across different folders (seeds)
Usage
---
Assume the output files are saved under output/my_experiment,
which contains results of different seeds, e.g.,
my_experiment/
seed1/
log.txt
seed2/
log.txt
seed3/
log.txt
Run the following command from the root directory:
$ python parse_test_res.py output/my_experiment
Add --ci95 to the arguments if you want the 95% confidence
interval instead of the standard deviation:
$ python parse_test_res.py output/my_experiment --ci95
If my_experiment/ has the following structure,
my_experiment/
    exp-1/
        seed1/
            log.txt
            ...
        seed2/
            log.txt
            ...
        seed3/
            log.txt
            ...
    exp-2/
        ...
    exp-3/
        ...
then run
$ python parse_test_res.py output/my_experiment --multi-exp
"""
import re
import numpy as np
import os.path as osp
import argparse
from collections import OrderedDict, defaultdict
from dassl.utils import check_isfile, listdir_nohidden
def compute_ci95(res):
return 1.96 * np.std(res) / np.sqrt(len(res))
def parse_function(*metrics, directory="", args=None, end_signal=None):
print(f"Parsing files in {directory}")
subdirs = listdir_nohidden(directory, sort=True)
outputs = []
for subdir in subdirs:
fpath = osp.join(directory, subdir, "log.txt")
assert check_isfile(fpath)
good_to_go = False
output = OrderedDict()
with open(fpath, "r") as f:
lines = f.readlines()
for line in lines:
line = line.strip()
if line == end_signal:
good_to_go = True
for metric in metrics:
match = metric["regex"].search(line)
if match and good_to_go:
if "file" not in output:
output["file"] = fpath
num = float(match.group(1))
name = metric["name"]
output[name] = num
if output:
outputs.append(output)
assert len(outputs) > 0, f"Nothing found in {directory}"
metrics_results = defaultdict(list)
for output in outputs:
msg = ""
for key, value in output.items():
if isinstance(value, float):
msg += f"{key}: {value:.2f}%. "
else:
msg += f"{key}: {value}. "
if key != "file":
metrics_results[key].append(value)
print(msg)
output_results = OrderedDict()
print("===")
print(f"Summary of directory: {directory}")
for key, values in metrics_results.items():
avg = np.mean(values)
std = compute_ci95(values) if args.ci95 else np.std(values)
print(f"* {key}: {avg:.2f}% +- {std:.2f}%")
output_results[key] = avg
print("===")
return output_results
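# Illustrative sketch (hypothetical helpers, not used by main()): how one log
# is reduced to one number. Lines are ignored until `end_signal` appears; after
# that, the last line matching "* <keyword>: <value>%" wins, mirroring the loop
# in parse_function(). Per-seed values are then summarised as the mean plus
# either the std or the 95% CI half-width 1.96 * std / sqrt(n).
import re

import numpy as np


def extract_metric(lines, keyword="accuracy", end_signal="Finished training"):
    """Return the last `keyword` value reported after `end_signal`, or None."""
    regex = re.compile(fr"\* {keyword}: ([\.\deE+-]+)%")
    good_to_go = False
    value = None
    for line in lines:
        line = line.strip()
        if line == end_signal:
            good_to_go = True
        match = regex.search(line)
        if match and good_to_go:
            value = float(match.group(1))
    return value


def summarize(values, ci95=False):
    """Mean and spread across seeds, computed as in parse_function()."""
    if ci95:
        spread = 1.96 * np.std(values) / np.sqrt(len(values))
    else:
        spread = np.std(values)
    return float(np.mean(values)), float(spread)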
def main(args, end_signal):
metric = {
"name": args.keyword,
"regex": re.compile(fr"\* {args.keyword}: ([\.\deE+-]+)%"),
}
if args.multi_exp:
final_results = defaultdict(list)
for directory in listdir_nohidden(args.directory, sort=True):
directory = osp.join(args.directory, directory)
results = parse_function(metric,
directory=directory,
args=args,
end_signal=end_signal)
for key, value in results.items():
final_results[key].append(value)
print("Average performance")
for key, values in final_results.items():
avg = np.mean(values)
print(f"* {key}: {avg:.2f}%")
else:
parse_function(metric,
directory=args.directory,
args=args,
end_signal=end_signal)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("directory", type=str, help="path to directory")
parser.add_argument("--ci95",
action="store_true",
                        help="compute 95%% confidence interval")
parser.add_argument("--test-log",
action="store_true",
help="parse test-only logs")
parser.add_argument("--multi-exp",
action="store_true",
help="parse multiple experiments")
parser.add_argument("--keyword",
default="accuracy",
type=str,
help="which keyword to extract")
args = parser.parse_args()
end_signal = "Finished training"
if args.test_log:
end_signal = "=> result"
main(args, end_signal)
================================================
FILE: ProGrad.public/requirements.txt
================================================
ftfy
regex
tqdm
================================================
FILE: ProGrad.public/scripts/base2new_test_main.sh
================================================
#!/bin/bash
cd ..
# custom config
DATA=/data1/CoOpData/
TRAINER=CoOp
DATASET=$1
CFG=rn50_ep100 # config file
CTP=end # class token position (end or middle)
NCTX=16 # number of context tokens
SHOTS=4 # number of shots (1, 2, 4, 8, 16)
CSC=False # class-specific context (False or True)
LOADEP=100
SUB=new
for SEED in 1 2 3
do
COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
MODEL_DIR=output/base2new/train_base/${COMMON_DIR}
DIR=output/base2new/test_${SUB}/${COMMON_DIR}
if [ -d "$DIR" ]; then
echo "Results are available in ${DIR}. Skip this job"
else
echo "Run this job and save the output to ${DIR}"
python train.py \
--root ${DATA} \
--seed ${SEED} \
--trainer ${TRAINER} \
--dataset-config-file configs/datasets/${DATASET}.yaml \
--config-file configs/trainers/${TRAINER}/${CFG}.yaml \
--output-dir ${DIR} \
--model-dir ${MODEL_DIR} \
--load-epoch ${LOADEP} \
--eval-only \
TRAINER.COOP.N_CTX ${NCTX} \
TRAINER.COOP.CSC ${CSC} \
TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \
DATASET.NUM_SHOTS ${SHOTS} \
DATASET.SUBSAMPLE_CLASSES ${SUB}
fi
done
================================================
FILE: ProGrad.public/scripts/base2new_test_prograd.sh
================================================
#!/bin/bash
cd ..
# custom config
DATA=/data1/CoOpData/
TRAINER=ProGrad
DATASET=$1
CFG=rn50_ep100 # config file
CTP=end # class token position (end or middle)
NCTX=16 # number of context tokens
SHOTS=4 # number of shots (1, 2, 4, 8, 16)
CSC=False # class-specific context (False or True)
LAMBDA=1.0
LOADEP=100
SUB=new
for SEED in 1 2 3
do
COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
MODEL_DIR=output/base2new/train_base/${COMMON_DIR}
DIR=output/base2new/test_${SUB}/${COMMON_DIR}
if [ -d "$DIR" ]; then
echo "Results are available in ${DIR}. Skip this job"
else
echo "Run this job and save the output to ${DIR}"
python train.py \
--root ${DATA} \
--seed ${SEED} \
--trainer ${TRAINER} \
--dataset-config-file configs/datasets/${DATASET}.yaml \
--config-file configs/trainers/${TRAINER}/${CFG}.yaml \
--output-dir ${DIR} \
--model-dir ${MODEL_DIR} \
--load-epoch ${LOADEP} \
--eval-only \
LOSS.LAMBDA ${LAMBDA} \
TRAINER.COOP.N_CTX ${NCTX} \
TRAINER.COOP.CSC ${CSC} \
TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \
DATASET.NUM_SHOTS ${SHOTS} \
DATASET.SUBSAMPLE_CLASSES ${SUB}
fi
done
================================================
FILE: ProGrad.public/scripts/base2new_train_main.sh
================================================
#!/bin/bash
cd ..
# custom config
DATA=/data1/CoOpData/
TRAINER=CoOp
DATASET=$1
CFG=rn50_ep100 # config file
CTP=end # class token position (end or middle)
NCTX=16 # number of context tokens
SHOTS=4 # number of shots (1, 2, 4, 8, 16)
CSC=False # class-specific context (False or True)
for SEED in 1 2 3
do
DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
if [ -d "$DIR" ]; then
echo "Results are available in ${DIR}. Skip this job"
else
echo "Run this job and save the output to ${DIR}"
python train.py \
--root ${DATA} \
--seed ${SEED} \
--trainer ${TRAINER} \
--dataset-config-file configs/datasets/${DATASET}.yaml \
--config-file configs/trainers/${TRAINER}/${CFG}.yaml \
--output-dir ${DIR} \
TRAINER.COOP.N_CTX ${NCTX} \
TRAINER.COOP.CSC ${CSC} \
TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \
DATASET.NUM_SHOTS ${SHOTS} \
DATASET.SUBSAMPLE_CLASSES base
fi
done
================================================
FILE: ProGrad.public/scripts/base2new_train_prograd.sh
================================================
#!/bin/bash
cd ..
# custom config
DATA=/data1/CoOpData/
TRAINER=ProGrad
DATASET=$1
CFG=rn50_ep100 # config file
CTP=end # class token position (end or middle)
NCTX=16 # number of context tokens
SHOTS=4 # number of shots (1, 2, 4, 8, 16)
CSC=False # class-specific context (False or True)
LAMBDA=1.0
for SEED in 1 2 3
do
DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
if [ -d "$DIR" ]; then
echo "Results are available in ${DIR}. Skip this job"
else
echo "Run this job and save the output to ${DIR}"
python train.py \
--root ${DATA} \
--seed ${SEED} \
--trainer ${TRAINER} \
--dataset-config-file configs/datasets/${DATASET}.yaml \
--config-file configs/trainers/${TRAINER}/${CFG}.yaml \
--output-dir ${DIR} \
LOSS.LAMBDA ${LAMBDA} \
TRAINER.COOP.N_CTX ${NCTX} \
TRAINER.COOP.CSC ${CSC} \
TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \
DATASET.NUM_SHOTS ${SHOTS} \
DATASET.SUBSAMPLE_CLASSES base
fi
done
================================================
FILE: ProGrad.public/scripts/eval.sh
================================================
#!/bin/bash
cd ..
# custom config
DATA=/path/to/datasets
TRAINER=CoOp
SHOTS=4
NCTX=16
CSC=False
CTP=end
DATASET=$1
CFG=$2
for SEED in 1 2 3
do
python train.py \
--root ${DATA} \
--seed ${SEED} \
--trainer ${TRAINER} \
--dataset-config-file configs/datasets/${DATASET}.yaml \
--config-file configs/trainers/${TRAINER}/${CFG}.yaml \
--output-dir output/evaluation/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/${DATASET}/seed${SEED} \
--model-dir output/imagenet/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/seed${SEED} \
--load-epoch 50 \
--eval-only \
TRAINER.COOP.N_CTX ${NCTX} \
TRAINER.COOP.CSC ${CSC} \
TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP}
done
================================================
FILE: ProGrad.public/scripts/main.sh
================================================
#!/bin/bash
cd ..
# custom config
DATA=/data1/CoOpData/
TRAINER=CoOp
DATASET=$1
CFG=$2 # config file
CTP=$3 # class token position (end or middle)
NCTX=$4 # number of context tokens
SHOTS=$5 # number of shots (1, 2, 4, 8, 16)
CSC=$6 # class-specific context (False or True)
for SEED in 1 2 3
do
DIR=output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/seed${SEED}
if [ -d "$DIR" ]; then
echo "Results are available in ${DIR}. Skip this job"
else
echo "Run this job and save the output to ${DIR}"
python train.py \
--root ${DATA} \
--seed ${SEED} \
--trainer ${TRAINER} \
--dataset-config-file configs/datasets/${DATASET}.yaml \
--config-file configs/trainers/${TRAINER}/${CFG}.yaml \
--output-dir ${DIR} \
TRAINER.COOP.N_CTX ${NCTX} \
TRAINER.COOP.CSC ${CSC} \
TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \
DATASET.NUM_SHOTS ${SHOTS}
fi
done
================================================
FILE: ProGrad.public/scripts/prograd.sh
================================================
#!/bin/bash
cd ..
# custom config
DATA=/data1/CoOpData/
TRAINER=ProGrad
DATASET=$1
CFG=$2 # config file
CTP=$3 # class token position (end or middle)
NCTX=$4 # number of context tokens
SHOTS=$5 # number of shots (1, 2, 4, 8, 16)
CSC=$6 # class-specific context (False or True)
LAMBDA=1.0
for SEED in 1 2 3
do
DIR=output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/seed${SEED}
if [ -d "$DIR" ]; then
echo "Results are available in ${DIR}. Skip this job"
else
echo "Run this job and save the output to ${DIR}"
python train.py \
--root ${DATA} \
--seed ${SEED} \
--trainer ${TRAINER} \
--dataset-config-file configs/datasets/${DATASET}.yaml \
--config-file configs/trainers/${TRAINER}/${CFG}.yaml \
--output-dir ${DIR} \
LOSS.LAMBDA ${LAMBDA} \
TRAINER.COOP.N_CTX ${NCTX} \
TRAINER.COOP.CSC ${CSC} \
TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \
DATASET.NUM_SHOTS ${SHOTS}
fi
done
================================================
FILE: ProGrad.public/scripts/zeroshot.sh
================================================
#!/bin/bash
cd ..
# custom config
DATA=/data1/CoOpData
TRAINER=ZeroshotCLIP
DATASET=$1
CFG=$2 # rn50, rn101, vit_b32 or vit_b16
python train.py \
--root ${DATA} \
--trainer ${TRAINER} \
--dataset-config-file configs/datasets/${DATASET}.yaml \
--config-file configs/trainers/CoOp/${CFG}.yaml \
--output-dir output/${TRAINER}/${CFG}/${DATASET} \
--eval-only
================================================
FILE: ProGrad.public/train.py
================================================
import argparse
import torch
import time
import os
from dassl.utils import setup_logger, set_random_seed, collect_env_info
from dassl.config import get_cfg_default
from dassl.engine import build_trainer
# custom
import datasets.oxford_pets
import datasets.oxford_flowers
import datasets.fgvc_aircraft
import datasets.dtd
import datasets.eurosat
import datasets.stanford_cars
import datasets.food101
import datasets.sun397
import datasets.caltech101
import datasets.ucf101
import datasets.imagenet
import datasets.imagenet_sketch
import datasets.imagenetv2
import datasets.imagenet_a
import datasets.imagenet_r
import trainers.coop
import trainers.cocoop
import trainers.zsclip
import trainers.prograd
def print_args(args, cfg):
print("***************")
print("** Arguments **")
print("***************")
optkeys = list(args.__dict__.keys())
optkeys.sort()
for key in optkeys:
print("{}: {}".format(key, args.__dict__[key]))
print("************")
print("** Config **")
print("************")
print(cfg)
def reset_cfg(cfg, args):
if args.root:
cfg.DATASET.ROOT = args.root
if args.output_dir:
cfg.OUTPUT_DIR = args.output_dir
if args.resume:
cfg.RESUME = args.resume
if args.seed:
cfg.SEED = args.seed
if args.source_domains:
cfg.DATASET.SOURCE_DOMAINS = args.source_domains
if args.target_domains:
cfg.DATASET.TARGET_DOMAINS = args.target_domains
if args.transforms:
cfg.INPUT.TRANSFORMS = args.transforms
if args.trainer:
cfg.TRAINER.NAME = args.trainer
if args.backbone:
cfg.MODEL.BACKBONE.NAME = args.backbone
if args.head:
cfg.MODEL.HEAD.NAME = args.head
def extend_cfg(cfg):
"""
Add new config variables.
E.g.
from yacs.config import CfgNode as CN
cfg.TRAINER.MY_MODEL = CN()
cfg.TRAINER.MY_MODEL.PARAM_A = 1.
cfg.TRAINER.MY_MODEL.PARAM_B = 0.5
cfg.TRAINER.MY_MODEL.PARAM_C = False
"""
from yacs.config import CfgNode as CN
cfg.TRAINER.COOP = CN()
cfg.TRAINER.COOP.ALPHA = 1.0
cfg.TRAINER.COOP.N_CTX = 16 # number of context vectors
cfg.TRAINER.COOP.CSC = False # class-specific context
cfg.TRAINER.COOP.CTX_INIT = False # initialization words
cfg.TRAINER.COOP.PREC = "fp16" # fp16, fp32, amp
cfg.TRAINER.COOP.CLASS_TOKEN_POSITION = "end" # 'middle' or 'end' or 'front'
cfg.TRAINER.COCOOP = CN()
cfg.TRAINER.COCOOP.N_CTX = 16 # number of context vectors
cfg.TRAINER.COCOOP.CTX_INIT = False # initialization words
cfg.TRAINER.COCOOP.PREC = "fp16" # fp16, fp32, amp
cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new
    # Loss options; LOSS.LAMBDA is overridden from the command line by the ProGrad scripts
cfg.LOSS = CN()
cfg.LOSS.GM = False
cfg.LOSS.NAME = ""
cfg.LOSS.ALPHA = 0.
cfg.LOSS.T = 1.
cfg.LOSS.LAMBDA = 1.
def setup_cfg(args):
cfg = get_cfg_default()
extend_cfg(cfg)
# 1. From the dataset config file
if args.dataset_config_file:
cfg.merge_from_file(args.dataset_config_file)
# 2. From the method config file
if args.config_file:
cfg.merge_from_file(args.config_file)
# 3. From input arguments
reset_cfg(cfg, args)
# 4. From optional input arguments
cfg.merge_from_list(args.opts)
cfg.freeze()
return cfg
def main(args):
cfg = setup_cfg(args)
if cfg.SEED >= 0:
print("Setting fixed seed: {}".format(cfg.SEED))
set_random_seed(cfg.SEED)
setup_logger(cfg.OUTPUT_DIR)
if torch.cuda.is_available() and cfg.USE_CUDA:
torch.backends.cudnn.benchmark = True
print_args(args, cfg)
print("Collecting env info ...")
print("** System info **\n{}\n".format(collect_env_info()))
trainer = build_trainer(cfg)
if args.eval_only:
trainer.load_model(args.model_dir, epoch=args.load_epoch)
trainer.test()
return
if not args.no_train:
trainer.train()
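# Illustrative sketch (hypothetical helper, not used by main()): setup_cfg()
# applies overrides in order (dataset file, method file, CLI flags, then the
# trailing `opts`), so later sources win. The trailing `opts` used by the
# scripts, e.g. "LOSS.LAMBDA 1.0 TRAINER.COOP.N_CTX 16", are dotted KEY VALUE
# pairs. This plain-dict version only approximates that last step; yacs'
# CfgNode.merge_from_list additionally checks type compatibility.
from ast import literal_eval


def merge_opts(cfg, opts):
    """Apply alternating dotted KEY VALUE overrides to a nested dict in place."""
    assert len(opts) % 2 == 0, "opts must be KEY VALUE pairs"
    for key, raw in zip(opts[0::2], opts[1::2]):
        *parents, leaf = key.split(".")
        node = cfg
        for name in parents:
            node = node[name]  # unknown sections raise KeyError
        if leaf not in node:
            raise KeyError(f"Non-existent config key: {key}")
        try:
            node[leaf] = literal_eval(raw)  # "16" -> 16, "False" -> False
        except (ValueError, SyntaxError):
            node[leaf] = raw  # plain words such as "end" stay strings
    return cfg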
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--root", type=str, default="", help="path to dataset")
parser.add_argument("--output-dir",
type=str,
default="",
help="output directory")
parser.add_argument(
"--resume",
type=str,
default="",
help="checkpoint directory (from which the training resumes)",
)
parser.add_argument("--seed",
type=int,
default=-1,
help="only positive value enables a fixed seed")
parser.add_argument("--source-domains",
type=str,
nargs="+",
help="source domains for DA/DG")
parser.add_argument("--target-domains",
type=str,
nargs="+",
help="target domains for DA/DG")
parser.add_argument("--transforms",
type=str,
nargs="+",
help="data augmentation methods")
parser.add_argument("--config-file",
type=str,
default="",
help="path to config file")
parser.add_argument(
"--dataset-config-file",
type=str,
default="",
help="path to config file for dataset setup",
)
parser.add_argument("--trainer",
type=str,
default="",
help="name of trainer")
parser.add_argument("--backbone",
type=str,
default="",
help="name of CNN backbone")
parser.add_argument("--head", type=str, default="", help="name of head")
parser.add_argument("--eval-only",
action="store_true",
help="evaluation only")
parser.add_argument(
"--model-dir",
type=str,
default="",
help="load model from this directory for eval-only mode",
)
parser.add_argument("--load-epoch",
type=int,
help="load model weights at this epoch for evaluation")
parser.add_argument("--no-train",
action="store_true",
help="do not call trainer.train()")
parser.add_argument(
"opts",
default=None,
nargs=argparse.REMAINDER,
help="modify config options using the command-line",
)
args = parser.parse_args()
main(args)
================================================
FILE: ProGrad.public/trainers/__init__.py
================================================
================================================
FILE: ProGrad.public/trainers/cocoop.py
================================================
import os.path as osp
from collections import OrderedDict
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import GradScaler, autocast
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy
from dassl.utils import load_pretrained_weights, load_checkpoint
from dassl.optim import build_optimizer, build_lr_scheduler
from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
_tokenizer = _Tokenizer()
def load_clip_to_cpu(cfg):
backbone_name = cfg.MODEL.BACKBONE.NAME
url = clip._MODELS[backbone_name]
model_path = clip._download(url)
try:
# loading JIT archive
model = torch.jit.load(model_path, map_location="cpu").eval()
state_dict = None
except RuntimeError:
state_dict = torch.load(model_path, map_location="cpu")
model = clip.build_model(state_dict or model.state_dict())
return model
class TextEncoder(nn.Module):
def __init__(self, clip_model):
super().__init__()
self.transformer = clip_model.transformer
self.positional_embedding = clip_model.positional_embedding
self.ln_final = clip_model.ln_final
self.text_projection = clip_model.text_projection
self.dtype = clip_model.dtype
def forward(self, prompts, tokenized_prompts):
x = prompts + self.positional_embedding.type(self.dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x).type(self.dtype)
        # x.shape = [batch_size, n_tkn, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]),
tokenized_prompts.argmax(dim=-1)] @ self.text_projection
return x
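# Illustrative sketch (hypothetical helper, not used by the trainer): the
# indexing in TextEncoder.forward picks, for every prompt, the hidden state at
# the end-of-text position. As the comment in forward() notes, the EOT token
# has the highest id in each sequence, so argmax over the token ids locates it.
# A NumPy version of the same gather, with made-up shapes:
import numpy as np


def gather_eot_features(x, tokenized_prompts):
    """x: (batch, seq_len, width); tokenized_prompts: (batch, seq_len) ids."""
    eot_pos = tokenized_prompts.argmax(axis=-1)  # (batch,)
    return x[np.arange(x.shape[0]), eot_pos]  # (batch, width)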
class PromptLearner(nn.Module):
def __init__(self, cfg, classnames, clip_model):
super().__init__()
n_cls = len(classnames)
n_ctx = cfg.TRAINER.COCOOP.N_CTX
ctx_init = cfg.TRAINER.COCOOP.CTX_INIT
dtype = clip_model.dtype
ctx_dim = clip_model.ln_final.weight.shape[0]
vis_dim = clip_model.visual.output_dim
clip_imsize = clip_model.visual.input_resolution
cfg_imsize = cfg.INPUT.SIZE[0]
        assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal clip_imsize ({clip_imsize})"
if ctx_init:
ctx_init = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
ctx_init = ctx_init.replace(" {}.", "")
ctx_init = ctx_init.replace("_", " ")
prompt_n_ctx = len(ctx_init.split(" "))
            assert n_ctx >= prompt_n_ctx, f"#tokens ({n_ctx}) should be greater than or equal to #initial prompt tokens ({prompt_n_ctx}, {ctx_init})"
prompt = clip.tokenize(ctx_init)
with torch.no_grad():
embedding = clip_model.token_embedding(prompt).type(dtype)
ctx_vectors = torch.zeros(n_ctx, ctx_dim, dtype=dtype)
ctx_vectors[n_ctx - prompt_n_ctx:, :] = embedding[0, 1:1 +
prompt_n_ctx, :]
prompt_prefix = " ".join(["X"] * (n_ctx - prompt_n_ctx))
prompt_prefix = f"{prompt_prefix} {ctx_init}"
else:
# random initialization
ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
nn.init.normal_(ctx_vectors, std=0.02)
prompt_prefix = " ".join(["X"] * n_ctx)
print(f'Initial context: "{prompt_prefix}"')
print(f"Number of context words (tokens): {n_ctx}")
self.ctx = nn.Parameter(ctx_vectors)
self.meta_net = nn.Sequential(
OrderedDict([("linear1", nn.Linear(vis_dim, vis_dim // 16)),
("relu", nn.ReLU(inplace=True)),
("linear2", nn.Linear(vis_dim // 16, ctx_dim))]))
if cfg.TRAINER.COCOOP.PREC == "fp16":
self.meta_net.half()
classnames = [name.replace("_", " ") for name in classnames]
name_lens = [len(_tokenizer.encode(name)) for name in classnames]
prompts = [prompt_prefix + " " + name + "." for name in classnames]
tokenized_prompts = torch.cat([clip.tokenize(p)
for p in prompts]) # (n_cls, n_tkn)
with torch.no_grad():
embedding = clip_model.token_embedding(tokenized_prompts).type(
dtype)
# These token vectors will be saved when in save_model(),
# but they should be ignored in load_model() as we want to use
# those computed using the current class names
self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS
self.register_buffer("token_suffix",
embedding[:, 1 + n_ctx:, :]) # CLS, EOS
self.n_cls = n_cls
self.n_ctx = n_ctx
self.tokenized_prompts = tokenized_prompts # torch.Tensor
self.name_lens = name_lens
def construct_prompts(self, ctx, prefix, suffix, label=None):
# dim0 is either batch_size (during training) or n_cls (during testing)
# ctx: context tokens, with shape of (dim0, n_ctx, ctx_dim)
# prefix: the sos token, with shape of (n_cls, 1, ctx_dim)
# suffix: remaining tokens, with shape of (n_cls, *, ctx_dim)
if label is not None:
prefix = prefix[label]
suffix = suffix[label]
prompts = torch.cat(
[
prefix, # (dim0, 1, dim)
ctx, # (dim0, n_ctx, dim)
suffix, # (dim0, *, dim)
],
dim=1,
)
return prompts
def forward(self, im_features):
prefix = self.token_prefix
suffix = self.token_suffix
ctx = self.ctx # (n_ctx, ctx_dim)
bias = self.meta_net(im_features) # (batch, ctx_dim)
bias = bias.unsqueeze(1) # (batch, 1, ctx_dim)
ctx = ctx.unsqueeze(0) # (1, n_ctx, ctx_dim)
ctx_shifted = ctx + bias # (batch, n_ctx, ctx_dim)
# Use instance-conditioned context tokens for all classes
prompts = []
for ctx_shifted_i in ctx_shifted:
ctx_i = ctx_shifted_i.unsqueeze(0).expand(self.n_cls, -1, -1)
pts_i = self.construct_prompts(ctx_i, prefix,
suffix) # (n_cls, n_tkn, ctx_dim)
prompts.append(pts_i)
prompts = torch.stack(prompts)
return prompts
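# Illustrative sketch (hypothetical helper, not used by the trainer): shape
# walk-through of PromptLearner.forward, covering the context part only (the
# SOS prefix and class/EOS suffix are omitted). The meta-net output, passed in
# here as `bias`, gives one vector per image. That vector is broadcast-added to
# all n_ctx learnable context vectors, and the shifted context is then shared
# by every class prompt of that image. NumPy arrays stand in for the tensors.
import numpy as np


def shift_context(ctx, bias, n_cls):
    """ctx: (n_ctx, dim); bias: (batch, dim) -> (batch, n_cls, n_ctx, dim)."""
    ctx_shifted = ctx[None, :, :] + bias[:, None, :]  # (batch, n_ctx, dim)
    # Repeat each image's context for every class, like expand() in forward()
    return np.broadcast_to(ctx_shifted[:, None], (bias.shape[0], n_cls) + ctx.shape)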
CUSTOM_TEMPLATES = {
# "OxfordPets": "a photo of a {}, a type of pet.",
"OxfordPets": "a type of pet, a photo of a {}.",
# "OxfordFlowers": "a photo of a {}, a type of flower.",
"OxfordFlowers": "a type of flower, a photo of a {}.",
"FGVCAircraft": "a type of aircraft, a photo of a {}.",
"DescribableTextures": "a texture of {}.",
"EuroSAT": "a centered satellite photo of {}.",
"StanfordCars": "a photo of a {}.",
# "Food101": "a photo of {}, a type of food.",
"Food101": "a type of food, a photo of {}.",
"SUN397": "a photo of a {}.",
"Caltech101": "a photo of a {}.",
"UCF101": "a photo of a person doing {}.",
"ImageNet": "a photo of a {}.",
"ImageNetSketch": "a photo of a {}.",
"ImageNetV2": "a photo of a {}.",
"ImageNetA": "a photo of a {}.",
"ImageNetR": "a photo of a {}.",
}
class CustomCLIP(nn.Module):
def __init__(self, cfg, classnames, clip_model):
super().__init__()
self.prompt_learner = PromptLearner(cfg, classnames, clip_model)
self.tokenized_prompts = self.prompt_learner.tokenized_prompts
self.image_encoder = clip_model.visual
self.text_encoder = TextEncoder(clip_model)
self.logit_scale = clip_model.logit_scale
self.dtype = clip_model.dtype
def forward(self, image, label=None):
tokenized_prompts = self.tokenized_prompts
logit_scale = self.logit_scale.exp()
image_features = self.image_encoder(image.type(self.dtype))
image_features = image_features / image_features.norm(dim=-1,
keepdim=True)
prompts = self.prompt_learner(image_features)
logits = []
for pts_i, imf_i in zip(prompts, image_features):
text_features = self.text_encoder(pts_i, tokenized_prompts)
text_features = text_features / text_features.norm(dim=-1,
keepdim=True)
l_i = logit_scale * imf_i @ text_features.t()
logits.append(l_i)
logits = torch.stack(logits)
if self.prompt_learner.training:
return F.cross_entropy(logits, label)
return logits
@TRAINER_REGISTRY.register()
class CoCoOp(TrainerX):
def check_cfg(self, cfg):
assert cfg.TRAINER.COCOOP.PREC in ["fp16", "fp32", "amp"]
def build_model(self):
cfg = self.cfg
classnames = self.dm.dataset.classnames
print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
clip_model = load_clip_to_cpu(cfg)
if cfg.TRAINER.COCOOP.PREC == "fp32" or cfg.TRAINER.COCOOP.PREC == "amp":
# CLIP's default precision is fp16
clip_model.float()
print("Building custom CLIP")
self.model = CustomCLIP(cfg, classnames, clip_model)
print("Turning off gradients in both the image and the text encoder")
name_to_update = "prompt_learner"
for name, param in self.model.named_parameters():
if name_to_update not in name:
param.requires_grad_(False)
# Double check
enabled = set()
for name, param in self.model.named_parameters():
if param.requires_grad:
enabled.add(name)
print(f"Parameters to be updated: {enabled}")
if cfg.MODEL.INIT_WEIGHTS:
load_pretrained_weights(self.model.prompt_learner,
cfg.MODEL.INIT_WEIGHTS)
self.model.to(self.device)
# NOTE: only give prompt_learner to the optimizer
self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM)
self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
self.register_model("prompt_learner", self.model.prompt_learner,
self.optim, self.sched)
self.scaler = GradScaler(
) if cfg.TRAINER.COCOOP.PREC == "amp" else None
# Note that multi-gpu training could be slow because CLIP's size is
# big, which slows down the copy operation in DataParallel
device_count = torch.cuda.device_count()
if device_count > 1:
print(
f"Multiple GPUs detected (n_gpus={device_count}), use all of them!"
)
self.model = nn.DataParallel(self.model)
def forward_backward(self, batch):
image, label = self.parse_batch_train(batch)
model = self.model
optim = self.optim
scaler = self.scaler
prec = self.cfg.TRAINER.COCOOP.PREC
if prec == "amp":
with autocast():
loss = model(image, label)
optim.zero_grad()
scaler.scale(loss).backward()
scaler.step(optim)
scaler.update()
else:
loss = model(image, label)
optim.zero_grad()
loss.backward()
optim.step()
loss_summary = {"loss": loss.item()}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def parse_batch_train(self, batch):
input = batch["img"]
label = batch["label"]
input = input.to(self.device)
label = label.to(self.device)
return input, label
def load_model(self, directory, epoch=None):
if not directory:
print(
"Note that load_model() is skipped as no pretrained model is given"
)
return
names = self.get_model_names()
# By default, the best model is loaded
model_file = "model-best.pth.tar"
if epoch is not None:
model_file = "model.pth.tar-" + str(epoch)
for name in names:
model_path = osp.join(directory, name, model_file)
if not osp.exists(model_path):
raise FileNotFoundError(
'Model not found at "{}"'.format(model_path))
checkpoint = load_checkpoint(model_path)
state_dict = checkpoint["state_dict"]
epoch = checkpoint["epoch"]
# Ignore fixed token vectors
if "token_prefix" in state_dict:
del state_dict["token_prefix"]
if "token_suffix" in state_dict:
del state_dict["token_suffix"]
print("Loading weights to {} "
'from "{}" (epoch = {})'.format(name, model_path, epoch))
# set strict=False
self._models[name].load_state_dict(state_dict, strict=False)
================================================
FILE: ProGrad.public/trainers/coop.py
================================================
import os.path as osp
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import GradScaler, autocast
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy
from dassl.utils import load_pretrained_weights, load_checkpoint
from dassl.optim import build_optimizer, build_lr_scheduler
from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
_tokenizer = _Tokenizer()
def load_clip_to_cpu(cfg):
backbone_name = cfg.MODEL.BACKBONE.NAME
url = clip._MODELS[backbone_name]
model_path = clip._download(url)
try:
# loading JIT archive
model = torch.jit.load(model_path, map_location="cpu").eval()
state_dict = None
except RuntimeError:
state_dict = torch.load(model_path, map_location="cpu")
model = clip.build_model(state_dict or model.state_dict())
return model
class TextEncoder(nn.Module):
def __init__(self, clip_model):
super().__init__()
self.transformer = clip_model.transformer
self.positional_embedding = clip_model.positional_embedding
self.ln_final = clip_model.ln_final
self.text_projection = clip_model.text_projection
self.dtype = clip_model.dtype
def forward(self, prompts, tokenized_prompts):
x = prompts + self.positional_embedding.type(self.dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x).type(self.dtype)
# x.shape = [batch_size, context_length, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]),
tokenized_prompts.argmax(dim=-1)] @ self.text_projection
return x
class PromptLearner(nn.Module):
def __init__(self, cfg, classnames, clip_model):
super().__init__()
n_cls = len(classnames)
n_ctx = cfg.TRAINER.COOP.N_CTX
ctx_init = cfg.TRAINER.COOP.CTX_INIT
dtype = clip_model.dtype
ctx_dim = clip_model.ln_final.weight.shape[0]
clip_imsize = clip_model.visual.input_resolution
cfg_imsize = cfg.INPUT.SIZE[0]
assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal clip_imsize ({clip_imsize})"
# if ctx_init:
# # use given words to initialize context vectors
# ctx_init = ctx_init.replace("_", " ")
# n_ctx = len(ctx_init.split(" "))
# prompt = clip.tokenize(ctx_init)
# with torch.no_grad():
# embedding = clip_model.token_embedding(prompt).type(dtype)
# ctx_vectors = embedding[0, 1:1 + n_ctx, :]
# prompt_prefix = ctx_init
if ctx_init:
ctx_init = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
ctx_init = ctx_init.replace(" {}.", "")
ctx_init = ctx_init.replace("_", " ")
prompt_n_ctx = len(ctx_init.split(" "))
assert n_ctx >= prompt_n_ctx, f"#tokens ({n_ctx}) should be greater than or equal to #initial prompt tokens ({prompt_n_ctx}, {ctx_init})"
prompt = clip.tokenize(ctx_init)
with torch.no_grad():
embedding = clip_model.token_embedding(prompt).type(dtype)
ctx_vectors = torch.zeros(n_ctx, ctx_dim, dtype=dtype)
ctx_vectors[n_ctx - prompt_n_ctx:, :] = embedding[0, 1:1 +
prompt_n_ctx, :]
prompt_prefix = " ".join(["X"] * (n_ctx - prompt_n_ctx))
prompt_prefix = f"{prompt_prefix} {ctx_init}"
else:
# random initialization
if cfg.TRAINER.COOP.CSC:
print("Initializing class-specific contexts")
ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype)
else:
print("Initializing a generic context")
ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
nn.init.normal_(ctx_vectors, std=0.02)
prompt_prefix = " ".join(["X"] * n_ctx)
print(f'Initial context: "{prompt_prefix}"')
print(f"Number of context words (tokens): {n_ctx}")
self.ctx = nn.Parameter(ctx_vectors) # to be optimized
classnames = [name.replace("_", " ") for name in classnames]
name_lens = [len(_tokenizer.encode(name)) for name in classnames]
prompts = [prompt_prefix + " " + name + "." for name in classnames]
tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])
with torch.no_grad():
embedding = clip_model.token_embedding(tokenized_prompts).type(
dtype)
# These token vectors are saved by save_model(), but they are
# ignored in load_model() because we want to use the vectors
# computed from the current class names
self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS
self.register_buffer("token_suffix",
embedding[:, 1 + n_ctx:, :]) # CLS, EOS
self.n_cls = n_cls
self.n_ctx = n_ctx
self.tokenized_prompts = tokenized_prompts # torch.Tensor
self.name_lens = name_lens
self.class_token_position = cfg.TRAINER.COOP.CLASS_TOKEN_POSITION
def forward(self):
ctx = self.ctx
if ctx.dim() == 2:
ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)
prefix = self.token_prefix
suffix = self.token_suffix
if self.class_token_position == "end":
prompts = torch.cat(
[
prefix, # (n_cls, 1, dim)
ctx, # (n_cls, n_ctx, dim)
suffix, # (n_cls, *, dim)
],
dim=1,
)
elif self.class_token_position == "middle":
half_n_ctx = self.n_ctx // 2
prompts = []
for i in range(self.n_cls):
name_len = self.name_lens[i]
prefix_i = prefix[i:i + 1, :, :]
class_i = suffix[i:i + 1, :name_len, :]
suffix_i = suffix[i:i + 1, name_len:, :]
ctx_i_half1 = ctx[i:i + 1, :half_n_ctx, :]
ctx_i_half2 = ctx[i:i + 1, half_n_ctx:, :]
prompt = torch.cat(
[
prefix_i, # (1, 1, dim)
ctx_i_half1, # (1, n_ctx//2, dim)
class_i, # (1, name_len, dim)
ctx_i_half2, # (1, n_ctx//2, dim)
suffix_i, # (1, *, dim)
],
dim=1,
)
prompts.append(prompt)
prompts = torch.cat(prompts, dim=0)
elif self.class_token_position == "front":
prompts = []
for i in range(self.n_cls):
name_len = self.name_lens[i]
prefix_i = prefix[i:i + 1, :, :]
class_i = suffix[i:i + 1, :name_len, :]
suffix_i = suffix[i:i + 1, name_len:, :]
ctx_i = ctx[i:i + 1, :, :]
prompt = torch.cat(
[
prefix_i, # (1, 1, dim)
class_i, # (1, name_len, dim)
ctx_i, # (1, n_ctx, dim)
suffix_i, # (1, *, dim)
],
dim=1,
)
prompts.append(prompt)
prompts = torch.cat(prompts, dim=0)
else:
raise ValueError
return prompts
CUSTOM_TEMPLATES = {
# "OxfordPets": "a photo of a {}, a type of pet.",
"OxfordPets": "a type of pet, a photo of a {}.",
# "OxfordFlowers": "a photo of a {}, a type of flower.",
"OxfordFlowers": "a type of flower, a photo of a {}.",
"FGVCAircraft": "a type of aircraft, a photo of a {}.",
"DescribableTextures": "a texture of {}.",
"EuroSAT": "a centered satellite photo of {}.",
"StanfordCars": "a photo of a {}.",
# "Food101": "a photo of {}, a type of food.",
"Food101": "a type of food, a photo of {}.",
"SUN397": "a photo of a {}.",
"Caltech101": "a photo of a {}.",
"UCF101": "a photo of a person doing {}.",
"ImageNet": "a photo of a {}.",
"ImageNetSketch": "a photo of a {}.",
"ImageNetV2": "a photo of a {}.",
"ImageNetA": "a photo of a {}.",
"ImageNetR": "a photo of a {}.",
}
class CustomCLIP(nn.Module):
def __init__(self, cfg, classnames, clip_model):
super().__init__()
self.prompt_learner = PromptLearner(cfg, classnames, clip_model)
self.tokenized_prompts = self.prompt_learner.tokenized_prompts
self.image_encoder = clip_model.visual
self.text_encoder = TextEncoder(clip_model)
self.logit_scale = clip_model.logit_scale
self.dtype = clip_model.dtype
def forward(self, image):
image_features = self.image_encoder(image.type(self.dtype))
prompts = self.prompt_learner()
tokenized_prompts = self.tokenized_prompts
text_features = self.text_encoder(prompts, tokenized_prompts)
image_features = image_features / image_features.norm(dim=-1,
keepdim=True)
text_features = text_features / text_features.norm(dim=-1,
keepdim=True)
logit_scale = self.logit_scale.exp()
logits = logit_scale * image_features @ text_features.t()
return logits
@TRAINER_REGISTRY.register()
class CoOp(TrainerX):
"""Context Optimization (CoOp).
Learning to Prompt for Vision-Language Models
https://arxiv.org/abs/2109.01134
"""
def check_cfg(self, cfg):
assert cfg.TRAINER.COOP.PREC in ["fp16", "fp32", "amp"]
def build_model(self):
cfg = self.cfg
classnames = self.dm.dataset.classnames
print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
clip_model = load_clip_to_cpu(cfg)
if cfg.TRAINER.COOP.PREC == "fp32" or cfg.TRAINER.COOP.PREC == "amp":
# CLIP's default precision is fp16
clip_model.float()
print("Building custom CLIP")
self.model = CustomCLIP(cfg, classnames, clip_model)
print("Turning off gradients in both the image and the text encoder")
for name, param in self.model.named_parameters():
if "prompt_learner" not in name:
param.requires_grad_(False)
if cfg.MODEL.INIT_WEIGHTS:
load_pretrained_weights(self.model.prompt_learner,
cfg.MODEL.INIT_WEIGHTS)
self.model.to(self.device)
# NOTE: only give prompt_learner to the optimizer
self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM)
self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
self.register_model("prompt_learner", self.model.prompt_learner,
self.optim, self.sched)
self.scaler = GradScaler() if cfg.TRAINER.COOP.PREC == "amp" else None
# Note that multi-gpu training could be slow because CLIP's size is
# big, which slows down the copy operation in DataParallel
device_count = torch.cuda.device_count()
if device_count > 1:
print(
f"Multiple GPUs detected (n_gpus={device_count}), use all of them!"
)
self.model = nn.DataParallel(self.model)
def forward_backward(self, batch):
image, label = self.parse_batch_train(batch)
prec = self.cfg.TRAINER.COOP.PREC
if prec == "amp":
with autocast():
output = self.model(image)
loss = F.cross_entropy(output, label)
self.optim.zero_grad()
self.scaler.scale(loss).backward()
self.scaler.step(self.optim)
self.scaler.update()
else:
output = self.model(image)
loss = F.cross_entropy(output, label)
self.model_backward_and_update(loss)
loss_summary = {
"loss": loss.item(),
"acc": compute_accuracy(output, label)[0].item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def parse_batch_train(self, batch):
input = batch["img"]
label = batch["label"]
input = input.to(self.device)
label = label.to(self.device)
return input, label
def load_model(self, directory, epoch=None):
if not directory:
print(
"Note that load_model() is skipped as no pretrained model is given"
)
return
names = self.get_model_names()
# By default, the best model is loaded
model_file = "model-best.pth.tar"
if epoch is not None:
model_file = "model.pth.tar-" + str(epoch)
for name in names:
model_path = osp.join(directory, name, model_file)
if not osp.exists(model_path):
raise FileNotFoundError(
'Model not found at "{}"'.format(model_path))
checkpoint = load_checkpoint(model_path)
state_dict = checkpoint["state_dict"]
epoch = checkpoint["epoch"]
# Ignore fixed token vectors
if "token_prefix" in state_dict:
del state_dict["token_prefix"]
if "token_suffix" in state_dict:
del state_dict["token_suffix"]
print("Loading weights to {} "
'from "{}" (epoch = {})'.format(name, model_path, epoch))
# set strict=False
self._models[name].load_state_dict(state_dict, strict=False)
================================================
FILE: ProGrad.public/trainers/imagenet_templates.py
================================================
# source: https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb
IMAGENET_TEMPLATES = [
"a bad photo of a {}.",
"a photo of many {}.",
"a sculpture of a {}.",
"a photo of the hard to see {}.",
"a low resolution photo of the {}.",
"a rendering of a {}.",
"graffiti of a {}.",
"a bad photo of the {}.",
"a cropped photo of the {}.",
"a tattoo of a {}.",
"the embroidered {}.",
"a photo of a hard to see {}.",
"a bright photo of a {}.",
"a photo of a clean {}.",
"a photo of a dirty {}.",
"a dark photo of the {}.",
"a drawing of a {}.",
"a photo of my {}.",
"the plastic {}.",
"a photo of the cool {}.",
"a close-up photo of a {}.",
"a black and white photo of the {}.",
"a painting of the {}.",
"a painting of a {}.",
"a pixelated photo of the {}.",
"a sculpture of the {}.",
"a bright photo of the {}.",
"a cropped photo of a {}.",
"a plastic {}.",
"a photo of the dirty {}.",
"a jpeg corrupted photo of a {}.",
"a blurry photo of the {}.",
"a photo of the {}.",
"a good photo of the {}.",
"a rendering of the {}.",
"a {} in a video game.",
"a photo of one {}.",
"a doodle of a {}.",
"a close-up photo of the {}.",
"a photo of a {}.",
"the origami {}.",
"the {} in a video game.",
"a sketch of a {}.",
"a doodle of the {}.",
"a origami {}.",
"a low resolution photo of a {}.",
"the toy {}.",
"a rendition of the {}.",
"a photo of the clean {}.",
"a photo of a large {}.",
"a rendition of a {}.",
"a photo of a nice {}.",
"a photo of a weird {}.",
"a blurry photo of a {}.",
"a cartoon {}.",
"art of a {}.",
"a sketch of the {}.",
"a embroidered {}.",
"a pixelated photo of a {}.",
"itap of the {}.",
"a jpeg corrupted photo of the {}.",
"a good photo of a {}.",
"a plushie {}.",
"a photo of the nice {}.",
"a photo of the small {}.",
"a photo of the weird {}.",
"the cartoon {}.",
"art of the {}.",
"a drawing of the {}.",
"a photo of the large {}.",
"a black and white photo of a {}.",
"the plushie {}.",
"a dark photo of a {}.",
"itap of a {}.",
"graffiti of the {}.",
"a toy {}.",
"itap of my {}.",
"a photo of a cool {}.",
"a photo of a small {}.",
"a tattoo of the {}.",
]
IMAGENET_TEMPLATES_SELECT = [
"itap of a {}.",
"a bad photo of the {}.",
"a origami {}.",
"a photo of the large {}.",
"a {} in a video game.",
"art of the {}.",
"a photo of the small {}.",
]
================================================
FILE: ProGrad.public/trainers/prograd.py
================================================
import os.path as osp
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import GradScaler, autocast
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy
from dassl.utils import load_pretrained_weights, load_checkpoint
from dassl.optim import build_optimizer, build_lr_scheduler
from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
from torch.nn.modules.loss import _Loss
from tqdm import tqdm
import json
_tokenizer = _Tokenizer()
def load_clip_to_cpu(cfg):
backbone_name = cfg.MODEL.BACKBONE.NAME
url = clip._MODELS[backbone_name]
model_path = clip._download(url)
try:
# loading JIT archive
model = torch.jit.load(model_path, map_location="cpu").eval()
state_dict = None
except RuntimeError:
state_dict = torch.load(model_path, map_location="cpu")
model = clip.build_model(state_dict or model.state_dict())
return model
class TextEncoder(nn.Module):
def __init__(self, clip_model):
super().__init__()
self.transformer = clip_model.transformer
self.positional_embedding = clip_model.positional_embedding
self.ln_final = clip_model.ln_final
self.text_projection = clip_model.text_projection
self.dtype = clip_model.dtype
def forward(self, prompts, tokenized_prompts):
x = prompts + self.positional_embedding.type(self.dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x).type(self.dtype)
# x.shape = [batch_size, context_length, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]),
tokenized_prompts.argmax(dim=-1)] @ self.text_projection
return x
class PromptLearner(nn.Module):
def __init__(self, cfg, classnames, clip_model):
super().__init__()
n_cls = len(classnames)
n_ctx = cfg.TRAINER.COOP.N_CTX
ctx_init = cfg.TRAINER.COOP.CTX_INIT
dtype = clip_model.dtype
ctx_dim = clip_model.ln_final.weight.shape[0]
clip_imsize = clip_model.visual.input_resolution
cfg_imsize = cfg.INPUT.SIZE[0]
assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal clip_imsize ({clip_imsize})"
if ctx_init:
ctx_init = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
ctx_init = ctx_init.replace(" {}.", "")
ctx_init = ctx_init.replace("_", " ")
prompt_n_ctx = len(ctx_init.split(" "))
assert n_ctx >= prompt_n_ctx, f"#tokens ({n_ctx}) should be greater than or equal to #initial prompt tokens ({prompt_n_ctx}, {ctx_init})"
prompt = clip.tokenize(ctx_init)
with torch.no_grad():
embedding = clip_model.token_embedding(prompt).type(dtype)
ctx_vectors = torch.zeros(n_ctx, ctx_dim, dtype=dtype)
ctx_vectors[n_ctx - prompt_n_ctx:, :] = embedding[0, 1:1 +
prompt_n_ctx, :]
prompt_prefix = " ".join(["X"] * (n_ctx - prompt_n_ctx))
prompt_prefix = f"{prompt_prefix} {ctx_init}"
else:
# random initialization
if cfg.TRAINER.COOP.CSC:
print("Initializing class-specific contexts")
ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype)
else:
print("Initializing a generic context")
ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
nn.init.normal_(ctx_vectors, std=0.02)
prompt_prefix = " ".join(["X"] * n_ctx)
print(f'Initial context: "{prompt_prefix}"')
print(f"Number of context words (tokens): {n_ctx}")
self.ctx = nn.Parameter(ctx_vectors) # to be optimized
classnames = [name.replace("_", " ") for name in classnames]
name_lens = [len(_tokenizer.encode(name)) for name in classnames]
prompts = [prompt_prefix + " " + name + "." for name in classnames]
tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])
with torch.no_grad():
embedding = clip_model.token_embedding(tokenized_prompts).type(
dtype)
# These token vectors are saved by save_model(), but they are
# ignored in load_model() because we want to use the vectors
# computed from the current class names
self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS
self.register_buffer("token_suffix",
embedding[:, 1 + n_ctx:, :]) # CLS, EOS
self.n_cls = n_cls
self.n_ctx = n_ctx
self.tokenized_prompts = tokenized_prompts # torch.Tensor
self.class_token_position = cfg.TRAINER.COOP.CLASS_TOKEN_POSITION
self.name_lens = name_lens
def forward(self):
ctx = self.ctx
if ctx.dim() == 2:
ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)
prefix = self.token_prefix
suffix = self.token_suffix
if self.class_token_position == "end":
prompts = torch.cat(
[
prefix, # (n_cls, 1, dim)
ctx, # (n_cls, n_ctx, dim)
suffix, # (n_cls, *, dim)
],
dim=1,
)
elif self.class_token_position == "middle":
half_n_ctx = self.n_ctx // 2
prompts = []
for i in range(self.n_cls):
name_len = self.name_lens[i]
prefix_i = prefix[i:i + 1, :, :]
class_i = suffix[i:i + 1, :name_len, :]
suffix_i = suffix[i:i + 1, name_len:, :]
ctx_i_half1 = ctx[i:i + 1, :half_n_ctx, :]
ctx_i_half2 = ctx[i:i + 1, half_n_ctx:, :]
prompt = torch.cat(
[
prefix_i, # (1, 1, dim)
ctx_i_half1, # (1, n_ctx//2, dim)
class_i, # (1, name_len, dim)
ctx_i_half2, # (1, n_ctx//2, dim)
suffix_i, # (1, *, dim)
],
dim=1,
)
prompts.append(prompt)
prompts = torch.cat(prompts, dim=0)
elif self.class_token_position == "front":
prompts = []
for i in range(self.n_cls):
name_len = self.name_lens[i]
prefix_i = prefix[i:i + 1, :, :]
class_i = suffix[i:i + 1, :name_len, :]
suffix_i = suffix[i:i + 1, name_len:, :]
ctx_i = ctx[i:i + 1, :, :]
prompt = torch.cat(
[
prefix_i, # (1, 1, dim)
class_i, # (1, name_len, dim)
ctx_i, # (1, n_ctx, dim)
suffix_i, # (1, *, dim)
],
dim=1,
)
prompts.append(prompt)
prompts = torch.cat(prompts, dim=0)
else:
raise ValueError
return prompts
CUSTOM_TEMPLATES = {
"OxfordPets": "a type of pet, a photo of a {}.",
"OxfordFlowers": "a type of flower, a photo of a {}.",
"FGVCAircraft": "a type of aircraft, a photo of a {}.",
"DescribableTextures": "a texture of {}.",
"EuroSAT": "a centered satellite photo of {}.",
"StanfordCars": "a photo of a {}.",
"Food101": "a type of food, a photo of {}.",
"SUN397": "a photo of a {}.",
"Caltech101": "a photo of a {}.",
"UCF101": "a photo of a person doing {}.",
"ImageNet": "a photo of a {}.",
"ImageNetSketch": "a photo of a {}.",
"ImageNetV2": "a photo of a {}.",
"ImageNetA": "a photo of a {}.",
"ImageNetR": "a photo of a {}.",
}
class CLIP(nn.Module):
def __init__(self, cfg, classnames):
super().__init__()
print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
clip_model = load_clip_to_cpu(cfg)
clip_model.float()
temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
prompts = [temp.format(c.replace("_", " ")) for c in classnames]
print(f"Prompts: {prompts}")
prompts = torch.cat([clip.tokenize(p) for p in prompts])
with torch.no_grad():
text_features = clip_model.encode_text(prompts)
text_features = text_features / text_features.norm(dim=-1,
keepdim=True)
self.text_features = text_features
self.clip_model = clip_model
def forward(self, image):
image_features = self.clip_model.encode_image(image)
image_features = image_features / image_features.norm(dim=-1,
keepdim=True)
logit_scale = self.clip_model.logit_scale.exp()
text_features = self.text_features
text_features = text_features.to(image_features.device)
logits = logit_scale * image_features @ text_features.t()
return logits
class CustomCLIP(nn.Module):
def __init__(self, cfg, classnames, clip_model):
super().__init__()
self.prompt_learner = PromptLearner(cfg, classnames, clip_model)
self.tokenized_prompts = self.prompt_learner.tokenized_prompts
self.image_encoder = clip_model.visual
self.text_encoder = TextEncoder(clip_model)
self.logit_scale = clip_model.logit_scale
self.dtype = clip_model.dtype
def forward(self, image):
image_features = self.image_encoder(image.type(self.dtype))
prompts = self.prompt_learner()
tokenized_prompts = self.tokenized_prompts
text_features = self.text_encoder(prompts, tokenized_prompts)
image_features = image_features / image_features.norm(dim=-1,
keepdim=True)
text_features = text_features / text_features.norm(dim=-1,
keepdim=True)
logit_scale = self.logit_scale.exp()
logits = logit_scale * image_features @ text_features.t()
return logits
class ProGradLoss(_Loss):
def __init__(self, T):
super(ProGradLoss, self).__init__()
self.T = T
def forward(self, stu_logits, tea_logits, label):
xe_loss = F.cross_entropy(stu_logits, label)
tea_prob = F.softmax(tea_logits / self.T, dim=-1)
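# Temperature-scaled soft cross-entropy against the zero-shot CLIP
# distribution. It differs from KL(tea || stu) only by the teacher's
# entropy, which does not depend on the student, so the gradients
# with respect to the prompt are identical.
kl_loss = -tea_prob * F.log_softmax(stu_logits / self.T,
-1) * self.T * self.T
kl_loss = kl_loss.sum(1).mean()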
return xe_loss, kl_loss
@TRAINER_REGISTRY.register()
class ProGrad(TrainerX):
"""Prompt-aligned Gradient (ProGrad) for few-shot CLIP.
Prompt-aligned Gradient for Prompt Tuning
https://arxiv.org/abs/2205.14865
"""
def check_cfg(self, cfg):
assert cfg.TRAINER.COOP.PREC in ["fp16", "fp32", "amp"]
def build_model(self):
cfg = self.cfg
classnames = self.dm.dataset.classnames
print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
clip_model = load_clip_to_cpu(cfg)
if cfg.TRAINER.COOP.PREC == "fp32" or cfg.TRAINER.COOP.PREC == "amp":
# CLIP's default precision is fp16
clip_model.float()
print("Building zeroshot CLIP")
self.zs_clip = CLIP(cfg, classnames)
print("Building custom CLIP")
self.model = CustomCLIP(cfg, classnames, clip_model)
print("Turning off gradients in ZS Clip model")
for name, param in self.zs_clip.named_parameters():
param.requires_grad_(False)
print("Turning off gradients in CoOp model")
for name, param in self.model.named_parameters():
if "prompt_learner" not in name:
param.requires_grad_(False)
if cfg.MODEL.INIT_WEIGHTS:
load_pretrained_weights(self.model.prompt_learner,
cfg.MODEL.INIT_WEIGHTS)
self.model.to(self.device)
self.zs_clip = self.zs_clip.to(self.device)
# NOTE: only give prompt_learner to the optimizer
self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM)
self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
self.register_model("prompt_learner", self.model.prompt_learner,
self.optim, self.sched)
self.scaler = GradScaler() if cfg.TRAINER.COOP.PREC == "amp" else None
# Note that multi-gpu training could be slow because CLIP's size is
# big, which slows down the copy operation in DataParallel
device_count = torch.cuda.device_count()
if device_count > 1:
print(
f"Multiple GPUs detected (n_gpus={device_count}), use all of them!"
)
self.model = nn.DataParallel(self.model)
self.zs_clip = nn.DataParallel(self.zs_clip)
# build criterion
if cfg.LOSS.NAME == "prograd":
self.criterion = ProGradLoss(T=cfg.LOSS.T)
else:
raise NotImplementedError
def forward_backward(self, batch):
image, label = self.parse_batch_train(batch)
prec = self.cfg.TRAINER.COOP.PREC
if prec == "amp":
with autocast():
output = self.model(image)
with torch.no_grad():
zs_clip_output = self.zs_clip(image)
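# ProGradLoss returns (xe_loss, kl_loss); unpack it so that both
# values exist for loss_summary below.
xe_loss, kl_loss = self.criterion(output, zs_clip_output.detach(),
label)
# NOTE: the gradient projection is not applied in the AMP path;
# only the cross-entropy loss is backpropagated here.
loss = xe_loss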
self.optim.zero_grad()
self.scaler.scale(loss).backward()
self.scaler.step(self.optim)
self.scaler.update()
else:
output = self.model(image)
with torch.no_grad():
zs_clip_output = self.zs_clip(image)
xe_loss, kl_loss = self.criterion(output,
zs_clip_output.detach(),
label)
self.prograd_backward_and_update(xe_loss, kl_loss,
self.cfg.LOSS.LAMBDA)
loss_summary = {
"xe_loss": xe_loss.item(),
"kl_loss": kl_loss.item(),
"acc": compute_accuracy(output, label)[0].item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def parse_batch_train(self, batch):
input = batch["img"]
label = batch["label"]
input = input.to(self.device)
label = label.to(self.device)
return input, label
def load_model(self, directory, epoch=None):
if not directory:
print(
"Note that load_model() is skipped as no pretrained model is given"
)
return
names = self.get_model_names()
# By default, the best model is loaded
model_file = "model-best.pth.tar"
if epoch is not None:
model_file = "model.pth.tar-" + str(epoch)
for name in names:
model_path = osp.join(directory, name, model_file)
if not osp.exists(model_path):
raise FileNotFoundError(
'Model not found at "{}"'.format(model_path))
checkpoint = load_checkpoint(model_path)
state_dict = checkpoint["state_dict"]
epoch = checkpoint["epoch"]
# Ignore fixed token vectors
if "token_prefix" in state_dict:
del state_dict["token_prefix"]
if "token_suffix" in state_dict:
del state_dict["token_suffix"]
print("Loading weights to {} "
'from "{}" (epoch = {})'.format(name, model_path, epoch))
# set strict=False
self._models[name].load_state_dict(state_dict, strict=False)
================================================
FILE: ProGrad.public/trainers/zsclip.py
================================================
import torch
import torch.nn as nn
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.optim import build_optimizer, build_lr_scheduler
from clip import clip
from clip.model import convert_weights
from .coop import load_clip_to_cpu
from .imagenet_templates import IMAGENET_TEMPLATES, IMAGENET_TEMPLATES_SELECT
CUSTOM_TEMPLATES = {
# "OxfordPets": "a photo of a {}, a type of pet.",
"OxfordPets": "a type of pet, a photo of a {}.",
# "OxfordFlowers": "a photo of a {}, a type of flower.",
"OxfordFlowers": "a type of flower, a photo of a {}.",
"FGVCAircraft": "a photo of a {}, a type of aircraft.",
"DescribableTextures": "{} texture.",
"EuroSAT": "a centered satellite photo of {}.",
"StanfordCars": "a photo of a {}.",
# "Food101": "a photo of {}, a type of food.",
"Food101": "a type of food, a photo of {}.",
"SUN397": "a photo of a {}.",
"Caltech101": "a photo of a {}.",
"UCF101": "a photo of a person doing {}.",
"ImageNet": "a photo of a {}.",
"ImageNetSketch": "a photo of a {}.",
"ImageNetV2": "a photo of a {}.",
"ImageNetA": "a photo of a {}.",
"ImageNetR": "a photo of a {}.",
}
@TRAINER_REGISTRY.register()
class ZeroshotCLIP(TrainerX):
def build_model(self):
cfg = self.cfg
classnames = self.dm.dataset.classnames
print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
clip_model = load_clip_to_cpu(cfg)
clip_model.to(self.device)
temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
prompts = [temp.format(c.replace("_", " ")) for c in classnames]
print(f"Prompts: {prompts}")
prompts = torch.cat([clip.tokenize(p) for p in prompts])
prompts = prompts.to(self.device)
with torch.no_grad():
text_features = clip_model.encode_text(prompts)
text_features = text_features / text_features.norm(dim=-1,
keepdim=True)
self.text_features = text_features
self.clip_model = clip_model
def model_inference(self, image):
image_features = self.clip_model.encode_image(image)
image_features = image_features / image_features.norm(dim=-1,
keepdim=True)
logit_scale = self.clip_model.logit_scale.exp()
logits = logit_scale * image_features @ self.text_features.t()
return logits
@TRAINER_REGISTRY.register()
class ZeroshotCLIP2(ZeroshotCLIP):
"""Prompt ensembling."""
# templates = IMAGENET_TEMPLATES
templates = IMAGENET_TEMPLATES_SELECT
def build_model(self):
cfg = self.cfg
classnames = self.dm.dataset.classnames
print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
clip_model = load_clip_to_cpu(cfg)
clip_model.to(self.device)
for params in clip_model.parameters():
params.requires_grad_(False)
# add custom-made prompt
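# Build a new list instead of using +=, which would mutate the shared
# class-level template list (IMAGENET_TEMPLATES_SELECT) in place
if cfg.DATASET.NAME != "ImageNet":
self.templates = self.templates + [CUSTOM_TEMPLATES[cfg.DATASET.NAME]]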
num_temp = len(self.templates)
print(f"Prompt ensembling (n={num_temp})")
mean_text_features = 0
for i, temp in enumerate(self.templates):
prompts = [temp.format(c.replace("_", " ")) for c in classnames]
prompts = torch.cat([clip.tokenize(p)
for p in prompts]).to(self.device)
text_features = clip_model.encode_text(prompts)
text_features = text_features / text_features.norm(dim=-1,
keepdim=True)
mean_text_features = mean_text_features + text_features
mean_text_features = mean_text_features / num_temp
mean_text_features = mean_text_features / mean_text_features.norm(
dim=-1, keepdim=True)
self.text_features = mean_text_features
self.clip_model = clip_model
================================================
FILE: readme.md
================================================
# [ICCV23] Prompt-aligned Gradient for Prompt Tuning
We present Prompt-aligned Gradient, dubbed ProGrad, to prevent prompt tuning from forgetting the general knowledge learned by VLMs. In particular, ProGrad only updates the prompt along gradients that are aligned with (or do not conflict with) the “general direction”, which is represented by the gradient of the KL loss with respect to the predictions of the pre-defined prompt. Extensive experiments demonstrate that ProGrad achieves stronger few-shot generalization than state-of-the-art prompt tuning methods.
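
The gradient-alignment rule described above can be sketched in plain Python on flat gradient vectors. This is a minimal illustration, assuming the rule as stated in the paper: if the cross-entropy gradient `g_ce` and the general-direction gradient `g_kl` have a non-negative dot product, `g_ce` is used unchanged; otherwise the component of `g_ce` along `g_kl` is subtracted, scaled by `lam` (assumed here to play the role of `LOSS.LAMBDA` in the configs). The helper names `dot` and `prograd_direction` are hypothetical; the repository's own implementation is the `prograd_backward_and_update` trainer method from the modified Dassl toolbox, which works on PyTorch parameter gradients.

```python
def dot(a, b):
    """Dot product of two equal-length lists of floats."""
    return sum(x * y for x, y in zip(a, b))


def prograd_direction(g_ce, g_kl, lam=1.0):
    """Hypothetical sketch of the ProGrad update direction.

    g_ce: gradient of the cross-entropy loss (domain-specific direction)
    g_kl: gradient of the KL loss to zero-shot CLIP (general direction)
    lam:  projection strength; lam=1.0 removes the conflicting component fully
    """
    d = dot(g_ce, g_kl)
    if d >= 0:
        # Aligned (non-conflicting): keep the cross-entropy gradient as-is.
        return list(g_ce)
    # Conflicting: subtract lam times the projection of g_ce onto g_kl.
    # d < 0 implies g_kl is non-zero, so the division is safe.
    coef = lam * d / dot(g_kl, g_kl)
    return [x - coef * y for x, y in zip(g_ce, g_kl)]
```

For example, `prograd_direction([1.0, -1.0], [0.0, 1.0])` returns `[1.0, 0.0]`: the component that opposes the general direction is removed, so the resulting update is orthogonal to `g_kl` instead of working against it.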

[[paper link]](https://doi.org/10.48550/arxiv.2205.14865)
The code is organized into two folders:
1. [Dassl.ProGrad.pytorch](Dassl.ProGrad.pytorch/) is the modified toolbox of [Dassl.pytorch](https://github.com/KaiyangZhou/Dassl.pytorch).
2. [ProGrad.public](ProGrad.public/): to reproduce the results in our paper, follow the [README.md](ProGrad.public/README.md) under [ProGrad.public/](ProGrad.public/) to set up the environment.
## Citation
If you find our paper or this project helpful for your research, please consider citing our paper:
```
@inproceedings{https://doi.org/10.48550/arxiv.2205.14865,
author = {Zhu, Beier and Niu, Yulei and Han, Yucheng and Wu, Yue and Zhang, Hanwang},
title = {Prompt-aligned Gradient for Prompt Tuning},
booktitle = {International Conference on Computer Vision (ICCV)},
year = {2023},
}
```
## Acknowledgement
Our code is built on top of [CoOp](https://github.com/KaiyangZhou/CoOp) and [Dassl](https://github.com/KaiyangZhou/Dassl.pytorch).